From 4a84a5576225eb168e6242dd8d55a166bcd2fe34 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 24 Nov 2025 16:08:23 +0200 Subject: [PATCH 1/8] chore: use internal_err macro in `assert_eq_or_internal_err` for backtrace --- datafusion/common/src/error.rs | 28 ++++++++++++------------ datafusion/common/src/utils/mod.rs | 2 +- datafusion/expr-common/src/statistics.rs | 2 +- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 6a5cb31fe3fba..993dc0b322b62 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -768,10 +768,10 @@ impl DataFusionErrorBuilder { macro_rules! unwrap_or_internal_err { ($Value: ident) => { $Value.ok_or_else(|| { - DataFusionError::Internal(format!( + $crate::error::_internal_datafusion_err!( "{} should not be None", stringify!($Value) - )) + ) })? }; } @@ -789,19 +789,19 @@ macro_rules! unwrap_or_internal_err { macro_rules! assert_or_internal_err { ($cond:expr) => { if !$cond { - return Err(DataFusionError::Internal(format!( + return Err($crate::error::_internal_datafusion_err!( "Assertion failed: {}", stringify!($cond) - ))); + )); } }; ($cond:expr, $($arg:tt)+) => { if !$cond { - return Err(DataFusionError::Internal(format!( + return Err($crate::error::_internal_datafusion_err!( "Assertion failed: {}: {}", stringify!($cond), format!($($arg)+) - ))); + )); } }; } @@ -821,27 +821,27 @@ macro_rules! assert_eq_or_internal_err { let left_val = &$left; let right_val = &$right; if left_val != right_val { - return Err(DataFusionError::Internal(format!( + return Err($crate::error::_internal_datafusion_err!( "Assertion failed: {} == {} (left: {:?}, right: {:?})", stringify!($left), stringify!($right), left_val, right_val - ))); + )); } }}; ($left:expr, $right:expr, $($arg:tt)+) => {{ let left_val = &$left; let right_val = &$right; if left_val != right_val { - return Err(DataFusionError::Internal(format!( + return Err($crate::error::_internal_datafusion_err!( "Assertion failed: {} == {} (left: {:?}, right: {:?}): {}", stringify!($left), stringify!($right), left_val, right_val, format!($($arg)+) - ))); + )); } }}; } @@ -861,27 +861,27 @@ macro_rules! assert_ne_or_internal_err { let left_val = &$left; let right_val = &$right; if left_val == right_val { - return Err(DataFusionError::Internal(format!( + return Err($crate::error::_internal_datafusion_err!( "Assertion failed: {} != {} (left: {:?}, right: {:?})", stringify!($left), stringify!($right), left_val, right_val - ))); + )); } }}; ($left:expr, $right:expr, $($arg:tt)+) => {{ let left_val = &$left; let right_val = &$right; if left_val == right_val { - return Err(DataFusionError::Internal(format!( + return Err($crate::error::_internal_datafusion_err!( "Assertion failed: {} != {} (left: {:?}, right: {:?}): {}", stringify!($left), stringify!($right), left_val, right_val, format!($($arg)+) - ))); + )); } }}; } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 6e7396a7c577e..99bdcb6f74fe6 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -24,7 +24,7 @@ pub mod string_utils; use crate::assert_or_internal_err; use crate::error::{_exec_datafusion_err, _internal_datafusion_err}; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::{Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs index 7961bf0872e4d..7284673d9a8f7 100644 --- a/datafusion/expr-common/src/statistics.rs +++ b/datafusion/expr-common/src/statistics.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType; use datafusion_common::rounding::alter_fp_rounding_mode; use datafusion_common::{ assert_eq_or_internal_err, assert_ne_or_internal_err, assert_or_internal_err, - internal_err, not_impl_err, DataFusionError, Result, ScalarValue, + internal_err, not_impl_err, Result, ScalarValue, }; /// This object defines probabilistic distributions that encode uncertain From db486aa9f1ff0cadd89c130581023a76d76d6f5a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:24:31 +0200 Subject: [PATCH 2/8] remove all imports due to clippy: error: unused import: `DataFusionError` --- datafusion/catalog-listing/src/helpers.rs | 4 +--- datafusion/datasource-parquet/src/access_plan.rs | 2 +- datafusion/datasource/src/memory.rs | 3 +-- datafusion/datasource/src/sink.rs | 2 +- datafusion/expr/src/expr_schema.rs | 4 +--- datafusion/expr/src/logical_plan/invariants.rs | 2 +- datafusion/expr/src/udf.rs | 3 +-- datafusion/functions-aggregate/src/array_agg.rs | 4 +--- datafusion/functions-aggregate/src/covariance.rs | 3 +-- datafusion/functions-aggregate/src/nth_value.rs | 2 +- datafusion/functions-aggregate/src/regr.rs | 3 +-- datafusion/functions-nested/src/position.rs | 2 +- datafusion/functions-nested/src/set_ops.rs | 3 +-- datafusion/functions/src/core/greatest.rs | 4 +--- datafusion/functions/src/core/greatest_least_utils.rs | 4 +--- datafusion/functions/src/core/least.rs | 4 +--- datafusion/functions/src/math/pi.rs | 2 +- datafusion/functions/src/math/random.rs | 2 +- datafusion/functions/src/string/uuid.rs | 2 +- datafusion/optimizer/src/decorrelate_predicate_subquery.rs | 4 +--- datafusion/optimizer/src/extract_equijoin_predicate.rs | 2 +- datafusion/optimizer/src/optimize_projections/mod.rs | 3 +-- datafusion/optimizer/src/push_down_filter.rs | 2 +- datafusion/optimizer/src/scalar_subquery_to_join.rs | 4 +--- datafusion/physical-expr-common/src/physical_expr.rs | 3 +-- datafusion/physical-expr/src/aggregate.rs | 3 +-- datafusion/physical-expr/src/analysis.rs | 2 +- datafusion/physical-expr/src/expressions/in_list.rs | 4 ++-- datafusion/physical-expr/src/expressions/like.rs | 2 +- datafusion/physical-expr/src/projection.rs | 2 +- datafusion/physical-optimizer/src/coalesce_batches.rs | 4 +--- .../enforce_sorting/replace_with_order_preserving_variants.rs | 2 +- datafusion/physical-optimizer/src/filter_pushdown.rs | 4 +--- datafusion/physical-plan/src/aggregates/mod.rs | 3 +-- datafusion/physical-plan/src/async_func.rs | 2 +- datafusion/physical-plan/src/coalesce/mod.rs | 2 +- datafusion/physical-plan/src/coalesce_partitions.rs | 4 +--- datafusion/physical-plan/src/coop.rs | 2 +- datafusion/physical-plan/src/empty.rs | 2 +- datafusion/physical-plan/src/explain.rs | 2 +- datafusion/physical-plan/src/joins/cross_join.rs | 3 +-- datafusion/physical-plan/src/joins/hash_join/exec.rs | 4 ++-- datafusion/physical-plan/src/joins/sort_merge_join/exec.rs | 4 ++-- datafusion/physical-plan/src/joins/symmetric_hash_join.rs | 4 ++-- datafusion/physical-plan/src/limit.rs | 4 +--- datafusion/physical-plan/src/memory.rs | 4 +--- datafusion/physical-plan/src/placeholder_row.rs | 2 +- datafusion/physical-plan/src/sorts/sort_preserving_merge.rs | 4 +--- datafusion/physical-plan/src/sorts/streaming_merge.rs | 2 +- datafusion/physical-plan/src/test.rs | 3 +-- datafusion/physical-plan/src/union.rs | 2 +- datafusion/physical-plan/src/windows/window_agg_exec.rs | 2 +- datafusion/physical-plan/src/work_table.rs | 4 +--- datafusion/proto/src/common.rs | 4 +--- datafusion/proto/src/logical_plan/mod.rs | 2 +- datafusion/pruning/src/pruning_predicate.rs | 2 +- datafusion/spark/src/function/math/modulus.rs | 4 +--- datafusion/spark/src/function/math/rint.rs | 1 - datafusion/sql/src/expr/identifier.rs | 3 +-- datafusion/sql/src/unparser/expr.rs | 2 +- datafusion/sql/src/utils.rs | 2 +- 61 files changed, 64 insertions(+), 111 deletions(-) diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 5e69cf1a14022..34073338fbd7e 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -21,9 +21,7 @@ use std::mem; use std::sync::Arc; use datafusion_catalog::Session; -use datafusion_common::{ - assert_or_internal_err, DataFusionError, HashMap, Result, ScalarValue, -}; +use datafusion_common::{assert_or_internal_err, HashMap, Result, ScalarValue}; use datafusion_datasource::ListingTableUrl; use datafusion_datasource::PartitionedFile; use datafusion_expr::{lit, utils, BinaryExpr, Operator}; diff --git a/datafusion/datasource-parquet/src/access_plan.rs b/datafusion/datasource-parquet/src/access_plan.rs index 295ecea9468e7..7399a2cd0856a 100644 --- a/datafusion/datasource-parquet/src/access_plan.rs +++ b/datafusion/datasource-parquet/src/access_plan.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_eq_or_internal_err, Result}; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; use parquet::file::metadata::RowGroupMetaData; diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index 036574ddd3c39..595b1bf6d4268 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -30,8 +30,7 @@ use crate::source::{DataSource, DataSourceExec}; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{ - assert_or_internal_err, plan_err, project_schema, DataFusionError, Result, - ScalarValue, + assert_or_internal_err, plan_err, project_schema, Result, ScalarValue, }; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::project_orderings; diff --git a/datafusion/datasource/src/sink.rs b/datafusion/datasource/src/sink.rs index a4ab78d07840e..f66fbc408c68b 100644 --- a/datafusion/datasource/src/sink.rs +++ b/datafusion/datasource/src/sink.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, RecordBatch, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_eq_or_internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{Distribution, EquivalenceProperties}; use datafusion_physical_expr_common::sort_expr::{LexRequirement, OrderingRequirements}; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 94f5b0480b651..3ef61da91bd82 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -835,9 +835,7 @@ mod tests { use super::*; use crate::{and, col, lit, not, or, out_ref_col_with_metadata, when}; - use datafusion_common::{ - assert_or_internal_err, DFSchema, DataFusionError, ScalarValue, - }; + use datafusion_common::{assert_or_internal_err, DFSchema, ScalarValue}; macro_rules! test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index a416c7f7465c3..1c2c8a2a936f5 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -18,7 +18,7 @@ use datafusion_common::{ assert_or_internal_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, - DFSchemaRef, DataFusionError, Result, + DFSchemaRef, Result, }; use crate::{ diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 449ddf59094a0..92caf5427d637 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -26,8 +26,7 @@ use crate::{ColumnarValue, Documentation, Expr, Signature}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - assert_or_internal_err, not_impl_err, DataFusionError, ExprSchema, Result, - ScalarValue, + assert_or_internal_err, not_impl_err, ExprSchema, Result, ScalarValue, }; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index f0ee7327b90e9..4f5797c308f9b 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -32,9 +32,7 @@ use datafusion_common::cast::as_list_array; use datafusion_common::utils::{ compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder, }; -use datafusion_common::{ - assert_eq_or_internal_err, exec_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_eq_or_internal_err, exec_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index f74fddd603319..7e34ffbaad01b 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -24,8 +24,7 @@ use arrow::{ datatypes::{DataType, Field}, }; use datafusion_common::{ - downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, - ScalarValue, + downcast_value, plan_err, unwrap_or_internal_err, Result, ScalarValue, }; use datafusion_expr::{ function::{AccumulatorArgs, StateFieldsArgs}, diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index adf3e47b7d5ab..05026940fec45 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -28,7 +28,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; use datafusion_common::{ - assert_or_internal_err, exec_err, not_impl_err, DataFusionError, Result, ScalarValue, + assert_or_internal_err, exec_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 44ce0bd48ead6..045cb99838430 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -26,8 +26,7 @@ use arrow::{ datatypes::Field, }; use datafusion_common::{ - downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, HashMap, Result, - ScalarValue, + downcast_value, plan_err, unwrap_or_internal_err, HashMap, Result, ScalarValue, }; use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index 14f2ed3313d46..2844eefaf058d 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -38,7 +38,7 @@ use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; use datafusion_common::{ - assert_or_internal_err, exec_err, utils::take_function_args, DataFusionError, Result, + assert_or_internal_err, exec_err, utils::take_function_args, Result, }; use itertools::Itertools; diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index e3531d1cf8eec..4350bfdc5a9bc 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -30,8 +30,7 @@ use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::utils::ListCoercion; use datafusion_common::{ - assert_eq_or_internal_err, exec_err, internal_err, utils::take_function_args, - DataFusionError, Result, + assert_eq_or_internal_err, exec_err, internal_err, utils::take_function_args, Result, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs index 02491501dca66..95fd8e64d7274 100644 --- a/datafusion/functions/src/core/greatest.rs +++ b/datafusion/functions/src/core/greatest.rs @@ -21,9 +21,7 @@ use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::cmp; use arrow::compute::SortOptions; use arrow::datatypes::DataType; -use datafusion_common::{ - assert_eq_or_internal_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; use datafusion_doc::Documentation; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; diff --git a/datafusion/functions/src/core/greatest_least_utils.rs b/datafusion/functions/src/core/greatest_least_utils.rs index 2a8666685621e..78c6864d8c9f5 100644 --- a/datafusion/functions/src/core/greatest_least_utils.rs +++ b/datafusion/functions/src/core/greatest_least_utils.rs @@ -18,9 +18,7 @@ use arrow::array::{Array, ArrayRef, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::{ - assert_or_internal_err, plan_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_or_internal_err, plan_err, Result, ScalarValue}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::type_coercion::binary::type_union_resolution; use std::sync::Arc; diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 45e0d1c4cbc9f..602cd4169a3fd 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -21,9 +21,7 @@ use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::cmp; use arrow::compute::SortOptions; use arrow::datatypes::DataType; -use datafusion_common::{ - assert_eq_or_internal_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; use datafusion_doc::Documentation; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index dcb2886a3babe..92a27932e1649 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -19,7 +19,7 @@ use std::any::Any; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{assert_or_internal_err, Result, ScalarValue}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 77067d39dca8f..4270eff665728 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -23,7 +23,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; use rand::{rng, Rng}; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 6a8585a19242e..96ce9439028c6 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -24,7 +24,7 @@ use arrow::datatypes::DataType::Utf8; use rand::Rng; use uuid::Uuid; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index c9cc33708bb52..0590aba52bfab 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -27,9 +27,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{ - assert_or_internal_err, plan_err, Column, DataFusionError, Result, -}; +use datafusion_common::{assert_or_internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index b5256a338bb7e..9228e84abf931 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -19,7 +19,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{assert_or_internal_err, DFSchema, DataFusionError}; +use datafusion_common::{assert_or_internal_err, DFSchema}; use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 60343ff2ec2ab..ee7b006a2d496 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -26,8 +26,7 @@ use std::sync::Arc; use datafusion_common::{ assert_eq_or_internal_err, get_required_group_by_exprs_indices, - internal_datafusion_err, internal_err, Column, DFSchema, DataFusionError, HashMap, - JoinType, Result, + internal_datafusion_err, internal_err, Column, DFSchema, HashMap, JoinType, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::{ diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a38cd7a75bc1e..ea0980ad4e1c7 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err, - qualified_name, Column, DFSchema, DataFusionError, Result, + qualified_name, Column, DFSchema, Result, }; use datafusion_expr::expr::WindowFunction; use datafusion_expr::expr_rewriter::replace_col; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index dd27ded58699b..2df1be1b7f0ba 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -30,9 +30,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{ - assert_or_internal_err, plan_err, Column, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_or_internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index efdd6fcb6265e..f7f6326876523 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -31,8 +31,7 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - assert_eq_or_internal_err, exec_err, not_impl_err, DataFusionError, Result, - ScalarValue, + assert_eq_or_internal_err, exec_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index ae5a4a8559470..f16895b44bf5e 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -42,8 +42,7 @@ use crate::expressions::Column; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; use datafusion_common::{ - assert_or_internal_err, internal_err, not_impl_err, DataFusionError, Result, - ScalarValue, + assert_or_internal_err, internal_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 981acbb779b68..166e639966f13 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -29,7 +29,7 @@ use arrow::datatypes::Schema; use datafusion_common::stats::Precision; use datafusion_common::{ assert_or_internal_err, internal_datafusion_err, internal_err, ColumnStatistics, - DataFusionError, Result, ScalarValue, + Result, ScalarValue, }; use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 95029c1efe74c..bb033aac03ed6 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -33,8 +33,8 @@ use arrow::datatypes::*; use arrow::util::bit_iterator::BitIndexIterator; use datafusion_common::hash_utils::with_hashes; use datafusion_common::{ - assert_or_internal_err, exec_datafusion_err, exec_err, DFSchema, DataFusionError, - HashSet, Result, ScalarValue, + assert_or_internal_err, exec_datafusion_err, exec_err, DFSchema, HashSet, Result, + ScalarValue, }; use datafusion_expr::{expr_vec_fmt, ColumnarValue}; diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 3046e8a028a8e..5502def5820f6 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -18,7 +18,7 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::apply_cmp; use std::hash::Hash; diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 74edaa191508c..3d6740510bec6 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -27,7 +27,7 @@ use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - assert_or_internal_err, internal_datafusion_err, plan_err, DataFusionError, Result, + assert_or_internal_err, internal_datafusion_err, plan_err, Result, }; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs index 61e4c0e7f1801..d399fd335e463 100644 --- a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -23,9 +23,7 @@ use crate::PhysicalOptimizerRule; use std::sync::Arc; use datafusion_common::error::Result; -use datafusion_common::{ - assert_eq_or_internal_err, config::ConfigOptions, DataFusionError, -}; +use datafusion_common::{assert_eq_or_internal_err, config::ConfigOptions}; use datafusion_physical_expr::Partitioning; use datafusion_physical_plan::{ async_func::AsyncFuncExec, coalesce_batches::CoalesceBatchesExec, diff --git a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs index 49c66e0ab2442..2c9303d7ea690 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/replace_with_order_preserving_variants.rs @@ -27,7 +27,7 @@ use crate::utils::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::repartition::RepartitionExec; diff --git a/datafusion/physical-optimizer/src/filter_pushdown.rs b/datafusion/physical-optimizer/src/filter_pushdown.rs index 8bed6c3aeba06..22cb03fc3e876 100644 --- a/datafusion/physical-optimizer/src/filter_pushdown.rs +++ b/datafusion/physical-optimizer/src/filter_pushdown.rs @@ -36,9 +36,7 @@ use std::sync::Arc; use crate::PhysicalOptimizerRule; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{ - assert_eq_or_internal_err, config::ConfigOptions, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, config::ConfigOptions, Result}; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::physical_expr::is_volatile; use datafusion_physical_plan::filter_pushdown::{ diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 6bf59fd3d3039..6c59195f76358 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -45,8 +45,7 @@ use arrow::record_batch::RecordBatch; use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::{ - assert_eq_or_internal_err, not_impl_err, Constraint, Constraints, DataFusionError, - Result, + assert_eq_or_internal_err, not_impl_err, Constraint, Constraints, Result, }; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index fcb2ce54bed99..d442307e9488e 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -23,7 +23,7 @@ use crate::{ use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_eq_or_internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index 860fc81bf6658..d0930b2c0e58a 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -18,7 +18,7 @@ use arrow::array::RecordBatch; use arrow::compute::BatchCoalescer; use arrow::datatypes::SchemaRef; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; /// Concatenate multiple [`RecordBatch`]es and apply a limit /// diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index dfb1800a2e6aa..64e0315a523d1 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -33,9 +33,7 @@ use crate::projection::{make_with_child, ProjectionExec}; use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{ - assert_eq_or_internal_err, internal_err, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; diff --git a/datafusion/physical-plan/src/coop.rs b/datafusion/physical-plan/src/coop.rs index 29d5ba1ca84f5..aa5e7b4a8cec1 100644 --- a/datafusion/physical-plan/src/coop.rs +++ b/datafusion/physical-plan/src/coop.rs @@ -85,7 +85,7 @@ use crate::{ }; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result, Statistics}; +use datafusion_common::{assert_eq_or_internal_err, Result, Statistics}; use datafusion_execution::TaskContext; use crate::execution_plan::SchedulingType; diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index e9606fe1be692..e072b55ecff44 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -29,7 +29,7 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 91e2593ec4cd6..4b8491cf14dd8 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -27,7 +27,7 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use datafusion_common::display::StringifiedPlan; -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_eq_or_internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 42a76c6d51278..0488cd35a8e36 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -42,8 +42,7 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::stats::Precision; use datafusion_common::{ - assert_eq_or_internal_err, internal_err, DataFusionError, JoinType, Result, - ScalarValue, + assert_eq_or_internal_err, internal_err, JoinType, Result, ScalarValue, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 03bf516eadd17..97ee8ecbdba8a 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -66,8 +66,8 @@ use arrow_schema::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - assert_or_internal_err, plan_err, project_schema, DataFusionError, JoinSide, - JoinType, NullEquality, Result, + assert_or_internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, + Result, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index daf47603c217f..b5b4325798f9d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -45,8 +45,8 @@ use crate::{ use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::{ - assert_eq_or_internal_err, internal_err, plan_err, DataFusionError, JoinSide, - JoinType, NullEquality, Result, + assert_eq_or_internal_err, internal_err, plan_err, JoinSide, JoinType, NullEquality, + Result, }; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 8b3677713a463..9c778ad131846 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -67,8 +67,8 @@ use arrow::record_batch::RecordBatch; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; use datafusion_common::{ - assert_eq_or_internal_err, plan_err, DataFusionError, HashSet, JoinSide, JoinType, - NullEquality, Result, + assert_eq_or_internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, + Result, }; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index a78de53fccff9..4646e8ebc3132 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -32,9 +32,7 @@ use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{ - assert_eq_or_internal_err, internal_err, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, internal_err, Result}; use datafusion_execution::TaskContext; use futures::stream::{Stream, StreamExt}; diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 97a748e33a59c..92e789ebc5965 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -32,9 +32,7 @@ use crate::{ use arrow::array::RecordBatch; use arrow::datatypes::SchemaRef; -use datafusion_common::{ - assert_eq_or_internal_err, assert_or_internal_err, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, assert_or_internal_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index d657128ee12a4..be4c3da509e88 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -30,7 +30,7 @@ use crate::{ use arrow::array::{ArrayRef, NullArray, RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use datafusion_common::{assert_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 34b6d5108f9e4..3361a7cdb7185 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -30,9 +30,7 @@ use crate::{ Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, }; -use datafusion_common::{ - assert_eq_or_internal_err, internal_err, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, internal_err, Result}; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index ea0358a55fc27..047fbd8cbd81d 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -27,7 +27,7 @@ use crate::sorts::{ use crate::{SendableRecordBatchStream, SpillManager}; use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::{assert_or_internal_err, internal_err, DataFusionError, Result}; +use datafusion_common::{assert_or_internal_err, internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{ human_readable_size, MemoryConsumer, MemoryPool, MemoryReservation, diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 4f7b843262dea..e3b22611f4deb 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -37,8 +37,7 @@ use crate::{DisplayAs, DisplayFormatType, PlanProperties}; use arrow::array::{Array, ArrayRef, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{ - assert_or_internal_err, config::ConfigOptions, project_schema, DataFusionError, - Result, Statistics, + assert_or_internal_err, config::ConfigOptions, project_schema, Result, Statistics, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::equivalence::{ diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 1f4bc8817ed59..06c28a8081ef6 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -46,7 +46,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::{ - assert_or_internal_err, exec_err, internal_datafusion_err, DataFusionError, Result, + assert_or_internal_err, exec_err, internal_datafusion_err, Result, }; use datafusion_execution::TaskContext; use datafusion_physical_expr::{calculate_union, EquivalenceProperties, PhysicalExpr}; diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index 810c97cf47451..b588608397f40 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -42,7 +42,7 @@ use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; use datafusion_common::utils::{evaluate_partition_ranges, transpose}; -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result}; +use datafusion_common::{assert_eq_or_internal_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr_common::sort_expr::{ OrderingRequirements, PhysicalSortExpr, diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index a77e7b2cf10fc..e2c6efd508ba9 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -31,9 +31,7 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{ - assert_eq_or_internal_err, internal_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, internal_datafusion_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index da4cd942ccb90..508e9af419c58 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{ - assert_eq_or_internal_err, internal_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{assert_eq_or_internal_err, internal_datafusion_err, Result}; pub(crate) fn str_to_byte(s: &String, description: &str) -> Result { assert_eq_or_internal_err!( diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 84d6688b789b8..7a8cbafc22bf8 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -38,7 +38,7 @@ use datafusion_catalog::cte_worktable::CteWorkTable; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ assert_or_internal_err, context, internal_datafusion_err, internal_err, not_impl_err, - plan_err, DataFusionError, Result, TableReference, ToDFSchema, + plan_err, Result, TableReference, ToDFSchema, }; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_format::{ diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 527a0e094613e..4084da820c0d4 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -40,7 +40,7 @@ use datafusion_common::{assert_eq_or_internal_err, Column, DFSchema}; use datafusion_common::{ internal_datafusion_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, - DataFusionError, ScalarValue, + ScalarValue, }; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index aa66b179e2d4c..60d45baa7f380 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -18,9 +18,7 @@ use arrow::compute::kernels::numeric::add; use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip}; use arrow::datatypes::DataType; -use datafusion_common::{ - assert_eq_or_internal_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; diff --git a/datafusion/spark/src/function/math/rint.rs b/datafusion/spark/src/function/math/rint.rs index 40dac9cb31d87..3271be38f8338 100644 --- a/datafusion/spark/src/function/math/rint.rs +++ b/datafusion/spark/src/function/math/rint.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::DataFusionError; use std::any::Any; use std::sync::Arc; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 498c58dd89842..b5db580217741 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -18,8 +18,7 @@ use arrow::datatypes::Field; use datafusion_common::{ assert_or_internal_err, exec_datafusion_err, internal_err, not_impl_err, - plan_datafusion_err, plan_err, Column, DFSchema, DataFusionError, Result, Span, - TableReference, + plan_datafusion_err, plan_err, Column, DFSchema, Result, Span, TableReference, }; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 10c34d5a4df7b..575cfd27ee354 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -41,7 +41,7 @@ use arrow::datatypes::{ use arrow::util::display::array_value_to_string; use datafusion_common::{ assert_eq_or_internal_err, assert_or_internal_err, internal_datafusion_err, - internal_err, not_impl_err, plan_err, Column, DataFusionError, Result, ScalarValue, + internal_err, not_impl_err, plan_err, Column, Result, ScalarValue, }; use datafusion_expr::{ expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 91ab2e003c87a..c3f7cb37bdee8 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -27,7 +27,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{ assert_or_internal_err, exec_datafusion_err, exec_err, internal_err, plan_err, - Column, DFSchemaRef, DataFusionError, Diagnostic, HashMap, Result, ScalarValue, + Column, DFSchemaRef, Diagnostic, HashMap, Result, ScalarValue, }; use datafusion_expr::builder::get_struct_unnested_columns; use datafusion_expr::expr::{ From 359653555df48a1364529b05d0066aaaa7ced989 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:38:38 +0200 Subject: [PATCH 3/8] strip backtrace from error --- datafusion/common/src/error.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 993dc0b322b62..aaa84a6b94e38 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1121,7 +1121,7 @@ mod test { ok_result() } - let err = check().unwrap_err(); + let err = check().unwrap_err().strip_backtrace(); assert_snapshot!( err.to_string(), @r" @@ -1144,7 +1144,7 @@ mod test { ok_result() } - let err = check().unwrap_err(); + let err = check().unwrap_err().strip_backtrace(); assert_snapshot!( err.to_string(), @r" @@ -1168,7 +1168,7 @@ mod test { ok_result() } - let err = check().unwrap_err(); + let err = check().unwrap_err().strip_backtrace(); assert_snapshot!( err.to_string(), @r" @@ -1185,7 +1185,7 @@ mod test { ok_result() } - let err = check().unwrap_err(); + let err = check().unwrap_err().strip_backtrace(); assert_snapshot!( err.to_string(), @r" @@ -1202,7 +1202,7 @@ mod test { ok_result() } - let err = check().unwrap_err(); + let err = check().unwrap_err().strip_backtrace(); assert_snapshot!( err.to_string(), @r" From 27bc152fae80769990356eb4c67b46a7f84af0fd Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 10 Jun 2026 13:15:09 +0300 Subject: [PATCH 4/8] revert --- Cargo.lock | 6 + datafusion/catalog-listing/Cargo.toml | 1 + datafusion/catalog-listing/src/config.rs | 88 +- datafusion/catalog-listing/src/helpers.rs | 231 +- datafusion/catalog-listing/src/mod.rs | 4 +- datafusion/catalog-listing/src/options.rs | 33 +- datafusion/catalog-listing/src/table.rs | 594 +- datafusion/common/src/utils/mod.rs | 794 ++- datafusion/datasource-parquet/Cargo.toml | 18 + .../benches/parquet_metadata_statistics.rs | 303 + .../benches/parquet_nested_filter_pushdown.rs | 238 + .../benches/parquet_struct_filter_pushdown.rs | 353 ++ .../datasource-parquet/src/access_plan.rs | 497 +- .../datasource-parquet/src/bloom_filter.rs | 560 ++ .../src/decoder_projection.rs | 159 + .../datasource-parquet/src/file_format.rs | 1462 +---- datafusion/datasource-parquet/src/metadata.rs | 1239 +++- datafusion/datasource-parquet/src/metrics.rs | 158 +- datafusion/datasource-parquet/src/mod.rs | 24 +- .../src/opener/early_stop.rs | 107 + .../src/opener/encryption.rs | 104 + .../datasource-parquet/src/opener/mod.rs | 3397 +++++++++++ .../datasource-parquet/src/page_filter.rs | 152 +- .../datasource-parquet/src/push_decoder.rs | 221 + datafusion/datasource-parquet/src/reader.rs | 9 +- .../datasource-parquet/src/row_filter.rs | 1883 ++++++- .../src/row_group_filter.rs | 904 ++- .../datasource-parquet/src/schema_coercion.rs | 843 +++ datafusion/datasource-parquet/src/sink.rs | 754 +++ datafusion/datasource-parquet/src/sort.rs | 1108 ++++ datafusion/datasource-parquet/src/source.rs | 1131 +++- .../src/supported_predicates.rs | 143 + .../src/test_data/ndv_test.parquet | Bin 0 -> 1141 bytes .../datasource-parquet/src/test_util.rs | 71 + .../datasource-parquet/src/virtual_column.rs | 125 + datafusion/datasource/Cargo.toml | 13 +- .../benches/split_groups_by_statistics.rs | 2 +- datafusion/datasource/src/decoder.rs | 28 +- datafusion/datasource/src/display.rs | 11 +- datafusion/datasource/src/file.rs | 264 +- .../datasource/src/file_compression_type.rs | 60 +- datafusion/datasource/src/file_format.rs | 108 +- datafusion/datasource/src/file_groups.rs | 380 +- .../datasource/src/file_scan_config/mod.rs | 3192 +++++++++++ .../src/file_scan_config/sort_pushdown.rs | 622 ++ datafusion/datasource/src/file_sink_config.rs | 52 +- .../datasource/src/file_stream/builder.rs | 142 + .../datasource/src/file_stream/metrics.rs | 159 + datafusion/datasource/src/file_stream/mod.rs | 1678 ++++++ .../datasource/src/file_stream/scan_state.rs | 304 + .../datasource/src/file_stream/work_source.rs | 109 + datafusion/datasource/src/memory.rs | 124 +- datafusion/datasource/src/mod.rs | 333 +- datafusion/datasource/src/morsel/adapters.rs | 122 + datafusion/datasource/src/morsel/mocks.rs | 746 +++ datafusion/datasource/src/morsel/mod.rs | 234 + datafusion/datasource/src/projection.rs | 630 +++ datafusion/datasource/src/schema_adapter.rs | 1076 +--- datafusion/datasource/src/sink.rs | 36 +- datafusion/datasource/src/source.rs | 276 +- datafusion/datasource/src/statistics.rs | 523 +- datafusion/datasource/src/table_schema.rs | 544 +- datafusion/datasource/src/test_util.rs | 71 +- datafusion/datasource/src/url.rs | 395 +- datafusion/datasource/src/write/demux.rs | 51 +- datafusion/datasource/src/write/mod.rs | 25 +- .../datasource/src/write/orchestration.rs | 24 +- datafusion/expr-common/Cargo.toml | 4 +- datafusion/expr-common/src/accumulator.rs | 27 +- datafusion/expr-common/src/casts.rs | 105 +- datafusion/expr-common/src/columnar_value.rs | 233 +- datafusion/expr-common/src/dyn_eq.rs | 4 +- .../expr-common/src/groups_accumulator.rs | 65 +- .../expr-common/src/interval_arithmetic.rs | 497 +- datafusion/expr-common/src/lib.rs | 5 +- datafusion/expr-common/src/operator.rs | 98 +- datafusion/expr-common/src/placement.rs | 62 + datafusion/expr-common/src/signature.rs | 498 +- datafusion/expr-common/src/statistics.rs | 116 +- .../src/type_coercion/aggregates.rs | 18 +- .../expr-common/src/type_coercion/binary.rs | 751 ++- .../type_coercion/binary/tests/arithmetic.rs | 68 +- .../type_coercion/binary/tests/comparison.rs | 318 +- .../type_coercion/binary/tests/dictionary.rs | 20 +- .../binary/tests/run_end_encoded.rs | 56 +- datafusion/expr/Cargo.toml | 7 +- datafusion/expr/src/arguments.rs | 475 +- datafusion/expr/src/async_udf.rs | 16 +- .../expr/src/conditional_expressions.rs | 4 +- datafusion/expr/src/execution_props.rs | 215 +- datafusion/expr/src/expr.rs | 882 ++- datafusion/expr/src/expr_fn.rs | 41 +- .../expr/src/expr_rewriter/guarantees.rs | 13 +- datafusion/expr/src/expr_rewriter/mod.rs | 35 +- datafusion/expr/src/expr_rewriter/order_by.rs | 264 +- datafusion/expr/src/expr_schema.rs | 579 +- .../array_formatter_factory.rs | 67 + datafusion/expr/src/extension_types/mod.rs | 22 + datafusion/expr/src/function.rs | 40 +- datafusion/expr/src/higher_order_function.rs | 1684 ++++++ datafusion/expr/src/lib.rs | 37 +- datafusion/expr/src/literal.rs | 2 +- datafusion/expr/src/logical_plan/builder.rs | 295 +- datafusion/expr/src/logical_plan/ddl.rs | 189 +- datafusion/expr/src/logical_plan/display.rs | 68 +- datafusion/expr/src/logical_plan/dml.rs | 11 +- .../expr/src/logical_plan/invariants.rs | 42 +- datafusion/expr/src/logical_plan/mod.rs | 17 +- datafusion/expr/src/logical_plan/plan.rs | 1150 +++- datafusion/expr/src/logical_plan/statement.rs | 9 +- datafusion/expr/src/logical_plan/tree_node.rs | 153 +- datafusion/expr/src/partition_evaluator.rs | 12 +- datafusion/expr/src/planner.rs | 158 +- datafusion/expr/src/predicate_bounds.rs | 12 +- datafusion/expr/src/preimage.rs | 29 + datafusion/expr/src/ptr_eq.rs | 2 +- datafusion/expr/src/registry.rs | 385 +- datafusion/expr/src/select_expr.rs | 2 +- datafusion/expr/src/simplify.rs | 224 +- datafusion/expr/src/sql.rs | 174 + datafusion/expr/src/table_source.rs | 14 +- datafusion/expr/src/test/function_stub.rs | 52 +- datafusion/expr/src/tree_node.rs | 73 +- .../expr/src/type_coercion/functions.rs | 1298 ++++- datafusion/expr/src/type_coercion/mod.rs | 24 - datafusion/expr/src/type_coercion/other.rs | 66 +- datafusion/expr/src/udaf.rs | 199 +- datafusion/expr/src/udf.rs | 278 +- datafusion/expr/src/udf_eq.rs | 28 +- datafusion/expr/src/udwf.rs | 80 +- datafusion/expr/src/utils.rs | 290 +- datafusion/expr/src/window_frame.rs | 61 +- datafusion/expr/src/window_state.rs | 74 +- datafusion/functions-aggregate/Cargo.toml | 24 +- .../benches/approx_distinct.rs | 331 ++ .../functions-aggregate/benches/array_agg.rs | 29 +- .../functions-aggregate/benches/count.rs | 5 +- .../benches/count_distinct.rs | 459 ++ .../functions-aggregate/benches/first_last.rs | 359 ++ .../functions-aggregate/benches/median.rs | 122 + .../benches/min_max_bytes.rs | 4 +- .../benches/percentile_cont.rs | 129 + datafusion/functions-aggregate/benches/sum.rs | 5 +- .../src/approx_distinct.rs | 1024 +++- .../functions-aggregate/src/approx_median.rs | 90 +- .../src/approx_percentile_cont.rs | 215 +- .../src/approx_percentile_cont_with_weight.rs | 96 +- .../functions-aggregate/src/array_agg.rs | 1463 ++++- datafusion/functions-aggregate/src/average.rs | 114 +- .../functions-aggregate/src/bit_and_or_xor.rs | 65 +- .../functions-aggregate/src/bool_and_or.rs | 53 +- .../functions-aggregate/src/correlation.rs | 99 +- datafusion/functions-aggregate/src/count.rs | 227 +- .../functions-aggregate/src/covariance.rs | 132 +- .../functions-aggregate/src/first_last.rs | 740 +-- .../src/first_last/state.rs | 462 ++ .../functions-aggregate/src/grouping.rs | 34 +- .../functions-aggregate/src/hyperloglog.rs | 96 +- datafusion/functions-aggregate/src/lib.rs | 2 - datafusion/functions-aggregate/src/macros.rs | 2 - datafusion/functions-aggregate/src/median.rs | 146 +- datafusion/functions-aggregate/src/min_max.rs | 224 +- .../src/min_max/min_max_bytes.rs | 6 +- .../src/min_max/min_max_struct.rs | 7 +- .../functions-aggregate/src/nth_value.rs | 20 +- .../src/percentile_cont.rs | 632 ++- datafusion/functions-aggregate/src/planner.rs | 2 +- datafusion/functions-aggregate/src/regr.rs | 124 +- datafusion/functions-aggregate/src/stddev.rs | 75 +- .../functions-aggregate/src/string_agg.rs | 447 +- datafusion/functions-aggregate/src/sum.rs | 225 +- datafusion/functions-aggregate/src/utils.rs | 9 +- .../functions-aggregate/src/variance.rs | 216 +- datafusion/functions-nested/Cargo.toml | 61 +- .../functions-nested/benches/array_concat.rs | 94 + .../functions-nested/benches/array_has.rs | 856 ++- .../functions-nested/benches/array_min_max.rs | 74 + .../benches/array_position.rs | 344 ++ .../functions-nested/benches/array_range.rs | 208 + .../functions-nested/benches/array_remove.rs | 553 ++ .../functions-nested/benches/array_repeat.rs | 407 ++ .../functions-nested/benches/array_replace.rs | 589 ++ .../functions-nested/benches/array_resize.rs | 170 + .../functions-nested/benches/array_reverse.rs | 6 +- .../functions-nested/benches/array_set_ops.rs | 389 ++ .../functions-nested/benches/array_slice.rs | 228 + .../functions-nested/benches/array_sort.rs | 195 + .../benches/array_to_string.rs | 157 + .../functions-nested/benches/arrays_zip.rs | 117 + datafusion/functions-nested/benches/map.rs | 218 +- .../benches/string_to_array.rs | 244 + datafusion/functions-nested/src/array_add.rs | 203 + .../functions-nested/src/array_any_match.rs | 521 ++ .../functions-nested/src/array_compact.rs | 191 + .../functions-nested/src/array_filter.rs | 464 ++ datafusion/functions-nested/src/array_has.rs | 680 ++- .../functions-nested/src/array_normalize.rs | 207 + .../functions-nested/src/array_product.rs | 174 + .../functions-nested/src/array_scale.rs | 220 + .../functions-nested/src/array_subtract.rs | 130 + datafusion/functions-nested/src/array_sum.rs | 174 + .../functions-nested/src/array_transform.rs | 293 + datafusion/functions-nested/src/arrays_zip.rs | 613 ++ .../functions-nested/src/cardinality.rs | 28 +- datafusion/functions-nested/src/concat.rs | 155 +- .../functions-nested/src/cosine_distance.rs | 219 + datafusion/functions-nested/src/dimension.rs | 22 +- datafusion/functions-nested/src/distance.rs | 17 +- datafusion/functions-nested/src/empty.rs | 14 +- datafusion/functions-nested/src/except.rs | 183 +- datafusion/functions-nested/src/extract.rs | 193 +- datafusion/functions-nested/src/flatten.rs | 17 +- .../functions-nested/src/inner_product.rs | 214 + .../functions-nested/src/lambda_utils.rs | 191 + datafusion/functions-nested/src/length.rs | 13 +- datafusion/functions-nested/src/lib.rs | 82 +- datafusion/functions-nested/src/macros.rs | 6 - .../functions-nested/src/macros_lambda.rs | 107 + datafusion/functions-nested/src/make_array.rs | 62 +- datafusion/functions-nested/src/map.rs | 234 +- .../functions-nested/src/map_entries.rs | 16 +- .../functions-nested/src/map_extract.rs | 21 +- datafusion/functions-nested/src/map_keys.rs | 16 +- datafusion/functions-nested/src/map_values.rs | 16 +- datafusion/functions-nested/src/min_max.rs | 131 +- datafusion/functions-nested/src/planner.rs | 25 +- datafusion/functions-nested/src/position.rs | 571 +- datafusion/functions-nested/src/range.rs | 157 +- datafusion/functions-nested/src/remove.rs | 1000 +++- datafusion/functions-nested/src/repeat.rs | 431 +- datafusion/functions-nested/src/replace.rs | 425 +- datafusion/functions-nested/src/resize.rs | 182 +- datafusion/functions-nested/src/reverse.rs | 15 +- datafusion/functions-nested/src/set_ops.rs | 547 +- datafusion/functions-nested/src/sort.rs | 392 +- datafusion/functions-nested/src/string.rs | 951 ++-- datafusion/functions-nested/src/utils.rs | 158 +- datafusion/functions/Cargo.toml | 122 +- datafusion/functions/benches/ascii.rs | 34 +- datafusion/functions/benches/atan2.rs | 146 + .../functions/benches/character_length.rs | 4 +- datafusion/functions/benches/chr.rs | 35 +- datafusion/functions/benches/concat.rs | 104 +- datafusion/functions/benches/concat_ws.rs | 123 + datafusion/functions/benches/contains.rs | 183 + datafusion/functions/benches/cot.rs | 51 +- datafusion/functions/benches/crypto.rs | 73 + datafusion/functions/benches/date_bin.rs | 8 +- datafusion/functions/benches/date_trunc.rs | 19 +- datafusion/functions/benches/encoding.rs | 24 +- datafusion/functions/benches/ends_with.rs | 183 + datafusion/functions/benches/factorial.rs | 65 + datafusion/functions/benches/find_in_set.rs | 6 +- datafusion/functions/benches/floor_ceil.rs | 133 + datafusion/functions/benches/gcd.rs | 6 +- datafusion/functions/benches/helper.rs | 2 +- datafusion/functions/benches/initcap.rs | 167 +- datafusion/functions/benches/isnan.rs | 4 +- datafusion/functions/benches/iszero.rs | 49 +- datafusion/functions/benches/lcm.rs | 66 + datafusion/functions/benches/left_right.rs | 115 + datafusion/functions/benches/levenshtein.rs | 85 + datafusion/functions/benches/lower.rs | 43 +- datafusion/functions/benches/make_date.rs | 20 +- datafusion/functions/benches/nanvl.rs | 114 + datafusion/functions/benches/nullif.rs | 6 +- datafusion/functions/benches/overlay.rs | 200 + datafusion/functions/benches/pad.rs | 738 ++- datafusion/functions/benches/power.rs | 140 + datafusion/functions/benches/random.rs | 4 +- datafusion/functions/benches/regexp_count.rs | 116 + datafusion/functions/benches/regx.rs | 46 +- datafusion/functions/benches/repeat.rs | 45 +- datafusion/functions/benches/replace.rs | 170 + datafusion/functions/benches/reverse.rs | 3 +- datafusion/functions/benches/round.rs | 152 + datafusion/functions/benches/signum.rs | 50 +- datafusion/functions/benches/split_part.rs | 272 + datafusion/functions/benches/starts_with.rs | 183 + datafusion/functions/benches/strpos.rs | 325 +- datafusion/functions/benches/substr.rs | 218 +- datafusion/functions/benches/substr_index.rs | 263 +- datafusion/functions/benches/to_char.rs | 97 +- datafusion/functions/benches/to_hex.rs | 121 +- datafusion/functions/benches/to_local_time.rs | 90 + datafusion/functions/benches/to_time.rs | 94 + datafusion/functions/benches/to_timestamp.rs | 28 +- datafusion/functions/benches/translate.rs | 165 + datafusion/functions/benches/trim.rs | 435 ++ datafusion/functions/benches/trunc.rs | 52 +- datafusion/functions/benches/upper.rs | 4 +- datafusion/functions/benches/uuid.rs | 4 +- datafusion/functions/src/core/arrow_cast.rs | 53 +- datafusion/functions/src/core/arrow_field.rs | 162 + .../functions/src/core/arrow_metadata.rs | 155 + .../functions/src/core/arrow_try_cast.rs | 151 + datafusion/functions/src/core/arrowtypeof.rs | 6 +- datafusion/functions/src/core/cast_to_type.rs | 146 + datafusion/functions/src/core/coalesce.rs | 11 +- datafusion/functions/src/core/getfield.rs | 1229 +++- datafusion/functions/src/core/greatest.rs | 17 +- .../src/core/greatest_least_utils.rs | 20 +- datafusion/functions/src/core/least.rs | 17 +- datafusion/functions/src/core/mod.rs | 51 +- datafusion/functions/src/core/named_struct.rs | 44 +- datafusion/functions/src/core/nullif.rs | 123 +- datafusion/functions/src/core/nvl.rs | 8 +- datafusion/functions/src/core/nvl2.rs | 22 +- datafusion/functions/src/core/overlay.rs | 328 +- datafusion/functions/src/core/planner.rs | 4 +- datafusion/functions/src/core/struct.rs | 10 +- .../functions/src/core/try_cast_to_type.rs | 130 + .../functions/src/core/union_extract.rs | 24 +- datafusion/functions/src/core/union_tag.rs | 10 +- datafusion/functions/src/core/version.rs | 8 +- .../functions/src/core/with_metadata.rs | 335 ++ datafusion/functions/src/crypto/basic.rs | 300 +- datafusion/functions/src/crypto/digest.rs | 53 +- datafusion/functions/src/crypto/md5.rs | 103 +- datafusion/functions/src/crypto/mod.rs | 13 +- datafusion/functions/src/crypto/sha.rs | 170 + datafusion/functions/src/datetime/common.rs | 175 +- .../functions/src/datetime/current_date.rs | 58 +- .../functions/src/datetime/current_time.rs | 104 +- datafusion/functions/src/datetime/date_bin.rs | 668 ++- .../functions/src/datetime/date_part.rs | 249 +- .../functions/src/datetime/date_trunc.rs | 384 +- .../functions/src/datetime/from_unixtime.rs | 23 +- .../functions/src/datetime/make_date.rs | 347 +- .../functions/src/datetime/make_time.rs | 268 + datafusion/functions/src/datetime/mod.rs | 62 +- datafusion/functions/src/datetime/now.rs | 55 +- datafusion/functions/src/datetime/planner.rs | 2 +- datafusion/functions/src/datetime/to_char.rs | 283 +- datafusion/functions/src/datetime/to_date.rs | 76 +- .../functions/src/datetime/to_local_time.rs | 336 +- datafusion/functions/src/datetime/to_time.rs | 239 + .../functions/src/datetime/to_timestamp.rs | 1170 +++- .../functions/src/datetime/to_unixtime.rs | 58 +- datafusion/functions/src/encoding/inner.rs | 685 ++- datafusion/functions/src/lib.rs | 4 +- datafusion/functions/src/macros.rs | 238 +- datafusion/functions/src/math/abs.rs | 17 +- datafusion/functions/src/math/ceil.rs | 201 + datafusion/functions/src/math/cot.rs | 306 +- datafusion/functions/src/math/decimal.rs | 111 + datafusion/functions/src/math/factorial.rs | 129 +- datafusion/functions/src/math/floor.rs | 684 +++ datafusion/functions/src/math/gcd.rs | 73 +- datafusion/functions/src/math/iszero.rs | 240 +- datafusion/functions/src/math/lcm.rs | 62 +- datafusion/functions/src/math/log.rs | 255 +- datafusion/functions/src/math/mod.rs | 33 +- datafusion/functions/src/math/monotonicity.rs | 63 +- datafusion/functions/src/math/nans.rs | 169 +- datafusion/functions/src/math/nanvl.rs | 114 +- datafusion/functions/src/math/pi.rs | 8 +- datafusion/functions/src/math/power.rs | 565 +- datafusion/functions/src/math/random.rs | 15 +- datafusion/functions/src/math/round.rs | 785 ++- datafusion/functions/src/math/signum.rs | 95 +- datafusion/functions/src/math/trunc.rs | 78 +- datafusion/functions/src/planner.rs | 2 +- datafusion/functions/src/regex/mod.rs | 2 +- datafusion/functions/src/regex/regexpcount.rs | 134 +- datafusion/functions/src/regex/regexpinstr.rs | 38 +- datafusion/functions/src/regex/regexplike.rs | 306 +- datafusion/functions/src/regex/regexpmatch.rs | 45 +- .../functions/src/regex/regexpreplace.rs | 199 +- datafusion/functions/src/string/ascii.rs | 35 +- datafusion/functions/src/string/bit_length.rs | 5 - datafusion/functions/src/string/btrim.rs | 15 +- datafusion/functions/src/string/chr.rs | 182 +- datafusion/functions/src/string/common.rs | 622 +- datafusion/functions/src/string/concat.rs | 199 +- datafusion/functions/src/string/concat_ws.rs | 627 ++- datafusion/functions/src/string/contains.rs | 96 +- datafusion/functions/src/string/ends_with.rs | 297 +- .../functions/src/string/levenshtein.rs | 39 +- datafusion/functions/src/string/lower.rs | 211 +- datafusion/functions/src/string/ltrim.rs | 23 +- .../functions/src/string/octet_length.rs | 7 +- datafusion/functions/src/string/repeat.rs | 387 +- datafusion/functions/src/string/replace.rs | 165 +- datafusion/functions/src/string/rtrim.rs | 23 +- datafusion/functions/src/string/split_part.rs | 784 ++- .../functions/src/string/starts_with.rs | 289 +- datafusion/functions/src/string/to_hex.rs | 289 +- datafusion/functions/src/string/upper.rs | 211 +- datafusion/functions/src/string/uuid.rs | 17 +- datafusion/functions/src/strings.rs | 1362 ++++- .../functions/src/unicode/character_length.rs | 17 +- datafusion/functions/src/unicode/common.rs | 275 + .../functions/src/unicode/find_in_set.rs | 143 +- datafusion/functions/src/unicode/initcap.rs | 302 +- datafusion/functions/src/unicode/left.rs | 214 +- datafusion/functions/src/unicode/lpad.rs | 346 +- datafusion/functions/src/unicode/mod.rs | 1 + datafusion/functions/src/unicode/planner.rs | 2 +- datafusion/functions/src/unicode/reverse.rs | 186 +- datafusion/functions/src/unicode/right.rs | 223 +- datafusion/functions/src/unicode/rpad.rs | 562 +- datafusion/functions/src/unicode/strpos.rs | 292 +- datafusion/functions/src/unicode/substr.rs | 492 +- .../functions/src/unicode/substrindex.rs | 685 ++- datafusion/functions/src/unicode/translate.rs | 485 +- datafusion/functions/src/utils.rs | 178 +- datafusion/optimizer/Cargo.toml | 18 +- .../optimizer/benches/optimize_projections.rs | 235 + .../benches/projection_unnecessary.rs | 4 +- .../optimizer/benches/unions_to_filter.rs | 195 + .../src/analyzer/function_rewrite.rs | 2 +- datafusion/optimizer/src/analyzer/mod.rs | 2 +- .../src/analyzer/resolve_grouping_function.rs | 95 +- .../optimizer/src/analyzer/type_coercion.rs | 499 +- .../optimizer/src/common_subexpr_eliminate.rs | 119 +- datafusion/optimizer/src/decorrelate.rs | 76 +- .../optimizer/src/decorrelate_lateral_join.rs | 344 +- .../src/decorrelate_predicate_subquery.rs | 115 +- .../optimizer/src/eliminate_cross_join.rs | 182 +- .../src/eliminate_duplicated_expr.rs | 33 +- datafusion/optimizer/src/eliminate_filter.rs | 6 +- .../src/eliminate_group_by_constant.rs | 123 +- datafusion/optimizer/src/eliminate_join.rs | 4 +- datafusion/optimizer/src/eliminate_limit.rs | 11 +- .../optimizer/src/eliminate_outer_join.rs | 1460 ++++- .../src/extract_equijoin_predicate.rs | 10 +- .../optimizer/src/extract_leaf_expressions.rs | 3089 ++++++++++ .../optimizer/src/filter_null_join_keys.rs | 6 +- datafusion/optimizer/src/join_key_set.rs | 2 +- datafusion/optimizer/src/lib.rs | 7 +- .../optimizer/src/optimize_projections/mod.rs | 530 +- .../optimize_projections/required_indices.rs | 64 +- datafusion/optimizer/src/optimize_unions.rs | 71 +- datafusion/optimizer/src/optimizer.rs | 373 +- datafusion/optimizer/src/plan_signature.rs | 2 +- .../optimizer/src/propagate_empty_relation.rs | 253 +- datafusion/optimizer/src/push_down_filter.rs | 1044 ++-- datafusion/optimizer/src/push_down_limit.rs | 20 +- .../src/replace_distinct_aggregate.rs | 8 +- .../optimizer/src/rewrite_set_comparison.rs | 171 + .../optimizer/src/scalar_subquery_to_join.rs | 467 +- .../simplify_expressions/expr_simplifier.rs | 763 ++- .../simplify_expressions/inlist_simplifier.rs | 92 +- .../simplify_expressions/linear_aggregates.rs | 229 + .../optimizer/src/simplify_expressions/mod.rs | 7 +- .../src/simplify_expressions/regex.rs | 174 +- .../reorder_predicates.rs | 193 + .../simplify_expressions/simplify_exprs.rs | 251 +- .../simplify_expressions/simplify_literal.rs | 151 + .../simplify_predicates.rs | 14 +- .../src/simplify_expressions/udf_preimage.rs | 402 ++ .../src/simplify_expressions/unwrap_cast.rs | 28 +- .../src/simplify_expressions/utils.rs | 52 +- .../src/single_distinct_to_groupby.rs | 27 +- datafusion/optimizer/src/test/mod.rs | 27 +- datafusion/optimizer/src/test/udfs.rs | 98 + datafusion/optimizer/src/test/user_defined.rs | 2 +- datafusion/optimizer/src/unions_to_filter.rs | 652 +++ datafusion/optimizer/src/utils.rs | 52 +- .../optimizer/tests/optimizer_integration.rs | 576 +- datafusion/physical-expr-common/Cargo.toml | 20 +- .../benches/compare_nested.rs | 74 + .../physical-expr-common/src/binary_map.rs | 12 +- .../src/binary_view_map.rs | 241 +- datafusion/physical-expr-common/src/datum.rs | 37 +- datafusion/physical-expr-common/src/lib.rs | 3 +- .../src/metrics/baseline.rs | 376 ++ .../src/metrics/builder.rs | 341 ++ .../src/metrics/custom.rs | 114 + .../src/metrics/elapsed_compute.rs | 101 + .../src/metrics/expression.rs | 88 + .../physical-expr-common/src/metrics/mod.rs | 968 ++++ .../physical-expr-common/src/metrics/value.rs | 1566 ++++++ .../physical-expr-common/src/physical_expr.rs | 533 +- .../physical-expr-common/src/sort_expr.rs | 127 +- .../physical-expr-common/src/tree_node.rs | 4 +- datafusion/physical-expr-common/src/utils.rs | 503 +- datafusion/physical-expr/Cargo.toml | 27 +- datafusion/physical-expr/benches/binary_op.rs | 91 +- datafusion/physical-expr/benches/case_when.rs | 92 +- datafusion/physical-expr/benches/in_list.rs | 440 +- .../physical-expr/benches/in_list_strategy.rs | 1037 ++++ datafusion/physical-expr/benches/is_null.rs | 4 +- datafusion/physical-expr/benches/simplify.rs | 299 + .../physical-expr/benches/string_concat.rs | 94 + datafusion/physical-expr/src/aggregate.rs | 550 +- datafusion/physical-expr/src/analysis.rs | 196 +- .../src/async_scalar_function.rs | 25 +- .../physical-expr/src/equivalence/class.rs | 55 +- .../physical-expr/src/equivalence/mod.rs | 7 +- .../physical-expr/src/equivalence/ordering.rs | 15 +- .../src/equivalence/properties/dependency.rs | 147 +- .../src/equivalence/properties/joins.rs | 3 +- .../src/equivalence/properties/mod.rs | 96 +- .../src/equivalence/properties/union.rs | 4 +- .../physical-expr/src/expressions/binary.rs | 434 +- .../src/expressions/binary/kernels.rs | 119 +- .../physical-expr/src/expressions/case.rs | 1502 +++-- .../boolean_lookup_table.rs | 122 + .../bytes_like_lookup_table.rs | 223 + .../case/literal_lookup_table/mod.rs | 327 ++ .../primitive_lookup_table.rs | 229 + .../physical-expr/src/expressions/cast.rs | 701 ++- .../physical-expr/src/expressions/column.rs | 67 +- .../src/expressions/dynamic_filters/mod.rs | 1189 ++++ .../expressions/dynamic_filters/tracker.rs | 331 ++ .../physical-expr/src/expressions/in_list.rs | 4998 ++++++++++------- .../in_list/array_static_filter.rs | 160 + .../expressions/in_list/primitive_filter.rs | 233 + .../src/expressions/in_list/static_filter.rs | 37 + .../src/expressions/in_list/strategy.rs | 57 + .../src/expressions/is_not_null.rs | 158 +- .../physical-expr/src/expressions/is_null.rs | 155 +- .../physical-expr/src/expressions/lambda.rs | 252 + .../src/expressions/lambda_variable.rs | 146 + .../physical-expr/src/expressions/like.rs | 253 +- .../physical-expr/src/expressions/literal.rs | 147 +- .../physical-expr/src/expressions/mod.rs | 39 +- .../physical-expr/src/expressions/negative.rs | 198 +- .../physical-expr/src/expressions/no_op.rs | 8 +- .../physical-expr/src/expressions/not.rs | 199 +- .../physical-expr/src/expressions/try_cast.rs | 212 +- .../src/expressions/unknown_column.rs | 145 +- .../src/higher_order_function.rs | 718 +++ .../physical-expr/src/intervals/cp_solver.rs | 33 +- .../physical-expr/src/intervals/test_utils.rs | 6 +- .../physical-expr/src/intervals/utils.rs | 20 +- datafusion/physical-expr/src/lib.rs | 16 +- datafusion/physical-expr/src/partitioning.rs | 1226 +++- datafusion/physical-expr/src/physical_expr.rs | 17 +- datafusion/physical-expr/src/planner.rs | 320 +- datafusion/physical-expr/src/projection.rs | 1003 +++- .../physical-expr/src/proto_test_util.rs | 141 + .../physical-expr/src/scalar_function.rs | 87 +- .../physical-expr/src/scalar_subquery.rs | 240 + .../src/simplifier/const_evaluator.rs | 207 + .../physical-expr/src/simplifier/mod.rs | 636 ++- .../physical-expr/src/simplifier/not.rs | 128 + .../src/simplifier/unwrap_cast.rs | 188 +- .../physical-expr/src/statistics/mod.rs | 5 +- .../src/statistics/stats_solver.rs | 23 +- .../physical-expr/src/utils/guarantee.rs | 36 +- datafusion/physical-expr/src/utils/mod.rs | 124 +- .../physical-expr/src/window/aggregate.rs | 4 +- .../src/window/sliding_aggregate.rs | 4 +- .../physical-expr/src/window/standard.rs | 8 +- .../window/standard_window_function_expr.rs | 3 +- .../physical-expr/src/window/window_expr.rs | 9 +- datafusion/physical-optimizer/Cargo.toml | 1 + .../src/aggregate_statistics.rs | 75 +- .../src/combine_partial_final_agg.rs | 20 +- .../physical-optimizer/src/ensure_coop.rs | 336 +- .../enforce_distribution.rs | 1423 +++++ .../enforce_sorting/mod.rs | 746 +++ .../replace_with_order_preserving_variants.rs | 299 + .../enforce_sorting/sort_pushdown.rs | 966 ++++ .../src/ensure_requirements/mod.rs | 259 + .../physical-optimizer/src/filter_pushdown.rs | 6 +- .../src/hash_join_buffering.rs | 103 + .../physical-optimizer/src/join_selection.rs | 341 +- datafusion/physical-optimizer/src/lib.rs | 16 +- .../physical-optimizer/src/limit_pushdown.rs | 260 +- .../src/limit_pushdown_past_window.rs | 135 +- .../src/limited_distinct_aggregation.rs | 42 +- .../physical-optimizer/src/optimizer.rs | 140 +- .../src/output_requirements.rs | 47 +- .../src/projection_pushdown.rs | 53 +- .../physical-optimizer/src/pushdown_sort.rs | 199 + .../physical-optimizer/src/sanity_checker.rs | 55 +- .../src/topk_aggregation.rs | 74 +- .../src/topk_repartition.rs | 367 ++ .../src/update_aggr_exprs.rs | 12 +- datafusion/physical-optimizer/src/utils.rs | 66 +- .../physical-optimizer/src/window_topn.rs | 331 ++ datafusion/physical-plan/Cargo.toml | 42 +- .../benches/aggregate_vectorized.rs | 21 +- .../benches/dictionary_group_values.rs | 176 + .../benches/hash_join_semi_anti.rs | 387 ++ .../physical-plan/benches/multi_group_by.rs | 356 ++ .../physical-plan/benches/partial_ordering.rs | 2 +- .../physical-plan/benches/sort_merge_join.rs | 204 + .../benches/sort_preserving_merge.rs | 4 +- datafusion/physical-plan/benches/spill_io.rs | 7 +- .../src/aggregates/group_values/metrics.rs | 11 +- .../src/aggregates/group_values/mod.rs | 8 +- .../group_values/multi_group_by/boolean.rs | 196 +- .../group_values/multi_group_by/bytes.rs | 150 +- .../group_values/multi_group_by/bytes_view.rs | 313 +- .../group_values/multi_group_by/mod.rs | 619 +- .../group_values/multi_group_by/primitive.rs | 319 +- .../src/aggregates/group_values/row.rs | 51 +- .../group_values/single_group_by/boolean.rs | 3 +- .../group_values/single_group_by/bytes.rs | 4 +- .../single_group_by/bytes_view.rs | 4 +- .../group_values/single_group_by/primitive.rs | 103 +- .../src/aggregates/hash_aggregate.rs | 425 ++ .../src/aggregates/hash_table.rs | 623 ++ .../physical-plan/src/aggregates/mod.rs | 3719 ++++++++++-- .../src/aggregates/no_grouping.rs | 301 +- .../physical-plan/src/aggregates/order/mod.rs | 103 +- .../physical-plan/src/aggregates/row_hash.rs | 945 +++- .../src/aggregates/topk/hash_table.rs | 337 +- .../physical-plan/src/aggregates/topk/heap.rs | 329 +- .../src/aggregates/topk/priority_map.rs | 267 +- .../src/aggregates/topk_stream.rs | 133 +- datafusion/physical-plan/src/analyze.rs | 217 +- datafusion/physical-plan/src/async_func.rs | 184 +- datafusion/physical-plan/src/buffer.rs | 639 +++ datafusion/physical-plan/src/coalesce/mod.rs | 6 +- .../physical-plan/src/coalesce_batches.rs | 83 +- .../physical-plan/src/coalesce_partitions.rs | 141 +- .../physical-plan/src/column_rewriter.rs | 382 ++ datafusion/physical-plan/src/common.rs | 266 +- datafusion/physical-plan/src/coop.rs | 78 +- datafusion/physical-plan/src/display.rs | 563 +- datafusion/physical-plan/src/empty.rs | 60 +- .../physical-plan/src/execution_plan.rs | 413 +- datafusion/physical-plan/src/explain.rs | 26 +- datafusion/physical-plan/src/filter.rs | 2534 ++++++++- .../physical-plan/src/filter_pushdown.rs | 166 +- .../physical-plan/src/joins/array_map.rs | 547 ++ datafusion/physical-plan/src/joins/chain.rs | 69 + .../physical-plan/src/joins/cross_join.rs | 221 +- .../physical-plan/src/joins/hash_join/exec.rs | 3061 +++++++--- .../src/joins/hash_join/inlist_builder.rs | 158 + .../physical-plan/src/joins/hash_join/mod.rs | 4 +- .../joins/hash_join/partitioned_hash_eval.rs | 734 ++- .../src/joins/hash_join/shared_bounds.rs | 703 ++- .../src/joins/hash_join/stream.rs | 659 ++- .../physical-plan/src/joins/join_hash_map.rs | 188 +- datafusion/physical-plan/src/joins/mod.rs | 39 +- .../src/joins/nested_loop_join.rs | 1780 +++++- .../piecewise_merge_join/classic_join.rs | 174 +- .../src/joins/piecewise_merge_join/exec.rs | 115 +- .../joins/sort_merge_join/bitwise_stream.rs | 1344 +++++ .../src/joins/sort_merge_join/exec.rs | 162 +- .../src/joins/sort_merge_join/filter.rs | 388 ++ .../sort_merge_join/materializing_stream.rs | 1948 +++++++ .../src/joins/sort_merge_join/metrics.rs | 27 +- .../src/joins/sort_merge_join/mod.rs | 4 +- .../src/joins/sort_merge_join/tests.rs | 3840 ++++++++++--- .../src/joins/stream_join_utils.rs | 147 +- .../src/joins/symmetric_hash_join.rs | 100 +- .../physical-plan/src/joins/test_utils.rs | 7 +- datafusion/physical-plan/src/joins/utils.rs | 2003 ++++++- datafusion/physical-plan/src/lib.rs | 22 +- datafusion/physical-plan/src/limit.rs | 123 +- datafusion/physical-plan/src/memory.rs | 122 +- datafusion/physical-plan/src/metrics.rs | 21 + .../src/operator_statistics/mod.rs | 2297 ++++++++ .../physical-plan/src/placeholder_row.rs | 36 +- datafusion/physical-plan/src/projection.rs | 903 ++- .../physical-plan/src/recursive_query.rs | 221 +- datafusion/physical-plan/src/render_tree.rs | 3 +- .../src/repartition/distributor_channels.rs | 4 +- .../physical-plan/src/repartition/mod.rs | 990 +++- .../physical-plan/src/scalar_subquery.rs | 558 ++ datafusion/physical-plan/src/sort_pushdown.rs | 120 + datafusion/physical-plan/src/sorts/builder.rs | 283 +- datafusion/physical-plan/src/sorts/cursor.rs | 5 +- datafusion/physical-plan/src/sorts/merge.rs | 127 +- datafusion/physical-plan/src/sorts/mod.rs | 3 + .../src/sorts/multi_level_merge.rs | 54 +- .../physical-plan/src/sorts/partial_sort.rs | 204 +- .../src/sorts/partitioned_topk.rs | 515 ++ datafusion/physical-plan/src/sorts/sort.rs | 1098 +++- .../src/sorts/sort_preserving_merge.rs | 317 +- datafusion/physical-plan/src/sorts/stream.rs | 174 +- .../src/sorts/streaming_merge.rs | 6 +- .../src/spill/in_progress_spill_file.rs | 132 +- datafusion/physical-plan/src/spill/mod.rs | 994 +++- .../src/spill/replayable_spill_input.rs | 448 ++ .../physical-plan/src/spill/spill_manager.rs | 211 +- .../physical-plan/src/spill/spill_pool.rs | 188 +- datafusion/physical-plan/src/stream.rs | 265 +- datafusion/physical-plan/src/streaming.rs | 21 +- datafusion/physical-plan/src/test.rs | 42 +- datafusion/physical-plan/src/test/exec.rs | 218 +- datafusion/physical-plan/src/topk/mod.rs | 308 +- datafusion/physical-plan/src/tree_node.rs | 4 +- datafusion/physical-plan/src/union.rs | 638 ++- datafusion/physical-plan/src/unnest.rs | 152 +- .../src/windows/bounded_window_agg_exec.rs | 206 +- datafusion/physical-plan/src/windows/mod.rs | 70 +- .../src/windows/window_agg_exec.rs | 123 +- datafusion/physical-plan/src/work_table.rs | 130 +- datafusion/proto/Cargo.toml | 21 +- datafusion/proto/regen.sh | 4 +- datafusion/proto/src/bytes/mod.rs | 185 +- datafusion/proto/src/common.rs | 20 +- datafusion/proto/src/convert.rs | 44 + datafusion/proto/src/lib.rs | 41 +- .../proto/src/logical_plan/file_formats.rs | 398 +- .../proto/src/logical_plan/from_proto.rs | 294 +- datafusion/proto/src/logical_plan/mod.rs | 619 +- datafusion/proto/src/logical_plan/to_proto.rs | 168 +- .../proto/src/physical_plan/from_proto.rs | 700 ++- datafusion/proto/src/physical_plan/mod.rs | 2188 +++++--- .../proto/src/physical_plan/to_proto.rs | 478 +- datafusion/proto/tests/cases/mod.rs | 11 - .../tests/cases/roundtrip_logical_plan.rs | 1011 +++- .../tests/cases/roundtrip_physical_plan.rs | 2114 ++++++- datafusion/proto/tests/cases/serialize.rs | 12 +- datafusion/pruning/Cargo.toml | 2 +- datafusion/pruning/LICENSE.txt | 1 + datafusion/pruning/NOTICE.txt | 1 + datafusion/pruning/src/file_pruner.rs | 186 +- datafusion/pruning/src/lib.rs | 6 +- datafusion/pruning/src/pruning_predicate.rs | 998 +++- datafusion/spark/Cargo.toml | 48 +- datafusion/spark/benches/char.rs | 4 +- datafusion/spark/benches/floor.rs | 119 + datafusion/spark/benches/hex.rs | 150 + datafusion/spark/benches/sha2.rs | 105 + datafusion/spark/benches/slice.rs | 185 + datafusion/spark/benches/space.rs | 71 + datafusion/spark/benches/substring.rs | 205 + datafusion/spark/benches/unhex.rs | 146 + .../spark/src/function/aggregate/avg.rs | 188 +- .../spark/src/function/aggregate/collect.rs | 192 + .../spark/src/function/aggregate/mod.rs | 29 +- .../spark/src/function/aggregate/try_sum.rs | 655 +++ .../src/function/array/array_contains.rs | 163 + datafusion/spark/src/function/array/mod.rs | 29 +- datafusion/spark/src/function/array/repeat.rs | 121 + .../spark/src/function/array/shuffle.rs | 94 +- datafusion/spark/src/function/array/slice.rs | 249 + .../spark/src/function/array/spark_array.rs | 154 +- .../function/bitmap/bitmap_bit_position.rs | 138 + .../function/bitmap/bitmap_bucket_number.rs | 138 + .../spark/src/function/bitmap/bitmap_count.rs | 78 +- datafusion/spark/src/function/bitmap/mod.rs | 23 +- .../spark/src/function/bitwise/bit_count.rs | 65 +- .../spark/src/function/bitwise/bit_get.rs | 66 +- .../spark/src/function/bitwise/bit_shift.rs | 81 +- .../spark/src/function/bitwise/bitwise_not.rs | 91 +- .../spark/src/function/collection/mod.rs | 13 +- .../spark/src/function/collection/size.rs | 157 + .../spark/src/function/conditional/if.rs | 12 +- .../spark/src/function/conversion/cast.rs | 1007 ++++ .../spark/src/function/conversion/mod.rs | 19 +- .../spark/src/function/datetime/add_months.rs | 90 + .../spark/src/function/datetime/date_add.rs | 94 +- .../spark/src/function/datetime/date_diff.rs | 114 + .../spark/src/function/datetime/date_part.rs | 138 + .../spark/src/function/datetime/date_sub.rs | 90 +- .../spark/src/function/datetime/date_trunc.rs | 167 + .../spark/src/function/datetime/extract.rs | 254 + .../function/datetime/from_utc_timestamp.rs | 190 + .../spark/src/function/datetime/last_day.rs | 90 +- .../src/function/datetime/make_dt_interval.rs | 178 +- .../src/function/datetime/make_interval.rs | 94 +- datafusion/spark/src/function/datetime/mod.rs | 134 + .../spark/src/function/datetime/monthname.rs | 115 + .../spark/src/function/datetime/next_day.rs | 92 +- .../spark/src/function/datetime/time_trunc.rs | 117 + .../src/function/datetime/to_utc_timestamp.rs | 220 + .../spark/src/function/datetime/trunc.rs | 138 + .../spark/src/function/datetime/unix.rs | 165 + datafusion/spark/src/function/error_utils.rs | 6 +- datafusion/spark/src/function/hash/crc32.rs | 52 +- datafusion/spark/src/function/hash/mod.rs | 8 +- datafusion/spark/src/function/hash/sha1.rs | 63 +- datafusion/spark/src/function/hash/sha2.rs | 357 +- datafusion/spark/src/function/hash/utils.rs | 1005 ++++ .../spark/src/function/hash/xxhash64.rs | 445 ++ .../spark/src/function/json/json_tuple.rs | 238 + datafusion/spark/src/function/json/mod.rs | 17 +- .../spark/src/function/map/map_from_arrays.rs | 113 +- .../src/function/map/map_from_entries.rs | 124 +- datafusion/spark/src/function/map/mod.rs | 10 +- .../spark/src/function/map/str_to_map.rs | 306 + datafusion/spark/src/function/map/utils.rs | 232 +- datafusion/spark/src/function/math/abs.rs | 437 +- datafusion/spark/src/function/math/bin.rs | 106 + datafusion/spark/src/function/math/ceil.rs | 304 + datafusion/spark/src/function/math/expm1.rs | 5 - .../spark/src/function/math/factorial.rs | 9 +- datafusion/spark/src/function/math/floor.rs | 182 + datafusion/spark/src/function/math/hex.rs | 519 +- datafusion/spark/src/function/math/mod.rs | 50 +- datafusion/spark/src/function/math/modulus.rs | 190 +- .../spark/src/function/math/negative.rs | 472 ++ datafusion/spark/src/function/math/pow.rs | 152 + datafusion/spark/src/function/math/rint.rs | 9 +- datafusion/spark/src/function/math/round.rs | 654 +++ .../spark/src/function/math/trigonometry.rs | 9 - datafusion/spark/src/function/math/unhex.rs | 216 + .../spark/src/function/math/width_bucket.rs | 253 +- datafusion/spark/src/function/mod.rs | 1 + datafusion/spark/src/function/null_utils.rs | 108 + datafusion/spark/src/function/string/ascii.rs | 81 +- .../spark/src/function/string/base64.rs | 174 + datafusion/spark/src/function/string/char.rs | 69 +- .../spark/src/function/string/concat.rs | 226 +- datafusion/spark/src/function/string/elt.rs | 12 +- .../src/function/string/format_string.rs | 1113 +++- datafusion/spark/src/function/string/ilike.rs | 94 +- .../src/function/string/is_valid_utf8.rs | 120 + .../spark/src/function/string/length.rs | 58 +- datafusion/spark/src/function/string/like.rs | 96 +- .../spark/src/function/string/luhn_check.rs | 8 +- .../src/function/string/make_valid_utf8.rs | 125 + datafusion/spark/src/function/string/mod.rs | 55 + datafusion/spark/src/function/string/quote.rs | 121 + .../spark/src/function/string/soundex.rs | 150 + datafusion/spark/src/function/string/space.rs | 227 + .../spark/src/function/string/substring.rs | 404 ++ datafusion/spark/src/function/url/mod.rs | 29 +- .../spark/src/function/url/parse_url.rs | 276 +- .../spark/src/function/url/try_parse_url.rs | 8 +- .../spark/src/function/url/try_url_decode.rs | 103 + .../spark/src/function/url/url_decode.rs | 254 + .../spark/src/function/url/url_encode.rs | 126 + datafusion/spark/src/lib.rs | 50 +- datafusion/spark/src/planner.rs | 43 + datafusion/spark/src/session_state.rs | 147 + 818 files changed, 182777 insertions(+), 34613 deletions(-) create mode 100644 datafusion/datasource-parquet/benches/parquet_metadata_statistics.rs create mode 100644 datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs create mode 100644 datafusion/datasource-parquet/benches/parquet_struct_filter_pushdown.rs create mode 100644 datafusion/datasource-parquet/src/bloom_filter.rs create mode 100644 datafusion/datasource-parquet/src/decoder_projection.rs create mode 100644 datafusion/datasource-parquet/src/opener/early_stop.rs create mode 100644 datafusion/datasource-parquet/src/opener/encryption.rs create mode 100644 datafusion/datasource-parquet/src/opener/mod.rs create mode 100644 datafusion/datasource-parquet/src/push_decoder.rs create mode 100644 datafusion/datasource-parquet/src/schema_coercion.rs create mode 100644 datafusion/datasource-parquet/src/sink.rs create mode 100644 datafusion/datasource-parquet/src/sort.rs create mode 100644 datafusion/datasource-parquet/src/supported_predicates.rs create mode 100644 datafusion/datasource-parquet/src/test_data/ndv_test.parquet create mode 100644 datafusion/datasource-parquet/src/test_util.rs create mode 100644 datafusion/datasource-parquet/src/virtual_column.rs create mode 100644 datafusion/datasource/src/file_scan_config/mod.rs create mode 100644 datafusion/datasource/src/file_scan_config/sort_pushdown.rs create mode 100644 datafusion/datasource/src/file_stream/builder.rs create mode 100644 datafusion/datasource/src/file_stream/metrics.rs create mode 100644 datafusion/datasource/src/file_stream/mod.rs create mode 100644 datafusion/datasource/src/file_stream/scan_state.rs create mode 100644 datafusion/datasource/src/file_stream/work_source.rs create mode 100644 datafusion/datasource/src/morsel/adapters.rs create mode 100644 datafusion/datasource/src/morsel/mocks.rs create mode 100644 datafusion/datasource/src/morsel/mod.rs create mode 100644 datafusion/datasource/src/projection.rs create mode 100644 datafusion/expr-common/src/placement.rs create mode 100644 datafusion/expr/src/extension_types/array_formatter_factory.rs create mode 100644 datafusion/expr/src/extension_types/mod.rs create mode 100644 datafusion/expr/src/higher_order_function.rs create mode 100644 datafusion/expr/src/preimage.rs create mode 100644 datafusion/expr/src/sql.rs create mode 100644 datafusion/functions-aggregate/benches/approx_distinct.rs create mode 100644 datafusion/functions-aggregate/benches/count_distinct.rs create mode 100644 datafusion/functions-aggregate/benches/first_last.rs create mode 100644 datafusion/functions-aggregate/benches/median.rs create mode 100644 datafusion/functions-aggregate/benches/percentile_cont.rs create mode 100644 datafusion/functions-aggregate/src/first_last/state.rs create mode 100644 datafusion/functions-nested/benches/array_concat.rs create mode 100644 datafusion/functions-nested/benches/array_min_max.rs create mode 100644 datafusion/functions-nested/benches/array_position.rs create mode 100644 datafusion/functions-nested/benches/array_range.rs create mode 100644 datafusion/functions-nested/benches/array_remove.rs create mode 100644 datafusion/functions-nested/benches/array_repeat.rs create mode 100644 datafusion/functions-nested/benches/array_replace.rs create mode 100644 datafusion/functions-nested/benches/array_resize.rs create mode 100644 datafusion/functions-nested/benches/array_set_ops.rs create mode 100644 datafusion/functions-nested/benches/array_slice.rs create mode 100644 datafusion/functions-nested/benches/array_sort.rs create mode 100644 datafusion/functions-nested/benches/array_to_string.rs create mode 100644 datafusion/functions-nested/benches/arrays_zip.rs create mode 100644 datafusion/functions-nested/benches/string_to_array.rs create mode 100644 datafusion/functions-nested/src/array_add.rs create mode 100644 datafusion/functions-nested/src/array_any_match.rs create mode 100644 datafusion/functions-nested/src/array_compact.rs create mode 100644 datafusion/functions-nested/src/array_filter.rs create mode 100644 datafusion/functions-nested/src/array_normalize.rs create mode 100644 datafusion/functions-nested/src/array_product.rs create mode 100644 datafusion/functions-nested/src/array_scale.rs create mode 100644 datafusion/functions-nested/src/array_subtract.rs create mode 100644 datafusion/functions-nested/src/array_sum.rs create mode 100644 datafusion/functions-nested/src/array_transform.rs create mode 100644 datafusion/functions-nested/src/arrays_zip.rs create mode 100644 datafusion/functions-nested/src/cosine_distance.rs create mode 100644 datafusion/functions-nested/src/inner_product.rs create mode 100644 datafusion/functions-nested/src/lambda_utils.rs create mode 100644 datafusion/functions-nested/src/macros_lambda.rs create mode 100644 datafusion/functions/benches/atan2.rs create mode 100644 datafusion/functions/benches/concat_ws.rs create mode 100644 datafusion/functions/benches/contains.rs create mode 100644 datafusion/functions/benches/crypto.rs create mode 100644 datafusion/functions/benches/ends_with.rs create mode 100644 datafusion/functions/benches/factorial.rs create mode 100644 datafusion/functions/benches/floor_ceil.rs create mode 100644 datafusion/functions/benches/lcm.rs create mode 100644 datafusion/functions/benches/left_right.rs create mode 100644 datafusion/functions/benches/levenshtein.rs create mode 100644 datafusion/functions/benches/nanvl.rs create mode 100644 datafusion/functions/benches/overlay.rs create mode 100644 datafusion/functions/benches/power.rs create mode 100644 datafusion/functions/benches/regexp_count.rs create mode 100644 datafusion/functions/benches/replace.rs create mode 100644 datafusion/functions/benches/round.rs create mode 100644 datafusion/functions/benches/split_part.rs create mode 100644 datafusion/functions/benches/starts_with.rs create mode 100644 datafusion/functions/benches/to_local_time.rs create mode 100644 datafusion/functions/benches/to_time.rs create mode 100644 datafusion/functions/benches/translate.rs create mode 100644 datafusion/functions/benches/trim.rs create mode 100644 datafusion/functions/src/core/arrow_field.rs create mode 100644 datafusion/functions/src/core/arrow_metadata.rs create mode 100644 datafusion/functions/src/core/arrow_try_cast.rs create mode 100644 datafusion/functions/src/core/cast_to_type.rs create mode 100644 datafusion/functions/src/core/try_cast_to_type.rs create mode 100644 datafusion/functions/src/core/with_metadata.rs create mode 100644 datafusion/functions/src/crypto/sha.rs create mode 100644 datafusion/functions/src/datetime/make_time.rs create mode 100644 datafusion/functions/src/datetime/to_time.rs create mode 100644 datafusion/functions/src/math/ceil.rs create mode 100644 datafusion/functions/src/math/decimal.rs create mode 100644 datafusion/functions/src/math/floor.rs create mode 100644 datafusion/functions/src/unicode/common.rs create mode 100644 datafusion/optimizer/benches/optimize_projections.rs create mode 100644 datafusion/optimizer/benches/unions_to_filter.rs create mode 100644 datafusion/optimizer/src/extract_leaf_expressions.rs create mode 100644 datafusion/optimizer/src/rewrite_set_comparison.rs create mode 100644 datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs create mode 100644 datafusion/optimizer/src/simplify_expressions/reorder_predicates.rs create mode 100644 datafusion/optimizer/src/simplify_expressions/simplify_literal.rs create mode 100644 datafusion/optimizer/src/simplify_expressions/udf_preimage.rs create mode 100644 datafusion/optimizer/src/test/udfs.rs create mode 100644 datafusion/optimizer/src/unions_to_filter.rs create mode 100644 datafusion/physical-expr-common/benches/compare_nested.rs create mode 100644 datafusion/physical-expr-common/src/metrics/baseline.rs create mode 100644 datafusion/physical-expr-common/src/metrics/builder.rs create mode 100644 datafusion/physical-expr-common/src/metrics/custom.rs create mode 100644 datafusion/physical-expr-common/src/metrics/elapsed_compute.rs create mode 100644 datafusion/physical-expr-common/src/metrics/expression.rs create mode 100644 datafusion/physical-expr-common/src/metrics/mod.rs create mode 100644 datafusion/physical-expr-common/src/metrics/value.rs create mode 100644 datafusion/physical-expr/benches/in_list_strategy.rs create mode 100644 datafusion/physical-expr/benches/simplify.rs create mode 100644 datafusion/physical-expr/benches/string_concat.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs create mode 100644 datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs create mode 100644 datafusion/physical-expr/src/expressions/dynamic_filters/mod.rs create mode 100644 datafusion/physical-expr/src/expressions/dynamic_filters/tracker.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/static_filter.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/strategy.rs create mode 100644 datafusion/physical-expr/src/expressions/lambda.rs create mode 100644 datafusion/physical-expr/src/expressions/lambda_variable.rs create mode 100644 datafusion/physical-expr/src/higher_order_function.rs create mode 100644 datafusion/physical-expr/src/proto_test_util.rs create mode 100644 datafusion/physical-expr/src/scalar_subquery.rs create mode 100644 datafusion/physical-expr/src/simplifier/const_evaluator.rs create mode 100644 datafusion/physical-expr/src/simplifier/not.rs create mode 100644 datafusion/physical-optimizer/src/ensure_requirements/enforce_distribution.rs create mode 100644 datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/mod.rs create mode 100644 datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/replace_with_order_preserving_variants.rs create mode 100644 datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/sort_pushdown.rs create mode 100644 datafusion/physical-optimizer/src/ensure_requirements/mod.rs create mode 100644 datafusion/physical-optimizer/src/hash_join_buffering.rs create mode 100644 datafusion/physical-optimizer/src/pushdown_sort.rs create mode 100644 datafusion/physical-optimizer/src/topk_repartition.rs create mode 100644 datafusion/physical-optimizer/src/window_topn.rs create mode 100644 datafusion/physical-plan/benches/dictionary_group_values.rs create mode 100644 datafusion/physical-plan/benches/hash_join_semi_anti.rs create mode 100644 datafusion/physical-plan/benches/multi_group_by.rs create mode 100644 datafusion/physical-plan/benches/sort_merge_join.rs create mode 100644 datafusion/physical-plan/src/aggregates/hash_aggregate.rs create mode 100644 datafusion/physical-plan/src/aggregates/hash_table.rs create mode 100644 datafusion/physical-plan/src/buffer.rs create mode 100644 datafusion/physical-plan/src/column_rewriter.rs create mode 100644 datafusion/physical-plan/src/joins/array_map.rs create mode 100644 datafusion/physical-plan/src/joins/chain.rs create mode 100644 datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs create mode 100644 datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs create mode 100644 datafusion/physical-plan/src/joins/sort_merge_join/filter.rs create mode 100644 datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs create mode 100644 datafusion/physical-plan/src/metrics.rs create mode 100644 datafusion/physical-plan/src/operator_statistics/mod.rs create mode 100644 datafusion/physical-plan/src/scalar_subquery.rs create mode 100644 datafusion/physical-plan/src/sort_pushdown.rs create mode 100644 datafusion/physical-plan/src/sorts/partitioned_topk.rs create mode 100644 datafusion/physical-plan/src/spill/replayable_spill_input.rs create mode 100644 datafusion/proto/src/convert.rs create mode 120000 datafusion/pruning/LICENSE.txt create mode 120000 datafusion/pruning/NOTICE.txt create mode 100644 datafusion/spark/benches/floor.rs create mode 100644 datafusion/spark/benches/hex.rs create mode 100644 datafusion/spark/benches/sha2.rs create mode 100644 datafusion/spark/benches/slice.rs create mode 100644 datafusion/spark/benches/space.rs create mode 100644 datafusion/spark/benches/substring.rs create mode 100644 datafusion/spark/benches/unhex.rs create mode 100644 datafusion/spark/src/function/aggregate/collect.rs create mode 100644 datafusion/spark/src/function/aggregate/try_sum.rs create mode 100644 datafusion/spark/src/function/array/array_contains.rs create mode 100644 datafusion/spark/src/function/array/repeat.rs create mode 100644 datafusion/spark/src/function/array/slice.rs create mode 100644 datafusion/spark/src/function/bitmap/bitmap_bit_position.rs create mode 100644 datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs create mode 100644 datafusion/spark/src/function/collection/size.rs create mode 100644 datafusion/spark/src/function/conversion/cast.rs create mode 100644 datafusion/spark/src/function/datetime/add_months.rs create mode 100644 datafusion/spark/src/function/datetime/date_diff.rs create mode 100644 datafusion/spark/src/function/datetime/date_part.rs create mode 100644 datafusion/spark/src/function/datetime/date_trunc.rs create mode 100644 datafusion/spark/src/function/datetime/extract.rs create mode 100644 datafusion/spark/src/function/datetime/from_utc_timestamp.rs create mode 100644 datafusion/spark/src/function/datetime/monthname.rs create mode 100644 datafusion/spark/src/function/datetime/time_trunc.rs create mode 100644 datafusion/spark/src/function/datetime/to_utc_timestamp.rs create mode 100644 datafusion/spark/src/function/datetime/trunc.rs create mode 100644 datafusion/spark/src/function/datetime/unix.rs create mode 100644 datafusion/spark/src/function/hash/utils.rs create mode 100644 datafusion/spark/src/function/hash/xxhash64.rs create mode 100644 datafusion/spark/src/function/json/json_tuple.rs create mode 100644 datafusion/spark/src/function/map/str_to_map.rs create mode 100644 datafusion/spark/src/function/math/bin.rs create mode 100644 datafusion/spark/src/function/math/ceil.rs create mode 100644 datafusion/spark/src/function/math/floor.rs create mode 100644 datafusion/spark/src/function/math/negative.rs create mode 100644 datafusion/spark/src/function/math/pow.rs create mode 100644 datafusion/spark/src/function/math/round.rs create mode 100644 datafusion/spark/src/function/math/unhex.rs create mode 100644 datafusion/spark/src/function/null_utils.rs create mode 100644 datafusion/spark/src/function/string/base64.rs create mode 100644 datafusion/spark/src/function/string/is_valid_utf8.rs create mode 100644 datafusion/spark/src/function/string/make_valid_utf8.rs create mode 100644 datafusion/spark/src/function/string/quote.rs create mode 100644 datafusion/spark/src/function/string/soundex.rs create mode 100644 datafusion/spark/src/function/string/space.rs create mode 100644 datafusion/spark/src/function/string/substring.rs create mode 100644 datafusion/spark/src/function/url/try_url_decode.rs create mode 100644 datafusion/spark/src/function/url/url_decode.rs create mode 100644 datafusion/spark/src/function/url/url_encode.rs create mode 100644 datafusion/spark/src/planner.rs create mode 100644 datafusion/spark/src/session_state.rs diff --git a/Cargo.lock b/Cargo.lock index b5620839efced..f20ae8cb55804 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1941,6 +1941,7 @@ version = "51.0.0" dependencies = [ "arrow", "async-trait", + "chrono", "datafusion-catalog", "datafusion-common", "datafusion-datasource", @@ -2147,15 +2148,19 @@ name = "datafusion-datasource-parquet" version = "51.0.0" dependencies = [ "arrow", + "arrow-schema", "async-trait", "bytes", "chrono", + "criterion", "datafusion-common", "datafusion-common-runtime", "datafusion-datasource", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-functions-aggregate-common", + "datafusion-functions-nested", "datafusion-physical-expr", "datafusion-physical-expr-adapter", "datafusion-physical-expr-common", @@ -2168,6 +2173,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", + "tempfile", "tokio", ] diff --git a/datafusion/catalog-listing/Cargo.toml b/datafusion/catalog-listing/Cargo.toml index be1374b371485..61b55397137df 100644 --- a/datafusion/catalog-listing/Cargo.toml +++ b/datafusion/catalog-listing/Cargo.toml @@ -48,6 +48,7 @@ log = { workspace = true } object_store = { workspace = true } [dev-dependencies] +chrono = { workspace = true } datafusion-datasource-parquet = { workspace = true } # Note: add additional linter rules in lib.rs. diff --git a/datafusion/catalog-listing/src/config.rs b/datafusion/catalog-listing/src/config.rs index 3370d2ea75535..ca4d2abfcd737 100644 --- a/datafusion/catalog-listing/src/config.rs +++ b/datafusion/catalog-listing/src/config.rs @@ -19,9 +19,10 @@ use crate::options::ListingOptions; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::{config_err, internal_err}; +use datafusion_datasource::ListingTableUrl; use datafusion_datasource::file_compression_type::FileCompressionType; +#[expect(deprecated)] use datafusion_datasource::schema_adapter::SchemaAdapterFactory; -use datafusion_datasource::ListingTableUrl; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use std::str::FromStr; use std::sync::Arc; @@ -44,15 +45,12 @@ pub enum SchemaSource { /// # Schema Evolution Support /// /// This configuration supports schema evolution through the optional -/// [`SchemaAdapterFactory`]. You might want to override the default factory when you need: +/// [`PhysicalExprAdapterFactory`]. You might want to override the default factory when you need: /// /// - **Type coercion requirements**: When you need custom logic for converting between /// different Arrow data types (e.g., Int32 ↔ Int64, Utf8 ↔ LargeUtf8) /// - **Column mapping**: You need to map columns with a legacy name to a new name /// - **Custom handling of missing columns**: By default they are filled in with nulls, but you may e.g. want to fill them in with `0` or `""`. -/// -/// If not specified, a [`datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory`] -/// will be used, which handles basic schema compatibility cases. #[derive(Debug, Clone, Default)] pub struct ListingTableConfig { /// Paths on the `ObjectStore` for creating [`crate::ListingTable`]. @@ -68,8 +66,6 @@ pub struct ListingTableConfig { pub options: Option, /// Tracks the source of the schema information pub(crate) schema_source: SchemaSource, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - pub(crate) schema_adapter_factory: Option>, /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters pub(crate) expr_adapter_factory: Option>, } @@ -218,8 +214,7 @@ impl ListingTableConfig { file_schema, options: _, schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, + expr_adapter_factory, } = self; let (schema, new_schema_source) = match file_schema { @@ -241,8 +236,7 @@ impl ListingTableConfig { file_schema: Some(schema), options: Some(options), schema_source: new_schema_source, - schema_adapter_factory, - expr_adapter_factory: physical_expr_adapter_factory, + expr_adapter_factory, }) } None => internal_err!("No `ListingOptions` set for inferring schema"), @@ -282,7 +276,6 @@ impl ListingTableConfig { file_schema: self.file_schema, options: Some(options), schema_source: self.schema_source, - schema_adapter_factory: self.schema_adapter_factory, expr_adapter_factory: self.expr_adapter_factory, }) } @@ -290,63 +283,11 @@ impl ListingTableConfig { } } - /// Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] - /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. - /// - /// If not provided, a default schema adapter factory will be used. - /// - /// # Example: Custom Schema Adapter for Type Coercion - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion_catalog_listing::{ListingTableConfig, ListingOptions}; - /// # use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion_datasource::ListingTableUrl; - /// # use datafusion_datasource_parquet::file_format::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # - /// # #[derive(Debug)] - /// # struct MySchemaAdapterFactory; - /// # impl SchemaAdapterFactory for MySchemaAdapterFactory { - /// # fn create(&self, _projected_table_schema: SchemaRef, _file_schema: SchemaRef) -> Box { - /// # unimplemented!() - /// # } - /// # } - /// # let table_paths = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let listing_options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let table_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// let config = ListingTableConfig::new(table_paths) - /// .with_listing_options(listing_options) - /// .with_schema(table_schema) - /// .with_schema_adapter_factory(Arc::new(MySchemaAdapterFactory)); - /// ``` - pub fn with_schema_adapter_factory( - self, - schema_adapter_factory: Arc, - ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this configuration - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() - } - /// Set the [`PhysicalExprAdapterFactory`] for the [`crate::ListingTable`] /// /// The expression adapter factory is used to create physical expression adapters that can /// handle schema evolution and type conversions when evaluating expressions /// with different schemas than the table schema. - /// - /// If not provided, a default physical expression adapter factory will be used unless a custom - /// `SchemaAdapterFactory` is set, in which case only the `SchemaAdapterFactory` will be used. - /// - /// See for details on this transition. pub fn with_expr_adapter_factory( self, expr_adapter_factory: Arc, @@ -356,4 +297,23 @@ impl ListingTableConfig { ..self } } + + /// Deprecated: Set the [`SchemaAdapterFactory`] for the [`crate::ListingTable`] + /// + /// `SchemaAdapterFactory` has been removed. Use [`Self::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. + /// + /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn with_schema_adapter_factory( + self, + _schema_adapter_factory: Arc, + ) -> Self { + // No-op - just return self unchanged + self + } } diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 34073338fbd7e..4f83ec4b3730f 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -21,10 +21,12 @@ use std::mem; use std::sync::Arc; use datafusion_catalog::Session; -use datafusion_common::{assert_or_internal_err, HashMap, Result, ScalarValue}; -use datafusion_datasource::ListingTableUrl; +use datafusion_common::{ + HashMap, Result, ScalarValue, TableReference, assert_or_internal_err, +}; use datafusion_datasource::PartitionedFile; -use datafusion_expr::{lit, utils, BinaryExpr, Operator}; +use datafusion_datasource::{FileExtensions, ListingTableUrl}; +use datafusion_expr::{BinaryExpr, Operator, lit, utils}; use arrow::{ array::AsArray, @@ -33,7 +35,7 @@ use arrow::{ }; use datafusion_expr::execution_props::ExecutionProps; use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt, stream::BoxStream}; use log::{debug, trace}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; @@ -51,7 +53,7 @@ use object_store::{ObjectMeta, ObjectStore}; pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. }) => { + Expr::Column(Column { name, .. }) => { is_applicable &= col_names.contains(&name.as_str()); if is_applicable { Ok(TreeNodeRecursion::Jump) @@ -83,13 +85,28 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context + // https://github.com/apache/datafusion/issues/21690 + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) + } + } + } + Expr::HigherOrderFunction(hof) => { + match hof.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + // https://github.com/apache/datafusion/issues/21690 Volatility::Stable | Volatility::Volatile => { is_applicable = false; Ok(TreeNodeRecursion::Stop) @@ -101,6 +118,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context + // https://github.com/apache/datafusion/issues/21690 // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction { .. } @@ -247,23 +265,19 @@ fn populate_partition_values<'a>( partition_values: &mut HashMap<&'a str, PartitionValue>, filter: &'a Expr, ) { - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = filter - { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = filter { match op { Operator::Eq => match (left.as_ref(), right.as_ref()) { - (Expr::Column(Column { ref name, .. }), Expr::Literal(val, _)) - | (Expr::Literal(val, _), Expr::Column(Column { ref name, .. })) => { + (Expr::Column(Column { name, .. }), Expr::Literal(val, _)) + | (Expr::Literal(val, _), Expr::Column(Column { name, .. })) if partition_values .insert(name, PartitionValue::Single(val.to_string())) - .is_some() - { - partition_values.insert(name, PartitionValue::Multi); - } + .is_some() => + { + partition_values.insert(name, PartitionValue::Multi); } + (Expr::Column(Column { .. }), Expr::Literal(_, _)) + | (Expr::Literal(_, _), Expr::Column(Column { .. })) => {} _ => {} }, Operator::And => { @@ -344,17 +358,25 @@ fn filter_partitions( Ok(None) } +/// Returns `Ok(None)` when the file is not inside a valid partition path +/// (e.g. a stale file in the table root directory). Such files are skipped +/// because hive-style partition values are never null and there is no valid +/// value to assign for non-partitioned files. fn try_into_partitioned_file( object_meta: ObjectMeta, partition_cols: &[(String, DataType)], table_path: &ListingTableUrl, -) -> Result { +) -> Result> { let cols = partition_cols.iter().map(|(name, _)| name.as_str()); let parsed = parse_partitions_for_path(table_path, &object_meta.location, cols); + let Some(parsed) = parsed else { + // parse_partitions_for_path already logs a debug message + return Ok(None); + }; + let partition_values = parsed .into_iter() - .flatten() .zip(partition_cols) .map(|(parsed, (_, datatype))| { ScalarValue::try_from_string(parsed.to_string(), datatype) @@ -363,8 +385,9 @@ fn try_into_partitioned_file( let mut pf: PartitionedFile = object_meta.into(); pf.partition_values = partition_values; + pf.table_reference.clone_from(table_path.get_table_ref()); - Ok(pf) + Ok(Some(pf)) } /// Discover the partitions on the given path and prune out files @@ -397,8 +420,15 @@ pub async fn pruned_partition_list<'a>( table_path ); - // if no partition col => simply list all the files - Ok(objects.map_ok(|object_meta| object_meta.into()).boxed()) + // if no partition col => list all the files + Ok(objects + .try_filter_map(|object_meta| { + futures::future::ready(object_meta_to_partitioned_file( + object_meta, + table_path.get_table_ref(), + )) + }) + .boxed()) } else { let df_schema = DFSchema::from_unqualified_fields( partition_cols @@ -409,18 +439,37 @@ pub async fn pruned_partition_list<'a>( )?; Ok(objects - .map_ok(|object_meta| { - try_into_partitioned_file(object_meta, partition_cols, table_path) + .try_filter_map(|object_meta| { + futures::future::ready(try_into_partitioned_file( + object_meta, + partition_cols, + table_path, + )) }) .try_filter_map(move |pf| { - futures::future::ready( - pf.and_then(|pf| filter_partitions(pf, filters, &df_schema)), - ) + futures::future::ready(filter_partitions(pf, filters, &df_schema)) }) .boxed()) } } +fn object_meta_to_partitioned_file( + object_meta: ObjectMeta, + table_ref: &Option, +) -> Result> { + Ok(Some(PartitionedFile { + object_meta, + arrow_schema: None, + partition_values: vec![], + range: None, + statistics: None, + ordering: None, + extensions: FileExtensions::new(), + metadata_size_hint: None, + table_reference: table_ref.clone(), + })) +} + /// Extract the partition values for the given `file_path` (in the given `table_path`) /// associated to the partitions defined by `table_partition_cols` pub fn parse_partitions_for_path<'a, I>( @@ -466,7 +515,7 @@ mod tests { use std::ops::Not; use super::*; - use datafusion_expr::{case, col, lit, Expr}; + use datafusion_expr::{case, col}; #[test] fn test_split_files() { @@ -578,6 +627,130 @@ mod tests { ); } + #[test] + fn test_try_into_partitioned_file_valid_partition() { + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![("year_month".to_string(), DataType::Utf8)]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/year_month=2024-01/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!(result.is_some()); + let pf = result.unwrap(); + assert_eq!(pf.partition_values.len(), 1); + assert_eq!( + pf.partition_values[0], + ScalarValue::Utf8(Some("2024-01".to_string())) + ); + } + + #[test] + fn test_try_into_partitioned_file_root_file_skipped() { + // File in root directory (not inside any partition path) should be + // skipped — this is the case where a stale file exists from before + // hive partitioning was added. + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![("year_month".to_string(), DataType::Utf8)]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!( + result.is_none(), + "Files outside partition structure should be skipped" + ); + } + + #[test] + fn test_try_into_partitioned_file_wrong_partition_name() { + // File in a directory that doesn't match the expected partition column + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![("year_month".to_string(), DataType::Utf8)]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/wrong_col=2024-01/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!( + result.is_none(), + "Files with wrong partition column name should be skipped" + ); + } + + #[test] + fn test_try_into_partitioned_file_multiple_partitions() { + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![ + ("year".to_string(), DataType::Utf8), + ("month".to_string(), DataType::Utf8), + ]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/year=2024/month=01/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + assert!(result.is_some()); + let pf = result.unwrap(); + assert_eq!(pf.partition_values.len(), 2); + assert_eq!( + pf.partition_values[0], + ScalarValue::Utf8(Some("2024".to_string())) + ); + assert_eq!( + pf.partition_values[1], + ScalarValue::Utf8(Some("01".to_string())) + ); + } + + #[test] + fn test_try_into_partitioned_file_partial_partition_skipped() { + // File has first partition but not second — should be skipped + let table_path = ListingTableUrl::parse("file:///bucket/mytable").unwrap(); + let partition_cols = vec![ + ("year".to_string(), DataType::Utf8), + ("month".to_string(), DataType::Utf8), + ]; + let meta = ObjectMeta { + location: Path::from("bucket/mytable/year=2024/data.parquet"), + last_modified: chrono::Utc::now(), + size: 100, + e_tag: None, + version: None, + }; + + let result = + try_into_partitioned_file(meta, &partition_cols, &table_path).unwrap(); + // File has year=2024 but no month= directory — parse_partitions_for_path + // returns None because the path component "data.parquet" doesn't match + // the expected "month=..." pattern. + assert!( + result.is_none(), + "Files with incomplete partition structure should be skipped" + ); + } + #[test] fn test_expr_applicable_for_cols() { assert!(expr_applicable_for_cols( diff --git a/datafusion/catalog-listing/src/mod.rs b/datafusion/catalog-listing/src/mod.rs index 1e06483261d2e..9efb5aa96267e 100644 --- a/datafusion/catalog-listing/src/mod.rs +++ b/datafusion/catalog-listing/src/mod.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", @@ -34,4 +32,4 @@ mod table; pub use config::{ListingTableConfig, SchemaSource}; pub use options::ListingOptions; -pub use table::ListingTable; +pub use table::{ListFilesResult, ListingTable}; diff --git a/datafusion/catalog-listing/src/options.rs b/datafusion/catalog-listing/src/options.rs index 7da8005f90ec2..0ab15e05abba1 100644 --- a/datafusion/catalog-listing/src/options.rs +++ b/datafusion/catalog-listing/src/options.rs @@ -18,12 +18,12 @@ use arrow::datatypes::{DataType, SchemaRef}; use datafusion_catalog::Session; use datafusion_common::plan_err; -use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::ListingTableUrl; +use datafusion_datasource::file_format::FileFormat; use datafusion_execution::config::SessionConfig; use datafusion_expr::SortExpr; use futures::StreamExt; -use futures::{future, TryStreamExt}; +use futures::TryStreamExt; use itertools::Itertools; use std::sync::Arc; @@ -263,7 +263,15 @@ impl ListingOptions { /// Infer the schema of the files at the given path on the provided object store. /// /// If the table_path contains one or more files (i.e. it is a directory / - /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`] + /// prefix of files) their schema is merged by calling [`FileFormat::infer_schema`]. + /// + /// Returns a `Plan` error if `table_path` contains no files at all (e.g. an + /// empty or non-existent directory), since an inferred schema with zero + /// columns produces confusing "column not found" errors at query time. + /// Callers that need to support empty locations must declare an explicit + /// schema instead of relying on inference. Locations that contain files + /// which all happen to be 0-byte are still accepted — the empty files are + /// filtered out before format-specific inference runs. /// /// Note: The inferred schema does not include any partitioning columns. /// @@ -275,14 +283,27 @@ impl ListingOptions { ) -> datafusion_common::Result { let store = state.runtime_env().object_store(table_path)?; - let files: Vec<_> = table_path + let all_files: Vec<_> = table_path .list_all_files(state, store.as_ref(), &self.file_extension) .await? - // Empty files cannot affect schema but may throw when trying to read for it - .try_filter(|object_meta| future::ready(object_meta.size > 0)) .try_collect() .await?; + if all_files.is_empty() { + return plan_err!( + "No files found at {}. \ + Cannot infer schema from an empty location; either add data files \ + or declare an explicit schema for the table.", + table_path + ); + } + + // Empty files cannot affect schema but may throw when trying to read for it + let files: Vec<_> = all_files + .into_iter() + .filter(|object_meta| object_meta.size > 0) + .collect(); + let schema = self.format.infer_schema(state, &store, &files).await?; Ok(schema) diff --git a/datafusion/catalog-listing/src/table.rs b/datafusion/catalog-listing/src/table.rs index 33d5c86bf88dc..dd3675bd2b39d 100644 --- a/datafusion/catalog-listing/src/table.rs +++ b/datafusion/catalog-listing/src/table.rs @@ -23,35 +23,43 @@ use async_trait::async_trait; use datafusion_catalog::{ScanArgs, ScanResult, Session, TableProvider}; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_datafusion_err, plan_err, project_schema, Constraints, DataFusionError, - SchemaExt, Statistics, + Constraints, SchemaExt, Statistics, internal_datafusion_err, plan_err, project_schema, }; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::file_sink_config::FileSinkConfig; -use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, -}; +use datafusion_datasource::file_sink_config::{FileOutputMode, FileSinkConfig}; +#[expect(deprecated)] +use datafusion_datasource::schema_adapter::SchemaAdapterFactory; use datafusion_datasource::{ - compute_all_files_statistics, ListingTableUrl, PartitionedFile, TableSchema, + ListingTableUrl, PartitionedFile, TableSchemaBuilder, compute_all_files_statistics, }; +use datafusion_execution::cache::TableScopedPath; use datafusion_execution::cache::cache_manager::FileStatisticsCache; -use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_expr::dml::InsertOp; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion_physical_expr::create_lex_ordering; use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::ExecutionPlan; -use futures::{future, stream, Stream, StreamExt, TryStreamExt}; +use datafusion_physical_plan::empty::EmptyExec; +use futures::{Stream, StreamExt, TryStreamExt, future, stream}; use object_store::ObjectStore; -use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +/// Result of a file listing operation from [`ListingTable::list_files_for_scan`]. +#[derive(Debug)] +pub struct ListFilesResult { + /// File groups organized by the partitioning strategy. + pub file_groups: Vec, + /// Aggregated statistics for all files. + pub statistics: Statistics, + /// Whether files are grouped by partition values (enables Hash partitioning). + pub grouped_by_partition: bool, +} + /// Built in [`TableProvider`] that reads data from one or more files as a single table. /// /// The files are read using an [`ObjectStore`] instance, for example from @@ -178,13 +186,11 @@ pub struct ListingTable { /// The SQL definition for this table, if any definition: Option, /// Cache for collected file statistics - collected_statistics: FileStatisticsCache, + collected_statistics: Option>, /// Constraints applied to this table constraints: Constraints, /// Column default expressions for columns that are not physically present in the data files column_defaults: HashMap, - /// Optional [`SchemaAdapterFactory`] for creating schema adapters - schema_adapter_factory: Option>, /// Optional [`PhysicalExprAdapterFactory`] for creating physical expression adapters expr_adapter_factory: Option>, } @@ -224,10 +230,9 @@ impl ListingTable { schema_source, options, definition: None, - collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), + collected_statistics: None, constraints: Constraints::default(), column_defaults: HashMap::new(), - schema_adapter_factory: config.schema_adapter_factory, expr_adapter_factory: config.expr_adapter_factory, }; @@ -254,10 +259,8 @@ impl ListingTable { /// Setting a statistics cache on the `SessionContext` can avoid refetching statistics /// multiple times in the same session. /// - /// If `None`, creates a new [`DefaultFileStatisticsCache`] scoped to this query. - pub fn with_cache(mut self, cache: Option) -> Self { - self.collected_statistics = - cache.unwrap_or_else(|| Arc::new(DefaultFileStatisticsCache::default())); + pub fn with_cache(mut self, cache: Option>) -> Self { + self.collected_statistics = cache; self } @@ -282,92 +285,152 @@ impl ListingTable { self.schema_source } - /// Set the [`SchemaAdapterFactory`] for this [`ListingTable`] + /// Deprecated: Set the [`SchemaAdapterFactory`] for this [`ListingTable`] /// - /// The schema adapter factory is used to create schema adapters that can - /// handle schema evolution and type conversions when reading files with - /// different schemas than the table schema. + /// `SchemaAdapterFactory` has been removed. Use [`ListingTableConfig::with_expr_adapter_factory`] + /// and `PhysicalExprAdapterFactory` instead. See `upgrading.md` for more details. /// - /// # Example: Adding Schema Evolution Support - /// ```rust - /// # use std::sync::Arc; - /// # use datafusion_catalog_listing::{ListingTable, ListingTableConfig, ListingOptions}; - /// # use datafusion_datasource::ListingTableUrl; - /// # use datafusion_datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapter}; - /// # use datafusion_datasource_parquet::file_format::ParquetFormat; - /// # use arrow::datatypes::{SchemaRef, Schema, Field, DataType}; - /// # let table_path = ListingTableUrl::parse("file:///path/to/data").unwrap(); - /// # let options = ListingOptions::new(Arc::new(ParquetFormat::default())); - /// # let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); - /// # let config = ListingTableConfig::new(table_path).with_listing_options(options).with_schema(schema); - /// # let table = ListingTable::try_new(config).unwrap(); - /// let table_with_evolution = table - /// .with_schema_adapter_factory(Arc::new(DefaultSchemaAdapterFactory)); - /// ``` - /// See [`ListingTableConfig::with_schema_adapter_factory`] for an example of custom SchemaAdapterFactory. + /// This method is a no-op and returns `self` unchanged. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use ListingTableConfig::with_expr_adapter_factory and PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] pub fn with_schema_adapter_factory( self, - schema_adapter_factory: Arc, + _schema_adapter_factory: Arc, ) -> Self { - Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self - } - } - - /// Get the [`SchemaAdapterFactory`] for this table - pub fn schema_adapter_factory(&self) -> Option<&Arc> { - self.schema_adapter_factory.as_ref() + // No-op - just return self unchanged + self } - /// Creates a schema adapter for mapping between file and table schemas + /// Deprecated: Returns the [`SchemaAdapterFactory`] used by this [`ListingTable`]. /// - /// Uses the configured schema adapter factory if available, otherwise falls back - /// to the default implementation. - fn create_schema_adapter(&self) -> Box { - let table_schema = self.schema(); - match &self.schema_adapter_factory { - Some(factory) => { - factory.create_with_projected_schema(Arc::clone(&table_schema)) - } - None => DefaultSchemaAdapterFactory::from_schema(Arc::clone(&table_schema)), - } + /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. + /// See `upgrading.md` for more details. + /// + /// Always returns `None`. + #[deprecated( + since = "52.0.0", + note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] + pub fn schema_adapter_factory(&self) -> Option> { + None } - /// Creates a file source and applies schema adapter factory if available - fn create_file_source_with_schema_adapter( - &self, - ) -> datafusion_common::Result> { - let table_schema = TableSchema::new( - Arc::clone(&self.file_schema), - self.options - .table_partition_cols - .iter() - .map(|(col, field)| Arc::new(Field::new(col, field.clone(), false))) - .collect(), - ); + /// Creates a file source for this table + fn create_file_source(&self) -> Arc { + let table_schema = TableSchemaBuilder::from(&self.file_schema) + .with_table_partition_cols( + self.options + .table_partition_cols + .iter() + .map(|(col, field)| Arc::new(Field::new(col, field.clone(), false))) + .collect::>(), + ) + .build(); - let mut source = self.options.format.file_source(table_schema); - // Apply schema adapter to source if available - // - // The source will use this SchemaAdapter to adapt data batches as they flow up the plan. - // Note: ListingTable also creates a SchemaAdapter in `scan()` but that is only used to adapt collected statistics. - if let Some(factory) = &self.schema_adapter_factory { - source = source.with_schema_adapter_factory(Arc::clone(factory))?; - } - Ok(source) + self.options.format.file_source(table_schema) } - /// If file_sort_order is specified, creates the appropriate physical expressions + /// Creates output ordering from user-specified file_sort_order or derives + /// from file orderings when user doesn't specify. + /// + /// If user specified `file_sort_order`, that takes precedence. + /// Otherwise, attempts to derive common ordering from file orderings in + /// the provided file groups. pub fn try_create_output_ordering( &self, execution_props: &ExecutionProps, + file_groups: &[FileGroup], ) -> datafusion_common::Result> { - create_lex_ordering( - &self.table_schema, - &self.options.file_sort_order, - execution_props, - ) + // If user specified sort order, use that + if !self.options.file_sort_order.is_empty() { + return create_lex_ordering( + &self.table_schema, + &self.options.file_sort_order, + execution_props, + ); + } + if let Some(ordering) = derive_common_ordering_from_files(file_groups) { + return Ok(vec![ordering]); + } + Ok(vec![]) + } +} + +/// Derives a common ordering from file orderings across all file groups. +/// +/// Returns the common ordering if all files have compatible orderings, +/// otherwise returns None. +/// +/// The function finds the longest common prefix among all file orderings. +/// For example, if files have orderings `[a, b, c]` and `[a, b]`, the common +/// ordering is `[a, b]`. +fn derive_common_ordering_from_files(file_groups: &[FileGroup]) -> Option { + enum CurrentOrderingState { + /// Initial state before processing any files + FirstFile, + /// Some common ordering found so far + SomeOrdering(LexOrdering), + /// No files have ordering + NoOrdering, + } + let mut state = CurrentOrderingState::FirstFile; + + // Collect file orderings and track counts + for group in file_groups { + for file in group.iter() { + state = match (&state, &file.ordering) { + // If this is the first file with ordering, set it as current + (CurrentOrderingState::FirstFile, Some(ordering)) => { + CurrentOrderingState::SomeOrdering(ordering.clone()) + } + (CurrentOrderingState::FirstFile, None) => { + CurrentOrderingState::NoOrdering + } + // If we have an existing ordering, find common prefix with new ordering + (CurrentOrderingState::SomeOrdering(current), Some(ordering)) => { + // Find common prefix between current and new ordering + let prefix_len = current + .as_ref() + .iter() + .zip(ordering.as_ref().iter()) + .take_while(|(a, b)| a == b) + .count(); + if prefix_len == 0 { + log::trace!( + "Cannot derive common ordering: no common prefix between orderings {current:?} and {ordering:?}" + ); + return None; + } else { + let ordering = + LexOrdering::new(current.as_ref()[..prefix_len].to_vec()) + .expect("prefix_len > 0, so ordering must be valid"); + CurrentOrderingState::SomeOrdering(ordering) + } + } + // If one file has ordering and another doesn't, no common ordering + // Return None and log a trace message explaining why + (CurrentOrderingState::SomeOrdering(ordering), None) + | (CurrentOrderingState::NoOrdering, Some(ordering)) => { + log::trace!( + "Cannot derive common ordering: some files have ordering {ordering:?}, others don't" + ); + return None; + } + // Both have no ordering, remain in NoOrdering state + (CurrentOrderingState::NoOrdering, None) => { + CurrentOrderingState::NoOrdering + } + }; + } + } + + match state { + CurrentOrderingState::SomeOrdering(ordering) => Some(ordering), + _ => None, } } @@ -383,10 +446,6 @@ fn can_be_evaluated_for_partition_pruning( #[async_trait] impl TableProvider for ListingTable { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.table_schema) } @@ -446,7 +505,11 @@ impl TableProvider for ListingTable { // at the same time. This is because the limit should be applied after the filters are applied. let statistic_file_limit = if filters.is_empty() { limit } else { None }; - let (mut partitioned_file_lists, statistics) = self + let ListFilesResult { + file_groups: mut partitioned_file_lists, + statistics, + grouped_by_partition: partitioned_by_file_group, + } = self .list_files_for_scan(state, &partition_filters, statistic_file_limit) .await?; @@ -456,7 +519,10 @@ impl TableProvider for ListingTable { return Ok(ScanResult::new(Arc::new(EmptyExec::new(projected_schema)))); } - let output_ordering = self.try_create_output_ordering(state.execution_props())?; + let output_ordering = self.try_create_output_ordering( + state.execution_props(), + &partitioned_file_lists, + )?; match state .config_options() .execution @@ -478,7 +544,9 @@ impl TableProvider for ListingTable { if new_groups.len() <= self.options.target_partitions { partitioned_file_lists = new_groups; } else { - log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") + log::debug!( + "attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered" + ) } } None => {} // no ordering required @@ -492,7 +560,7 @@ impl TableProvider for ListingTable { ))))); }; - let file_source = self.create_file_source_with_schema_adapter()?; + let file_source = self.create_file_source(); // create the execution plan let plan = self @@ -504,10 +572,11 @@ impl TableProvider for ListingTable { .with_file_groups(partitioned_file_lists) .with_constraints(self.constraints.clone()) .with_statistics(statistics) - .with_projection_indices(projection) + .with_projection_indices(projection)? .with_limit(limit) .with_output_ordering(output_ordering) .with_expr_adapter(self.expr_adapter_factory.clone()) + .with_partitioned_by_file_group(partitioned_by_file_group) .build(), ) .await?; @@ -578,6 +647,15 @@ impl TableProvider for ListingTable { let keep_partition_by_columns = state.config_options().execution.keep_partition_by_columns; + // Invalidate cache entries for this table if they exist + if let Some(lfc) = state.runtime_env().cache_manager.get_list_files_cache() { + let key = TableScopedPath { + table: table_path.get_table_ref().clone(), + path: table_path.prefix().clone(), + }; + let _ = lfc.remove(&key); + } + // Sink related option, apart from format let config = FileSinkConfig { original_url: String::default(), @@ -589,9 +667,11 @@ impl TableProvider for ListingTable { insert_op, keep_partition_by_columns, file_extension: self.options().format.get_ext(), + file_output_mode: FileOutputMode::Automatic, }; - let orderings = self.try_create_output_ordering(state.execution_props())?; + // For writes, we only use user-specified ordering (no file groups to derive from) + let orderings = self.try_create_output_ordering(state.execution_props(), &[])?; // It is sufficient to pass only one of the equivalent orderings: let order_requirements = orderings.into_iter().next().map(Into::into); @@ -615,11 +695,15 @@ impl ListingTable { ctx: &'a dyn Session, filters: &'a [Expr], limit: Option, - ) -> datafusion_common::Result<(Vec, Statistics)> { + ) -> datafusion_common::Result { let store = if let Some(url) = self.table_paths.first() { ctx.runtime_env().object_store(url)? } else { - return Ok((vec![], Statistics::new_unknown(&self.file_schema))); + return Ok(ListFilesResult { + file_groups: vec![], + statistics: Statistics::new_unknown(&self.file_schema), + grouped_by_partition: false, + }); }; // list files (with partitions) let file_list = future::try_join_all(self.table_paths.iter().map(|table_path| { @@ -636,16 +720,19 @@ impl ListingTable { let meta_fetch_concurrency = ctx.config_options().execution.meta_fetch_concurrency; let file_list = stream::iter(file_list).flatten_unordered(meta_fetch_concurrency); - // collect the statistics if required by the config + // collect the statistics and ordering if required by the config let files = file_list .map(|part_file| async { let part_file = part_file?; - let statistics = if self.options.collect_stat { - self.do_collect_statistics(ctx, &store, &part_file).await? + let (statistics, ordering) = if self.options.collect_stat { + self.do_collect_statistics_and_ordering(ctx, &store, &part_file) + .await? } else { - Arc::new(Statistics::new_unknown(&self.file_schema)) + (Arc::new(Statistics::new_unknown(&self.file_schema)), None) }; - Ok(part_file.with_statistics(statistics)) + Ok(part_file + .with_statistics(statistics) + .with_ordering(ordering)) }) .boxed() .buffer_unordered(ctx.config_options().execution.meta_fetch_concurrency); @@ -653,65 +740,103 @@ impl ListingTable { let (file_group, inexact_stats) = get_files_with_limit(files, limit, self.options.collect_stat).await?; - let file_groups = file_group.split_files(self.options.target_partitions); - let (mut file_groups, mut stats) = compute_all_files_statistics( + // Threshold: 0 = disabled, N > 0 = enabled when distinct_keys >= N + // + // When enabled, files are grouped by their Hive partition column values, allowing + // FileScanConfig to declare Hash partitioning. This enables the optimizer to skip + // hash repartitioning for aggregates and joins on partition columns. + let threshold = ctx.config_options().optimizer.preserve_file_partitions; + + let (file_groups, grouped_by_partition) = if threshold > 0 + && !self.options.table_partition_cols.is_empty() + { + let grouped = + file_group.group_by_partition_values(self.options.target_partitions); + if grouped.len() >= threshold { + (grouped, true) + } else { + let all_files: Vec<_> = + grouped.into_iter().flat_map(|g| g.into_inner()).collect(); + ( + FileGroup::new(all_files).split_files(self.options.target_partitions), + false, + ) + } + } else { + ( + file_group.split_files(self.options.target_partitions), + false, + ) + }; + + let (file_groups, stats) = compute_all_files_statistics( file_groups, self.schema(), self.options.collect_stat, inexact_stats, )?; - let schema_adapter = self.create_schema_adapter(); - let (schema_mapper, _) = schema_adapter.map_schema(self.file_schema.as_ref())?; - - stats.column_statistics = - schema_mapper.map_column_statistics(&stats.column_statistics)?; - file_groups.iter_mut().try_for_each(|file_group| { - if let Some(stat) = file_group.statistics_mut() { - stat.column_statistics = - schema_mapper.map_column_statistics(&stat.column_statistics)?; - } - Ok::<_, DataFusionError>(()) - })?; - Ok((file_groups, stats)) + // Note: Statistics already include both file columns and partition columns. + // PartitionedFile::with_statistics automatically appends exact partition column + // statistics (min=max=partition_value, null_count=0, distinct_count=1) computed + // from partition_values. + Ok(ListFilesResult { + file_groups, + statistics: stats, + grouped_by_partition, + }) } - /// Collects statistics for a given partitioned file. + /// Collects statistics and ordering for a given partitioned file. /// - /// This method first checks if the statistics for the given file are already cached. - /// If they are, it returns the cached statistics. - /// If they are not, it infers the statistics from the file and stores them in the cache. - async fn do_collect_statistics( + /// This method checks if statistics are cached. If cached, it returns the + /// cached statistics and infers ordering separately. If not cached, it infers + /// both statistics and ordering in a single metadata read for efficiency. + async fn do_collect_statistics_and_ordering( &self, ctx: &dyn Session, store: &Arc, part_file: &PartitionedFile, - ) -> datafusion_common::Result> { - match self - .collected_statistics - .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) + ) -> datafusion_common::Result<(Arc, Option)> { + use datafusion_execution::cache::cache_manager::CachedFileMetadata; + + let path = TableScopedPath { + table: part_file.table_reference.clone(), + path: part_file.object_meta.location.clone(), + }; + let meta = &part_file.object_meta; + + // Check cache first - if we have valid cached statistics and ordering + if let Some(cache) = &self.collected_statistics + && let Some(cached) = cache.get(&path) + && cached.is_valid_for(meta) { - Some(statistics) => Ok(statistics), - None => { - let statistics = self - .options - .format - .infer_stats( - ctx, - store, - Arc::clone(&self.file_schema), - &part_file.object_meta, - ) - .await?; - let statistics = Arc::new(statistics); - self.collected_statistics.put_with_extra( - &part_file.object_meta.location, + // Return cached statistics and ordering + return Ok((Arc::clone(&cached.statistics), cached.ordering.clone())); + } + + // Cache miss or invalid: fetch both statistics and ordering in a single metadata read + let file_meta = self + .options + .format + .infer_stats_and_ordering(ctx, store, Arc::clone(&self.file_schema), meta) + .await?; + + let statistics = Arc::new(file_meta.statistics); + + // Store in cache + if let Some(cache) = &self.collected_statistics { + cache.put( + &path, + CachedFileMetadata::new( + meta.clone(), Arc::clone(&statistics), - &part_file.object_meta, - ); - Ok(statistics) - } + file_meta.ordering.clone(), + ), + ); } + + Ok((statistics, file_meta.ordering)) } } @@ -760,28 +885,25 @@ async fn get_files_with_limit( let file = file_result?; // Update file statistics regardless of state - if collect_stats { - if let Some(file_stats) = &file.statistics { - num_rows = if file_group.is_empty() { - // For the first file, just take its row count - file_stats.num_rows - } else { - // For subsequent files, accumulate the counts - num_rows.add(&file_stats.num_rows) - }; - } + if collect_stats && let Some(file_stats) = &file.statistics { + num_rows = if file_group.is_empty() { + // For the first file, just take its row count + file_stats.num_rows + } else { + // For subsequent files, accumulate the counts + num_rows.add(&file_stats.num_rows) + }; } // Always add the file to our group file_group.push(file); // Check if we've hit the limit (if one was specified) - if let Some(limit) = limit { - if let Precision::Exact(row_count) = num_rows { - if row_count > limit { - state = ProcessingState::ReachedLimit; - } - } + if let Some(limit) = limit + && let Precision::Exact(row_count) = num_rows + && row_count > limit + { + state = ProcessingState::ReachedLimit; } } // If we still have files in the stream, it means that the limit kicked @@ -790,3 +912,145 @@ async fn get_files_with_limit( let inexact_stats = all_files.next().await.is_some(); Ok((file_group, inexact_stats)) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::compute::SortOptions; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + /// Helper to create a PhysicalSortExpr + fn sort_expr( + name: &str, + idx: usize, + descending: bool, + nulls_first: bool, + ) -> PhysicalSortExpr { + PhysicalSortExpr::new( + Arc::new(Column::new(name, idx)), + SortOptions { + descending, + nulls_first, + }, + ) + } + + /// Helper to create a LexOrdering (unwraps the Option) + fn lex_ordering(exprs: Vec) -> LexOrdering { + LexOrdering::new(exprs).expect("expected non-empty ordering") + } + + /// Helper to create a PartitionedFile with optional ordering + fn create_file(name: &str, ordering: Option) -> PartitionedFile { + PartitionedFile::new(name.to_string(), 1024).with_ordering(ordering) + } + + #[test] + fn test_derive_common_ordering_all_files_same_ordering() { + // All files have the same ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![ + FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering.clone())), + create_file("f2.parquet", Some(ordering.clone())), + ]), + FileGroup::new(vec![create_file("f3.parquet", Some(ordering.clone()))]), + ]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } + + #[test] + fn test_derive_common_ordering_common_prefix() { + // Files have different orderings but share a common prefix + let ordering_abc = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + sort_expr("c", 2, false, true), + ]); + let ordering_ab = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, false, true), + ]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_abc)), + create_file("f2.parquet", Some(ordering_ab.clone())), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering_ab)); + } + + #[test] + fn test_derive_common_ordering_no_common_prefix() { + // Files have completely different orderings -> returns None + let ordering_a = lex_ordering(vec![sort_expr("a", 0, false, true)]); + let ordering_b = lex_ordering(vec![sort_expr("b", 1, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering_a)), + create_file("f2.parquet", Some(ordering_b)), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_mixed_with_none() { + // Some files have ordering, some don't -> returns None + let ordering = lex_ordering(vec![sort_expr("a", 0, false, true)]); + + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", Some(ordering)), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_all_none() { + // No files have ordering -> returns None + let file_groups = vec![FileGroup::new(vec![ + create_file("f1.parquet", None), + create_file("f2.parquet", None), + ])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_empty_groups() { + // Empty file groups -> returns None + let file_groups: Vec = vec![]; + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, None); + } + + #[test] + fn test_derive_common_ordering_single_file() { + // Single file with ordering -> returns that ordering + let ordering = lex_ordering(vec![ + sort_expr("a", 0, false, true), + sort_expr("b", 1, true, false), + ]); + + let file_groups = vec![FileGroup::new(vec![create_file( + "f1.parquet", + Some(ordering.clone()), + )])]; + + let result = derive_common_ordering_from_files(&file_groups); + assert_eq!(result, Some(ordering)); + } +} diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 99bdcb6f74fe6..12b3f44fe796a 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -17,29 +17,40 @@ //! This module provides the bisect function, which implements binary search. +pub(crate) mod aggregate; pub mod expr; pub mod memory; pub mod proxy; pub mod string_utils; use crate::assert_or_internal_err; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err}; +use crate::error::{_exec_datafusion_err, _exec_err, _internal_datafusion_err}; use crate::{Result, ScalarValue}; use arrow::array::{ - cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, - OffsetSizeTrait, + Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, + cast::AsArray, +}; +use arrow::array::{ + ArrowPrimitiveType, BooleanArray, Datum, GenericListArray, Int32Array, Int64Array, + MutableArrayData, PrimitiveArray, make_array, +}; +use arrow::array::{LargeListViewArray, ListViewArray}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::compute::kernels::cmp::eq; +use arrow::compute::kernels::length::length; +use arrow::compute::{SortColumn, SortOptions, partition}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, }; -use arrow::buffer::OffsetBuffer; -use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Field, SchemaRef}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; -use std::cmp::{min, Ordering}; +use std::cmp::{Ordering, min}; use std::collections::HashSet; +use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use std::thread::available_parallelism; /// Applies an optional projection to a [`SchemaRef`], returning the @@ -70,10 +81,10 @@ use std::thread::available_parallelism; /// ``` pub fn project_schema( schema: &SchemaRef, - projection: Option<&Vec>, + projection: Option<&impl AsRef<[usize]>>, ) -> Result { let schema = match projection { - Some(columns) => Arc::new(schema.project(columns)?), + Some(columns) => Arc::new(schema.project(columns.as_ref())?), None => Arc::clone(schema), }; Ok(schema) @@ -266,10 +277,10 @@ fn needs_quotes(s: &str) -> bool { let mut chars = s.chars(); // first char can not be a number unless escaped - if let Some(first_char) = chars.next() { - if !(first_char.is_ascii_lowercase() || first_char == '_') { - return true; - } + if let Some(first_char) = chars.next() + && !(first_char.is_ascii_lowercase() || first_char == '_') + { + return true; } !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_') @@ -384,6 +395,137 @@ pub fn longest_consecutive_prefix>( count } +/// Splits `vec` at index `n`, returning the first `n` elements and leaving the +/// remaining `vec.len() - n` elements in `vec`. +/// +/// Allocates for whichever side is smaller, so the new allocation is +/// `min(n, vec.len() - n)` rather than always `n` (as `vec.drain(0..n).collect()` +/// would). This matters when the split emits a prefix under memory pressure, +/// where `n` can be close to `vec.len()`. +pub fn split_vec_min_alloc(vec: &mut Vec, n: usize) -> Vec { + if n * 2 <= vec.len() { + vec.drain(0..n).collect() + } else { + let remaining = vec.split_off(n); + std::mem::replace(vec, remaining) + } +} + +#[cfg(test)] +mod split_vec_min_alloc_tests { + use super::split_vec_min_alloc; + + #[test] + fn drain_branch() { + // n * 2 <= len -> drain+collect branch (allocates n elements) + let mut v = vec![1, 2, 3, 4, 5, 6]; + let first = split_vec_min_alloc(&mut v, 2); + assert_eq!(first, vec![1, 2]); + assert_eq!(v, vec![3, 4, 5, 6]); + } + + #[test] + fn split_off_branch() { + // remaining < n -> split_off+replace branch (allocates remaining elements) + let mut v = vec![1, 2, 3, 4, 5, 6]; + let first = split_vec_min_alloc(&mut v, 4); + assert_eq!(first, vec![1, 2, 3, 4]); + assert_eq!(v, vec![5, 6]); + } + + #[test] + fn exactly_half() { + // n * 2 == len -> drain branch (boundary) + let mut v = vec![1, 2, 3, 4]; + let first = split_vec_min_alloc(&mut v, 2); + assert_eq!(first, vec![1, 2]); + assert_eq!(v, vec![3, 4]); + } + + #[test] + fn take_all() { + let mut v = vec![1, 2, 3]; + let first = split_vec_min_alloc(&mut v, 3); + assert_eq!(first, vec![1, 2, 3]); + assert!(v.is_empty()); + } + + #[test] + fn take_none() { + let mut v = vec![1, 2, 3]; + let first = split_vec_min_alloc(&mut v, 0); + assert!(first.is_empty()); + assert_eq!(v, vec![1, 2, 3]); + } + + #[test] + fn emitted_prefix_does_not_realloc_on_push() { + // Demonstrates *why* the split-off branch must NOT call `shrink_to_fit`. + // + // Downstream callers (e.g. `multi_group_by/bytes.rs`, which does + // `first_n_offsets.push(offset_n)` right after the split) push onto the + // emitted prefix immediately. The split-off branch hands the original + // backing allocation to that prefix, so the prefix already has spare + // capacity for the very next push. + // + // If we shrank the prefix to fit, that next push would have to + // reallocate, and Vec's growth strategy would land it at a *larger* + // capacity than the original allocation we started with -- the opposite + // of the memory saving `shrink_to_fit` was meant to deliver. + + // A Vec with a known, deliberately large capacity. n*2 > len, so this + // takes the split-off branch. + let mut v: Vec = Vec::with_capacity(64); + v.extend(0..10); + let original_capacity = v.capacity(); + assert!(original_capacity >= 64); + + // Emit a prefix that is most of the Vec (n = 8, remaining = 2). + let mut prefix = split_vec_min_alloc(&mut v, 8); + assert_eq!(prefix, vec![0, 1, 2, 3, 4, 5, 6, 7]); + + // The split-off branch moved the original backing store into `prefix`, + // so it keeps the original (large) capacity -- no shrink happened. + assert_eq!( + prefix.capacity(), + original_capacity, + "split-off branch must hand the original allocation to the prefix" + ); + + // The caller's very next operation: push one element onto the prefix. + prefix.push(99); + + // Because the capacity was preserved, the push reused the existing + // allocation: post-push capacity is unchanged and still <= original. + // This is the realloc that `shrink_to_fit` would have forced. + assert_eq!( + prefix.capacity(), + original_capacity, + "push must reuse the preserved allocation (no realloc)" + ); + assert!(prefix.capacity() <= original_capacity); + + // Counter-demonstration: had we shrunk the prefix to fit (capacity 8), + // the same push would have reallocated. Vec doubles on growth, so the + // post-push capacity (16) ends up LARGER than where a length-8 prefix + // started -- and we paid a realloc for it. + let mut shrunk: Vec = prefix[..8].to_vec(); + shrunk.shrink_to_fit(); + let shrunk_capacity = shrink_then_push_capacity(&mut shrunk); + assert!( + shrunk_capacity > 8, + "shrink-to-fit then push reallocates to a larger capacity" + ); + } + + /// Helper for the counter-demonstration above: push one element and report + /// the resulting capacity. + fn shrink_then_push_capacity(v: &mut Vec) -> usize { + v.push(99); + v.capacity() + } +} + /// Creates single element [`ListArray`], [`LargeListArray`] and /// [`FixedSizeListArray`] from other arrays /// @@ -479,6 +621,34 @@ impl SingleRowListArrayBuilder { ScalarValue::FixedSizeList(Arc::new(self.build_fixed_size_list_array(list_size))) } + /// Build a single element [`ListViewArray`] + pub fn build_list_view_array(self) -> ListViewArray { + let (field, arr) = self.into_field_and_arr(); + let offsets = ScalarBuffer::from(vec![0]); + let sizes = ScalarBuffer::from(vec![i32::try_from(arr.len()).expect( + "Trying to construct a ListView where element length exceeds i32::MAX", + )]); + ListViewArray::new(field, offsets, sizes, arr, None) + } + + /// Build a single element [`ListViewArray`] and wrap as [`ScalarValue::ListView`] + pub fn build_list_view_scalar(self) -> ScalarValue { + ScalarValue::ListView(Arc::new(self.build_list_view_array())) + } + + /// Build a single element [`LargeListViewArray`] + pub fn build_large_list_view_array(self) -> LargeListViewArray { + let (field, arr) = self.into_field_and_arr(); + let offsets = ScalarBuffer::from(vec![0]); + let sizes = ScalarBuffer::from(vec![arr.len() as i64]); + LargeListViewArray::new(field, offsets, sizes, arr, None) + } + + /// Build a single element [`LargeListViewArray`] and wrap as [`ScalarValue::LargeListView`] + pub fn build_large_list_view_scalar(self) -> ScalarValue { + ScalarValue::LargeListView(Arc::new(self.build_large_list_view_array())) + } + /// Helper function: convert this builder into a tuple of field and array fn into_field_and_arr(self) -> (Arc, ArrayRef) { let Self { @@ -516,6 +686,7 @@ impl SingleRowListArrayBuilder { /// ); /// /// assert_eq!(list_arr, expected); +/// ``` pub fn arrays_into_list_array( arr: impl IntoIterator, ) -> Result { @@ -563,11 +734,17 @@ pub fn base_type(data_type: &DataType) -> DataType { match data_type { DataType::List(field) | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) | DataType::FixedSizeList(field, _) => base_type(field.data_type()), _ => data_type.to_owned(), } } +// TODO: Modify this to also allow specifying how listviews should be treated. +// For example if cast to List (default) or maintain as ListView (requires +// function to implement support for ListViews) +// https://github.com/apache/datafusion/issues/21777 /// Information about how to coerce lists. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ListCoercion { @@ -587,6 +764,7 @@ pub enum ListCoercion { /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type, None); /// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); +/// ``` pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, @@ -620,6 +798,19 @@ pub fn coerced_type_with_base_type_only( *len, ) } + (DataType::ListView(field), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); + + DataType::ListView(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } (DataType::LargeList(field), _) => { let field_type = coerced_type_with_base_type_only( field.data_type(), @@ -633,6 +824,19 @@ pub fn coerced_type_with_base_type_only( field.is_nullable(), ))) } + (DataType::LargeListView(field), _) => { + let field_type = coerced_type_with_base_type_only( + field.data_type(), + base_type, + array_coercion, + ); + + DataType::LargeListView(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } _ => base_type.clone(), } @@ -650,6 +854,15 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType { field.is_nullable(), ))) } + DataType::ListView(field) => { + let field_type = coerced_fixed_size_list_to_list(field.data_type()); + + DataType::ListView(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } DataType::LargeList(field) => { let field_type = coerced_fixed_size_list_to_list(field.data_type()); @@ -659,6 +872,15 @@ pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType { field.is_nullable(), ))) } + DataType::LargeListView(field) => { + let field_type = coerced_fixed_size_list_to_list(field.data_type()); + + DataType::LargeListView(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } _ => data_type.clone(), } @@ -669,6 +891,8 @@ pub fn list_ndims(data_type: &DataType) -> u64 { match data_type { DataType::List(field) | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) | DataType::FixedSizeList(field, _) => 1 + list_ndims(field.data_type()), _ => 0, } @@ -693,10 +917,14 @@ pub mod datafusion_strsim { } /// Calculates the minimum number of insertions, deletions, and substitutions - /// required to change one sequence into the other. - fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>( + /// required to change one sequence into the other, using a reusable cache buffer. + /// + /// This is the generic implementation that works with any iterator types. + /// The `cache` buffer will be resized as needed and reused across calls. + fn generic_levenshtein_with_buffer<'a, 'b, Iter1, Iter2, Elem1, Elem2>( a: &'a Iter1, b: &'b Iter2, + cache: &mut Vec, ) -> usize where &'a Iter1: IntoIterator, @@ -709,7 +937,9 @@ pub mod datafusion_strsim { return b_len; } - let mut cache: Vec = (1..b_len + 1).collect(); + // Resize cache to fit b_len elements + cache.clear(); + cache.extend(1..=b_len); let mut result = 0; @@ -729,6 +959,21 @@ pub mod datafusion_strsim { result } + /// Calculates the minimum number of insertions, deletions, and substitutions + /// required to change one sequence into the other. + fn generic_levenshtein<'a, 'b, Iter1, Iter2, Elem1, Elem2>( + a: &'a Iter1, + b: &'b Iter2, + ) -> usize + where + &'a Iter1: IntoIterator, + &'b Iter2: IntoIterator, + Elem1: PartialEq, + { + let mut cache = Vec::new(); + generic_levenshtein_with_buffer(a, b, &mut cache) + } + /// Calculates the minimum number of insertions, deletions, and substitutions /// required to change one string into the other. /// @@ -741,6 +986,15 @@ pub mod datafusion_strsim { generic_levenshtein(&StringWrapper(a), &StringWrapper(b)) } + /// Calculates the Levenshtein distance using a reusable cache buffer. + /// This avoids allocating a new Vec for each call, improving performance + /// when computing many distances. + /// + /// The `cache` buffer will be resized as needed and reused across calls. + pub fn levenshtein_with_buffer(a: &str, b: &str, cache: &mut Vec) -> usize { + generic_levenshtein_with_buffer(&StringWrapper(a), &StringWrapper(b), cache) + } + /// Calculates the normalized Levenshtein distance between two strings. /// The normalized distance is a value between 0.0 and 1.0, where 1.0 indicates /// that the strings are identical and 0.0 indicates no similarity. @@ -890,10 +1144,15 @@ pub fn combine_limit( /// /// This is a wrapper around `std::thread::available_parallelism`, providing a default value /// of `1` if the system's parallelism cannot be determined. +/// +/// The result is cached after the first call. pub fn get_available_parallelism() -> usize { - available_parallelism() - .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero")) - .get() + static PARALLELISM: LazyLock = LazyLock::new(|| { + available_parallelism() + .unwrap_or(NonZero::new(1).expect("literal value `1` shouldn't be zero")) + .get() + }); + *PARALLELISM } /// Converts a collection of function arguments into a fixed-size array of length N @@ -938,13 +1197,296 @@ pub fn take_function_args( }) } +/// Returns the inner values of a list, or an error otherwise +/// For [`ListArray`] and [`LargeListArray`], if it's sliced, it returns a +/// sliced array too. Therefore, too reconstruct a list using it, +/// you must adjust the offsets using [`adjust_offsets_for_slice`] +pub fn list_values(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(sliced_list_values(array.as_list::())), + DataType::LargeList(_) => Ok(sliced_list_values(array.as_list::())), + DataType::FixedSizeList(_, _) => { + Ok(Arc::clone(array.as_fixed_size_list().values())) + } + other => _exec_err!("expected list, got {other}"), + } +} + +fn sliced_list_values(list: &GenericListArray) -> ArrayRef { + let values = list.values(); + let offsets = list.offsets(); + + if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) { + let first = first.as_usize(); + let last = last.as_usize(); + + if first != 0 || last != values.len() { + return values.slice(first, last - first); + } + } + + Arc::clone(values) +} + +/// If `list` is sliced, returns an adjusted offset buffer so that +/// it points to the sliced portion of the list values, and not the whole list values +pub fn adjust_offsets_for_slice( + list: &GenericListArray, +) -> OffsetBuffer { + let offsets = list.offsets(); + + if let (Some(first), Some(last)) = (offsets.first(), offsets.last()) + && (!first.is_zero() || last.as_usize() != list.values().len()) + { + let offsets = offsets.iter().map(|offset| *offset - *first).collect(); + + //todo: use unsafe Offset::new_unchecked? + return OffsetBuffer::new(offsets); + } + + offsets.clone() +} + +/// For lists and large lists, truncates the sublist of null values +/// Otherwise returns an error +pub fn remove_list_null_values(array: &ArrayRef) -> Result { + // todo: handle list view and map + match array.data_type() { + DataType::List(_) => Ok(Arc::new(truncate_list_nulls(array.as_list::())?)), + DataType::LargeList(_) => { + Ok(Arc::new(truncate_list_nulls(array.as_list::())?)) + } + dt => _exec_err!("expected List or LargeList, got {dt}"), + } +} + +/// Create a new list array where all the nulls point to empty lists +fn truncate_list_nulls( + list: &GenericListArray, +) -> Result> { + if let Some(nulls) = list.nulls() + && nulls.null_count() > 0 + { + let lengths = length(list)?; + let zero: &dyn Datum = if lengths.data_type() == &DataType::Int32 { + &Int32Array::new_scalar(0) + } else { + &Int64Array::new_scalar(0) + }; + + let (mut valid_or_empty, _nulls) = eq(&lengths, zero)?.into_parts(); + valid_or_empty |= nulls.inner(); + let valid_or_empty = BooleanArray::from(valid_or_empty); + + if valid_or_empty.has_false() { + let array_data = list.values().to_data(); + let offsets = list.offsets(); + let capacity = offsets[offsets.len() - 1] - offsets[0]; + let mut mutable_array_data = + MutableArrayData::new(vec![&array_data], false, capacity.as_usize()); + + let (valid_or_empty, _nulls) = valid_or_empty.into_parts(); + + for (start, end) in valid_or_empty.set_slices() { + mutable_array_data.extend( + 0, + offsets[start].as_usize(), + offsets[end].as_usize(), + ); + } + + let lengths = std::iter::zip(offsets.lengths(), nulls) + .map(|(length, is_valid)| if is_valid { length } else { 0 }); + + let offsets = OffsetBuffer::from_lengths(lengths); + let values = make_array(mutable_array_data.freeze()); + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + _ => unreachable!(), + }; + + return Ok(GenericListArray::try_new( + Arc::clone(field), + offsets, + values, + list.nulls().cloned(), + )?); + } + } + Ok(list.clone()) +} + +/// If `array` is a list or a map, returns a new array of the same length as it's inner values +/// where each value is the 1-based index of the sublist it's contained. Example: +/// +/// `[[1], [2, 3], [4, 5, 6]] => [1, 2, 2, 3, 3, 3]` +/// +/// Otherwise returns an error +pub fn list_values_row_number(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_list().offsets()))), + DataType::LargeList(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int64Type, + >(array.as_list().offsets()))), + DataType::ListView(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_list_view().offsets()))), + DataType::LargeListView(_) => { + Ok(Arc::new(variable_size_list_values_row_number::( + array.as_list_view().offsets(), + ))) + } + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(fsl_values_row_number( + fixed_size_list.value_length(), + fixed_size_list.len(), + )?)) + } + DataType::Map(_, _) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_map().offsets()))), + other => _exec_err!("expected list, got {other}"), + } +} + +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +fn variable_size_list_values_row_number( + offsets: &[T::Native], +) -> PrimitiveArray { + let mut rows_number = Vec::with_capacity( + offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), + ); + + for (i, w) in offsets.windows(2).enumerate() { + let len = w[1].as_usize() - w[0].as_usize(); + rows_number.extend(repeat_n(T::Native::usize_as(i), len)); + } + + PrimitiveArray::new(rows_number.into(), None) +} + +/// (2, 3) -> [0, 0, 1, 1, 2, 2] +fn fsl_values_row_number(list_size: i32, array_len: usize) -> Result { + let list_size = list_size.to_usize().ok_or_else(|| { + _exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") + })?; + + let mut rows_number = Vec::with_capacity(list_size * array_len); + + for i in 0..array_len { + rows_number.extend(repeat_n(i as i32, list_size)); + } + + Ok(PrimitiveArray::new(rows_number.into(), None)) +} + +/// Replace `-0.0` with `+0.0` in any `Float16`, `Float32`, or `Float64` array. +/// For non-float arrays returns the input unchanged. NaN payloads are +/// preserved. +/// +/// Arrow's comparison kernels (`arrow::compute::kernels::cmp::eq` etc.) and +/// row-encoding (`arrow::row::RowConverter`) use IEEE 754 totalOrder +/// semantics, which treats `-0.0` and `+0.0` as distinct. SQL semantics +/// (PostgreSQL / IEEE 754 equality) require them to compare equal, so +/// callers normalize before invoking those kernels. +/// +/// The common case - no `-0.0` present - is allocation-free: a single +/// read-only scan of the underlying buffer (auto-vectorizable to an +/// OR-reduction) decides whether to fall through to the rewriting path. +/// Only arrays that actually contain `-0.0` pay for a new buffer. +pub fn normalize_float_zero(array: &ArrayRef) -> ArrayRef { + use arrow::array::{Float16Array, Float32Array, Float64Array}; + use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; + // -0.0 has only the sign bit set; no other finite or NaN value shares + // this bit pattern, so a strict-equality scan reliably gates the rewrite. + const NEG_ZERO_F16_BITS: u16 = half::f16::NEG_ZERO.to_bits(); + const NEG_ZERO_F32_BITS: u32 = (-0.0_f32).to_bits(); + const NEG_ZERO_F64_BITS: u64 = (-0.0_f64).to_bits(); + match array.data_type() { + DataType::Float32 => { + let arr: &Float32Array = array.as_primitive::(); + if !arr + .values() + .iter() + .any(|v| v.to_bits() == NEG_ZERO_F32_BITS) + { + return Arc::clone(array); + } + let normalized: Float32Array = + arr.unary(|v| if v.to_bits() << 1 == 0 { 0.0_f32 } else { v }); + Arc::new(normalized) + } + DataType::Float64 => { + let arr: &Float64Array = array.as_primitive::(); + if !arr + .values() + .iter() + .any(|v| v.to_bits() == NEG_ZERO_F64_BITS) + { + return Arc::clone(array); + } + let normalized: Float64Array = + arr.unary(|v| if v.to_bits() << 1 == 0 { 0.0_f64 } else { v }); + Arc::new(normalized) + } + DataType::Float16 => { + let arr: &Float16Array = array.as_primitive::(); + if !arr + .values() + .iter() + .any(|v| v.to_bits() == NEG_ZERO_F16_BITS) + { + return Arc::clone(array); + } + let normalized: Float16Array = arr.unary(|v| { + if v.to_bits() << 1 == 0 { + half::f16::from_bits(0) + } else { + v + } + }); + Arc::new(normalized) + } + _ => Arc::clone(array), + } +} + +/// Replace `-0.0` with `+0.0` in `Float16`, `Float32`, or `Float64` scalar +/// values. Other variants are returned unchanged. See [`normalize_float_zero`] +/// for context. +pub fn normalize_float_zero_scalar(scalar: ScalarValue) -> ScalarValue { + match scalar { + ScalarValue::Float32(Some(v)) if v.to_bits() << 1 == 0 => { + ScalarValue::Float32(Some(0.0)) + } + ScalarValue::Float64(Some(v)) if v.to_bits() << 1 == 0 => { + ScalarValue::Float64(Some(0.0)) + } + ScalarValue::Float16(Some(v)) if v.to_bits() << 1 == 0 => { + ScalarValue::Float16(Some(half::f16::from_bits(0))) + } + other => other, + } +} + #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::ScalarValue::Null; - use arrow::array::Float64Array; + use arrow::{ + array::{Float64Array, Int32Array}, + buffer::NullBuffer, + datatypes::Int32Type, + }; use sqlparser::ast::Ident; - use sqlparser::tokenizer::Span; #[test] fn test_bisect_linear_left_and_right() -> Result<()> { @@ -1173,7 +1715,7 @@ mod tests { let expected_parsed = vec![Ident { value: identifier.to_string(), quote_style, - span: Span::empty(), + span: sqlparser::tokenizer::Span::empty(), }]; assert_eq!( @@ -1244,4 +1786,210 @@ mod tests { assert_eq!(expected, transposed); Ok(()) } + + #[test] + fn test_sliced_list_values() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + + let list = ListArray::from_iter_primitive::(data); + + assert_eq!( + sliced_list_values(&list).as_primitive(), + &Int32Array::from(vec![ + Some(0), + Some(1), + Some(2), + Some(3), + None, + Some(5), + Some(6), + Some(7) + ]) + ); + + assert_eq!( + sliced_list_values(&list.slice(0, 1)).as_primitive(), + &Int32Array::from(vec![Some(0), Some(1), Some(2)]) + ); + + assert_eq!( + sliced_list_values(&list.slice(2, 1)).as_primitive(), + &Int32Array::from(vec![Some(3), None, Some(5)]) + ); + + assert_eq!( + sliced_list_values(&list.slice(3, 1)).as_primitive(), + &Int32Array::from(vec![Some(6), Some(7)]) + ); + + assert!(sliced_list_values(&list.slice(0, 0)).is_empty()); + assert!(sliced_list_values(&list.slice(1, 0)).is_empty()); + assert!(sliced_list_values(&list.slice(3, 0)).is_empty()); + } + + #[test] + fn test_adjust_offsets() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), None, Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let list = ListArray::from_iter_primitive::(data); + + assert_eq!( + adjust_offsets_for_slice(&list), + OffsetBuffer::from_lengths([3, 0, 3, 2]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(0, 1)), + OffsetBuffer::from_lengths([3]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 2)), + OffsetBuffer::from_lengths([0, 3]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 3)), + OffsetBuffer::from_lengths([0, 3, 2]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(0, 0)), + OffsetBuffer::from_lengths([]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(1, 0)), + OffsetBuffer::from_lengths([]) + ); + + assert_eq!( + adjust_offsets_for_slice(&list.slice(3, 0)), + OffsetBuffer::from_lengths([]) + ); + } + + fn create_i32_list( + values: impl Into, + offsets: OffsetBuffer, + nulls: Option, + ) -> ListArray { + let list_field = Arc::new(Field::new_list_field(DataType::Int32, true)); + + ListArray::new(list_field, offsets, Arc::new(values.into()), nulls) + } + + #[test] + fn test_remove_list_null_values_list() { + let list = Arc::new(create_i32_list( + vec![100, 20, 10, 0, 0, 0, 0, 1, 50], + OffsetBuffer::::from_lengths(vec![3, 4, 0, 2, 0]), + Some(NullBuffer::from(vec![true, false, false, true, false])), + )) as ArrayRef; + + let res = remove_list_null_values(&list).unwrap(); + let res = res.as_list::(); + + let expected = Arc::new(create_i32_list( + vec![100, 20, 10, 1, 50], + OffsetBuffer::::from_lengths(vec![3, 0, 0, 2, 0]), + Some(NullBuffer::from(vec![true, false, false, true, false])), + )) as ArrayRef; + let expected = expected.as_list::(); + + assert_eq!(res, expected); + // check above skips inner value of nulls + assert_eq!(res.values(), expected.values()); + assert_eq!(res.offsets(), expected.offsets()); + } + + #[test] + fn test_list_array_values_row_number() { + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([1, 3, 0, 2,]) + ), + Int32Array::from(vec![0, 1, 1, 1, 3, 3]) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([]) + ), + Int32Array::new_null(0) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([0]) + ), + Int32Array::new_null(0) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([0, 0]) + ), + Int32Array::new_null(0) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([1]) + ), + Int32Array::from(vec![0]) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([2]) + ), + Int32Array::from(vec![0, 0]) + ); + } + + #[test] + fn test_fsl_values_row_number() { + assert_eq!( + fsl_values_row_number(2, 3).unwrap(), + Int32Array::from(vec![0, 0, 1, 1, 2, 2]) + ); + + assert_eq!( + fsl_values_row_number(1, 3).unwrap(), + Int32Array::from(vec![0, 1, 2]) + ); + + assert_eq!( + fsl_values_row_number(2, 1).unwrap(), + Int32Array::from(vec![0, 0]) + ); + + assert_eq!( + fsl_values_row_number(2, 0).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 2).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 0).unwrap(), + Int32Array::new_null(0), + ); + + fsl_values_row_number(-1, 2).unwrap_err(); + fsl_values_row_number(-1, 0).unwrap_err(); + } } diff --git a/datafusion/datasource-parquet/Cargo.toml b/datafusion/datasource-parquet/Cargo.toml index a5f6f56ac6f33..32424069c17a0 100644 --- a/datafusion/datasource-parquet/Cargo.toml +++ b/datafusion/datasource-parquet/Cargo.toml @@ -32,6 +32,7 @@ all-features = true [dependencies] arrow = { workspace = true } +arrow-schema = { workspace = true } async-trait = { workspace = true } bytes = { workspace = true } datafusion-common = { workspace = true, features = ["object_store", "parquet"] } @@ -39,6 +40,7 @@ datafusion-common-runtime = { workspace = true } datafusion-datasource = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-adapter = { workspace = true } @@ -56,6 +58,10 @@ tokio = { workspace = true } [dev-dependencies] chrono = { workspace = true } +criterion = { workspace = true } +datafusion-functions = { workspace = true } +datafusion-functions-nested = { workspace = true } +tempfile = { workspace = true } # Note: add additional linter rules in lib.rs. # Rust does not support workspace + new linter rules in subcrates yet @@ -73,3 +79,15 @@ parquet_encryption = [ "datafusion-common/parquet_encryption", "datafusion-execution/parquet_encryption", ] + +[[bench]] +name = "parquet_nested_filter_pushdown" +harness = false + +[[bench]] +name = "parquet_struct_filter_pushdown" +harness = false + +[[bench]] +name = "parquet_metadata_statistics" +harness = false diff --git a/datafusion/datasource-parquet/benches/parquet_metadata_statistics.rs b/datafusion/datasource-parquet/benches/parquet_metadata_statistics.rs new file mode 100644 index 0000000000000..46ebd100fde88 --- /dev/null +++ b/datafusion/datasource-parquet/benches/parquet_metadata_statistics.rs @@ -0,0 +1,303 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for deriving DataFusion table statistics from Parquet metadata. +//! +//! This mirrors the structure of Arrow's `arrow_statistics` benchmark: build +//! Parquet metadata once, then repeatedly measure statistics extraction. The +//! benchmark targets the cold planning/statistics path used by listing tables. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{BatchSize, BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_datasource_parquet::metadata::DFParquetMetadata; +use parquet::arrow::ArrowSchemaConverter; +use parquet::data_type::ByteArray; +use parquet::file::metadata::{ + ColumnChunkMetaData, FileMetaData, ParquetMetaData, RowGroupMetaData, +}; +use parquet::file::statistics::{Statistics as ParquetStatistics, ValueStatistics}; + +const ROWS_PER_GROUP: usize = 8; + +#[derive(Debug, Copy, Clone)] +struct BenchmarkSpec { + columns: usize, + row_groups: usize, + metadata: MetadataState, +} + +#[derive(Debug, Copy, Clone)] +enum MetadataState { + Full, + Mixed, + None, +} + +impl std::fmt::Display for MetadataState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Full => write!(f, "full"), + Self::Mixed => write!(f, "mixed"), + Self::None => write!(f, "none"), + } + } +} + +struct BenchmarkCase { + schema: SchemaRef, + metadata: ParquetMetaData, +} + +fn parquet_metadata_statistics(c: &mut Criterion) { + let metadata_states = [ + MetadataState::Full, + MetadataState::Mixed, + MetadataState::None, + ]; + let column_counts = [8, 64, 256]; + let row_group_counts = [1, 32, 128]; + + let mut group = c.benchmark_group("parquet_metadata_statistics"); + + for metadata in metadata_states { + for columns in column_counts { + for row_groups in row_group_counts { + let spec = BenchmarkSpec { + columns, + row_groups, + metadata, + }; + group.bench_function( + BenchmarkId::from_parameter(format!( + "metadata_{}_col_{}_rg_{}", + spec.metadata, spec.columns, spec.row_groups, + )), + |b| { + b.iter_batched( + || BenchmarkCase::new(spec), + |case| { + let statistics = + DFParquetMetadata::statistics_from_parquet_metadata( + black_box(&case.metadata), + black_box(&case.schema), + ) + .expect("statistics extraction failed"); + black_box(statistics); + }, + BatchSize::PerIteration, + ); + }, + ); + } + } + } + + group.finish(); +} + +impl BenchmarkCase { + fn new(spec: BenchmarkSpec) -> Self { + let schema = make_schema(spec.columns); + let metadata = match spec.metadata { + MetadataState::Full => { + make_synthetic_metadata(&schema, spec, full_statistics) + } + MetadataState::Mixed => { + make_synthetic_metadata(&schema, spec, mixed_statistics) + } + MetadataState::None => make_synthetic_metadata(&schema, spec, |_, _, _| None), + }; + + Self { schema, metadata } + } +} + +fn make_synthetic_metadata( + schema: &SchemaRef, + spec: BenchmarkSpec, + statistics: fn(&DataType, usize, usize) -> Option, +) -> ParquetMetaData { + let schema_descr = Arc::new( + ArrowSchemaConverter::new() + .convert(schema.as_ref()) + .expect("failed to convert arrow schema"), + ); + let row_groups = (0..spec.row_groups) + .map(|row_group| { + let columns = schema + .fields() + .iter() + .enumerate() + .map(|(column_idx, field)| { + let mut builder = + ColumnChunkMetaData::builder(schema_descr.column(column_idx)); + if let Some(statistics) = + statistics(field.data_type(), column_idx, row_group) + { + builder = builder.set_statistics(statistics); + } + builder + .set_num_values(ROWS_PER_GROUP as i64) + .build() + .expect("failed to build column metadata") + }) + .collect::>(); + + RowGroupMetaData::builder(Arc::clone(&schema_descr)) + .set_num_rows(ROWS_PER_GROUP as i64) + .set_total_byte_size((spec.columns * ROWS_PER_GROUP * 8) as i64) + .set_column_metadata(columns) + .build() + .expect("failed to build row group metadata") + }) + .collect::>(); + + let file_metadata = FileMetaData::new( + 1, + (spec.row_groups * ROWS_PER_GROUP) as i64, + Some("datafusion parquet metadata benchmark".to_string()), + None, + schema_descr, + None, + ); + + ParquetMetaData::new(file_metadata, row_groups) +} + +fn full_statistics( + data_type: &DataType, + column_idx: usize, + row_group: usize, +) -> Option { + Some(statistics( + data_type, + column_idx, + row_group, + true, + true, + Some(null_count_for_rows()), + )) +} + +fn mixed_statistics( + data_type: &DataType, + column_idx: usize, + row_group: usize, +) -> Option { + if column_idx.is_multiple_of(16) || row_group.is_multiple_of(5) { + return None; + } + + let min_exact = !row_group.is_multiple_of(3); + let max_exact = !row_group.is_multiple_of(4); + let null_count = (!row_group.is_multiple_of(7)).then(null_count_for_rows); + + Some(statistics( + data_type, column_idx, row_group, min_exact, max_exact, null_count, + )) +} + +fn statistics( + data_type: &DataType, + column_idx: usize, + row_group: usize, + min_exact: bool, + max_exact: bool, + null_count: Option, +) -> ParquetStatistics { + let min_row = first_non_null_row(); + let max_row = last_non_null_row(); + + match data_type { + DataType::Int64 => { + let min = min_row.map(|row| value(column_idx, row_group, row)); + let max = max_row.map(|row| value(column_idx, row_group, row)); + ParquetStatistics::Int64( + ValueStatistics::new(min, max, None, null_count, false) + .with_min_is_exact(min_exact) + .with_max_is_exact(max_exact), + ) + } + DataType::Float64 => { + let min = min_row.map(|row| value(column_idx, row_group, row) as f64 * 1.5); + let max = max_row.map(|row| value(column_idx, row_group, row) as f64 * 1.5); + ParquetStatistics::Double( + ValueStatistics::new(min, max, None, null_count, false) + .with_min_is_exact(min_exact) + .with_max_is_exact(max_exact), + ) + } + DataType::Utf8 => { + let min = min_row.map(|row| { + ByteArray::from(string_value(column_idx, row_group, row).into_bytes()) + }); + let max = max_row.map(|row| { + ByteArray::from(string_value(column_idx, row_group, row).into_bytes()) + }); + ParquetStatistics::ByteArray( + ValueStatistics::new(min, max, None, null_count, false) + .with_min_is_exact(min_exact) + .with_max_is_exact(max_exact), + ) + } + other => unreachable!("unsupported benchmark data type: {other:?}"), + } +} + +fn make_schema(columns: usize) -> SchemaRef { + let fields = (0..columns) + .map(|idx| { + let data_type = match idx % 4 { + 0 => DataType::Int64, + 1 => DataType::Float64, + 2 => DataType::Utf8, + _ => DataType::Int64, + }; + Field::new(format!("c{idx:04}"), data_type, true) + }) + .collect::>(); + + Arc::new(Schema::new(fields)) +} + +fn first_non_null_row() -> Option { + (0..ROWS_PER_GROUP).find(|row| !row.is_multiple_of(7)) +} + +fn last_non_null_row() -> Option { + (0..ROWS_PER_GROUP).rev().find(|row| !row.is_multiple_of(7)) +} + +fn null_count_for_rows() -> u64 { + (0..ROWS_PER_GROUP) + .filter(|row| row.is_multiple_of(7)) + .count() as u64 +} + +fn value(column_idx: usize, row_group: usize, row: usize) -> i64 { + (column_idx as i64 * 10_000) + (row_group as i64 * 100) + row as i64 +} + +fn string_value(column_idx: usize, row_group: usize, row: usize) -> String { + format!("s{column_idx:04}_{row_group:04}_{row:04}") +} + +criterion_group!(benches, parquet_metadata_statistics); +criterion_main!(benches); diff --git a/datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs b/datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs new file mode 100644 index 0000000000000..02137b5a1d288 --- /dev/null +++ b/datafusion/datasource-parquet/benches/parquet_nested_filter_pushdown.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{ + BinaryBuilder, BooleanArray, ListBuilder, RecordBatch, StringBuilder, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{Criterion, Throughput, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_datasource_parquet::{ParquetFileMetrics, build_row_filter}; +use datafusion_expr::{Expr, col}; +use datafusion_functions_nested::expr_fn::array_has; +use datafusion_physical_expr::planner::logical2physical; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::arrow::{ArrowWriter, ProjectionMask}; +use parquet::file::properties::WriterProperties; +use tempfile::TempDir; + +const ROW_GROUP_ROW_COUNT: usize = 10_000; +const TOTAL_ROW_GROUPS: usize = 10; +const TOTAL_ROWS: usize = ROW_GROUP_ROW_COUNT * TOTAL_ROW_GROUPS; +const TARGET_VALUE: &str = "target_value"; +const COLUMN_NAME: &str = "list_col"; +const PAYLOAD_COLUMN_NAME: &str = "payload"; +// Large binary payload to emphasize decoding overhead when pushdown is disabled. +const PAYLOAD_BYTES: usize = 8 * 1024; + +struct BenchmarkDataset { + _tempdir: TempDir, + file_path: PathBuf, +} + +impl BenchmarkDataset { + fn path(&self) -> &Path { + &self.file_path + } +} + +static DATASET: LazyLock = LazyLock::new(|| { + create_dataset().expect("failed to prepare parquet benchmark dataset") +}); + +fn parquet_nested_filter_pushdown(c: &mut Criterion) { + let dataset_path = DATASET.path().to_owned(); + let mut group = c.benchmark_group("parquet_nested_filter_pushdown"); + group.throughput(Throughput::Elements(TOTAL_ROWS as u64)); + + group.bench_function("no_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&create_predicate(), &file_schema); + b.iter(|| { + let matched = scan_with_predicate(&dataset_path, &predicate, false) + .expect("baseline parquet scan with filter succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.bench_function("with_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&create_predicate(), &file_schema); + b.iter(|| { + let matched = scan_with_predicate(&dataset_path, &predicate, true) + .expect("pushdown parquet scan with filter succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.finish(); +} + +fn setup_reader(path: &Path) -> SchemaRef { + let file = std::fs::File::open(path).expect("failed to open file"); + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).expect("failed to build reader"); + Arc::clone(builder.schema()) +} + +fn create_predicate() -> Expr { + array_has( + col(COLUMN_NAME), + Expr::Literal(ScalarValue::Utf8(Some(TARGET_VALUE.to_string())), None), + ) +} + +fn scan_with_predicate( + path: &Path, + predicate: &Arc, + pushdown: bool, +) -> datafusion_common::Result { + let file = std::fs::File::open(path)?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file)?; + let metadata = builder.metadata().clone(); + let file_schema = builder.schema(); + let projection = ProjectionMask::all(); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = ParquetFileMetrics::new(0, &path.display().to_string(), &metrics); + + let builder = if pushdown { + if let Some(row_filter) = + build_row_filter(predicate, file_schema, &metadata, false, &file_metrics)? + { + builder.with_row_filter(row_filter) + } else { + builder + } + } else { + builder + }; + + let reader = builder.with_projection(projection).build()?; + + let mut matched_rows = 0usize; + for batch in reader { + let batch = batch?; + matched_rows += count_matches(predicate, &batch)?; + } + + if pushdown { + let pruned_rows = file_metrics.pushdown_rows_pruned.value(); + assert_eq!( + pruned_rows, + TOTAL_ROWS - matched_rows, + "row-level pushdown should prune 90% of rows" + ); + } + + Ok(matched_rows) +} + +fn count_matches( + expr: &Arc, + batch: &RecordBatch, +) -> datafusion_common::Result { + let values = expr.evaluate(batch)?.into_array(batch.num_rows())?; + let bools = values + .as_any() + .downcast_ref::() + .expect("boolean filter result"); + + Ok(bools.iter().filter(|v| matches!(v, Some(true))).count()) +} + +fn create_dataset() -> datafusion_common::Result { + let tempdir = TempDir::new()?; + let file_path = tempdir.path().join("nested_lists.parquet"); + + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = Arc::new(Schema::new(vec![ + Field::new(COLUMN_NAME, DataType::List(field), false), + Field::new(PAYLOAD_COLUMN_NAME, DataType::Binary, false), + ])); + + let writer_props = WriterProperties::builder() + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) + .build(); + + let mut writer = ArrowWriter::try_new( + std::fs::File::create(&file_path)?, + Arc::clone(&schema), + Some(writer_props), + )?; + + // Create sorted row groups with distinct values so that min/max statistics + // allow skipping most groups when applying a selective predicate. + let sorted_values = [ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "golf", + "hotel", + "india", + TARGET_VALUE, + ]; + + for value in sorted_values { + let batch = build_list_batch(&schema, value, ROW_GROUP_ROW_COUNT)?; + writer.write(&batch)?; + } + + writer.close()?; + + // Ensure the writer respected the requested row group size + let reader = + ParquetRecordBatchReaderBuilder::try_new(std::fs::File::open(&file_path)?)?; + assert_eq!(reader.metadata().row_groups().len(), TOTAL_ROW_GROUPS); + + Ok(BenchmarkDataset { + _tempdir: tempdir, + file_path, + }) +} + +fn build_list_batch( + schema: &SchemaRef, + value: &str, + len: usize, +) -> datafusion_common::Result { + let mut builder = ListBuilder::new(StringBuilder::new()); + let mut payload_builder = BinaryBuilder::new(); + let payload = vec![1u8; PAYLOAD_BYTES]; + for _ in 0..len { + builder.values().append_value(value); + builder.append(true); + payload_builder.append_value(&payload); + } + + let array = builder.finish(); + let payload_array = payload_builder.finish(); + Ok(RecordBatch::try_new( + Arc::clone(schema), + vec![Arc::new(array), Arc::new(payload_array)], + )?) +} + +criterion_group!(benches, parquet_nested_filter_pushdown); +criterion_main!(benches); diff --git a/datafusion/datasource-parquet/benches/parquet_struct_filter_pushdown.rs b/datafusion/datasource-parquet/benches/parquet_struct_filter_pushdown.rs new file mode 100644 index 0000000000000..b52408d4222d8 --- /dev/null +++ b/datafusion/datasource-parquet/benches/parquet_struct_filter_pushdown.rs @@ -0,0 +1,353 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for struct field filter pushdown in Parquet. +//! +//! Compares scanning with vs without row-level filter pushdown for +//! predicates on struct sub-fields (e.g. `get_field(s, 'id') = 42`). +//! +//! The dataset schema (in SQL-like notation): +//! +//! ```sql +//! CREATE TABLE t ( +//! id INT, -- top-level id, useful for correctness checks +//! large_string TEXT, -- wide column so SELECT * is expensive +//! s STRUCT< +//! id: INT, -- mirrors top-level id +//! large_string: TEXT -- wide sub-field; pushdown with proper projection +//! -- should avoid reading this when filtering on s.id +//! > +//! ); +//! ``` +//! +//! Benchmark queries: +//! +//! 1. `SELECT * FROM t WHERE get_field(s, 'id') = 42` +//! - no pushdown vs. row-level filter pushdown +//! 2. `SELECT * FROM t WHERE get_field(s, 'id') = id` +//! - cross-column predicate; no pushdown vs. row-level filter pushdown +//! 3. `SELECT id FROM t WHERE get_field(s, 'id') = 42` +//! - narrow projection; pushdown should avoid reading s.large_string + +use std::path::{Path, PathBuf}; +use std::sync::{Arc, LazyLock}; + +use arrow::array::{BooleanArray, Int32Array, RecordBatch, StringBuilder, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; +use criterion::{Criterion, Throughput, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_datasource_parquet::{ParquetFileMetrics, build_row_filter}; +use datafusion_expr::{Expr, col}; +use datafusion_physical_expr::planner::logical2physical; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; +use parquet::arrow::{ArrowWriter, ProjectionMask}; +use parquet::file::properties::WriterProperties; +use tempfile::TempDir; + +const ROW_GROUP_ROW_COUNT: usize = 10_000; +const TOTAL_ROW_GROUPS: usize = 10; +const TOTAL_ROWS: usize = ROW_GROUP_ROW_COUNT * TOTAL_ROW_GROUPS; +/// Only one row group will contain the target value. +const TARGET_VALUE: i32 = 42; +const ID_COLUMN_NAME: &str = "id"; +const LARGE_STRING_COLUMN_NAME: &str = "large_string"; +const STRUCT_COLUMN_NAME: &str = "s"; +// Large string payload to emphasize decoding overhead when pushdown is disabled. +const LARGE_STRING_LEN: usize = 8 * 1024; + +struct BenchmarkDataset { + _tempdir: TempDir, + file_path: PathBuf, +} + +impl BenchmarkDataset { + fn path(&self) -> &Path { + &self.file_path + } +} + +static DATASET: LazyLock = LazyLock::new(|| { + create_dataset().expect("failed to prepare parquet benchmark dataset") +}); + +fn parquet_struct_filter_pushdown(c: &mut Criterion) { + let dataset_path = DATASET.path().to_owned(); + let mut group = c.benchmark_group("parquet_struct_filter_pushdown"); + group.throughput(Throughput::Elements(TOTAL_ROWS as u64)); + + // Scenario 1: SELECT * FROM t WHERE get_field(s, 'id') = 42 + group.bench_function("select_star/no_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&struct_id_eq_literal(), &file_schema); + b.iter(|| { + let matched = scan(&dataset_path, &predicate, false, ProjectionMask::all()) + .expect("scan succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.bench_function("select_star/with_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&struct_id_eq_literal(), &file_schema); + b.iter(|| { + let matched = scan(&dataset_path, &predicate, true, ProjectionMask::all()) + .expect("scan succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + // Scenario 2: SELECT * FROM t WHERE get_field(s, 'id') = id + group.bench_function("select_star_cross_col/no_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&struct_id_eq_top_id(), &file_schema); + b.iter(|| { + let matched = scan(&dataset_path, &predicate, false, ProjectionMask::all()) + .expect("scan succeeded"); + assert_eq!(matched, TOTAL_ROWS); + }); + }); + + group.bench_function("select_star_cross_col/with_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&struct_id_eq_top_id(), &file_schema); + b.iter(|| { + let matched = scan(&dataset_path, &predicate, true, ProjectionMask::all()) + .expect("scan succeeded"); + assert_eq!(matched, TOTAL_ROWS); + }); + }); + + // Scenario 3: SELECT id FROM t WHERE get_field(s, 'id') = 42 + group.bench_function("select_id/no_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&struct_id_eq_literal(), &file_schema); + b.iter(|| { + // Without pushdown we must read all columns to evaluate the predicate. + let matched = scan(&dataset_path, &predicate, false, ProjectionMask::all()) + .expect("scan succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.bench_function("select_id/with_pushdown", |b| { + let file_schema = setup_reader(&dataset_path); + let predicate = logical2physical(&struct_id_eq_literal(), &file_schema); + let id_only = id_projection(&dataset_path); + b.iter(|| { + // With pushdown the filter runs first, then we only project `id`. + let matched = scan(&dataset_path, &predicate, true, id_only.clone()) + .expect("scan succeeded"); + assert_eq!(matched, ROW_GROUP_ROW_COUNT); + }); + }); + + group.finish(); +} + +fn setup_reader(path: &Path) -> SchemaRef { + let file = std::fs::File::open(path).expect("failed to open file"); + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).expect("failed to build reader"); + Arc::clone(builder.schema()) +} + +/// `get_field(s, 'id') = TARGET_VALUE` +fn struct_id_eq_literal() -> Expr { + let get_field_expr = datafusion_functions::core::get_field().call(vec![ + col(STRUCT_COLUMN_NAME), + Expr::Literal(ScalarValue::Utf8(Some("id".to_string())), None), + ]); + get_field_expr.eq(Expr::Literal(ScalarValue::Int32(Some(TARGET_VALUE)), None)) +} + +/// `get_field(s, 'id') = id` +fn struct_id_eq_top_id() -> Expr { + let get_field_expr = datafusion_functions::core::get_field().call(vec![ + col(STRUCT_COLUMN_NAME), + Expr::Literal(ScalarValue::Utf8(Some("id".to_string())), None), + ]); + get_field_expr.eq(col(ID_COLUMN_NAME)) +} + +/// Build a [`ProjectionMask`] that only reads the top-level `id` leaf column. +fn id_projection(path: &Path) -> ProjectionMask { + let file = std::fs::File::open(path).expect("failed to open file"); + let builder = + ParquetRecordBatchReaderBuilder::try_new(file).expect("failed to build reader"); + let parquet_schema = builder.metadata().file_metadata().schema_descr_ptr(); + // Leaf index 0 corresponds to the top-level `id` column. + ProjectionMask::leaves(&parquet_schema, [0]) +} + +fn scan( + path: &Path, + predicate: &Arc, + pushdown: bool, + projection: ProjectionMask, +) -> datafusion_common::Result { + let file = std::fs::File::open(path)?; + let builder = ParquetRecordBatchReaderBuilder::try_new(file)?; + let metadata = builder.metadata().clone(); + let file_schema = builder.schema(); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = ParquetFileMetrics::new(0, &path.display().to_string(), &metrics); + + let mut filter_applied = false; + let builder = if pushdown { + if let Some(row_filter) = + build_row_filter(predicate, file_schema, &metadata, false, &file_metrics)? + { + filter_applied = true; + builder.with_row_filter(row_filter) + } else { + builder + } + } else { + builder + }; + + // Only apply a narrow projection when the filter was actually pushed down. + // Otherwise we need all columns to evaluate the predicate manually. + let output_projection = if filter_applied { + projection + } else { + ProjectionMask::all() + }; + let reader = builder.with_projection(output_projection).build()?; + + let mut matched_rows = 0usize; + for batch in reader { + let batch = batch?; + if filter_applied { + // When the row filter was applied, rows are already filtered. + matched_rows += batch.num_rows(); + } else { + matched_rows += count_matches(predicate, &batch)?; + } + } + + Ok(matched_rows) +} + +fn count_matches( + expr: &Arc, + batch: &RecordBatch, +) -> datafusion_common::Result { + let values = expr.evaluate(batch)?.into_array(batch.num_rows())?; + let bools = values + .as_any() + .downcast_ref::() + .expect("boolean filter result"); + + Ok(bools.iter().filter(|v| matches!(v, Some(true))).count()) +} + +fn schema() -> SchemaRef { + let struct_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new(LARGE_STRING_COLUMN_NAME, DataType::Utf8, false), + ]); + Arc::new(Schema::new(vec![ + Field::new(ID_COLUMN_NAME, DataType::Int32, false), + Field::new(LARGE_STRING_COLUMN_NAME, DataType::Utf8, false), + Field::new(STRUCT_COLUMN_NAME, DataType::Struct(struct_fields), false), + ])) +} + +fn create_dataset() -> datafusion_common::Result { + let tempdir = TempDir::new()?; + let file_path = tempdir.path().join("struct_filter.parquet"); + + let schema = schema(); + let writer_props = WriterProperties::builder() + .set_max_row_group_row_count(Some(ROW_GROUP_ROW_COUNT)) + .build(); + + let mut writer = ArrowWriter::try_new( + std::fs::File::create(&file_path)?, + Arc::clone(&schema), + Some(writer_props), + )?; + + // Each row group has a distinct `s.id` value. Only one row group + // matches the target, so pushdown should prune 90% of rows. + for rg_idx in 0..TOTAL_ROW_GROUPS { + let id_value = if rg_idx == TOTAL_ROW_GROUPS - 1 { + TARGET_VALUE + } else { + (rg_idx as i32 + 1) * 1000 + }; + let batch = build_struct_batch(&schema, id_value, ROW_GROUP_ROW_COUNT)?; + writer.write(&batch)?; + } + + writer.close()?; + + let reader = + ParquetRecordBatchReaderBuilder::try_new(std::fs::File::open(&file_path)?)?; + assert_eq!(reader.metadata().row_groups().len(), TOTAL_ROW_GROUPS); + + Ok(BenchmarkDataset { + _tempdir: tempdir, + file_path, + }) +} + +fn build_struct_batch( + schema: &SchemaRef, + id_value: i32, + len: usize, +) -> datafusion_common::Result { + let large_string: String = "x".repeat(LARGE_STRING_LEN); + + // Top-level columns + let top_id_array = Arc::new(Int32Array::from(vec![id_value; len])); + let mut top_string_builder = StringBuilder::new(); + for _ in 0..len { + top_string_builder.append_value(&large_string); + } + let top_string_array = Arc::new(top_string_builder.finish()); + + // Struct sub-fields: s.id mirrors top-level id, s.large_string is the same payload + let struct_id_array = Arc::new(Int32Array::from(vec![id_value; len])); + let mut struct_string_builder = StringBuilder::new(); + for _ in 0..len { + struct_string_builder.append_value(&large_string); + } + let struct_string_array = Arc::new(struct_string_builder.finish()); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, false)), + struct_id_array as Arc, + ), + ( + Arc::new(Field::new(LARGE_STRING_COLUMN_NAME, DataType::Utf8, false)), + struct_string_array as Arc, + ), + ]); + + Ok(RecordBatch::try_new( + Arc::clone(schema), + vec![top_id_array, top_string_array, Arc::new(struct_array)], + )?) +} + +criterion_group!(benches, parquet_struct_filter_pushdown); +criterion_main!(benches); diff --git a/datafusion/datasource-parquet/src/access_plan.rs b/datafusion/datasource-parquet/src/access_plan.rs index 7399a2cd0856a..edbea39948f09 100644 --- a/datafusion/datasource-parquet/src/access_plan.rs +++ b/datafusion/datasource-parquet/src/access_plan.rs @@ -15,9 +15,15 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{assert_eq_or_internal_err, Result}; +use crate::sort::reverse_row_selection; +use arrow::datatypes::Schema; +use datafusion_common::{Result, assert_eq_or_internal_err}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use log::debug; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; -use parquet::file::metadata::RowGroupMetaData; +use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; /// A selection of rows and row groups within a ParquetFile to decode. /// @@ -82,10 +88,20 @@ use parquet::file::metadata::RowGroupMetaData; /// └───────────────────┘ /// Row Group 3 /// ``` +/// +/// For more background, please also see the [Embedding User-Defined Indexes in Apache Parquet Files blog] +/// +/// [Embedding User-Defined Indexes in Apache Parquet Files blog]: https://datafusion.apache.org/blog/2025/07/14/user-defined-parquet-indexes #[derive(Debug, Clone, PartialEq)] pub struct ParquetAccessPlan { /// How to access the i-th row group row_groups: Vec, + /// Whether all rows in the i-th row group are known to match the predicate. + /// + /// This is tracked separately from [`RowGroupAccess`] because it describes + /// whether row-level filter evaluation can be skipped, not which rows should + /// be read. + fully_matched: Vec, } /// Describes how the parquet reader will access a row group @@ -99,6 +115,24 @@ pub enum RowGroupAccess { Selection(RowSelection), } +/// A consecutive set of row groups that share the same row filter requirement. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct RowGroupRun { + /// True if this run needs row filter evaluation. + pub(crate) needs_filter: bool, + /// The access plan for this run. + pub(crate) access_plan: ParquetAccessPlan, +} + +impl RowGroupRun { + fn new(needs_filter: bool, access_plan: ParquetAccessPlan) -> Self { + Self { + needs_filter, + access_plan, + } + } +} + impl RowGroupAccess { /// Return true if this row group should be scanned pub fn should_scan(&self) -> bool { @@ -114,6 +148,7 @@ impl ParquetAccessPlan { pub fn new_all(row_group_count: usize) -> Self { Self { row_groups: vec![RowGroupAccess::Scan; row_group_count], + fully_matched: vec![false; row_group_count], } } @@ -121,17 +156,26 @@ impl ParquetAccessPlan { pub fn new_none(row_group_count: usize) -> Self { Self { row_groups: vec![RowGroupAccess::Skip; row_group_count], + fully_matched: vec![false; row_group_count], } } /// Create a new `ParquetAccessPlan` from the specified [`RowGroupAccess`]es pub fn new(row_groups: Vec) -> Self { - Self { row_groups } + let row_group_count = row_groups.len(); + Self { + row_groups, + fully_matched: vec![false; row_group_count], + } } /// Set the i-th row group to the specified [`RowGroupAccess`] pub fn set(&mut self, idx: usize, access: RowGroupAccess) { + let should_scan = access.should_scan(); self.row_groups[idx] = access; + if !should_scan { + self.fully_matched[idx] = false; + } } /// skips the i-th row group (should not be scanned) @@ -149,6 +193,32 @@ impl ParquetAccessPlan { self.row_groups[idx].should_scan() } + /// Marks the i-th row group as fully matched. + /// + /// Fully matched row groups are still read according to their + /// [`RowGroupAccess`], but row-level filter evaluation can be skipped. + pub(crate) fn mark_fully_matched(&mut self, idx: usize) { + if self.should_scan(idx) { + self.fully_matched[idx] = true; + } + } + + /// Return true if the i-th row group is fully matched and scanned. + pub(crate) fn is_fully_matched(&self, idx: usize) -> bool { + self.should_scan(idx) && self.fully_matched[idx] + } + + /// Returns the fully matched row group flags. + pub(crate) fn fully_matched(&self) -> &Vec { + &self.fully_matched + } + + /// Return true if any scanned row group is fully matched. + fn has_fully_matched(&self) -> bool { + self.row_group_index_iter() + .any(|idx| self.is_fully_matched(idx)) + } + /// Set to scan only the [`RowSelection`] in the specified row group. /// /// Behavior is different depending on the existing access @@ -302,13 +372,10 @@ impl ParquetAccessPlan { /// Return an iterator over the row group indexes that should be scanned pub fn row_group_index_iter(&self) -> impl Iterator + '_ { - self.row_groups.iter().enumerate().filter_map(|(idx, b)| { - if b.should_scan() { - Some(idx) - } else { - None - } - }) + self.row_groups + .iter() + .enumerate() + .filter_map(|(idx, b)| if b.should_scan() { Some(idx) } else { None }) } /// Return a vec of all row group indexes to scan @@ -336,6 +403,233 @@ impl ParquetAccessPlan { pub fn into_inner(self) -> Vec { self.row_groups } + + /// Split this plan into consecutive row group runs that share the same row + /// filter requirement. + pub(crate) fn split_runs(self, needs_filter: bool) -> Vec { + if !needs_filter || !self.has_fully_matched() { + return vec![RowGroupRun::new(needs_filter, self)]; + } + + let num_row_groups = self.row_groups.len(); + let row_groups = self.row_groups; + let fully_matched = self.fully_matched; + let mut runs: Vec = Vec::new(); + + for (idx, (access, fully_matched)) in + row_groups.into_iter().zip(fully_matched).enumerate() + { + if !access.should_scan() { + continue; + } + + let row_group_needs_filter = !fully_matched; + if let Some(run) = runs + .last_mut() + .filter(|run| run.needs_filter == row_group_needs_filter) + { + run.access_plan.set(idx, access); + if fully_matched { + run.access_plan.mark_fully_matched(idx); + } + } else { + let mut run_plan = ParquetAccessPlan::new_none(num_row_groups); + run_plan.set(idx, access); + if fully_matched { + run_plan.mark_fully_matched(idx); + } + runs.push(RowGroupRun::new(row_group_needs_filter, run_plan)); + } + } + + if runs.is_empty() { + vec![RowGroupRun::new( + needs_filter, + ParquetAccessPlan::new_none(num_row_groups), + )] + } else { + runs + } + } + + /// Prepare this plan and resolve to the final `PreparedAccessPlan` + pub(crate) fn prepare( + self, + row_group_meta_data: &[RowGroupMetaData], + ) -> Result { + let row_group_indexes = self.row_group_indexes(); + let row_selection = self.into_overall_row_selection(row_group_meta_data)?; + + PreparedAccessPlan::new(row_group_indexes, row_selection) + } +} + +/// Represents a prepared, fully resolved [`ParquetAccessPlan`] +/// +/// The [`RowSelection`] represents the result of applying all pruning such as +/// user provided scans, Row Group statistics, DataPage statistics, and Bloom +/// Filters. +/// +/// This plan is what is passed to the parquet reader +pub(crate) struct PreparedAccessPlan { + /// Row group indexes to read + pub(crate) row_group_indexes: Vec, + /// Optional row selection for filtering within row groups + pub(crate) row_selection: Option, +} + +impl PreparedAccessPlan { + /// Create a new prepared access plan + fn new( + row_group_indexes: Vec, + row_selection: Option, + ) -> Result { + Ok(Self { + row_group_indexes, + row_selection, + }) + } + + /// Reorder row groups by their min statistics for the given sort order. + /// + /// This helps TopK queries find optimal values first. Row groups are + /// always sorted by min values in ASC order — direction (DESC) is + /// handled separately by `reverse()` which is applied after reorder. + /// + /// Gracefully skips reordering when: + /// - There is a row_selection (too complex to remap) + /// - 0 or 1 row groups (nothing to reorder) + /// - Sort expression is not a simple column reference + /// - Statistics are unavailable + pub(crate) fn reorder_by_statistics( + mut self, + sort_order: &LexOrdering, + file_metadata: &ParquetMetaData, + arrow_schema: &Schema, + ) -> Result { + // Skip if row_selection present (too complex to remap) + if self.row_selection.is_some() { + debug!("Skipping RG reorder: row_selection present"); + return Ok(self); + } + + // Nothing to reorder + if self.row_group_indexes.len() <= 1 { + return Ok(self); + } + + let first_sort_expr = sort_order.first(); + + // Extract column name from sort expression + let column: &Column = match first_sort_expr.expr.downcast_ref::() { + Some(col) => col, + None => { + debug!("Skipping RG reorder: sort expr is not a simple column"); + return Ok(self); + } + }; + + // Expected graceful skip: the sort column lives outside the + // file schema (e.g. a partition column whose ordering came + // through `reversed_satisfies` rather than `column_in_file_schema`). + // Parquet has no per-RG stats for it. Bail out quietly — no + // `debug_assert!` because this is a normal pushdown shape. + if arrow_schema.field_with_name(column.name()).is_err() { + debug!( + "Skipping RG reorder: column `{}` not in file schema", + column.name() + ); + return Ok(self); + } + + // From here, any `StatisticsConverter` / stats read / sort + // failure is unexpected — the column exists in the file + // schema, so building the converter and pulling typed mins + // should succeed on any well-formed parquet file. Trip a + // `debug_assert!` so CI catches regressions, but stay graceful + // in release so a single odd file can't take down a scan. + let converter = match StatisticsConverter::try_new( + column.name(), + arrow_schema, + file_metadata.file_metadata().schema_descr(), + ) { + Ok(c) => c, + Err(e) => { + debug_assert!( + false, + "RG reorder: cannot create stats converter for `{}`: {e}", + column.name(), + ); + return Ok(self); + } + }; + + // Always sort ASC by min values — direction is handled by reverse + let rg_metadata: Vec<&RowGroupMetaData> = self + .row_group_indexes + .iter() + .map(|&idx| file_metadata.row_group(idx)) + .collect(); + + let stat_mins = match converter.row_group_mins(rg_metadata.iter().copied()) { + Ok(vals) => vals, + Err(e) => { + debug_assert!( + false, + "RG reorder: cannot get min values for `{}`: {e}", + column.name(), + ); + return Ok(self); + } + }; + + let sort_options = arrow::compute::SortOptions { + descending: false, + nulls_first: first_sort_expr.options.nulls_first, + }; + let sorted_indices = + match arrow::compute::sort_to_indices(&stat_mins, Some(sort_options), None) { + Ok(indices) => indices, + Err(e) => { + debug_assert!( + false, + "RG reorder: arrow sort_to_indices failed for `{}`: {e}", + column.name(), + ); + return Ok(self); + } + }; + + // Apply the reordering + let original_indexes = self.row_group_indexes.clone(); + self.row_group_indexes = sorted_indices + .values() + .iter() + .map(|&i| original_indexes[i as usize]) + .collect(); + + Ok(self) + } + + /// Reverse the access plan for reverse scanning + pub(crate) fn reverse(mut self, file_metadata: &ParquetMetaData) -> Result { + // Get the row group indexes before reversing + let row_groups_to_scan = self.row_group_indexes.clone(); + + // Reverse the row group indexes + self.row_group_indexes = self.row_group_indexes.into_iter().rev().collect(); + + // If we have a row selection, reverse it to match the new row group order + if let Some(row_selection) = self.row_selection { + self.row_selection = Some(reverse_row_selection( + &row_selection, + file_metadata, + &row_groups_to_scan, // Pass the original (non-reversed) row group indexes + )?); + } + + Ok(self) + } } #[cfg(test)] @@ -511,7 +805,10 @@ mod test { .unwrap_err() .to_string(); assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); - assert_contains!(err, "Invalid ParquetAccessPlan Selection. Row group 1 has 20 rows but selection only specifies 22 rows"); + assert_contains!( + err, + "Invalid ParquetAccessPlan Selection. Row group 1 has 20 rows but selection only specifies 22 rows" + ); } /// [`RowGroupMetaData`] that returns 4 row groups with 10, 20, 30, 40 rows @@ -551,4 +848,182 @@ mod test { .unwrap(); Arc::new(SchemaDescriptor::new(Arc::new(schema))) } + + // ---------------------------------------------------------------- + // `reorder_by_statistics` tests + // ---------------------------------------------------------------- + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, lit}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use parquet::file::metadata::FileMetaData; + use parquet::file::statistics::Statistics as ParquetStatistics; + + /// Single-column int32 schema named "a". + fn int_schema_descr() -> SchemaDescPtr { + use parquet::basic::Type as PhysicalType; + use parquet::schema::types::Type as SchemaType; + let field = SchemaType::primitive_type_builder("a", PhysicalType::INT32) + .build() + .unwrap(); + let schema = SchemaType::group_type_builder("schema") + .with_fields(vec![Arc::new(field)]) + .build() + .unwrap(); + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + } + + /// Build a `ParquetMetaData` with one row group per element of + /// `mins`. Each row group declares int32 statistics with + /// `min == max == mins[i]` so the reorder key is unambiguous. + fn parquet_metadata_with_int_mins(mins: &[i32]) -> ParquetMetaData { + let schema_descr = int_schema_descr(); + let row_groups: Vec = mins + .iter() + .map(|&m| { + let stats = + ParquetStatistics::int32(Some(m), Some(m), None, Some(0), false); + let column = ColumnChunkMetaData::builder(schema_descr.column(0)) + .set_statistics(stats) + .set_num_values(100) + .build() + .unwrap(); + RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(100) + .set_column_metadata(vec![column]) + .build() + .unwrap() + }) + .collect(); + let file_metadata = + FileMetaData::new(0, 0, None, None, schema_descr.clone(), None); + ParquetMetaData::new(file_metadata, row_groups) + } + + fn arrow_schema_a_int() -> Schema { + Schema::new(vec![Field::new("a", DataType::Int32, true)]) + } + + fn lex_ordering_a_asc() -> LexOrdering { + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap() + } + + /// Happy path: three row groups with mins 50/10/100. After + /// `reorder_by_statistics` the indexes are ordered ASC by `min`, + /// i.e. RG 1 (min=10) first, then RG 0 (min=50), then RG 2 + /// (min=100). + #[test] + fn reorder_by_statistics_sorts_row_groups_asc_by_min() { + let metadata = parquet_metadata_with_int_mins(&[50, 10, 100]); + let plan = PreparedAccessPlan::new(vec![0, 1, 2], None).unwrap(); + + let result = plan + .reorder_by_statistics( + &lex_ordering_a_asc(), + &metadata, + &arrow_schema_a_int(), + ) + .unwrap(); + + assert_eq!(result.row_group_indexes, vec![1, 0, 2]); + } + + /// A `row_selection` is "too complex to remap" through reorder, + /// so the function short-circuits and returns the input untouched. + #[test] + fn reorder_by_statistics_skips_when_row_selection_present() { + let metadata = parquet_metadata_with_int_mins(&[50, 10]); + let selection = RowSelection::from(vec![RowSelector::select(100)]); + let plan = PreparedAccessPlan::new(vec![0, 1], Some(selection)).unwrap(); + + let result = plan + .reorder_by_statistics( + &lex_ordering_a_asc(), + &metadata, + &arrow_schema_a_int(), + ) + .unwrap(); + + assert_eq!(result.row_group_indexes, vec![0, 1]); + } + + /// One row group means nothing to reorder. + #[test] + fn reorder_by_statistics_skips_when_at_most_one_row_group() { + let metadata = parquet_metadata_with_int_mins(&[50]); + let plan = PreparedAccessPlan::new(vec![0], None).unwrap(); + + let result = plan + .reorder_by_statistics( + &lex_ordering_a_asc(), + &metadata, + &arrow_schema_a_int(), + ) + .unwrap(); + + assert_eq!(result.row_group_indexes, vec![0]); + } + + /// Non-`Column` sort expressions (e.g. `a + 1`, + /// `date_trunc(...)`) can't drive a stats lookup, so reorder is + /// skipped. The opener falls back to whatever order it received. + #[test] + fn reorder_by_statistics_skips_for_non_column_sort_expr() { + let metadata = parquet_metadata_with_int_mins(&[50, 10]); + let plan = PreparedAccessPlan::new(vec![0, 1], None).unwrap(); + let arrow_schema = arrow_schema_a_int(); + let order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + lit(1i32), + )), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap(); + + let result = plan + .reorder_by_statistics(&order, &metadata, &arrow_schema) + .unwrap(); + + assert_eq!(result.row_group_indexes, vec![0, 1]); + } + + /// When the sort column lives outside the file's arrow schema + /// (e.g. a partition column that reached this method through + /// `try_pushdown_sort`'s reversed-equivalence branch), reorder is + /// an expected graceful skip — no `debug_assert!` should fire. + #[test] + fn reorder_by_statistics_skips_when_column_not_in_arrow_schema() { + let metadata = parquet_metadata_with_int_mins(&[50, 10]); + let plan = PreparedAccessPlan::new(vec![0, 1], None).unwrap(); + // Arrow schema only has "a"; the sort references "b". + let arrow_schema = arrow_schema_a_int(); + let order = LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("b", 0)), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]) + .unwrap(); + + let result = plan + .reorder_by_statistics(&order, &metadata, &arrow_schema) + .unwrap(); + + assert_eq!(result.row_group_indexes, vec![0, 1]); + } } diff --git a/datafusion/datasource-parquet/src/bloom_filter.rs b/datafusion/datasource-parquet/src/bloom_filter.rs new file mode 100644 index 0000000000000..9388aba4385f2 --- /dev/null +++ b/datafusion/datasource-parquet/src/bloom_filter.rs @@ -0,0 +1,560 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Loaded Parquet Split Block Bloom Filter (SBBF) data, with a +//! [`PruningStatistics`] adapter so the predicate-pruning machinery in +//! [`datafusion_pruning`] can consume it. + +use std::collections::{HashMap, HashSet}; + +use arrow::array::{ArrayRef, BooleanArray}; +use datafusion_common::pruning::PruningStatistics; +use datafusion_common::{Column, ScalarValue}; +use parquet::basic::Type; +use parquet::bloom_filter::Sbbf; +use parquet::data_type::Decimal; + +/// In memory Parquet Split Block Bloom Filters (SBBF). +/// +/// This structure implements [`PruningStatistics`] and is used to prune +/// Parquet row groups and data pages based on the query predicate. +#[derive(Debug, Clone, Default)] +pub(crate) struct BloomFilterStatistics { + /// Per-column Bloom filters + /// Key: predicate column name + /// Value: + /// * [`Sbbf`] (Bloom filter), + /// * Parquet physical [`Type`] needed to evaluate literals against the filter + column_sbbf: HashMap, +} + +impl BloomFilterStatistics { + /// Create an empty [`BloomFilterStatistics`] + pub(crate) fn new() -> Self { + Default::default() + } + + /// Create an empty [`BloomFilterStatistics`] with the specified capacity + pub(crate) fn with_capacity(capacity: usize) -> Self { + Self { + column_sbbf: HashMap::with_capacity(capacity), + } + } + + /// Add a Bloom filter and type for the specified column + pub(crate) fn insert(&mut self, column: impl Into, sbbf: Sbbf, ty: Type) { + self.column_sbbf.insert(column.into(), (sbbf, ty)); + } + + /// Helper function for checking if [`Sbbf`] filter contains [`ScalarValue`]. + /// + /// In case the type of scalar is not supported, returns `true`, assuming that the + /// value may be present. + fn check_scalar(sbbf: &Sbbf, value: &ScalarValue, parquet_type: &Type) -> bool { + match value { + ScalarValue::Utf8(Some(v)) + | ScalarValue::Utf8View(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) + | ScalarValue::BinaryView(Some(v)) + | ScalarValue::LargeBinary(Some(v)) => sbbf.check(v), + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::UInt64(Some(v)) => sbbf.check(v), + ScalarValue::UInt32(Some(v)) => sbbf.check(v), + ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { + Type::INT32 => { + //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 + // All physical type are little-endian + if *p > 9 { + //DECIMAL can be used to annotate the following types: + // + // int32: for 1 <= precision <= 9 + // int64: for 1 <= precision <= 18 + return true; + } + let b = (*v as i32).to_le_bytes(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Int32 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::INT64 => { + if *p > 18 { + return true; + } + let b = (*v as i64).to_le_bytes(); + let decimal = Decimal::Int64 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::FIXED_LEN_BYTE_ARRAY => { + // keep with from_bytes_to_i128 + let b = v.to_be_bytes().to_vec(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Bytes { + value: b.into(), + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + _ => true, + }, + ScalarValue::Dictionary(_, inner) => { + BloomFilterStatistics::check_scalar(sbbf, inner, parquet_type) + } + _ => true, + } + } +} + +impl PruningStatistics for BloomFilterStatistics { + fn min_values(&self, _column: &Column) -> Option { + None + } + + fn max_values(&self, _column: &Column) -> Option { + None + } + + fn num_containers(&self) -> usize { + 1 + } + + fn null_counts(&self, _column: &Column) -> Option { + None + } + + fn row_counts(&self) -> Option { + None + } + + /// Use bloom filters to determine if we are sure this column can not + /// possibly contain `values` + /// + /// The `contained` API returns false if the bloom filters knows that *ALL* + /// of the values in a column are not present. + fn contained( + &self, + column: &Column, + values: &HashSet, + ) -> Option { + let (sbbf, parquet_type) = self.column_sbbf.get(column.name.as_str())?; + + // Bloom filters are probabilistic data structures that can return false + // positives (i.e. it might return true even if the value is not + // present) however, the bloom filter will return `false` if the value is + // definitely not present. + + let known_not_present = values + .iter() + .map(|value| BloomFilterStatistics::check_scalar(sbbf, value, parquet_type)) + // The row group doesn't contain any of the values if + // all the checks are false + .all(|v| !v); + + let contains = if known_not_present { + Some(false) + } else { + // Given the bloom filter is probabilistic, we can't be sure that + // the row group actually contains the values. Return `None` to + // indicate this uncertainty + None + }; + + Some(BooleanArray::from(vec![contains])) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use crate::reader::ParquetFileReader; + use crate::test_util::ExpectedPruning; + use crate::{ParquetAccessPlan, ParquetFileMetrics, RowGroupAccessPlanFilter}; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::{Expr, col, lit}; + use datafusion_physical_expr::planner::logical2physical; + use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use datafusion_pruning::PruningPredicate; + use object_store::ObjectStoreExt; + use parquet::arrow::ParquetRecordBatchStreamBuilder; + use parquet::arrow::async_reader::ParquetObjectReader; + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists")` + .run(col(r#""String""#).eq(lit("Hello_Not_Exists"))) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(lit("Hello_Not_Exists")) + .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), + ), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr_view() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello_Not_Exists"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from( + "Hello_Not_Exists2", + ))), + None, + ))), + ), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { + // load parquet file + let testdata = datafusion_common::test_util::parquet_test_data(); + let file_name = "data_index_bloom_encoding_stats.parquet"; + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + // generate pruning predicate + let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + + let expr = col(r#""String""#).in_list( + (1..25) + .map(|i| lit(format!("Hello_Not_Exists{i}"))) + .collect::>(), + false, + ); + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + file_name, + data, + &pruning_predicate, + ) + .await + .unwrap(); + assert!( + pruned_row_groups + .access_plan() + .row_group_indexes() + .is_empty() + ); + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello")` + .run(col(r#""String""#).eq(lit("Hello"))) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(lit("Hello")) + .or(col(r#""String""#).eq(lit("the quick"))) + .or(col(r#""String""#).eq(lit("are you"))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values_view() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("Hello"))), + None, + )) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("the quick"))), + None, + ))) + .or(col(r#""String""#).eq(Expr::Literal( + ScalarValue::Utf8View(Some(String::from("are you"))), + None, + ))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "foo") OR (String != "bar")` + .run( + col(r#""String""#) + .not_eq(lit("foo")) + .or(col(r#""String""#).not_eq(lit("bar"))), + ) + .await + } + + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { + // generate pruning predicate on a column without a bloom filter + BloomFilterTest::new_all_types() + .with_expect_none_pruned() + .run(col(r#""string_col""#).eq(lit("0"))) + .await + } + + struct BloomFilterTest { + file_name: String, + schema: Schema, + // which row groups are expected to be left after pruning + post_pruning_row_groups: ExpectedPruning, + } + + impl BloomFilterTest { + /// Return a test for data_index_bloom_encoding_stats.parquet + /// Note the values in the `String` column are: + /// ```sql + /// > select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; + /// +-----------+ + /// | String | + /// +-----------+ + /// | Hello | + /// | This is | + /// | a | + /// | test | + /// | How | + /// | are you | + /// | doing | + /// | today | + /// | the quick | + /// | brown fox | + /// | jumps | + /// | over | + /// | the lazy | + /// | dog | + /// +-----------+ + /// ``` + fn new_data_index_bloom_encoding_stats() -> Self { + Self { + file_name: String::from("data_index_bloom_encoding_stats.parquet"), + schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), + post_pruning_row_groups: ExpectedPruning::None, + } + } + + // Return a test for alltypes_plain.parquet + fn new_all_types() -> Self { + Self { + file_name: String::from("alltypes_plain.parquet"), + schema: Schema::new(vec![Field::new( + "string_col", + DataType::Utf8, + false, + )]), + post_pruning_row_groups: ExpectedPruning::None, + } + } + + /// Expect all row groups to be pruned + pub fn with_expect_all_pruned(mut self) -> Self { + self.post_pruning_row_groups = ExpectedPruning::All; + self + } + + /// Expect all row groups not to be pruned + pub fn with_expect_none_pruned(mut self) -> Self { + self.post_pruning_row_groups = ExpectedPruning::None; + self + } + + /// Prune this file using the specified expression and check that the expected row groups are left + async fn run(self, expr: Expr) { + let Self { + file_name, + schema, + post_pruning_row_groups, + } = self; + + let testdata = datafusion_common::test_util::parquet_test_data(); + let path = format!("{testdata}/{file_name}"); + let data = bytes::Bytes::from(std::fs::read(path).unwrap()); + + let expr = logical2physical(&expr, &schema); + let pruning_predicate = + PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + + let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( + &file_name, + data, + &pruning_predicate, + ) + .await + .unwrap(); + + post_pruning_row_groups.assert(&pruned_row_groups); + } + } + + /// Evaluates the pruning predicate on the specified row groups and returns the row groups that are left + async fn test_row_group_bloom_filter_pruning_predicate( + file_name: &str, + data: bytes::Bytes, + pruning_predicate: &PruningPredicate, + ) -> Result { + use datafusion_datasource::PartitionedFile; + use object_store::ObjectMeta; + + let object_meta = ObjectMeta { + location: object_store::path::Path::parse(file_name).expect("creating path"), + last_modified: chrono::DateTime::from(std::time::SystemTime::now()), + size: data.len() as u64, + e_tag: None, + version: None, + }; + let in_memory = object_store::memory::InMemory::new(); + in_memory + .put(&object_meta.location, data.into()) + .await + .expect("put parquet file into in memory object store"); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = + ParquetFileMetrics::new(0, object_meta.location.as_ref(), &metrics); + let inner = + ParquetObjectReader::new(Arc::new(in_memory), object_meta.location.clone()) + .with_file_size(object_meta.size); + + let partitioned_file = PartitionedFile::new_from_meta(object_meta); + + let reader = ParquetFileReader { + inner, + file_metrics: file_metrics.clone(), + partitioned_file, + }; + let mut builder = ParquetRecordBatchStreamBuilder::new(reader).await.unwrap(); + + let access_plan = ParquetAccessPlan::new_all(builder.metadata().num_row_groups()); + let mut pruned_row_groups = RowGroupAccessPlanFilter::new(access_plan); + let literal_columns = pruning_predicate.literal_columns(); + let parquet_columns: Vec<_> = literal_columns + .into_iter() + .filter_map(|column_name| { + let (column_idx, _) = parquet::arrow::parquet_column( + builder.parquet_schema(), + pruning_predicate.schema(), + &column_name, + )?; + Some(( + column_name.to_string(), + column_idx, + builder.parquet_schema().column(column_idx).physical_type(), + )) + }) + .collect::>(); + let mut row_group_bloom_filters = + Vec::with_capacity(builder.metadata().num_row_groups()); + row_group_bloom_filters.resize_with( + builder.metadata().num_row_groups(), + BloomFilterStatistics::new, + ); + for idx in pruned_row_groups.row_group_indexes() { + let mut bloom_filters = + BloomFilterStatistics::with_capacity(parquet_columns.len()); + for (column_name, column_idx, physical_type) in &parquet_columns { + let bf = match builder + .get_row_group_column_bloom_filter(idx, *column_idx) + .await + { + Ok(Some(bf)) => bf, + Ok(None) => continue, + Err(e) => { + log::debug!("Ignoring error reading bloom filter: {e}"); + file_metrics.predicate_evaluation_errors.add(1); + continue; + } + }; + bloom_filters.insert(column_name.clone(), bf, *physical_type); + } + row_group_bloom_filters[idx] = bloom_filters; + } + pruned_row_groups.prune_by_bloom_filters( + pruning_predicate, + &file_metrics, + &row_group_bloom_filters, + ); + + Ok(pruned_row_groups) + } +} diff --git a/datafusion/datasource-parquet/src/decoder_projection.rs b/datafusion/datasource-parquet/src/decoder_projection.rs new file mode 100644 index 0000000000000..27a84f2f50298 --- /dev/null +++ b/datafusion/datasource-parquet/src/decoder_projection.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decoder-projection construction for the parquet scan. +//! +//! [`DecoderProjection`] owns the two halves of "project a decoded parquet +//! batch onto the scan's output schema": +//! +//! * the [`ProjectionMask`] installed on every parquet decoder run, and +//! * the per-batch transform ([`DecoderProjection::map`]) that applies the +//! projector and, when needed, rebuilds the batch with the user's +//! `output_schema` to recover metadata / nullability the file schema does +//! not carry. +//! +//! The opener constructs one [`DecoderProjection`] per file via +//! [`DecoderProjection::try_new`] and hands it to the push-decoder stream, +//! which calls [`map`](DecoderProjection::map) on every decoded batch. + +use std::sync::Arc; + +use arrow::array::{RecordBatch, RecordBatchOptions}; +use arrow::datatypes::SchemaRef; + +use datafusion_common::Result; +use datafusion_physical_expr::projection::{ProjectionExprs, Projector}; +use datafusion_physical_expr::utils::reassign_expr_columns; +use datafusion_physical_expr_adapter::replace_columns_with_literals; + +use parquet::arrow::ProjectionMask; +use parquet::schema::types::SchemaDescriptor; + +use crate::opener::{VirtualColumnsState, append_fields}; +use crate::row_filter::build_projection_read_plan; + +/// Per-file decoder projection: the [`ProjectionMask`] installed on every +/// parquet decoder run, plus the per-batch transform that maps the decoder's +/// output onto the scan's `output_schema`. +/// +/// Built once per file by the opener via [`Self::try_new`]; the +/// push-decoder stream installs [`Self::projection_mask`] on each decoder +/// and calls [`Self::map`] on every decoded batch. +pub(crate) struct DecoderProjection { + projection_mask: ProjectionMask, + projector: Projector, + output_schema: SchemaRef, + /// `true` when the projector's output schema differs from `output_schema` + /// in metadata / nullability and [`map`](Self::map) must rebuild the batch + /// with `output_schema`. + replace_schema: bool, +} + +impl DecoderProjection { + /// Build the decoder projection for a file. + /// + /// `projection` references columns in `physical_file_schema` (i.e. already + /// adapted by the per-file expr adapter); `parquet_schema` is the + /// corresponding parquet [`SchemaDescriptor`]. `output_schema` is what + /// consumers of the scan stream expect. + /// + /// `virtual_state`, when present, describes virtual columns the reader + /// will append to each decoded batch (e.g. parquet `row_number`). Virtual + /// columns are stripped from the projection fed into + /// `build_projection_read_plan` (which only understands file columns) and + /// appended to the stream schema so the projector can resolve them. + pub(crate) fn try_new( + projection: &ProjectionExprs, + physical_file_schema: &SchemaRef, + parquet_schema: &SchemaDescriptor, + output_schema: &SchemaRef, + virtual_state: Option<&VirtualColumnsState>, + ) -> Result { + // Virtual columns are produced by the reader separately from the + // projection mask, so strip them from the expressions we feed into + // `build_projection_read_plan`. We substitute each virtual column + // reference with a null literal; that leaves the remaining Column + // refs (into `physical_file_schema`) intact for + // `ProjectionMask::roots`, which only understands file columns. + let projection_for_read_plan = match virtual_state { + None => projection.clone(), + Some(state) => projection.clone().try_map_exprs(|expr| { + replace_columns_with_literals(expr, state.null_replacements()) + })?, + }; + let read_plan = build_projection_read_plan( + projection_for_read_plan.expr_iter(), + physical_file_schema, + parquet_schema, + ); + + // The reader produces projected file columns followed by any virtual + // columns (`ArrowReaderOptions::with_virtual_columns` appends them to + // each decoded batch). + let stream_schema = match virtual_state { + Some(state) => { + append_fields(&read_plan.projected_schema, state.virtual_columns()) + } + None => Arc::clone(&read_plan.projected_schema), + }; + + // Rebase the projection onto the decoder's stream schema (column + // indices change because the decoder yields only the masked columns). + let rebased_projection = projection + .clone() + .try_map_exprs(|expr| reassign_expr_columns(expr, &stream_schema))?; + let projector = rebased_projection.make_projector(&stream_schema)?; + + // Compare against the projector's *output* schema rather than the + // stream schema, so future widening of the mask (e.g. for post-scan + // filter columns) does not flip this flag. + let replace_schema = projector.output_schema() != output_schema; + + Ok(Self { + projection_mask: read_plan.projection_mask, + projector, + output_schema: Arc::clone(output_schema), + replace_schema, + }) + } + + /// The projection mask to install on every parquet decoder in the scan. + pub(crate) fn projection_mask(&self) -> &ProjectionMask { + &self.projection_mask + } + + /// Map a decoded batch onto the scan's output schema. + /// + /// Applies the [`Projector`] and, when the projector's output schema + /// differs from `output_schema` in metadata or nullability, rebuilds the + /// batch with `output_schema` (some writers emit OPTIONAL fields even when + /// the data has no nulls; some logical schemas carry field-level metadata + /// the file schema does not). + pub(crate) fn map(&self, batch: &RecordBatch) -> Result { + let projected = self.projector.project_batch(batch)?; + if !self.replace_schema { + return Ok(projected); + } + let (_stream_schema, arrays, num_rows) = projected.into_parts(); + let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); + Ok(RecordBatch::try_new_with_options( + Arc::clone(&self.output_schema), + arrays, + &options, + )?) + } +} diff --git a/datafusion/datasource-parquet/src/file_format.rs b/datafusion/datasource-parquet/src/file_format.rs index 385bfb5472a53..fe81504e320d7 100644 --- a/datafusion/datasource-parquet/src/file_format.rs +++ b/datafusion/datasource-parquet/src/file_format.rs @@ -17,84 +17,59 @@ //! [`ParquetFormat`]: Parquet [`FileFormat`] abstractions -use std::any::Any; -use std::cell::RefCell; +use std::fmt; use std::fmt::Debug; use std::ops::Range; -use std::rc::Rc; use std::sync::Arc; -use std::{fmt, vec}; -use arrow::array::RecordBatch; -use arrow::datatypes::{Fields, Schema, SchemaRef, TimeUnit}; -use datafusion_datasource::file_compression_type::FileCompressionType; -use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; -use datafusion_datasource::write::{ - get_writer_schema, ObjectWriterBuilder, SharedBuffer, +// Re-export so the historical `file_format::*` paths still resolve. +#[expect(deprecated)] +pub use crate::schema_coercion::{ + Int96Coercer, apply_file_schema_type_coercions, coerce_file_schema_to_string_type, + coerce_file_schema_to_view_type, coerce_int96_to_resolution, + transform_binary_to_string, transform_schema_to_view, }; +pub use crate::sink::ParquetSink; + +use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_datasource::TableSchema; +use datafusion_datasource::file_compression_type::FileCompressionType; +use datafusion_datasource::file_sink_config::FileSinkConfig; use datafusion_datasource::file_format::{FileFormat, FileFormatFactory}; -use datafusion_datasource::write::demux::DemuxedStreamReceiver; -use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::Statistics; use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; use datafusion_common::encryption::FileDecryptionProperties; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, DataFusionError, GetExt, - HashSet, Result, DEFAULT_PARQUET_EXTENSION, + DEFAULT_PARQUET_EXTENSION, DataFusionError, GetExt, Result, internal_datafusion_err, + internal_err, not_impl_err, }; -use datafusion_common::{HashMap, Statistics}; -use datafusion_common_runtime::{JoinSet, SpawnedTask}; -use datafusion_datasource::display::FileGroupDisplay; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; -use datafusion_datasource::sink::{DataSink, DataSinkExec}; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_datasource::sink::DataSinkExec; use datafusion_expr::dml::InsertOp; -use datafusion_physical_expr_common::sort_expr::LexRequirement; -use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; -use crate::metadata::DFParquetMetadata; +use crate::metadata::{DFParquetMetadata, lex_ordering_to_sorting_columns}; use crate::reader::CachedParquetFileReaderFactory; -use crate::source::{parse_coerce_int96_string, ParquetSource}; +use crate::source::{ + ParquetSource, parse_coerce_int96_string, parse_coerce_int96_tz_string, +}; use async_trait::async_trait; use bytes::Bytes; use datafusion_datasource::source::DataSourceExec; use datafusion_execution::cache::cache_manager::FileMetadataCache; -use datafusion_execution::runtime_env::RuntimeEnv; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; -use object_store::buffered::BufWriter; use object_store::path::Path; -use object_store::{ObjectMeta, ObjectStore}; -use parquet::arrow::arrow_writer::{ - compute_leaves, ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, - ArrowRowGroupWriterFactory, ArrowWriterOptions, -}; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; use parquet::arrow::async_reader::MetadataFetch; -use parquet::arrow::{ArrowWriter, AsyncArrowWriter}; -use parquet::basic::Type; -#[cfg(feature = "parquet_encryption")] -use parquet::encryption::encrypt::FileEncryptionProperties; use parquet::errors::ParquetError; use parquet::file::metadata::ParquetMetaData; -use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder}; -use parquet::file::writer::SerializedFileWriter; -use parquet::schema::types::SchemaDescriptor; -use tokio::io::{AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{self, Receiver, Sender}; - -/// Initial writing buffer size. Note this is just a size hint for efficiency. It -/// will grow beyond the set value if needed. -const INITIAL_BUFFER_BYTES: usize = 1048576; - -/// When writing parquet files in parallel, if the buffered Parquet data exceeds -/// this size, it is flushed to object store -const BUFFER_FLUSH_BYTES: usize = 1024000; #[derive(Default)] /// Factory struct used to create [ParquetFormat] @@ -147,10 +122,6 @@ impl FileFormatFactory for ParquetFormatFactory { fn default(&self) -> Arc { Arc::new(ParquetFormat::default()) } - - fn as_any(&self) -> &dyn Any { - self - } } impl GetExt for ParquetFormatFactory { @@ -307,7 +278,7 @@ async fn get_file_decryption_properties( file_path: &Path, ) -> Result>> { Ok(match &options.crypto.file_decryption { - Some(cfd) => Some(Arc::new(FileDecryptionProperties::from(cfd.clone()))), + Some(cfd) => Some(Arc::new(FileDecryptionProperties::try_from(cfd.clone())?)), None => match &options.crypto.factory_id { Some(factory_id) => { let factory = @@ -335,10 +306,6 @@ async fn get_file_decryption_properties( #[async_trait] impl FileFormat for ParquetFormat { - fn as_any(&self) -> &dyn Any { - self - } - fn get_ext(&self) -> String { ParquetFormatFactory::new().get_ext() } @@ -368,6 +335,13 @@ impl FileFormat for ParquetFormat { Some(time_unit) => Some(parse_coerce_int96_string(time_unit.as_str())?), None => None, }; + let coerce_int96_tz = self + .options + .global + .coerce_int96_tz + .as_ref() + .map(|tz| parse_coerce_int96_tz_string(tz)) + .transpose()?; let file_metadata_cache = state.runtime_env().cache_manager.get_file_metadata_cache(); @@ -385,13 +359,14 @@ impl FileFormat for ParquetFormat { .with_decryption_properties(file_decryption_properties) .with_file_metadata_cache(Some(Arc::clone(&file_metadata_cache))) .with_coerce_int96(coerce_int96) + .with_coerce_int96_tz(coerce_int96_tz.clone()) .fetch_schema_with_location() .await?; Ok::<_, DataFusionError>(result) }) .boxed() // Workaround https://github.com/rust-lang/rust/issues/64552 // fetch schemas concurrently, if requested - .buffered(state.config_options().execution.meta_fetch_concurrency) + .buffer_unordered(state.config_options().execution.meta_fetch_concurrency) .try_collect() .await?; @@ -401,12 +376,10 @@ impl FileFormat for ParquetFormat { // is not deterministic. Thus, to ensure deterministic schema inference // sort the files first. // https://github.com/apache/datafusion/pull/6629 - schemas.sort_by(|(location1, _), (location2, _)| location1.cmp(location2)); + schemas + .sort_unstable_by(|(location1, _), (location2, _)| location1.cmp(location2)); - let schemas = schemas - .into_iter() - .map(|(_, schema)| schema) - .collect::>(); + let schemas = schemas.into_iter().map(|(_, schema)| schema); let schema = if self.skip_metadata() { Schema::try_merge(clear_metadata(schemas)) @@ -449,6 +422,57 @@ impl FileFormat for ParquetFormat { .await } + async fn infer_ordering( + &self, + state: &dyn Session, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result> { + let file_decryption_properties = + get_file_decryption_properties(state, &self.options, &object.location) + .await?; + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let metadata = DFParquetMetadata::new(store, object) + .with_metadata_size_hint(self.metadata_size_hint()) + .with_decryption_properties(file_decryption_properties) + .with_file_metadata_cache(Some(file_metadata_cache)) + .fetch_metadata() + .await?; + crate::metadata::ordering_from_parquet_metadata(&metadata, &table_schema) + } + + async fn infer_stats_and_ordering( + &self, + state: &dyn Session, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + let file_decryption_properties = + get_file_decryption_properties(state, &self.options, &object.location) + .await?; + let file_metadata_cache = + state.runtime_env().cache_manager.get_file_metadata_cache(); + let metadata = DFParquetMetadata::new(store, object) + .with_metadata_size_hint(self.metadata_size_hint()) + .with_decryption_properties(file_decryption_properties) + .with_file_metadata_cache(Some(file_metadata_cache)) + .fetch_metadata() + .await?; + let statistics = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &table_schema, + )?; + let ordering = + crate::metadata::ordering_from_parquet_metadata(&metadata, &table_schema)?; + Ok( + datafusion_datasource::file_format::FileMeta::new(statistics) + .with_ordering(ordering), + ) + } + async fn create_physical_plan( &self, state: &dyn Session, @@ -460,12 +484,12 @@ impl FileFormat for ParquetFormat { metadata_size_hint = Some(metadata); } - let table_schema = TableSchema::new( - Arc::clone(conf.file_schema()), - conf.table_partition_cols().clone(), - ); - let mut source = ParquetSource::new(table_schema) - .with_table_parquet_options(self.options.clone()); + let mut source = conf + .file_source() + .downcast_ref::() + .cloned() + .ok_or_else(|| internal_datafusion_err!("Expected ParquetSource"))?; + source = source.with_table_parquet_options(self.options.clone()); // Use the CachedParquetFileReaderFactory let metadata_cache = state.runtime_env().cache_manager.get_file_metadata_cache(); @@ -482,11 +506,8 @@ impl FileFormat for ParquetFormat { source = self.set_source_encryption_factory(source, state)?; - // Apply schema adapter factory before building the new config - let file_source = source.apply_schema_adapter(&conf)?; - let conf = FileScanConfigBuilder::from(conf) - .with_source(file_source) + .with_source(Arc::new(source)) .build(); Ok(DataSourceExec::from_data_source(conf)) } @@ -502,7 +523,22 @@ impl FileFormat for ParquetFormat { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } - let sink = Arc::new(ParquetSink::new(conf, self.options.clone())); + // Convert ordering requirements to Parquet SortingColumns for file metadata + let sorting_columns = if let Some(ref requirements) = order_requirements { + let ordering: LexOrdering = requirements.clone().into(); + // In cases like `COPY (... ORDER BY ...) TO ...` the ORDER BY clause + // may not be compatible with Parquet sorting columns (e.g. ordering on `random()`). + // So if we cannot create a Parquet sorting column from the ordering requirement, + // we skip setting sorting columns on the Parquet sink. + lex_ordering_to_sorting_columns(&ordering).ok() + } else { + None + }; + + let sink = Arc::new( + ParquetSink::new(conf, self.options.clone()) + .with_sorting_columns(sorting_columns), + ); Ok(Arc::new(DataSinkExec::new(input, sink, order_requirements)) as _) } @@ -542,458 +578,15 @@ impl ParquetFormat { _state: &dyn Session, ) -> Result { if let Some(encryption_factory_id) = &self.options.crypto.factory_id { - Err(DataFusionError::Configuration( - format!("Parquet encryption factory id is set to '{encryption_factory_id}' but the parquet_encryption feature is disabled"))) + Err(DataFusionError::Configuration(format!( + "Parquet encryption factory id is set to '{encryption_factory_id}' but the parquet_encryption feature is disabled" + ))) } else { Ok(source) } } } -/// Apply necessary schema type coercions to make file schema match table schema. -/// -/// This function performs two main types of transformations in a single pass: -/// 1. Binary types to string types conversion - Converts binary data types to their -/// corresponding string types when the table schema expects string data -/// 2. Regular to view types conversion - Converts standard string/binary types to -/// view types when the table schema uses view types -/// -/// # Arguments -/// * `table_schema` - The table schema containing the desired types -/// * `file_schema` - The file schema to be transformed -/// -/// # Returns -/// * `Some(Schema)` - If any transformations were applied, returns the transformed schema -/// * `None` - If no transformations were needed -pub fn apply_file_schema_type_coercions( - table_schema: &Schema, - file_schema: &Schema, -) -> Option { - let mut needs_view_transform = false; - let mut needs_string_transform = false; - - // Create a mapping of table field names to their data types for fast lookup - // and simultaneously check if we need any transformations - let table_fields: HashMap<_, _> = table_schema - .fields() - .iter() - .map(|f| { - let dt = f.data_type(); - // Check if we need view type transformation - if matches!(dt, &DataType::Utf8View | &DataType::BinaryView) { - needs_view_transform = true; - } - // Check if we need string type transformation - if matches!( - dt, - &DataType::Utf8 | &DataType::LargeUtf8 | &DataType::Utf8View - ) { - needs_string_transform = true; - } - - (f.name(), dt) - }) - .collect(); - - // Early return if no transformation needed - if !needs_view_transform && !needs_string_transform { - return None; - } - - let transformed_fields: Vec> = file_schema - .fields() - .iter() - .map(|field| { - let field_name = field.name(); - let field_type = field.data_type(); - - // Look up the corresponding field type in the table schema - if let Some(table_type) = table_fields.get(field_name) { - match (table_type, field_type) { - // table schema uses string type, coerce the file schema to use string type - ( - &DataType::Utf8, - DataType::Binary | DataType::LargeBinary | DataType::BinaryView, - ) => { - return field_with_new_type(field, DataType::Utf8); - } - // table schema uses large string type, coerce the file schema to use large string type - ( - &DataType::LargeUtf8, - DataType::Binary | DataType::LargeBinary | DataType::BinaryView, - ) => { - return field_with_new_type(field, DataType::LargeUtf8); - } - // table schema uses string view type, coerce the file schema to use view type - ( - &DataType::Utf8View, - DataType::Binary | DataType::LargeBinary | DataType::BinaryView, - ) => { - return field_with_new_type(field, DataType::Utf8View); - } - // Handle view type conversions - (&DataType::Utf8View, DataType::Utf8 | DataType::LargeUtf8) => { - return field_with_new_type(field, DataType::Utf8View); - } - (&DataType::BinaryView, DataType::Binary | DataType::LargeBinary) => { - return field_with_new_type(field, DataType::BinaryView); - } - _ => {} - } - } - - // If no transformation is needed, keep the original field - Arc::clone(field) - }) - .collect(); - - Some(Schema::new_with_metadata( - transformed_fields, - file_schema.metadata.clone(), - )) -} - -/// Coerces the file schema's Timestamps to the provided TimeUnit if Parquet schema contains INT96. -pub fn coerce_int96_to_resolution( - parquet_schema: &SchemaDescriptor, - file_schema: &Schema, - time_unit: &TimeUnit, -) -> Option { - // Traverse the parquet_schema columns looking for int96 physical types. If encountered, insert - // the field's full path into a set. - let int96_fields: HashSet<_> = parquet_schema - .columns() - .iter() - .filter(|f| f.physical_type() == Type::INT96) - .map(|f| f.path().string()) - .collect(); - - if int96_fields.is_empty() { - // The schema doesn't contain any int96 fields, so skip the remaining logic. - return None; - } - - // Do a DFS into the schema using a stack, looking for timestamp(nanos) fields that originated - // as int96 to coerce to the provided time_unit. - - type NestedFields = Rc>>; - type StackContext<'a> = ( - Vec<&'a str>, // The Parquet column path (e.g., "c0.list.element.c1") for the current field. - &'a FieldRef, // The current field to be processed. - NestedFields, // The parent's fields that this field will be (possibly) type-coerced and - // inserted into. All fields have a parent, so this is not an Option type. - Option, // Nested types need to create their own vector of fields for their - // children. For primitive types this will remain None. For nested - // types it is None the first time they are processed. Then, we - // instantiate a vector for its children, push the field back onto the - // stack to be processed again, and DFS into its children. The next - // time we process the field, we know we have DFS'd into the children - // because this field is Some. - ); - - // This is our top-level fields from which we will construct our schema. We pass this into our - // initial stack context as the parent fields, and the DFS populates it. - let fields = Rc::new(RefCell::new(Vec::with_capacity(file_schema.fields.len()))); - - // TODO: It might be possible to only DFS into nested fields that we know contain an int96 if we - // use some sort of LPM data structure to check if we're currently DFS'ing nested types that are - // in a column path that contains an int96. That can be a future optimization for large schemas. - let transformed_schema = { - // Populate the stack with our top-level fields. - let mut stack: Vec = file_schema - .fields() - .iter() - .rev() - .map(|f| (vec![f.name().as_str()], f, Rc::clone(&fields), None)) - .collect(); - - // Pop fields to DFS into until we have exhausted the stack. - while let Some((parquet_path, current_field, parent_fields, child_fields)) = - stack.pop() - { - match (current_field.data_type(), child_fields) { - (DataType::Struct(unprocessed_children), None) => { - // This is the first time popping off this struct. We don't yet know the - // correct types of its children (i.e., if they need coercing) so we create - // a vector for child_fields, push the struct node back onto the stack to be - // processed again (see below) after processing all its children. - let child_fields = Rc::new(RefCell::new(Vec::with_capacity( - unprocessed_children.len(), - ))); - // Note that here we push the struct back onto the stack with its - // parent_fields in the same position, now with Some(child_fields). - stack.push(( - parquet_path.clone(), - current_field, - parent_fields, - Some(Rc::clone(&child_fields)), - )); - // Push all the children in reverse to maintain original schema order due to - // stack processing. - for child in unprocessed_children.into_iter().rev() { - let mut child_path = parquet_path.clone(); - // Build up a normalized path that we'll use as a key into the original - // int96_fields set above to test if this originated as int96. - child_path.push("."); - child_path.push(child.name()); - // Note that here we push the field onto the stack using the struct's - // new child_fields vector as the field's parent_fields. - stack.push((child_path, child, Rc::clone(&child_fields), None)); - } - } - (DataType::Struct(unprocessed_children), Some(processed_children)) => { - // This is the second time popping off this struct. The child_fields vector - // now contains each field that has been DFS'd into, and we can construct - // the resulting struct with correct child types. - let processed_children = processed_children.borrow(); - assert_eq!(processed_children.len(), unprocessed_children.len()); - let processed_struct = Field::new_struct( - current_field.name(), - processed_children.as_slice(), - current_field.is_nullable(), - ); - parent_fields.borrow_mut().push(Arc::new(processed_struct)); - } - (DataType::List(unprocessed_child), None) => { - // This is the first time popping off this list. See struct docs above. - let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); - stack.push(( - parquet_path.clone(), - current_field, - parent_fields, - Some(Rc::clone(&child_fields)), - )); - let mut child_path = parquet_path.clone(); - // Spark uses a definition for arrays/lists that results in a group - // named "list" that is not maintained when parsing to Arrow. We just push - // this name into the path. - child_path.push(".list."); - child_path.push(unprocessed_child.name()); - stack.push(( - child_path.clone(), - unprocessed_child, - Rc::clone(&child_fields), - None, - )); - } - (DataType::List(_), Some(processed_children)) => { - // This is the second time popping off this list. See struct docs above. - let processed_children = processed_children.borrow(); - assert_eq!(processed_children.len(), 1); - let processed_list = Field::new_list( - current_field.name(), - Arc::clone(&processed_children[0]), - current_field.is_nullable(), - ); - parent_fields.borrow_mut().push(Arc::new(processed_list)); - } - (DataType::Map(unprocessed_child, _), None) => { - // This is the first time popping off this map. See struct docs above. - let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); - stack.push(( - parquet_path.clone(), - current_field, - parent_fields, - Some(Rc::clone(&child_fields)), - )); - let mut child_path = parquet_path.clone(); - child_path.push("."); - child_path.push(unprocessed_child.name()); - stack.push(( - child_path.clone(), - unprocessed_child, - Rc::clone(&child_fields), - None, - )); - } - (DataType::Map(_, sorted), Some(processed_children)) => { - // This is the second time popping off this map. See struct docs above. - let processed_children = processed_children.borrow(); - assert_eq!(processed_children.len(), 1); - let processed_map = Field::new( - current_field.name(), - DataType::Map(Arc::clone(&processed_children[0]), *sorted), - current_field.is_nullable(), - ); - parent_fields.borrow_mut().push(Arc::new(processed_map)); - } - (DataType::Timestamp(TimeUnit::Nanosecond, None), None) - if int96_fields.contains(parquet_path.concat().as_str()) => - // We found a timestamp(nanos) and it originated as int96. Coerce it to the correct - // time_unit. - { - parent_fields.borrow_mut().push(field_with_new_type( - current_field, - DataType::Timestamp(*time_unit, None), - )); - } - // Other types can be cloned as they are. - _ => parent_fields.borrow_mut().push(Arc::clone(current_field)), - } - } - assert_eq!(fields.borrow().len(), file_schema.fields.len()); - Schema::new_with_metadata( - fields.borrow_mut().clone(), - file_schema.metadata.clone(), - ) - }; - - Some(transformed_schema) -} - -/// Coerces the file schema if the table schema uses a view type. -#[deprecated( - since = "47.0.0", - note = "Use `apply_file_schema_type_coercions` instead" -)] -pub fn coerce_file_schema_to_view_type( - table_schema: &Schema, - file_schema: &Schema, -) -> Option { - let mut transform = false; - let table_fields: HashMap<_, _> = table_schema - .fields - .iter() - .map(|f| { - let dt = f.data_type(); - if dt.equals_datatype(&DataType::Utf8View) - || dt.equals_datatype(&DataType::BinaryView) - { - transform = true; - } - (f.name(), dt) - }) - .collect(); - - if !transform { - return None; - } - - let transformed_fields: Vec> = file_schema - .fields - .iter() - .map( - |field| match (table_fields.get(field.name()), field.data_type()) { - (Some(DataType::Utf8View), DataType::Utf8 | DataType::LargeUtf8) => { - field_with_new_type(field, DataType::Utf8View) - } - ( - Some(DataType::BinaryView), - DataType::Binary | DataType::LargeBinary, - ) => field_with_new_type(field, DataType::BinaryView), - _ => Arc::clone(field), - }, - ) - .collect(); - - Some(Schema::new_with_metadata( - transformed_fields, - file_schema.metadata.clone(), - )) -} - -/// If the table schema uses a string type, coerce the file schema to use a string type. -/// -/// See [ParquetFormat::binary_as_string] for details -#[deprecated( - since = "47.0.0", - note = "Use `apply_file_schema_type_coercions` instead" -)] -pub fn coerce_file_schema_to_string_type( - table_schema: &Schema, - file_schema: &Schema, -) -> Option { - let mut transform = false; - let table_fields: HashMap<_, _> = table_schema - .fields - .iter() - .map(|f| (f.name(), f.data_type())) - .collect(); - let transformed_fields: Vec> = file_schema - .fields - .iter() - .map( - |field| match (table_fields.get(field.name()), field.data_type()) { - // table schema uses string type, coerce the file schema to use string type - ( - Some(DataType::Utf8), - DataType::Binary | DataType::LargeBinary | DataType::BinaryView, - ) => { - transform = true; - field_with_new_type(field, DataType::Utf8) - } - // table schema uses large string type, coerce the file schema to use large string type - ( - Some(DataType::LargeUtf8), - DataType::Binary | DataType::LargeBinary | DataType::BinaryView, - ) => { - transform = true; - field_with_new_type(field, DataType::LargeUtf8) - } - // table schema uses string view type, coerce the file schema to use view type - ( - Some(DataType::Utf8View), - DataType::Binary | DataType::LargeBinary | DataType::BinaryView, - ) => { - transform = true; - field_with_new_type(field, DataType::Utf8View) - } - _ => Arc::clone(field), - }, - ) - .collect(); - - if !transform { - None - } else { - Some(Schema::new_with_metadata( - transformed_fields, - file_schema.metadata.clone(), - )) - } -} - -/// Create a new field with the specified data type, copying the other -/// properties from the input field -fn field_with_new_type(field: &FieldRef, new_type: DataType) -> FieldRef { - Arc::new(field.as_ref().clone().with_data_type(new_type)) -} - -/// Transform a schema to use view types for Utf8 and Binary -/// -/// See [ParquetFormat::force_view_types] for details -pub fn transform_schema_to_view(schema: &Schema) -> Schema { - let transformed_fields: Vec> = schema - .fields - .iter() - .map(|field| match field.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { - field_with_new_type(field, DataType::Utf8View) - } - DataType::Binary | DataType::LargeBinary => { - field_with_new_type(field, DataType::BinaryView) - } - _ => Arc::clone(field), - }) - .collect(); - Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) -} - -/// Transform a schema so that any binary types are strings -pub fn transform_binary_to_string(schema: &Schema) -> Schema { - let transformed_fields: Vec> = schema - .fields - .iter() - .map(|field| match field.data_type() { - DataType::Binary => field_with_new_type(field, DataType::Utf8), - DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), - DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), - _ => Arc::clone(field), - }) - .collect(); - Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) -} - /// [`MetadataFetch`] adapter for reading bytes from an [`ObjectStore`] pub struct ObjectStoreFetch<'a> { store: &'a dyn ObjectStore, @@ -1079,832 +672,3 @@ pub fn statistics_from_parquet_meta_calc( ) -> Result { DFParquetMetadata::statistics_from_parquet_metadata(metadata, &table_schema) } - -/// Implements [`DataSink`] for writing to a parquet file. -pub struct ParquetSink { - /// Config options for writing data - config: FileSinkConfig, - /// Underlying parquet options - parquet_options: TableParquetOptions, - /// File metadata from successfully produced parquet files. The Mutex is only used - /// to allow inserting to HashMap from behind borrowed reference in DataSink::write_all. - written: Arc>>, -} - -impl Debug for ParquetSink { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ParquetSink").finish() - } -} - -impl DisplayAs for ParquetSink { - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "ParquetSink(file_groups=",)?; - FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?; - write!(f, ")") - } - DisplayFormatType::TreeRender => { - // TODO: collect info - write!(f, "") - } - } - } -} - -impl ParquetSink { - /// Create from config. - pub fn new(config: FileSinkConfig, parquet_options: TableParquetOptions) -> Self { - Self { - config, - parquet_options, - written: Default::default(), - } - } - - /// Retrieve the file metadata for the written files, keyed to the path - /// which may be partitioned (in the case of hive style partitioning). - pub fn written(&self) -> HashMap { - self.written.lock().clone() - } - - /// Create writer properties based upon configuration settings, - /// including partitioning and the inclusion of arrow schema metadata. - async fn create_writer_props( - &self, - runtime: &Arc, - path: &Path, - ) -> Result { - let schema = self.config.output_schema(); - - // TODO: avoid this clone in follow up PR, where the writer properties & schema - // are calculated once on `ParquetSink::new` - let mut parquet_opts = self.parquet_options.clone(); - if !self.parquet_options.global.skip_arrow_metadata { - parquet_opts.arrow_schema(schema); - } - - let mut builder = WriterPropertiesBuilder::try_from(&parquet_opts)?; - builder = set_writer_encryption_properties( - builder, - runtime, - parquet_opts, - schema, - path, - ) - .await?; - Ok(builder.build()) - } - - /// Creates an AsyncArrowWriter which serializes a parquet file to an ObjectStore - /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized - async fn create_async_arrow_writer( - &self, - location: &Path, - object_store: Arc, - context: &Arc, - parquet_props: WriterProperties, - ) -> Result> { - let buf_writer = BufWriter::with_capacity( - object_store, - location.clone(), - context - .session_config() - .options() - .execution - .objectstore_writer_buffer_size, - ); - let options = ArrowWriterOptions::new() - .with_properties(parquet_props) - .with_skip_arrow_metadata(self.parquet_options.global.skip_arrow_metadata); - - let writer = AsyncArrowWriter::try_new_with_options( - buf_writer, - get_writer_schema(&self.config), - options, - )?; - Ok(writer) - } - - /// Parquet options - pub fn parquet_options(&self) -> &TableParquetOptions { - &self.parquet_options - } -} - -#[cfg(feature = "parquet_encryption")] -async fn set_writer_encryption_properties( - builder: WriterPropertiesBuilder, - runtime: &Arc, - parquet_opts: TableParquetOptions, - schema: &Arc, - path: &Path, -) -> Result { - if let Some(file_encryption_properties) = parquet_opts.crypto.file_encryption { - // Encryption properties have been specified directly - return Ok(builder.with_file_encryption_properties(Arc::new( - FileEncryptionProperties::from(file_encryption_properties), - ))); - } else if let Some(encryption_factory_id) = &parquet_opts.crypto.factory_id.as_ref() { - // Encryption properties will be generated by an encryption factory - let encryption_factory = - runtime.parquet_encryption_factory(encryption_factory_id)?; - let file_encryption_properties = encryption_factory - .get_file_encryption_properties( - &parquet_opts.crypto.factory_options, - schema, - path, - ) - .await?; - if let Some(file_encryption_properties) = file_encryption_properties { - return Ok( - builder.with_file_encryption_properties(file_encryption_properties) - ); - } - } - Ok(builder) -} - -#[cfg(not(feature = "parquet_encryption"))] -async fn set_writer_encryption_properties( - builder: WriterPropertiesBuilder, - _runtime: &Arc, - _parquet_opts: TableParquetOptions, - _schema: &Arc, - _path: &Path, -) -> Result { - Ok(builder) -} - -#[async_trait] -impl FileSink for ParquetSink { - fn config(&self) -> &FileSinkConfig { - &self.config - } - - async fn spawn_writer_tasks_and_join( - &self, - context: &Arc, - demux_task: SpawnedTask>, - mut file_stream_rx: DemuxedStreamReceiver, - object_store: Arc, - ) -> Result { - let parquet_opts = &self.parquet_options; - - let mut file_write_tasks: JoinSet< - std::result::Result<(Path, ParquetMetaData), DataFusionError>, - > = JoinSet::new(); - - let runtime = context.runtime_env(); - let parallel_options = ParallelParquetWriterOptions { - max_parallel_row_groups: parquet_opts - .global - .maximum_parallel_row_group_writers, - max_buffered_record_batches_per_stream: parquet_opts - .global - .maximum_buffered_record_batches_per_stream, - }; - - while let Some((path, mut rx)) = file_stream_rx.recv().await { - let parquet_props = self.create_writer_props(&runtime, &path).await?; - if !parquet_opts.global.allow_single_file_parallelism { - let mut writer = self - .create_async_arrow_writer( - &path, - Arc::clone(&object_store), - context, - parquet_props.clone(), - ) - .await?; - let mut reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) - .register(context.memory_pool()); - file_write_tasks.spawn(async move { - while let Some(batch) = rx.recv().await { - writer.write(&batch).await?; - reservation.try_resize(writer.memory_size())?; - } - let parquet_meta_data = writer - .close() - .await - .map_err(|e| DataFusionError::ParquetError(Box::new(e)))?; - Ok((path, parquet_meta_data)) - }); - } else { - let writer = ObjectWriterBuilder::new( - // Parquet files as a whole are never compressed, since they - // manage compressed blocks themselves. - FileCompressionType::UNCOMPRESSED, - &path, - Arc::clone(&object_store), - ) - .with_buffer_size(Some( - context - .session_config() - .options() - .execution - .objectstore_writer_buffer_size, - )) - .build()?; - let schema = get_writer_schema(&self.config); - let props = parquet_props.clone(); - let skip_arrow_metadata = self.parquet_options.global.skip_arrow_metadata; - let parallel_options_clone = parallel_options.clone(); - let pool = Arc::clone(context.memory_pool()); - file_write_tasks.spawn(async move { - let parquet_meta_data = output_single_parquet_file_parallelized( - writer, - rx, - schema, - &props, - skip_arrow_metadata, - parallel_options_clone, - pool, - ) - .await?; - Ok((path, parquet_meta_data)) - }); - } - } - - let mut row_count = 0; - while let Some(result) = file_write_tasks.join_next().await { - match result { - Ok(r) => { - let (path, parquet_meta_data) = r?; - row_count += parquet_meta_data.file_metadata().num_rows(); - let mut written_files = self.written.lock(); - written_files - .try_insert(path.clone(), parquet_meta_data) - .map_err(|e| internal_datafusion_err!("duplicate entry detected for partitioned file {path}: {e}"))?; - drop(written_files); - } - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - } - - demux_task - .join_unwind() - .await - .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; - - Ok(row_count as u64) - } -} - -#[async_trait] -impl DataSink for ParquetSink { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> &SchemaRef { - self.config.output_schema() - } - - async fn write_all( - &self, - data: SendableRecordBatchStream, - context: &Arc, - ) -> Result { - FileSink::write_all(self, data, context).await - } -} - -/// Consumes a stream of [ArrowLeafColumn] via a channel and serializes them using an [ArrowColumnWriter] -/// Once the channel is exhausted, returns the ArrowColumnWriter. -async fn column_serializer_task( - mut rx: Receiver, - mut writer: ArrowColumnWriter, - mut reservation: MemoryReservation, -) -> Result<(ArrowColumnWriter, MemoryReservation)> { - while let Some(col) = rx.recv().await { - writer.write(&col)?; - reservation.try_resize(writer.memory_size())?; - } - Ok((writer, reservation)) -} - -type ColumnWriterTask = SpawnedTask>; -type ColSender = Sender; - -/// Spawns a parallel serialization task for each column -/// Returns join handles for each columns serialization task along with a send channel -/// to send arrow arrays to each serialization task. -fn spawn_column_parallel_row_group_writer( - col_writers: Vec, - max_buffer_size: usize, - pool: &Arc, -) -> Result<(Vec, Vec)> { - let num_columns = col_writers.len(); - - let mut col_writer_tasks = Vec::with_capacity(num_columns); - let mut col_array_channels = Vec::with_capacity(num_columns); - for writer in col_writers.into_iter() { - // Buffer size of this channel limits the number of arrays queued up for column level serialization - let (send_array, receive_array) = - mpsc::channel::(max_buffer_size); - col_array_channels.push(send_array); - - let reservation = - MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); - let task = SpawnedTask::spawn(column_serializer_task( - receive_array, - writer, - reservation, - )); - col_writer_tasks.push(task); - } - - Ok((col_writer_tasks, col_array_channels)) -} - -/// Settings related to writing parquet files in parallel -#[derive(Clone)] -struct ParallelParquetWriterOptions { - max_parallel_row_groups: usize, - max_buffered_record_batches_per_stream: usize, -} - -/// This is the return type of calling [ArrowColumnWriter].close() on each column -/// i.e. the Vec of encoded columns which can be appended to a row group -type RBStreamSerializeResult = Result<(Vec, MemoryReservation, usize)>; - -/// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective -/// parallel column serializers. -async fn send_arrays_to_col_writers( - col_array_channels: &[ColSender], - rb: &RecordBatch, - schema: Arc, -) -> Result<()> { - // Each leaf column has its own channel, increment next_channel for each leaf column sent. - let mut next_channel = 0; - for (array, field) in rb.columns().iter().zip(schema.fields()) { - for c in compute_leaves(field, array)? { - // Do not surface error from closed channel (means something - // else hit an error, and the plan is shutting down). - if col_array_channels[next_channel].send(c).await.is_err() { - return Ok(()); - } - - next_channel += 1; - } - } - - Ok(()) -} - -/// Spawns a tokio task which joins the parallel column writer tasks, -/// and finalizes the row group -fn spawn_rg_join_and_finalize_task( - column_writer_tasks: Vec, - rg_rows: usize, - pool: &Arc, -) -> SpawnedTask { - let mut rg_reservation = - MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); - - SpawnedTask::spawn(async move { - let num_cols = column_writer_tasks.len(); - let mut finalized_rg = Vec::with_capacity(num_cols); - for task in column_writer_tasks.into_iter() { - let (writer, _col_reservation) = task - .join_unwind() - .await - .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; - let encoded_size = writer.get_estimated_total_bytes(); - rg_reservation.grow(encoded_size); - finalized_rg.push(writer.close()?); - } - - Ok((finalized_rg, rg_reservation, rg_rows)) - }) -} - -/// This task coordinates the serialization of a parquet file in parallel. -/// As the query produces RecordBatches, these are written to a RowGroup -/// via parallel [ArrowColumnWriter] tasks. Once the desired max rows per -/// row group is reached, the parallel tasks are joined on another separate task -/// and sent to a concatenation task. This task immediately continues to work -/// on the next row group in parallel. So, parquet serialization is parallelized -/// across both columns and row_groups, with a theoretical max number of parallel tasks -/// given by n_columns * num_row_groups. -fn spawn_parquet_parallel_serialization_task( - row_group_writer_factory: ArrowRowGroupWriterFactory, - mut data: Receiver, - serialize_tx: Sender>, - schema: Arc, - writer_props: Arc, - parallel_options: Arc, - pool: Arc, -) -> SpawnedTask> { - SpawnedTask::spawn(async move { - let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; - let max_row_group_rows = writer_props.max_row_group_size(); - let mut row_group_index = 0; - let col_writers = - row_group_writer_factory.create_column_writers(row_group_index)?; - let (mut column_writer_handles, mut col_array_channels) = - spawn_column_parallel_row_group_writer(col_writers, max_buffer_rb, &pool)?; - let mut current_rg_rows = 0; - - while let Some(mut rb) = data.recv().await { - // This loop allows the "else" block to repeatedly split the RecordBatch to handle the case - // when max_row_group_rows < execution.batch_size as an alternative to a recursive async - // function. - loop { - if current_rg_rows + rb.num_rows() < max_row_group_rows { - send_arrays_to_col_writers( - &col_array_channels, - &rb, - Arc::clone(&schema), - ) - .await?; - current_rg_rows += rb.num_rows(); - break; - } else { - let rows_left = max_row_group_rows - current_rg_rows; - let a = rb.slice(0, rows_left); - send_arrays_to_col_writers( - &col_array_channels, - &a, - Arc::clone(&schema), - ) - .await?; - - // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup - // on a separate task, so that we can immediately start on the next RG before waiting - // for the current one to finish. - drop(col_array_channels); - let finalize_rg_task = spawn_rg_join_and_finalize_task( - column_writer_handles, - max_row_group_rows, - &pool, - ); - - // Do not surface error from closed channel (means something - // else hit an error, and the plan is shutting down). - if serialize_tx.send(finalize_rg_task).await.is_err() { - return Ok(()); - } - - current_rg_rows = 0; - rb = rb.slice(rows_left, rb.num_rows() - rows_left); - - row_group_index += 1; - let col_writers = row_group_writer_factory - .create_column_writers(row_group_index)?; - (column_writer_handles, col_array_channels) = - spawn_column_parallel_row_group_writer( - col_writers, - max_buffer_rb, - &pool, - )?; - } - } - } - - drop(col_array_channels); - // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows - if current_rg_rows > 0 { - let finalize_rg_task = spawn_rg_join_and_finalize_task( - column_writer_handles, - current_rg_rows, - &pool, - ); - - // Do not surface error from closed channel (means something - // else hit an error, and the plan is shutting down). - if serialize_tx.send(finalize_rg_task).await.is_err() { - return Ok(()); - } - } - - Ok(()) - }) -} - -/// Consume RowGroups serialized by other parallel tasks and concatenate them in -/// to the final parquet file, while flushing finalized bytes to an [ObjectStore] -async fn concatenate_parallel_row_groups( - mut parquet_writer: SerializedFileWriter, - merged_buff: SharedBuffer, - mut serialize_rx: Receiver>, - mut object_store_writer: Box, - pool: Arc, -) -> Result { - let mut file_reservation = - MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); - - while let Some(task) = serialize_rx.recv().await { - let result = task.join_unwind().await; - let (serialized_columns, mut rg_reservation, _cnt) = - result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; - - let mut rg_out = parquet_writer.next_row_group()?; - for chunk in serialized_columns { - chunk.append_to_row_group(&mut rg_out)?; - rg_reservation.free(); - - let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); - file_reservation.try_resize(buff_to_flush.len())?; - - if buff_to_flush.len() > BUFFER_FLUSH_BYTES { - object_store_writer - .write_all(buff_to_flush.as_slice()) - .await?; - buff_to_flush.clear(); - file_reservation.try_resize(buff_to_flush.len())?; // will set to zero - } - } - rg_out.close()?; - } - - let parquet_meta_data = parquet_writer.close()?; - let final_buff = merged_buff.buffer.try_lock().unwrap(); - - object_store_writer.write_all(final_buff.as_slice()).await?; - object_store_writer.shutdown().await?; - file_reservation.free(); - - Ok(parquet_meta_data) -} - -/// Parallelizes the serialization of a single parquet file, by first serializing N -/// independent RecordBatch streams in parallel to RowGroups in memory. Another -/// task then stitches these independent RowGroups together and streams this large -/// single parquet file to an ObjectStore in multiple parts. -async fn output_single_parquet_file_parallelized( - object_store_writer: Box, - data: Receiver, - output_schema: Arc, - parquet_props: &WriterProperties, - skip_arrow_metadata: bool, - parallel_options: ParallelParquetWriterOptions, - pool: Arc, -) -> Result { - let max_rowgroups = parallel_options.max_parallel_row_groups; - // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel - let (serialize_tx, serialize_rx) = - mpsc::channel::>(max_rowgroups); - - let arc_props = Arc::new(parquet_props.clone()); - let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); - let options = ArrowWriterOptions::new() - .with_properties(parquet_props.clone()) - .with_skip_arrow_metadata(skip_arrow_metadata); - let writer = ArrowWriter::try_new_with_options( - merged_buff.clone(), - Arc::clone(&output_schema), - options, - )?; - let (writer, row_group_writer_factory) = writer.into_serialized_writer()?; - - let launch_serialization_task = spawn_parquet_parallel_serialization_task( - row_group_writer_factory, - data, - serialize_tx, - Arc::clone(&output_schema), - Arc::clone(&arc_props), - parallel_options.into(), - Arc::clone(&pool), - ); - let parquet_meta_data = concatenate_parallel_row_groups( - writer, - merged_buff, - serialize_rx, - object_store_writer, - pool, - ) - .await?; - - launch_serialization_task - .join_unwind() - .await - .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; - Ok(parquet_meta_data) -} - -#[cfg(test)] -mod tests { - use parquet::arrow::parquet_to_arrow_schema; - use std::sync::Arc; - - use super::*; - - use arrow::datatypes::DataType; - use parquet::schema::parser::parse_message_type; - - #[test] - fn coerce_int96_to_resolution_with_mixed_timestamps() { - // Unclear if Spark (or other writer) could generate a file with mixed timestamps like this, - // but we want to test the scenario just in case since it's at least a valid schema as far - // as the Parquet spec is concerned. - let spark_schema = " - message spark_schema { - optional int96 c0; - optional int64 c1 (TIMESTAMP(NANOS,true)); - optional int64 c2 (TIMESTAMP(NANOS,false)); - optional int64 c3 (TIMESTAMP(MILLIS,true)); - optional int64 c4 (TIMESTAMP(MILLIS,false)); - optional int64 c5 (TIMESTAMP(MICROS,true)); - optional int64 c6 (TIMESTAMP(MICROS,false)); - } - "; - - let schema = parse_message_type(spark_schema).expect("should parse schema"); - let descr = SchemaDescriptor::new(Arc::new(schema)); - - let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); - - let result = - coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) - .unwrap(); - - // Only the first field (c0) should be converted to a microsecond timestamp because it's the - // only timestamp that originated from an INT96. - let expected_schema = Schema::new(vec![ - Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), - Field::new( - "c1", - DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), - true, - ), - Field::new("c2", DataType::Timestamp(TimeUnit::Nanosecond, None), true), - Field::new( - "c3", - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - true, - ), - Field::new("c4", DataType::Timestamp(TimeUnit::Millisecond, None), true), - Field::new( - "c5", - DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), - true, - ), - Field::new("c6", DataType::Timestamp(TimeUnit::Microsecond, None), true), - ]); - - assert_eq!(result, expected_schema); - } - - #[test] - fn coerce_int96_to_resolution_with_nested_types() { - // This schema is derived from Comet's CometFuzzTestSuite ParquetGenerator only using int96 - // primitive types with generateStruct, generateArray, and generateMap set to true, with one - // additional field added to c4's struct to make sure all fields in a struct get modified. - // https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala - let spark_schema = " - message spark_schema { - optional int96 c0; - optional group c1 { - optional int96 c0; - } - optional group c2 { - optional group c0 (LIST) { - repeated group list { - optional int96 element; - } - } - } - optional group c3 (LIST) { - repeated group list { - optional int96 element; - } - } - optional group c4 (LIST) { - repeated group list { - optional group element { - optional int96 c0; - optional int96 c1; - } - } - } - optional group c5 (MAP) { - repeated group key_value { - required int96 key; - optional int96 value; - } - } - optional group c6 (LIST) { - repeated group list { - optional group element (MAP) { - repeated group key_value { - required int96 key; - optional int96 value; - } - } - } - } - } - "; - - let schema = parse_message_type(spark_schema).expect("should parse schema"); - let descr = SchemaDescriptor::new(Arc::new(schema)); - - let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); - - let result = - coerce_int96_to_resolution(&descr, &arrow_schema, &TimeUnit::Microsecond) - .unwrap(); - - let expected_schema = Schema::new(vec![ - Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), - Field::new_struct( - "c1", - vec![Field::new( - "c0", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - )], - true, - ), - Field::new_struct( - "c2", - vec![Field::new_list( - "c0", - Field::new( - "element", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - true, - )], - true, - ), - Field::new_list( - "c3", - Field::new( - "element", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - true, - ), - Field::new_list( - "c4", - Field::new_struct( - "element", - vec![ - Field::new( - "c0", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - Field::new( - "c1", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - ], - true, - ), - true, - ), - Field::new_map( - "c5", - "key_value", - Field::new( - "key", - DataType::Timestamp(TimeUnit::Microsecond, None), - false, - ), - Field::new( - "value", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - false, - true, - ), - Field::new_list( - "c6", - Field::new_map( - "element", - "key_value", - Field::new( - "key", - DataType::Timestamp(TimeUnit::Microsecond, None), - false, - ), - Field::new( - "value", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - ), - false, - true, - ), - true, - ), - ]); - - assert_eq!(result, expected_schema); - } -} diff --git a/datafusion/datasource-parquet/src/metadata.rs b/datafusion/datasource-parquet/src/metadata.rs index fcd3a22dcf943..d3831766a42ab 100644 --- a/datafusion/datasource-parquet/src/metadata.rs +++ b/datafusion/datasource-parquet/src/metadata.rs @@ -18,34 +18,44 @@ //! [`DFParquetMetadata`] for fetching Parquet file metadata, statistics //! and schema information. -use crate::{ - apply_file_schema_type_coercions, coerce_int96_to_resolution, ObjectStoreFetch, -}; -use arrow::array::{ArrayRef, BooleanArray}; -use arrow::compute::and; +use crate::{Int96Coercer, apply_file_schema_type_coercions}; +use arrow::array::{Array, ArrayRef, BooleanArray}; use arrow::compute::kernels::cmp::eq; -use arrow::compute::sum; +use arrow::compute::{and, sum}; use arrow::datatypes::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion_common::encryption::FileDecryptionProperties; use datafusion_common::stats::Precision; use datafusion_common::{ ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics, }; -use datafusion_execution::cache::cache_manager::{FileMetadata, FileMetadataCache}; +use datafusion_execution::cache::cache_manager::{ + CachedFileMetadataEntry, FileMetadata, FileMetadataCache, +}; use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_plan::Accumulator; use log::debug; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; +use parquet::DecodeResult; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; -use parquet::arrow::parquet_to_arrow_schema; +use parquet::arrow::{parquet_column, parquet_to_arrow_schema}; use parquet::file::metadata::{ - PageIndexPolicy, ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData, + PageIndexPolicy, ParquetMetaData, ParquetMetaDataPushDecoder, RowGroupMetaData, + SortingColumn, }; +use parquet::file::statistics::Statistics as ParquetStatistics; +use parquet::schema::types::SchemaDescriptor; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; +/// Minimum fraction of row groups that must report NDV statistics for the +/// merged result to be `Inexact` rather than `Absent`, as the estimate +/// would be too unreliable otherwise. +const PARTIAL_NDV_THRESHOLD: f64 = 0.75; + /// Handles fetching Parquet file schema, metadata and statistics /// from object store. /// @@ -62,6 +72,8 @@ pub struct DFParquetMetadata<'a> { file_metadata_cache: Option>, /// timeunit to coerce INT96 timestamps to pub coerce_int96: Option, + /// Optional timezone applied to INT96-coerced timestamps. + pub coerce_int96_tz: Option>, } impl<'a> DFParquetMetadata<'a> { @@ -73,6 +85,7 @@ impl<'a> DFParquetMetadata<'a> { decryption_properties: None, file_metadata_cache: None, coerce_int96: None, + coerce_int96_tz: None, } } @@ -106,68 +119,96 @@ impl<'a> DFParquetMetadata<'a> { self } + /// Set the optional timezone applied to INT96-coerced timestamps. + pub fn with_coerce_int96_tz(mut self, timezone: Option>) -> Self { + self.coerce_int96_tz = timezone; + self + } + /// Fetch parquet metadata from the remote object store pub async fn fetch_metadata(&self) -> Result> { - let Self { - store, - object_meta, - metadata_size_hint, - decryption_properties, - file_metadata_cache, - coerce_int96: _, - } = self; - - let fetch = ObjectStoreFetch::new(*store, object_meta); - // implementation to fetch parquet metadata let cache_metadata = - !cfg!(feature = "parquet_encryption") || decryption_properties.is_none(); - - if cache_metadata { - if let Some(parquet_metadata) = file_metadata_cache - .as_ref() - .and_then(|file_metadata_cache| file_metadata_cache.get(object_meta)) - .and_then(|file_metadata| { - file_metadata - .as_any() - .downcast_ref::() - .map(|cached_parquet_metadata| { - Arc::clone(cached_parquet_metadata.parquet_metadata()) - }) - }) - { - return Ok(parquet_metadata); - } + !cfg!(feature = "parquet_encryption") || self.decryption_properties.is_none(); + + if cache_metadata + && let Some(file_metadata_cache) = self.file_metadata_cache.as_ref() + && let Some(cached) = file_metadata_cache.get(&self.object_meta.location) + && cached.is_valid_for(self.object_meta) + && let Some(cached_parquet) = cached + .file_metadata + .as_any() + .downcast_ref::() + { + return Ok(Arc::clone(cached_parquet.parquet_metadata())); } - let mut reader = - ParquetMetaDataReader::new().with_prefetch_hint(*metadata_size_hint); + let file_size = self.object_meta.size; + let mut decoder = ParquetMetaDataPushDecoder::try_new(file_size) + .map_err(DataFusionError::from)?; #[cfg(feature = "parquet_encryption")] - if let Some(decryption_properties) = decryption_properties { - reader = reader - .with_decryption_properties(Some(Arc::clone(decryption_properties))); + if let Some(decryption_properties) = &self.decryption_properties { + decoder = decoder + .with_file_decryption_properties(Some(Arc::clone(decryption_properties))); } - if cache_metadata && file_metadata_cache.is_some() { + if cache_metadata && self.file_metadata_cache.is_some() { // Need to retrieve the entire metadata for the caching to be effective. - reader = reader.with_page_index_policy(PageIndexPolicy::Optional); + decoder = decoder.with_page_index_policy(PageIndexPolicy::Optional); + } else { + decoder = decoder.with_page_index_policy(PageIndexPolicy::Skip); } - let metadata = Arc::new( - reader - .load_and_finish(fetch, object_meta.size) + // If we have a size hint, prefetch that many bytes from the end of the file + if let Some(hint) = self.metadata_size_hint { + let prefetch_start = file_size.saturating_sub(hint as u64); + let prefetch_range = prefetch_start..file_size; + let data = self + .store + .get_ranges( + &self.object_meta.location, + std::slice::from_ref(&prefetch_range), + ) .await - .map_err(DataFusionError::from)?, - ); + .map_err(DataFusionError::from)?; + decoder + .push_ranges(vec![prefetch_range], data) + .map_err(DataFusionError::from)?; + } - if cache_metadata { - if let Some(file_metadata_cache) = file_metadata_cache { - file_metadata_cache.put( - object_meta, - Arc::new(CachedParquetMetaData::new(Arc::clone(&metadata))), - ); + let metadata = loop { + match decoder.try_decode().map_err(DataFusionError::from)? { + DecodeResult::Data(metadata) => break metadata, + DecodeResult::NeedsData(ranges) => { + let buffers = self + .store + .get_ranges(&self.object_meta.location, &ranges) + .await + .map_err(DataFusionError::from)?; + decoder + .push_ranges(ranges, buffers) + .map_err(DataFusionError::from)?; + } + DecodeResult::Finished => { + return Err(DataFusionError::Internal( + "ParquetMetaDataPushDecoder finished without producing metadata" + .to_string(), + )); + } } + }; + + let metadata = Arc::new(metadata); + + if cache_metadata && let Some(file_metadata_cache) = &self.file_metadata_cache { + file_metadata_cache.put( + &self.object_meta.location, + CachedFileMetadataEntry::new( + self.object_meta.clone(), + Arc::new(CachedParquetMetaData::new(Arc::clone(&metadata))), + ), + ); } Ok(metadata) @@ -186,11 +227,9 @@ impl<'a> DFParquetMetadata<'a> { .coerce_int96 .as_ref() .and_then(|time_unit| { - coerce_int96_to_resolution( - file_metadata.schema_descr(), - &schema, - time_unit, - ) + Int96Coercer::new(file_metadata.schema_descr(), &schema, time_unit) + .with_timezone(self.coerce_int96_tz.clone()) + .coerce() }) .unwrap_or(schema); Ok(schema) @@ -227,30 +266,40 @@ impl<'a> DFParquetMetadata<'a> { /// - Exact row count /// - Exact byte size /// - All column statistics marked as unknown via Statistics::unknown_column(&table_schema) + /// - Column byte sizes are still calculated and recorded + /// /// # When only some columns have statistics: /// /// For columns with statistics: /// - Min/max values are properly extracted and represented as Precision::Exact /// - Null counts are calculated by summing across row groups + /// - Byte sizes are calculated and recorded /// /// For columns without statistics, /// - For min/max, there are two situations: /// 1. The column isn't in arrow schema, then min/max values are set to Precision::Absent /// 2. The column is in arrow schema, but not in parquet schema due to schema revolution, min/max values are set to Precision::Exact(null) /// - Null counts are set to Precision::Exact(num_rows) (conservatively assuming all values could be null) + /// + /// # Byte Size Calculation: + /// + /// - For primitive types with known fixed size, exact byte size is calculated as (byte width * number of rows) + /// - For other types, uncompressed Parquet size is used as an estimate for in-memory size + /// - If neither method is applicable, byte size is marked as Precision::Absent pub fn statistics_from_parquet_metadata( metadata: &ParquetMetaData, - table_schema: &SchemaRef, + logical_file_schema: &SchemaRef, ) -> Result { let row_groups_metadata = metadata.row_groups(); - let mut statistics = Statistics::new_unknown(table_schema); + // Use Statistics::default() as opposed to Statistics::new_unknown() + // because we are going to replace the column statistics below + // and we don't want to initialize them twice. + let mut statistics = Statistics::default(); let mut has_statistics = false; let mut num_rows = 0_usize; - let mut total_byte_size = 0_usize; for row_group_meta in row_groups_metadata { num_rows += row_group_meta.num_rows() as usize; - total_byte_size += row_group_meta.total_byte_size() as usize; if !has_statistics { has_statistics = row_group_meta @@ -260,33 +309,37 @@ impl<'a> DFParquetMetadata<'a> { } } statistics.num_rows = Precision::Exact(num_rows); - statistics.total_byte_size = Precision::Exact(total_byte_size); let file_metadata = metadata.file_metadata(); - let mut file_schema = parquet_to_arrow_schema( + let mut physical_file_schema = parquet_to_arrow_schema( file_metadata.schema_descr(), file_metadata.key_value_metadata(), )?; - if let Some(merged) = apply_file_schema_type_coercions(table_schema, &file_schema) + if let Some(merged) = + apply_file_schema_type_coercions(logical_file_schema, &physical_file_schema) { - file_schema = merged; - } - - statistics.column_statistics = if has_statistics { - let (mut max_accs, mut min_accs) = create_max_min_accs(table_schema); - let mut null_counts_array = - vec![Precision::Exact(0); table_schema.fields().len()]; - let mut is_max_value_exact = vec![Some(true); table_schema.fields().len()]; - let mut is_min_value_exact = vec![Some(true); table_schema.fields().len()]; - table_schema - .fields() - .iter() - .enumerate() - .for_each(|(idx, field)| { - match StatisticsConverter::try_new( + physical_file_schema = merged; + } + + statistics.column_statistics = + if has_statistics { + let (mut max_accs, mut min_accs) = + create_max_min_accs(logical_file_schema); + let mut null_counts_array = + vec![Precision::Absent; logical_file_schema.fields().len()]; + let mut column_byte_sizes = + vec![Precision::Absent; logical_file_schema.fields().len()]; + let mut is_max_value_exact = + vec![Some(true); logical_file_schema.fields().len()]; + let mut is_min_value_exact = + vec![Some(true); logical_file_schema.fields().len()]; + let mut distinct_counts_array = + vec![Precision::Absent; logical_file_schema.fields().len()]; + logical_file_schema.fields().iter().enumerate().for_each( + |(idx, field)| match StatisticsConverter::try_new( field.name(), - &file_schema, + &physical_file_schema, file_metadata.schema_descr(), ) { Ok(stats_converter) => { @@ -296,12 +349,16 @@ impl<'a> DFParquetMetadata<'a> { null_counts_array: &mut null_counts_array, is_min_value_exact: &mut is_min_value_exact, is_max_value_exact: &mut is_max_value_exact, + column_byte_sizes: &mut column_byte_sizes, + distinct_counts_array: &mut distinct_counts_array, }; - summarize_min_max_null_counts( + summarize_column_statistics( + logical_file_schema, &mut accumulators, idx, &stats_converter, row_groups_metadata, + num_rows, ) .ok(); } @@ -309,20 +366,54 @@ impl<'a> DFParquetMetadata<'a> { debug!("Failed to create statistics converter: {e}"); null_counts_array[idx] = Precision::Exact(num_rows); } - } - }); - - get_col_stats( - table_schema, - &null_counts_array, - &mut max_accs, - &mut min_accs, - &mut is_max_value_exact, - &mut is_min_value_exact, - ) - } else { - Statistics::unknown_column(table_schema) - }; + }, + ); + + let mut accumulators = StatisticsAccumulators { + min_accs: &mut min_accs, + max_accs: &mut max_accs, + null_counts_array: &mut null_counts_array, + is_min_value_exact: &mut is_min_value_exact, + is_max_value_exact: &mut is_max_value_exact, + column_byte_sizes: &mut column_byte_sizes, + distinct_counts_array: &mut distinct_counts_array, + }; + accumulators.build_column_statistics(logical_file_schema) + } else { + // Record column sizes + logical_file_schema + .fields() + .iter() + .enumerate() + .map(|(logical_file_schema_index, field)| { + let arrow_field = + logical_file_schema.field(logical_file_schema_index); + let parquet_idx = parquet_column( + file_metadata.schema_descr(), + &physical_file_schema, + arrow_field.name(), + ) + .map(|(idx, _)| idx); + let byte_size = compute_arrow_column_size( + field.data_type(), + row_groups_metadata, + parquet_idx, + num_rows, + ); + ColumnStatistics::new_unknown().with_byte_size(byte_size) + }) + .collect() + }; + + #[cfg(debug_assertions)] + { + // Check that the column statistics length matches the table schema fields length + assert_eq!( + statistics.column_statistics.len(), + logical_file_schema.fields().len(), + "Column statistics length does not match table schema fields length" + ); + } Ok(statistics) } @@ -360,51 +451,6 @@ fn create_max_min_accs( (max_values, min_values) } -fn get_col_stats( - schema: &Schema, - null_counts: &[Precision], - max_values: &mut [Option], - min_values: &mut [Option], - is_max_value_exact: &mut [Option], - is_min_value_exact: &mut [Option], -) -> Vec { - (0..schema.fields().len()) - .map(|i| { - let max_value = match ( - max_values.get_mut(i).unwrap(), - is_max_value_exact.get(i).unwrap(), - ) { - (Some(max_value), Some(true)) => { - max_value.evaluate().ok().map(Precision::Exact) - } - (Some(max_value), Some(false)) | (Some(max_value), None) => { - max_value.evaluate().ok().map(Precision::Inexact) - } - (None, _) => None, - }; - let min_value = match ( - min_values.get_mut(i).unwrap(), - is_min_value_exact.get(i).unwrap(), - ) { - (Some(min_value), Some(true)) => { - min_value.evaluate().ok().map(Precision::Exact) - } - (Some(min_value), Some(false)) | (Some(min_value), None) => { - min_value.evaluate().ok().map(Precision::Inexact) - } - (None, _) => None, - }; - ColumnStatistics { - null_count: null_counts[i], - max_value: max_value.unwrap_or(Precision::Absent), - min_value: min_value.unwrap_or(Precision::Absent), - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - } - }) - .collect() -} - /// Holds the accumulator state for collecting statistics from row groups struct StatisticsAccumulators<'a> { min_accs: &'a mut [Option], @@ -412,52 +458,281 @@ struct StatisticsAccumulators<'a> { null_counts_array: &'a mut [Precision], is_min_value_exact: &'a mut [Option], is_max_value_exact: &'a mut [Option], + column_byte_sizes: &'a mut [Precision], + distinct_counts_array: &'a mut [Precision], +} + +impl StatisticsAccumulators<'_> { + /// Converts the accumulated statistics into a vector of `ColumnStatistics` + fn build_column_statistics(&mut self, schema: &Schema) -> Vec { + (0..schema.fields().len()) + .map(|i| { + let max_value = match ( + self.max_accs.get_mut(i).unwrap(), + self.is_max_value_exact.get(i).unwrap(), + ) { + (Some(max_value), Some(true)) => { + max_value.evaluate().ok().map(Precision::Exact) + } + (Some(max_value), Some(false)) | (Some(max_value), None) => { + max_value.evaluate().ok().map(Precision::Inexact) + } + (None, _) => None, + }; + let min_value = match ( + self.min_accs.get_mut(i).unwrap(), + self.is_min_value_exact.get(i).unwrap(), + ) { + (Some(min_value), Some(true)) => { + min_value.evaluate().ok().map(Precision::Exact) + } + (Some(min_value), Some(false)) | (Some(min_value), None) => { + min_value.evaluate().ok().map(Precision::Inexact) + } + (None, _) => None, + }; + ColumnStatistics { + null_count: self.null_counts_array[i], + max_value: max_value.unwrap_or(Precision::Absent), + min_value: min_value.unwrap_or(Precision::Absent), + sum_value: Precision::Absent, + distinct_count: self.distinct_counts_array[i], + byte_size: self.column_byte_sizes[i], + } + }) + .collect() + } } -fn summarize_min_max_null_counts( +fn summarize_column_statistics( + logical_file_schema: &Schema, accumulators: &mut StatisticsAccumulators, - arrow_schema_index: usize, + logical_schema_index: usize, stats_converter: &StatisticsConverter, row_groups_metadata: &[RowGroupMetaData], + num_rows: usize, ) -> Result<()> { - let max_values = stats_converter.row_group_maxes(row_groups_metadata)?; - let min_values = stats_converter.row_group_mins(row_groups_metadata)?; - let null_counts = stats_converter.row_group_null_counts(row_groups_metadata)?; - let is_max_value_exact_stat = - stats_converter.row_group_is_max_value_exact(row_groups_metadata)?; - let is_min_value_exact_stat = - stats_converter.row_group_is_min_value_exact(row_groups_metadata)?; - - if let Some(max_acc) = &mut accumulators.max_accs[arrow_schema_index] { - max_acc.update_batch(&[Arc::clone(&max_values)])?; - let mut cur_max_acc = max_acc.clone(); - accumulators.is_max_value_exact[arrow_schema_index] = has_any_exact_match( - &cur_max_acc.evaluate()?, - &max_values, - &is_max_value_exact_stat, - ); + let parquet_index = stats_converter.parquet_column_index(); + + if let Some(max_acc) = &mut accumulators.max_accs[logical_schema_index] { + accumulators.is_max_value_exact[logical_schema_index] = summarize_bound( + max_acc, + &stats_converter.row_group_maxes(row_groups_metadata)?, + parquet_index, + row_groups_metadata, + ParquetStatistics::max_is_exact, + || Ok(stats_converter.row_group_is_max_value_exact(row_groups_metadata)?), + )?; } - if let Some(min_acc) = &mut accumulators.min_accs[arrow_schema_index] { - min_acc.update_batch(&[Arc::clone(&min_values)])?; - let mut cur_min_acc = min_acc.clone(); - accumulators.is_min_value_exact[arrow_schema_index] = has_any_exact_match( - &cur_min_acc.evaluate()?, - &min_values, - &is_min_value_exact_stat, - ); + if let Some(min_acc) = &mut accumulators.min_accs[logical_schema_index] { + accumulators.is_min_value_exact[logical_schema_index] = summarize_bound( + min_acc, + &stats_converter.row_group_mins(row_groups_metadata)?, + parquet_index, + row_groups_metadata, + ParquetStatistics::min_is_exact, + || Ok(stats_converter.row_group_is_min_value_exact(row_groups_metadata)?), + )?; } - accumulators.null_counts_array[arrow_schema_index] = match sum(&null_counts) { - Some(null_count) => Precision::Exact(null_count as usize), + accumulators.null_counts_array[logical_schema_index] = + summarize_null_counts(stats_converter, row_groups_metadata)?; + + accumulators.distinct_counts_array[logical_schema_index] = + summarize_distinct_counts(parquet_index, row_groups_metadata); + + let arrow_field = logical_file_schema.field(logical_schema_index); + accumulators.column_byte_sizes[logical_schema_index] = compute_arrow_column_size( + arrow_field.data_type(), + row_groups_metadata, + parquet_index, + num_rows, + ); + + Ok(()) +} + +/// Feed a column's per-row-group min or max `values` into `acc` and decide +/// whether the resulting bound is exact across all row groups. +/// +/// `is_exact` reads the per-row-group exactness flag straight from the raw +/// parquet statistics. `row_group_exactness` rebuilds the exactness as a Boolean +/// array and is only called for the rare case where row groups disagree. +fn summarize_bound( + acc: &mut A, + values: &ArrayRef, + parquet_index: Option, + row_groups_metadata: &[RowGroupMetaData], + is_exact: impl Fn(&ParquetStatistics) -> bool, + row_group_exactness: impl FnOnce() -> Result, +) -> Result> { + acc.update_batch(&[Arc::clone(values)])?; + + Ok( + match summarize_row_group_exactness(parquet_index, row_groups_metadata, is_exact) + { + ExactnessSummary::AllExact => Some(true), + ExactnessSummary::NoneExact => Some(false), + ExactnessSummary::Mixed => { + let exactness = row_group_exactness()?; + has_any_exact_match(&acc.evaluate()?, values, &exactness) + } + }, + ) +} + +fn summarize_null_counts( + stats_converter: &StatisticsConverter, + row_groups_metadata: &[RowGroupMetaData], +) -> Result> { + if row_groups_metadata.is_empty() { + return Ok(Precision::Exact(0)); + } + + let null_counts = stats_converter.row_group_null_counts(row_groups_metadata)?; + + match sum(&null_counts) { + Some(count) => { + // If any row group has an unknown null_count, either because column + // statistics are absent or because the null_count field is omitted, + // report the aggregate as inexact. + if null_counts.null_count() > 0 { + Ok(Precision::Inexact(count as usize)) + } else { + Ok(Precision::Exact(count as usize)) + } + } None => match null_counts.len() { // If sum() returned None we either have no rows or all values are null - 0 => Precision::Exact(0), - _ => Precision::Absent, + 0 => Ok(Precision::Exact(0)), + _ => Ok(Precision::Absent), }, + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ExactnessSummary { + AllExact, + NoneExact, + Mixed, +} + +fn summarize_row_group_exactness( + parquet_idx: Option, + row_groups_metadata: &[RowGroupMetaData], + exactness: impl Fn(&ParquetStatistics) -> bool, +) -> ExactnessSummary { + let Some(parquet_idx) = parquet_idx else { + return ExactnessSummary::NoneExact; }; - Ok(()) + summarize_exactness(row_groups_metadata.iter().map(|row_group| { + row_group + .columns() + .get(parquet_idx) + .and_then(|column| column.statistics()) + .map(&exactness) + })) +} + +fn summarize_exactness(exactness: I) -> ExactnessSummary +where + I: IntoIterator>, +{ + let mut has_true = false; + let mut has_false_or_null = false; + + for exactness in exactness { + match exactness { + Some(true) => has_true = true, + Some(false) | None => has_false_or_null = true, + } + + if has_true && has_false_or_null { + return ExactnessSummary::Mixed; + } + } + + if has_true { + ExactnessSummary::AllExact + } else { + ExactnessSummary::NoneExact + } +} + +/// Extract distinct counts from row group column statistics. +fn summarize_distinct_counts( + parquet_idx: Option, + row_groups_metadata: &[RowGroupMetaData], +) -> Precision { + let Some(parquet_idx) = parquet_idx else { + return Precision::Absent; + }; + + let num_row_groups = row_groups_metadata.len(); + if num_row_groups == 0 { + return Precision::Absent; + } + + let required_count = (num_row_groups as f64 * PARTIAL_NDV_THRESHOLD).ceil() as usize; + let mut ndv_count = 0; + let mut max_distinct_count: Option = None; + + for (row_group_idx, row_group) in row_groups_metadata.iter().enumerate() { + if let Some(distinct_count) = row_group + .columns() + .get(parquet_idx) + .and_then(|col| col.statistics()) + .and_then(|stats| stats.distinct_count_opt()) + { + ndv_count += 1; + max_distinct_count = Some(match max_distinct_count { + Some(max) => max.max(distinct_count), + None => distinct_count, + }); + } + + // Return early if there's no chance to reach the required coverage. + let remaining = num_row_groups - row_group_idx - 1; + if ndv_count + remaining < required_count { + return Precision::Absent; + } + } + + match max_distinct_count { + Some(distinct_count) if num_row_groups == 1 => { + Precision::Exact(distinct_count as usize) + } + Some(distinct_count) => Precision::Inexact(distinct_count as usize), + None => Precision::Absent, + } +} + +/// Compute the Arrow in-memory size for a single column +fn compute_arrow_column_size( + data_type: &DataType, + row_groups_metadata: &[RowGroupMetaData], + parquet_idx: Option, + num_rows: usize, +) -> Precision { + // For primitive types with known fixed size, compute exact size + if let Some(byte_width) = data_type.primitive_width() { + return Precision::Exact(byte_width * num_rows); + } + + // Use the uncompressed Parquet size as an estimate for other types + if let Some(parquet_idx) = parquet_idx { + let uncompressed_bytes: i64 = row_groups_metadata + .iter() + .filter_map(|rg| rg.columns().get(parquet_idx)) + .map(|col| col.uncompressed_size()) + .sum(); + return Precision::Inexact(uncompressed_bytes as usize); + } + + // Otherwise, we cannot determine the size + Precision::Absent } /// Checks if any occurrence of `value` in `array` corresponds to a `true` @@ -479,10 +754,19 @@ fn has_any_exact_match( array: &ArrayRef, exactness: &BooleanArray, ) -> Option { + if value.is_null() { + return Some(false); + } + + // Shortcut for single row group + if array.len() == 1 { + return Some(exactness.is_valid(0) && exactness.value(0)); + } + let scalar_array = value.to_scalar().ok()?; let eq_mask = eq(&scalar_array, &array).ok()?; let combined_mask = and(&eq_mask, exactness).ok()?; - Some(combined_mask.true_count() > 0) + Some(combined_mask.has_true()) } /// Wrapper to implement [`FileMetadata`] for [`ParquetMetaData`]. @@ -514,12 +798,114 @@ impl FileMetadata for CachedParquetMetaData { } } +/// Convert a [`PhysicalSortExpr`] to a Parquet [`SortingColumn`]. +/// +/// Returns `Err` if the expression is not a simple column reference. +pub(crate) fn sort_expr_to_sorting_column( + sort_expr: &PhysicalSortExpr, +) -> Result { + let column = sort_expr.expr.downcast_ref::().ok_or_else(|| { + DataFusionError::Plan(format!( + "Parquet sorting_columns only supports simple column references, \ + but got expression: {}", + sort_expr.expr + )) + })?; + + let column_idx: i32 = column.index().try_into().map_err(|_| { + DataFusionError::Plan(format!( + "Column index {} is too large to be represented as i32", + column.index() + )) + })?; + + Ok(SortingColumn { + column_idx, + descending: sort_expr.options.descending, + nulls_first: sort_expr.options.nulls_first, + }) +} + +/// Convert a [`LexOrdering`] to `Vec` for Parquet. +/// +/// Returns `Err` if any expression is not a simple column reference. +pub(crate) fn lex_ordering_to_sorting_columns( + ordering: &LexOrdering, +) -> Result> { + ordering.iter().map(sort_expr_to_sorting_column).collect() +} + +/// Extracts ordering information from Parquet metadata. +/// +/// This function reads the sorting_columns from the first row group's metadata +/// and converts them into a [`LexOrdering`] that can be used by the query engine. +/// +/// # Arguments +/// * `metadata` - The Parquet metadata containing sorting_columns information +/// * `schema` - The Arrow schema to use for column lookup +/// +/// # Returns +/// * `Ok(Some(ordering))` if valid ordering information was found +/// * `Ok(None)` if no sorting columns were specified or they couldn't be resolved +pub fn ordering_from_parquet_metadata( + metadata: &ParquetMetaData, + schema: &SchemaRef, +) -> Result> { + // Get the sorting columns from the first row group metadata. + // If no row groups exist or no sorting columns are specified, return None. + let sorting_columns = metadata + .row_groups() + .first() + .and_then(|rg| rg.sorting_columns()) + .filter(|cols| !cols.is_empty()); + + let Some(sorting_columns) = sorting_columns else { + return Ok(None); + }; + + let parquet_schema = metadata.file_metadata().schema_descr(); + + let sort_exprs = + sorting_columns_to_physical_exprs(sorting_columns, parquet_schema, schema); + + if sort_exprs.is_empty() { + return Ok(None); + } + + Ok(LexOrdering::new(sort_exprs)) +} + +/// Converts Parquet sorting columns to physical sort expressions. +fn sorting_columns_to_physical_exprs( + sorting_columns: &[SortingColumn], + parquet_schema: &SchemaDescriptor, + arrow_schema: &SchemaRef, +) -> Vec { + use arrow::compute::SortOptions; + + sorting_columns + .iter() + .filter_map(|sc| { + let parquet_column = parquet_schema.column(sc.column_idx as usize); + let name = parquet_column.name(); + + // Find the column in the arrow schema + let (index, _) = arrow_schema.column_with_name(name)?; + + let expr = Arc::new(Column::new(name, index)); + let options = SortOptions { + descending: sc.descending, + nulls_first: sc.nulls_first, + }; + Some(PhysicalSortExpr::new(expr, options)) + }) + .collect() +} + #[cfg(test)] mod tests { use super::*; - use arrow::array::{ArrayRef, BooleanArray, Int32Array}; - use datafusion_common::ScalarValue; - use std::sync::Arc; + use arrow::array::Int32Array; #[test] fn test_has_any_exact_match() { @@ -567,4 +953,501 @@ mod tests { assert_eq!(result, Some(false)); } } + + #[test] + fn test_summarize_exactness() { + assert_eq!( + summarize_exactness([Some(true), Some(true)]), + ExactnessSummary::AllExact + ); + assert_eq!( + summarize_exactness([Some(false), None]), + ExactnessSummary::NoneExact + ); + assert_eq!( + summarize_exactness([Some(true), Some(false)]), + ExactnessSummary::Mixed + ); + assert_eq!( + summarize_exactness([Some(true), None]), + ExactnessSummary::Mixed + ); + assert_eq!( + summarize_exactness(std::iter::empty()), + ExactnessSummary::NoneExact + ); + } + + mod ndv_tests { + use super::*; + use arrow::datatypes::Field; + use parquet::basic::Type as PhysicalType; + use parquet::file::metadata::ColumnChunkMetaData; + use parquet::file::reader::{FileReader, SerializedFileReader}; + use parquet::file::statistics::Statistics as ParquetStatistics; + use parquet::schema::types::Type as SchemaType; + use std::fs::File; + use std::path::PathBuf; + + fn create_schema_descr(num_columns: usize) -> Arc { + let fields: Vec> = (0..num_columns) + .map(|i| { + Arc::new( + SchemaType::primitive_type_builder( + &format!("col_{i}"), + PhysicalType::INT32, + ) + .build() + .unwrap(), + ) + }) + .collect(); + + let schema = SchemaType::group_type_builder("schema") + .with_fields(fields) + .build() + .unwrap(); + + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + } + + fn create_arrow_schema(num_columns: usize) -> SchemaRef { + let fields: Vec = (0..num_columns) + .map(|i| Field::new(format!("col_{i}"), DataType::Int32, true)) + .collect(); + Arc::new(Schema::new(fields)) + } + + fn create_row_group_with_stats( + schema_descr: &Arc, + column_stats: Vec>, + num_rows: i64, + ) -> RowGroupMetaData { + let columns: Vec = column_stats + .into_iter() + .enumerate() + .map(|(i, stats)| { + let mut builder = + ColumnChunkMetaData::builder(schema_descr.column(i)); + if let Some(s) = stats { + builder = builder.set_statistics(s); + } + builder.set_num_values(num_rows).build().unwrap() + }) + .collect(); + + RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(num_rows) + .set_total_byte_size(1000) + .set_column_metadata(columns) + .build() + .unwrap() + } + + fn create_parquet_metadata( + schema_descr: Arc, + row_groups: Vec, + ) -> ParquetMetaData { + use parquet::file::metadata::FileMetaData; + + let num_rows: i64 = row_groups.iter().map(|rg| rg.num_rows()).sum(); + let file_meta = FileMetaData::new( + 1, // version + num_rows, // num_rows + None, // created_by + None, // key_value_metadata + schema_descr, // schema_descr + None, // column_orders + ); + + ParquetMetaData::new(file_meta, row_groups) + } + + #[test] + fn test_summarize_null_counts() { + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(2); + let stats_with_count = + ParquetStatistics::int32(Some(1), Some(10), None, Some(2), false); + let stats_without_count = + ParquetStatistics::int32(Some(1), Some(10), None, None, false); + + let row_groups = vec![ + create_row_group_with_stats( + &schema_descr, + vec![Some(stats_with_count)], + 10, + ), + create_row_group_with_stats( + &schema_descr, + vec![Some(stats_without_count.clone())], + 10, + ), + create_row_group_with_stats(&schema_descr, vec![None], 10), + ]; + let stats_converter = + StatisticsConverter::try_new("col_0", &arrow_schema, &schema_descr) + .unwrap(); + let missing_column_converter = + StatisticsConverter::try_new("col_1", &arrow_schema, &schema_descr) + .unwrap(); + + assert_eq!( + summarize_null_counts(&stats_converter, &row_groups).unwrap(), + Precision::Inexact(2) + ); + assert_eq!( + summarize_null_counts(&missing_column_converter, &row_groups).unwrap(), + Precision::Absent + ); + assert_eq!( + summarize_null_counts(&stats_converter, &[]).unwrap(), + Precision::Exact(0) + ); + assert_eq!( + summarize_null_counts(&missing_column_converter, &[]).unwrap(), + Precision::Exact(0) + ); + + let missing_counts_unknown_converter = + StatisticsConverter::try_new("col_0", &arrow_schema, &schema_descr) + .unwrap() + .with_missing_null_counts_as_zero(false); + assert_eq!( + summarize_null_counts(&missing_counts_unknown_converter, &row_groups) + .unwrap(), + Precision::Inexact(2) + ); + + let row_groups_without_count = vec![ + create_row_group_with_stats( + &schema_descr, + vec![Some(stats_without_count.clone())], + 10, + ), + create_row_group_with_stats( + &schema_descr, + vec![Some(stats_without_count)], + 10, + ), + ]; + assert_eq!( + summarize_null_counts(&stats_converter, &row_groups_without_count) + .unwrap(), + Precision::Exact(0) + ); + + let missing_counts_unknown_converter = + stats_converter.with_missing_null_counts_as_zero(false); + assert_eq!( + summarize_null_counts( + &missing_counts_unknown_converter, + &row_groups_without_count, + ) + .unwrap(), + Precision::Absent + ); + } + + #[test] + fn test_distinct_count_single_row_group_with_ndv() { + // Single row group with distinct count should return Exact + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Create statistics with distinct_count = 42 + let stats = ParquetStatistics::int32( + Some(1), // min + Some(100), // max + Some(42), // distinct_count + Some(0), // null_count + false, // is_deprecated + ); + + let row_group = + create_row_group_with_stats(&schema_descr, vec![Some(stats)], 1000); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Exact(42) + ); + } + + #[test] + fn test_distinct_count_multiple_row_groups_with_ndv() { + // Multiple row groups with distinct counts should return Inexact (sum) + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Row group 1: distinct_count = 10 + let stats1 = ParquetStatistics::int32( + Some(1), + Some(50), + Some(10), // distinct_count + Some(0), + false, + ); + + // Row group 2: distinct_count = 20 + let stats2 = ParquetStatistics::int32( + Some(51), + Some(100), + Some(20), // distinct_count + Some(0), + false, + ); + + let row_group1 = + create_row_group_with_stats(&schema_descr, vec![Some(stats1)], 500); + let row_group2 = + create_row_group_with_stats(&schema_descr, vec![Some(stats2)], 500); + let metadata = + create_parquet_metadata(schema_descr, vec![row_group1, row_group2]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + // Max of distinct counts (lower bound since we can't accurately merge NDV) + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(20) + ); + } + + #[test] + fn test_distinct_count_no_ndv_available() { + // No distinct count in statistics should return Absent + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Create statistics without distinct_count (None) + let stats = ParquetStatistics::int32( + Some(1), + Some(100), + None, // no distinct_count + Some(0), + false, + ); + + let row_group = + create_row_group_with_stats(&schema_descr, vec![Some(stats)], 1000); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent + ); + } + + #[test] + fn test_distinct_count_partial_ndv_below_threshold() { + // 1 of 2 row groups has NDV (50% < 75% threshold) -> Absent + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + let stats1 = + ParquetStatistics::int32(Some(1), Some(50), Some(15), Some(0), false); + let stats2 = + ParquetStatistics::int32(Some(51), Some(100), None, Some(0), false); + + let row_group1 = + create_row_group_with_stats(&schema_descr, vec![Some(stats1)], 500); + let row_group2 = + create_row_group_with_stats(&schema_descr, vec![Some(stats2)], 500); + let metadata = + create_parquet_metadata(schema_descr, vec![row_group1, row_group2]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent + ); + } + + #[test] + fn test_distinct_count_partial_ndv_above_threshold() { + // 3 of 4 row groups have NDV (75% >= 75% threshold) -> Inexact + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + let stats_with = |ndv| { + ParquetStatistics::int32(Some(1), Some(100), Some(ndv), Some(0), false) + }; + let stats_without = + ParquetStatistics::int32(Some(1), Some(100), None, Some(0), false); + + let rg1 = create_row_group_with_stats( + &schema_descr, + vec![Some(stats_with(10))], + 250, + ); + let rg2 = create_row_group_with_stats( + &schema_descr, + vec![Some(stats_with(20))], + 250, + ); + let rg3 = create_row_group_with_stats( + &schema_descr, + vec![Some(stats_with(15))], + 250, + ); + let rg4 = create_row_group_with_stats( + &schema_descr, + vec![Some(stats_without)], + 250, + ); + let metadata = + create_parquet_metadata(schema_descr, vec![rg1, rg2, rg3, rg4]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Inexact(20) + ); + } + + #[test] + fn test_distinct_count_multiple_columns() { + // Test with multiple columns, each with different NDV + let schema_descr = create_schema_descr(3); + let arrow_schema = create_arrow_schema(3); + + // col_0: distinct_count = 5 + let stats0 = + ParquetStatistics::int32(Some(1), Some(10), Some(5), Some(0), false); + // col_1: no distinct_count + let stats1 = + ParquetStatistics::int32(Some(1), Some(100), None, Some(0), false); + // col_2: distinct_count = 100 + let stats2 = + ParquetStatistics::int32(Some(1), Some(1000), Some(100), Some(0), false); + + let row_group = create_row_group_with_stats( + &schema_descr, + vec![Some(stats0), Some(stats1), Some(stats2)], + 1000, + ); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Exact(5) + ); + assert_eq!( + result.column_statistics[1].distinct_count, + Precision::Absent + ); + assert_eq!( + result.column_statistics[2].distinct_count, + Precision::Exact(100) + ); + } + + #[test] + fn test_distinct_count_no_statistics_at_all() { + // No statistics in row group should return Absent for all stats + let schema_descr = create_schema_descr(1); + let arrow_schema = create_arrow_schema(1); + + // Create row group without any statistics + let row_group = create_row_group_with_stats(&schema_descr, vec![None], 1000); + let metadata = create_parquet_metadata(schema_descr, vec![row_group]); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + &metadata, + &arrow_schema, + ) + .unwrap(); + + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent + ); + } + + /// Integration test that reads a real Parquet file with distinct_count statistics. + /// The test file was created with DuckDB and has known NDV values: + /// - id: NULL (high cardinality, not tracked) + /// - category: 10 distinct values + /// - name: 5 distinct values + #[test] + fn test_distinct_count_from_real_parquet_file() { + // Path to test file created by DuckDB with distinct_count statistics + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.push("src/test_data/ndv_test.parquet"); + + let file = File::open(&path).expect("Failed to open test parquet file"); + let reader = + SerializedFileReader::new(file).expect("Failed to create reader"); + let parquet_metadata = reader.metadata(); + + // Derive Arrow schema from parquet file metadata + let arrow_schema = Arc::new( + parquet_to_arrow_schema( + parquet_metadata.file_metadata().schema_descr(), + None, + ) + .expect("Failed to convert schema"), + ); + + let result = DFParquetMetadata::statistics_from_parquet_metadata( + parquet_metadata, + &arrow_schema, + ) + .expect("Failed to extract statistics"); + + // id: no distinct_count (high cardinality) + assert_eq!( + result.column_statistics[0].distinct_count, + Precision::Absent, + "id column should have Absent distinct_count" + ); + + // category: 10 distinct values + assert_eq!( + result.column_statistics[1].distinct_count, + Precision::Exact(10), + "category column should have Exact(10) distinct_count" + ); + + // name: 5 distinct values + assert_eq!( + result.column_statistics[2].distinct_count, + Precision::Exact(5), + "name column should have Exact(5) distinct_count" + ); + } + } } diff --git a/datafusion/datasource-parquet/src/metrics.rs b/datafusion/datasource-parquet/src/metrics.rs index 5eaa137e9a456..4bf009afd6d63 100644 --- a/datafusion/datasource-parquet/src/metrics.rs +++ b/datafusion/datasource-parquet/src/metrics.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use datafusion_physical_plan::metrics::{ - Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, PruningMetrics, - RatioMergeStrategy, RatioMetrics, Time, + Count, ExecutionPlanMetricsSet, Gauge, Label, MetricBuilder, MetricCategory, + MetricType, PruningMetrics, RatioMergeStrategy, RatioMetrics, Time, }; /// Stores metrics about the parquet execution for a particular parquet file. @@ -45,9 +47,11 @@ pub struct ParquetFileMetrics { pub files_ranges_pruned_statistics: PruningMetrics, /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, - /// Number of row groups whose bloom filters were checked, tracked with matched/pruned counts + /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: PruningMetrics, - /// Number of row groups whose statistics were checked, tracked with matched/pruned counts + /// Number of row groups pruned due to limit pruning. + pub limit_pruned_row_groups: PruningMetrics, + /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: PruningMetrics, /// Total number of bytes scanned pub bytes_scanned: Count, @@ -63,19 +67,32 @@ pub struct ParquetFileMetrics { pub bloom_filter_eval_time: Time, /// Total rows filtered or matched by parquet page index pub page_index_rows_pruned: PruningMetrics, + /// Total pages filtered or matched by parquet page index + pub page_index_pages_pruned: PruningMetrics, /// Total time spent evaluating parquet page index filters pub page_index_eval_time: Time, /// Total time spent reading and parsing metadata from the footer pub metadata_load_time: Time, /// Scan Efficiency Ratio, calculated as bytes_scanned / total_file_size pub scan_efficiency_ratio: RatioMetrics, - /// Predicate Cache: number of records read directly from the inner reader. - /// This is the number of rows decoded while evaluating predicates - pub predicate_cache_inner_records: Count, + /// Predicate Cache: Total number of rows physically read and decoded from the Parquet file. + /// + /// This metric tracks "cache misses" in the predicate pushdown optimization. + /// When the specialized predicate reader cannot find the requested data in its cache, + /// it must fall back to the "inner reader" to physically decode the data from the + /// Parquet. + /// + /// This is the expensive path (IO + Decompression + Decoding). + /// + /// We use a Gauge here as arrow-rs reports absolute numbers rather + /// than incremental readings, we want a `set` operation here rather + /// than `add`. Earlier it was `Count`, which led to this issue: + /// github.com/apache/datafusion/issues/19334 + pub predicate_cache_inner_records: Gauge, /// Predicate Cache: number of records read from the cache. This is the /// number of rows that were stored in the cache after evaluating predicates /// reused for the output. - pub predicate_cache_records: Count, + pub predicate_cache_records: Gauge, } impl ParquetFileMetrics { @@ -85,41 +102,52 @@ impl ParquetFileMetrics { filename: &str, metrics: &ExecutionPlanMetricsSet, ) -> Self { + // Share the filename label across all per-file metrics to avoid + // allocating the same filename string for each metric. + let filename_label = Label::new("filename", Arc::::from(filename)); + let builder = MetricBuilder::new(metrics).with_label(filename_label); + // ----------------------- // 'summary' level metrics // ----------------------- - let row_groups_pruned_bloom_filter = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .with_type(MetricType::SUMMARY) + let row_groups_pruned_bloom_filter = builder + .clone() + .with_type(MetricType::Summary) .pruning_metrics("row_groups_pruned_bloom_filter", partition); - let row_groups_pruned_statistics = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .with_type(MetricType::SUMMARY) + let limit_pruned_row_groups = builder + .clone() + .with_type(MetricType::Summary) + .pruning_metrics("limit_pruned_row_groups", partition); + + let row_groups_pruned_statistics = builder + .clone() + .with_type(MetricType::Summary) .pruning_metrics("row_groups_pruned_statistics", partition); - let page_index_rows_pruned = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .with_type(MetricType::SUMMARY) - .pruning_metrics("page_index_rows_pruned", partition); + let page_index_pages_pruned = builder + .clone() + .with_type(MetricType::Summary) + .pruning_metrics("page_index_pages_pruned", partition); - let bytes_scanned = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .with_type(MetricType::SUMMARY) + let bytes_scanned = builder + .clone() + .with_type(MetricType::Summary) + .with_category(MetricCategory::Bytes) .counter("bytes_scanned", partition); - let metadata_load_time = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .with_type(MetricType::SUMMARY) + let metadata_load_time = builder + .clone() + .with_type(MetricType::Summary) .subset_time("metadata_load_time", partition); let files_ranges_pruned_statistics = MetricBuilder::new(metrics) - .with_type(MetricType::SUMMARY) + .with_type(MetricType::Summary) .pruning_metrics("files_ranges_pruned_statistics", partition); - let scan_efficiency_ratio = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .with_type(MetricType::SUMMARY) + let scan_efficiency_ratio = builder + .clone() + .with_type(MetricType::Summary) .ratio_metrics_with_strategy( "scan_efficiency_ratio", partition, @@ -129,49 +157,59 @@ impl ParquetFileMetrics { // ----------------------- // 'dev' level metrics // ----------------------- - let predicate_evaluation_errors = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let predicate_evaluation_errors = builder + .clone() + .with_category(MetricCategory::Rows) .counter("predicate_evaluation_errors", partition); - let pushdown_rows_pruned = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let pushdown_rows_pruned = builder + .clone() + .with_category(MetricCategory::Rows) .counter("pushdown_rows_pruned", partition); - let pushdown_rows_matched = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let pushdown_rows_matched = builder + .clone() + .with_category(MetricCategory::Rows) .counter("pushdown_rows_matched", partition); - let row_pushdown_eval_time = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let row_pushdown_eval_time = builder + .clone() .subset_time("row_pushdown_eval_time", partition); - let statistics_eval_time = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let statistics_eval_time = builder + .clone() .subset_time("statistics_eval_time", partition); - let bloom_filter_eval_time = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let bloom_filter_eval_time = builder + .clone() .subset_time("bloom_filter_eval_time", partition); - let page_index_eval_time = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) + let page_index_eval_time = builder + .clone() .subset_time("page_index_eval_time", partition); - let predicate_cache_inner_records = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .counter("predicate_cache_inner_records", partition); + let page_index_rows_pruned = builder + .clone() + .pruning_metrics("page_index_rows_pruned", partition); - let predicate_cache_records = MetricBuilder::new(metrics) - .with_new_label("filename", filename.to_string()) - .counter("predicate_cache_records", partition); + let predicate_cache_inner_records = builder + .clone() + .with_category(MetricCategory::Rows) + .gauge("predicate_cache_inner_records", partition); + + let predicate_cache_records = builder + .with_category(MetricCategory::Rows) + .gauge("predicate_cache_records", partition); Self { files_ranges_pruned_statistics, predicate_evaluation_errors, row_groups_pruned_bloom_filter, row_groups_pruned_statistics, + limit_pruned_row_groups, bytes_scanned, pushdown_rows_pruned, pushdown_rows_matched, row_pushdown_eval_time, page_index_rows_pruned, + page_index_pages_pruned, statistics_eval_time, bloom_filter_eval_time, page_index_eval_time, @@ -181,4 +219,28 @@ impl ParquetFileMetrics { predicate_cache_records, } } + + /// Record pages whose page-index pruning was skipped because the containing + /// row group was fully matched by row-group statistics. + /// + /// The counter is only registered when there is a non-zero value. This keeps + /// [`ParquetFileMetrics::new`] from cloning the filename and metrics set for + /// files that never use this metric. + pub(crate) fn add_page_index_pages_skipped_by_fully_matched( + metrics: &ExecutionPlanMetricsSet, + partition: usize, + filename: &str, + n: usize, + ) { + if n == 0 { + return; + } + + let count = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .with_type(MetricType::Summary) + .with_category(MetricCategory::Rows) + .counter("page_index_pages_skipped_by_fully_matched", partition); + count.add(n); + } } diff --git a/datafusion/datasource-parquet/src/mod.rs b/datafusion/datasource-parquet/src/mod.rs index e0e906f3ce2a5..bec07363668e3 100644 --- a/datafusion/datasource-parquet/src/mod.rs +++ b/datafusion/datasource-parquet/src/mod.rs @@ -15,23 +15,35 @@ // specific language governing permissions and limitations // under the License. +//! DataFusion Parquet Reader: [`ParquetSource`] +//! +//! [`ParquetSource`]: source::ParquetSource + // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] pub mod access_plan; +mod bloom_filter; +mod decoder_projection; pub mod file_format; pub mod metadata; mod metrics; mod opener; mod page_filter; +mod push_decoder; mod reader; mod row_filter; mod row_group_filter; +mod schema_coercion; +mod sink; +mod sort; pub mod source; +mod supported_predicates; +#[cfg(test)] +mod test_util; +mod virtual_column; mod writer; pub use access_plan::{ParquetAccessPlan, RowGroupAccess}; @@ -42,4 +54,12 @@ pub use reader::*; // Expose so downstream crates can use it pub use row_filter::build_row_filter; pub use row_filter::can_expr_be_pushed_down_with_schemas; pub use row_group_filter::RowGroupAccessPlanFilter; +#[expect(deprecated)] +pub use schema_coercion::{ + Int96Coercer, apply_file_schema_type_coercions, coerce_file_schema_to_string_type, + coerce_file_schema_to_view_type, coerce_int96_to_resolution, + transform_binary_to_string, transform_schema_to_view, +}; +pub use sink::ParquetSink; +pub use virtual_column::ParquetVirtualColumn; pub use writer::plan_to_parquet; diff --git a/datafusion/datasource-parquet/src/opener/early_stop.rs b/datafusion/datasource-parquet/src/opener/early_stop.rs new file mode 100644 index 0000000000000..75749d284068b --- /dev/null +++ b/datafusion/datasource-parquet/src/opener/early_stop.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EarlyStoppingStream`] terminates a Parquet file scan when a dynamic +//! filter narrows after the scan has already started. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use arrow::array::RecordBatch; +use datafusion_common::Result; +use datafusion_physical_plan::metrics::PruningMetrics; +use datafusion_pruning::FilePruner; +use futures::{Stream, StreamExt, ready}; + +/// Wraps an inner RecordBatchStream and a [`FilePruner`] +/// +/// This can terminate the scan early when some dynamic filters is updated after +/// the scan starts, so we discover after the scan starts that the file can be +/// pruned (can't have matching rows). +pub(super) struct EarlyStoppingStream { + /// Has the stream finished processing? All subsequent polls will return + /// None + done: bool, + file_pruner: FilePruner, + files_ranges_pruned_statistics: PruningMetrics, + /// The inner stream + inner: S, +} + +impl EarlyStoppingStream { + pub(super) fn new( + stream: S, + file_pruner: FilePruner, + files_ranges_pruned_statistics: PruningMetrics, + ) -> Self { + Self { + done: false, + inner: stream, + file_pruner, + files_ranges_pruned_statistics, + } + } +} + +impl EarlyStoppingStream +where + S: Stream> + Unpin, +{ + fn check_prune(&mut self, input: Result) -> Result> { + let batch = input?; + + // Since dynamic filters may have been updated, see if we can stop + // reading this stream entirely. + if self.file_pruner.should_prune()? { + self.files_ranges_pruned_statistics.add_pruned(1); + // Previously this file range has been counted as matched + self.files_ranges_pruned_statistics.subtract_matched(1); + self.done = true; + Ok(None) + } else { + // Return the adapted batch + Ok(Some(batch)) + } + } +} + +impl Stream for EarlyStoppingStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if self.done { + return Poll::Ready(None); + } + match ready!(self.inner.poll_next_unpin(cx)) { + None => { + // input done + self.done = true; + Poll::Ready(None) + } + Some(input_batch) => { + let output = self.check_prune(input_batch); + Poll::Ready(output.transpose()) + } + } + } +} diff --git a/datafusion/datasource-parquet/src/opener/encryption.rs b/datafusion/datasource-parquet/src/opener/encryption.rs new file mode 100644 index 0000000000000..b725198237bbf --- /dev/null +++ b/datafusion/datasource-parquet/src/opener/encryption.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Encryption context used during Parquet file open. +//! +//! Isolated here so the `#[cfg(feature = "parquet_encryption")]` gating does +//! not pollute the rest of the opener module. + +#[cfg(feature = "parquet_encryption")] +use std::sync::Arc; + +use datafusion_common::Result; +#[cfg(feature = "parquet_encryption")] +use datafusion_common::config::EncryptionFactoryOptions; +#[cfg(feature = "parquet_encryption")] +use datafusion_common::encryption::FileDecryptionProperties; +#[cfg(feature = "parquet_encryption")] +use datafusion_execution::parquet_encryption::EncryptionFactory; + +use super::ParquetMorselizer; + +#[derive(Default)] +pub(super) struct EncryptionContext { + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: Option>, + #[cfg(feature = "parquet_encryption")] + encryption_factory: Option<(Arc, EncryptionFactoryOptions)>, +} + +#[cfg(feature = "parquet_encryption")] +impl EncryptionContext { + fn new( + file_decryption_properties: Option>, + encryption_factory: Option<( + Arc, + EncryptionFactoryOptions, + )>, + ) -> Self { + Self { + file_decryption_properties, + encryption_factory, + } + } + + pub(super) async fn get_file_decryption_properties( + &self, + file_location: &object_store::path::Path, + ) -> Result>> { + match &self.file_decryption_properties { + Some(file_decryption_properties) => { + Ok(Some(Arc::clone(file_decryption_properties))) + } + None => match &self.encryption_factory { + Some((encryption_factory, encryption_config)) => Ok(encryption_factory + .get_file_decryption_properties(encryption_config, file_location) + .await?), + None => Ok(None), + }, + } + } +} + +#[cfg(not(feature = "parquet_encryption"))] +#[expect(dead_code)] +impl EncryptionContext { + pub(super) async fn get_file_decryption_properties( + &self, + _file_location: &object_store::path::Path, + ) -> Result< + Option>, + > { + Ok(None) + } +} + +impl ParquetMorselizer { + #[cfg(feature = "parquet_encryption")] + pub(super) fn get_encryption_context(&self) -> EncryptionContext { + EncryptionContext::new( + self.file_decryption_properties.clone(), + self.encryption_factory.clone(), + ) + } + + #[cfg(not(feature = "parquet_encryption"))] + #[expect(dead_code)] + pub(super) fn get_encryption_context(&self) -> EncryptionContext { + EncryptionContext::default() + } +} diff --git a/datafusion/datasource-parquet/src/opener/mod.rs b/datafusion/datasource-parquet/src/opener/mod.rs new file mode 100644 index 0000000000000..5b517663f9c03 --- /dev/null +++ b/datafusion/datasource-parquet/src/opener/mod.rs @@ -0,0 +1,3397 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ParquetMorselizer`] state machines for opening Parquet files + +mod early_stop; +mod encryption; + +use self::early_stop::EarlyStoppingStream; +#[cfg(feature = "parquet_encryption")] +use self::encryption::EncryptionContext; +use crate::access_plan::PreparedAccessPlan; +use crate::decoder_projection::DecoderProjection; +use crate::page_filter::PagePruningAccessPlanFilter; +use crate::push_decoder::{DecoderBuilderConfig, PushDecoderStreamState}; +use crate::row_filter::RowFilterGenerator; +use crate::row_group_filter::{BloomFilterStatistics, RowGroupAccessPlanFilter}; +use crate::{ + Int96Coercer, ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory, + ParquetVirtualColumn, apply_file_schema_type_coercions, +}; +use arrow::array::RecordBatch; +use arrow::datatypes::DataType; +use datafusion_datasource::morsel::{Morsel, MorselPlan, MorselPlanner, Morselizer}; +use datafusion_physical_expr::projection::ProjectionExprs; +use datafusion_physical_expr_adapter::replace_columns_with_literals; +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::future::Future; +use std::mem; +use std::sync::Arc; + +use arrow::datatypes::{FieldRef, Schema, SchemaRef, TimeUnit}; +#[cfg(feature = "parquet_encryption")] +use datafusion_common::encryption::FileDecryptionProperties; +use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::{ + ColumnStatistics, HashSet, Result, ScalarValue, Statistics, exec_err, internal_err, +}; +use datafusion_datasource::{PartitionedFile, TableSchema}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, +}; +use datafusion_pruning::{FilePruner, PruningPredicate, build_pruning_predicate}; + +#[cfg(feature = "parquet_encryption")] +use datafusion_common::config::EncryptionFactoryOptions; +#[cfg(feature = "parquet_encryption")] +use datafusion_execution::parquet_encryption::EncryptionFactory; +use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; +use log::debug; +use parquet::arrow::ParquetRecordBatchStreamBuilder; +use parquet::arrow::arrow_reader::metrics::ArrowReaderMetrics; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; +use parquet::arrow::async_reader::AsyncFileReader; +use parquet::arrow::parquet_column; +use parquet::basic::Type; +use parquet::bloom_filter::Sbbf; +use parquet::file::metadata::{PageIndexPolicy, ParquetMetaDataReader}; + +/// Morselizer-level state for virtual columns, precomputed once per scan +/// partition so each file skips the validator walks, `null_replacements` +/// rebuild, and one of the `append_fields` allocations. +/// +/// Only constructed when the scan actually requests virtual columns; +/// [`ParquetMorselizer`] and [`PreparedParquetOpen`] hold +/// `Option>` so the zero-virtual-column path (the +/// common case) pays nothing. +pub(crate) struct VirtualColumnsState { + /// Shared list of virtual column fields. Cloned as a `Vec` only at the + /// arrow-rs `with_virtual_columns` call site, which takes it by value. + virtual_columns: Arc>, + /// Null-literal substitutions keyed by virtual column name, used to strip + /// virtual-column references from the projection fed into + /// `build_projection_read_plan` (which only understands file columns). + null_replacements: HashMap, + /// `logical_file_schema` with the virtual columns appended. Fed into the + /// per-file expression rewriter so virtual-column references + /// identity-rewrite instead of being replaced with null literals. + logical_schema_with_virtual: SchemaRef, +} + +impl VirtualColumnsState { + /// Validate each field carries a supported arrow virtual extension type + /// and precompute the per-scan derived state. + fn try_new( + virtual_columns: Vec, + logical_file_schema: &SchemaRef, + ) -> Result { + // Gate which extension types we forward to arrow-rs. Adding a new + // supported virtual column means adding a `ParquetVirtualColumn` + // variant — not editing a stringly-typed allowlist here. + for field in &virtual_columns { + ParquetVirtualColumn::try_from(field)?; + } + let null_replacements = virtual_columns + .iter() + .map(|f| ScalarValue::try_from(f.data_type()).map(|v| (f.name().clone(), v))) + .collect::>>()?; + let logical_schema_with_virtual = + append_fields(logical_file_schema, &virtual_columns); + Ok(Self { + virtual_columns: Arc::new(virtual_columns), + null_replacements, + logical_schema_with_virtual, + }) + } + + /// Validated virtual column fields, in declaration order. + pub(crate) fn virtual_columns(&self) -> &[FieldRef] { + &self.virtual_columns + } + + /// Null-literal substitutions keyed by virtual column name. Used to strip + /// virtual-column references from a projection before it is fed into the + /// parquet `ProjectionMask` (which only understands file columns). + pub(crate) fn null_replacements(&self) -> &HashMap { + &self.null_replacements + } +} + +/// Build the per-scan virtual-column state. +/// +/// Two checks run here: +/// - Extension-type allowlist via [`VirtualColumnsState::try_new`]: returns +/// `Err` for unsupported virtual extension types. +/// - Predicate-reference check (when pushdown is enabled): returns `Err` if +/// the predicate references a virtual column. The contract is that callers +/// route filters through +/// [`ParquetSource::try_pushdown_filters`](crate::source::ParquetSource), +/// which classifies virtual-col filters as `PushedDown::No`. Erroring here +/// prevents silent wrong results for callers that bypass that path and set +/// the predicate directly on `ParquetSource`. +/// +/// Returns `None` when the scan has no virtual columns, so callers avoid +/// allocating the shared state on the common path. +pub(crate) fn build_virtual_columns_state( + virtual_columns: &[FieldRef], + logical_file_schema: &SchemaRef, + predicate: Option<&Arc>, + pushdown_filters: bool, +) -> Result>> { + if virtual_columns.is_empty() { + return Ok(None); + } + if pushdown_filters && let Some(predicate) = predicate { + validate_predicate_does_not_reference_virtual_columns( + predicate, + virtual_columns, + )?; + } + let state = + VirtualColumnsState::try_new(virtual_columns.to_vec(), logical_file_schema)?; + Ok(Some(Arc::new(state))) +} + +/// Return `base` unchanged when `extra` is empty; otherwise build a new schema +/// with `extra` appended to `base`'s fields. +pub(crate) fn append_fields(base: &SchemaRef, extra: &[FieldRef]) -> SchemaRef { + if extra.is_empty() { + return Arc::clone(base); + } + let fields = base + .fields() + .iter() + .cloned() + .chain(extra.iter().cloned()) + .collect::>(); + Arc::new(Schema::new(fields)) +} + +/// Reject predicates that reference a virtual column. +/// +/// arrow-rs's `RowFilter` evaluates predicates against a `ProjectionMask` that +/// addresses parquet leaves only; virtual columns (e.g. `row_number`) are +/// synthesized by the reader *after* filter evaluation and cannot be referenced +/// inside a row filter. Silently dropping such a predicate would produce wrong +/// results. +fn validate_predicate_does_not_reference_virtual_columns( + predicate: &Arc, + virtual_columns: &[FieldRef], +) -> Result<()> { + if virtual_columns.is_empty() { + return Ok(()); + } + let virtual_names: HashSet<&str> = + virtual_columns.iter().map(|f| f.name().as_str()).collect(); + let mut offender: Option = None; + predicate.apply(|node: &Arc| { + if let Some(column) = node.downcast_ref::() + && virtual_names.contains(column.name()) + { + offender = Some(column.name().to_string()); + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) + })?; + if let Some(name) = offender { + return internal_err!( + "Predicate references virtual column '{name}'; route via \ + ParquetSource::try_pushdown_filters." + ); + } + Ok(()) +} + +/// Stateless Parquet morselizer implementation. +/// +/// Reading a Parquet file is a multi-stage process, with multiple CPU-intensive +/// steps interspersed with I/O steps. The code in this module implements the steps +/// as an explicit state machine -- see [`ParquetOpenState`] for details. +#[derive(Clone)] +pub(super) struct ParquetMorselizer { + /// Execution partition index + pub(crate) partition_index: usize, + /// Projection to apply on top of the table schema (i.e. can reference partition columns). + pub projection: ProjectionExprs, + /// Target number of rows in each output RecordBatch + pub batch_size: usize, + /// Optional limit on the number of rows to read + pub(crate) limit: Option, + /// If should keep the output rows in order + pub preserve_order: bool, + /// Optional predicate to apply during the scan + pub predicate: Option>, + /// Table schema, including partition columns. + pub table_schema: TableSchema, + /// Optional hint for how large the initial request to read parquet metadata + /// should be + pub metadata_size_hint: Option, + /// Metrics for reporting + pub metrics: ExecutionPlanMetricsSet, + /// Factory for instantiating parquet reader + pub parquet_file_reader_factory: Arc, + /// Should the filters be evaluated during the parquet scan using + /// [`DatafusionArrowPredicate`](crate::row_filter::DatafusionArrowPredicate)? + pub pushdown_filters: bool, + /// Should the filters be reordered to optimize the scan? + pub reorder_filters: bool, + /// Should we force the reader to use RowSelections for filtering + pub force_filter_selections: bool, + /// Should the page index be read from parquet files, if present, to skip + /// data pages + pub enable_page_index: bool, + /// Should the bloom filter be read from parquet, if present, to skip row + /// groups + pub enable_bloom_filter: bool, + /// Should row group pruning be applied + pub enable_row_group_stats_pruning: bool, + /// Coerce INT96 timestamps to specific TimeUnit + pub coerce_int96: Option, + /// Optional timezone applied to INT96-coerced timestamps. When `Some`, the + /// coerced column type becomes `Timestamp(, Some())`. + /// No effect when `coerce_int96` is `None`. + pub coerce_int96_tz: Option>, + /// Optional parquet FileDecryptionProperties + #[cfg(feature = "parquet_encryption")] + pub file_decryption_properties: Option>, + /// Rewrite expressions in the context of the file schema + pub(crate) expr_adapter_factory: Arc, + /// Optional factory to create file decryption properties dynamically + #[cfg(feature = "parquet_encryption")] + pub encryption_factory: + Option<(Arc, EncryptionFactoryOptions)>, + /// Maximum size of the predicate cache, in bytes. If none, uses + /// the arrow-rs default. + pub max_predicate_cache_size: Option, + /// Whether to read row groups in reverse order + pub reverse_row_groups: bool, + /// Optional sort order used to reorder row groups by their min/max statistics. + pub sort_order_for_reorder: Option, + /// Per-scan virtual-column state (validation already performed). `None` + /// when no virtual columns are requested — the common path. + pub(crate) virtual_state: Option>, +} + +impl fmt::Debug for ParquetMorselizer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetMorselizer") + .field("partition_index", &self.partition_index) + .field("preserve_order", &self.preserve_order) + .field("enable_page_index", &self.enable_page_index) + .field("enable_bloom_filter", &self.enable_bloom_filter) + .finish() + } +} + +impl Morselizer for ParquetMorselizer { + fn plan_file(&self, file: PartitionedFile) -> Result> { + Ok(Box::new(ParquetMorselPlanner::try_new(self, file)?)) + } +} + +/// States for [`ParquetMorselPlanner`] +/// +/// These states correspond to the steps required to read and apply various +/// filter operations. +/// +/// States whose names beginning with `Load` represent waiting on IO to resolve +/// +/// ```text +/// Start +/// | +/// v +/// [LoadEncryption]? +/// | +/// v +/// PruneFile +/// | +/// v +/// LoadMetadata +/// | +/// v +/// PrepareFilters +/// | +/// v +/// LoadPageIndex +/// | +/// v +/// PruneWithStatistics +/// | +/// v +/// LoadBloomFilters +/// | +/// v +/// PruneWithBloomFilters +/// | +/// v +/// BuildStream +/// | +/// v +/// Done +/// ``` +/// +/// Note: `LoadEncryption` is only present when the `parquet_encryption` feature is +/// enabled. All other states are always visited in the order shown above, +/// though any async state may return `Poll::Pending` and then resume later. +enum ParquetOpenState { + Start { + prepared: Box, + #[cfg(feature = "parquet_encryption")] + encryption_context: Arc, + }, + /// Loading encryption footers + #[cfg(feature = "parquet_encryption")] + LoadEncryption(BoxFuture<'static, Result>>), + /// Try to prune file using only file-level statistics and partition + /// values before loading any parquet metadata + PruneFile(Box), + /// Loading Parquet metadata (in footer) + LoadMetadata(BoxFuture<'static, Result>), + /// Specialize any filters for the actual file schema (only known after + /// metadata is loaded) + PrepareFilters(Box), + /// Loading [Parquet Page Index](https://parquet.apache.org/docs/file-format/pageindex/) + LoadPageIndex(BoxFuture<'static, Result>), + /// Pruning Row Groups + PruneWithStatistics(Box), + /// Loading bloom filters required for row-group pruning + LoadBloomFilters(BoxFuture<'static, Result>), + /// Pruning with preloaded Bloom Filters + PruneWithBloomFilters(Box), + /// Builds the final reader stream + /// + /// TODO: split state as this currently does both I/O and CPU work. + BuildStream(Box), + /// Terminal state: the final opened stream is ready to return. + Ready(BoxStream<'static, Result>), + /// Terminal state: reading complete + Done, +} + +impl fmt::Debug for ParquetOpenState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let state = match self { + ParquetOpenState::Start { .. } => "Start", + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(_) => "LoadEncryption", + ParquetOpenState::PruneFile(_) => "PruneFile", + ParquetOpenState::LoadMetadata(_) => "LoadMetadata", + ParquetOpenState::PrepareFilters(_) => "PrepareFilters", + ParquetOpenState::LoadPageIndex(_) => "LoadPageIndex", + ParquetOpenState::PruneWithStatistics(_) => "PruneWithStatistics", + ParquetOpenState::LoadBloomFilters(_) => "LoadBloomFilters", + ParquetOpenState::PruneWithBloomFilters(_) => "PruneWithBloomFilters", + ParquetOpenState::BuildStream(_) => "BuildStream", + ParquetOpenState::Ready(_) => "Ready", + ParquetOpenState::Done => "Done", + }; + f.write_str(state) + } +} + +struct PreparedParquetOpen { + partition_index: usize, + partitioned_file: PartitionedFile, + file_range: Option, + extensions: datafusion_datasource::FileExtensions, + file_name: String, + file_metrics: ParquetFileMetrics, + baseline_metrics: BaselineMetrics, + file_pruner: Option, + metadata_size_hint: Option, + metrics: ExecutionPlanMetricsSet, + parquet_file_reader_factory: Arc, + async_file_reader: Box, + batch_size: usize, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + output_schema: SchemaRef, + projection: ProjectionExprs, + predicate: Option>, + /// Per-scan virtual-column state, Arc-cloned from [`ParquetMorselizer`] so + /// each file shares validated fields, precomputed null replacements, and + /// the logical-with-virtual schema. `None` when no virtual columns were + /// requested. + virtual_state: Option>, + reorder_predicates: bool, + pushdown_filters: bool, + force_filter_selections: bool, + enable_page_index: bool, + enable_bloom_filter: bool, + enable_row_group_stats_pruning: bool, + limit: Option, + coerce_int96: Option, + coerce_int96_tz: Option>, + expr_adapter_factory: Arc, + predicate_creation_errors: Count, + max_predicate_cache_size: Option, + reverse_row_groups: bool, + sort_order_for_reorder: Option, + preserve_order: bool, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: Option>, +} + +/// State of [`ParquetOpenState`] +/// +/// Result of loading parquet metadata after file-level pruning is complete. +struct MetadataLoadedParquetOpen { + prepared: PreparedParquetOpen, + reader_metadata: ArrowReaderMetadata, + options: ArrowReaderOptions, +} + +/// State of [`ParquetOpenState`] +/// +/// Pruning Predicate and DataPage pruning information +/// specialized for the files specific schema. +struct FiltersPreparedParquetOpen { + loaded: MetadataLoadedParquetOpen, + pruning_predicate: Option>, + page_pruning_predicate: Option>, +} + +/// State of [`ParquetOpenState`] +/// +/// Result of CPU-only row-group pruning before optional bloom-filter I/O. +struct RowGroupsPrunedParquetOpen { + prepared: FiltersPreparedParquetOpen, + row_groups: RowGroupAccessPlanFilter, +} + +/// State of [`ParquetOpenState`] +/// +/// Result of loading bloom filters needed for row-group pruning. +struct BloomFiltersLoadedParquetOpen { + prepared: RowGroupsPrunedParquetOpen, + /// Bloom filters loaded for each row group that remains under consideration. + /// + /// indexed by parquet row-group index + row_group_bloom_filters: Vec, +} + +impl ParquetOpenState { + /// Applies one CPU-only state transition. + /// + /// `Load*` states do not transition here and are returned unchanged so the + /// driver loop can poll their inner futures separately. + /// + /// Implements state machine described in [`ParquetOpenState`] + fn transition(self) -> Result { + match self { + ParquetOpenState::Start { + prepared, + #[cfg(feature = "parquet_encryption")] + encryption_context, + } => { + #[cfg(feature = "parquet_encryption")] + { + let mut prepared = *prepared; + let future = async move { + let file_location = + &prepared.partitioned_file.object_meta.location; + prepared.file_decryption_properties = encryption_context + .get_file_decryption_properties(file_location) + .await?; + Ok(Box::new(prepared)) + } + .boxed(); + Ok(ParquetOpenState::LoadEncryption(future)) + } + #[cfg(not(feature = "parquet_encryption"))] + { + Ok(ParquetOpenState::PruneFile(prepared)) + } + } + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(future) => { + Ok(ParquetOpenState::LoadEncryption(future)) + } + ParquetOpenState::PruneFile(prepared) => { + let Some(prepared) = (*prepared).prune_file()? else { + return Ok(ParquetOpenState::Done); + }; + Ok(ParquetOpenState::LoadMetadata(prepared.load().boxed())) + } + ParquetOpenState::LoadMetadata(future) => { + Ok(ParquetOpenState::LoadMetadata(future)) + } + ParquetOpenState::PrepareFilters(loaded) => { + let prepared_filters = loaded.prepare_filters()?; + Ok(ParquetOpenState::LoadPageIndex( + prepared_filters.load_page_index().boxed(), + )) + } + ParquetOpenState::LoadPageIndex(future) => { + Ok(ParquetOpenState::LoadPageIndex(future)) + } + ParquetOpenState::PruneWithStatistics(prepared) => { + let prepared_row_groups = prepared.prune_row_groups()?; + Ok(ParquetOpenState::LoadBloomFilters( + prepared_row_groups.load_bloom_filters().boxed(), + )) + } + ParquetOpenState::LoadBloomFilters(future) => { + Ok(ParquetOpenState::LoadBloomFilters(future)) + } + ParquetOpenState::PruneWithBloomFilters(loaded) => Ok( + ParquetOpenState::BuildStream(Box::new(loaded.prune_bloom_filters())), + ), + ParquetOpenState::BuildStream(prepared) => { + Ok(ParquetOpenState::Ready(prepared.build_stream()?)) + } + ParquetOpenState::Ready(stream) => Ok(ParquetOpenState::Ready(stream)), + ParquetOpenState::Done => { + panic!("ParquetOpenFuture polled after completion"); + } + } + } +} + +/// Implements the Morsel API +struct ParquetStreamMorsel { + stream: BoxStream<'static, Result>, +} + +impl ParquetStreamMorsel { + fn new(stream: BoxStream<'static, Result>) -> Self { + Self { stream } + } +} + +impl fmt::Debug for ParquetStreamMorsel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetStreamMorsel") + .finish_non_exhaustive() + } +} + +impl Morsel for ParquetStreamMorsel { + fn into_stream(self: Box) -> BoxStream<'static, Result> { + self.stream + } +} + +/// Per-file planner that owns the current [`ParquetOpenState`]. +struct ParquetMorselPlanner { + /// Ready to perform CPU-only planning work. + state: ParquetOpenState, +} + +impl fmt::Debug for ParquetMorselPlanner { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ParquetMorselPlanner::Ready") + .field(&self.state) + .finish() + } +} + +impl ParquetMorselPlanner { + fn try_new(morselizer: &ParquetMorselizer, file: PartitionedFile) -> Result { + let prepared = morselizer.prepare_open_file(file)?; + #[cfg(feature = "parquet_encryption")] + let state = ParquetOpenState::Start { + prepared: Box::new(prepared), + encryption_context: Arc::new(morselizer.get_encryption_context()), + }; + #[cfg(not(feature = "parquet_encryption"))] + let state = ParquetOpenState::Start { + prepared: Box::new(prepared), + }; + Ok(Self { state }) + } + + /// Schedule an I/O future that resolves to the next planner to run. + /// + /// This helper + /// + /// 1. drives one I/O phase to completion + /// 2. wraps the resulting state in a new [`ParquetMorselPlanner`] + /// 3. returns a [`MorselPlan`] containing the boxed future for the caller + /// to poll + /// + fn schedule_io(future: F) -> MorselPlan + where + F: Future> + Send + 'static, + { + let io_future = async move { + let next_state = future.await?; + Ok(Box::new(ParquetMorselPlanner { state: next_state }) as _) + }; + MorselPlan::new().with_pending_planner(io_future) + } +} + +impl MorselPlanner for ParquetMorselPlanner { + fn plan(self: Box) -> Result> { + if let ParquetOpenState::Done = self.state { + return Ok(None); + } + + let state = self.state.transition()?; + + match state { + #[cfg(feature = "parquet_encryption")] + ParquetOpenState::LoadEncryption(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneFile(future.await?)) + }))) + } + ParquetOpenState::LoadMetadata(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PrepareFilters(Box::new(future.await?))) + }))) + } + ParquetOpenState::LoadPageIndex(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneWithStatistics(Box::new( + future.await?, + ))) + }))) + } + ParquetOpenState::LoadBloomFilters(future) => { + Ok(Some(Self::schedule_io(async move { + Ok(ParquetOpenState::PruneWithBloomFilters(Box::new( + future.await?, + ))) + }))) + } + ParquetOpenState::Ready(stream) => { + let morsels: Vec> = + vec![Box::new(ParquetStreamMorsel::new(stream))]; + Ok(Some(MorselPlan::new().with_morsels(morsels))) + } + ParquetOpenState::Done => Ok(None), + cpu_state => Ok(Some( + MorselPlan::new() + .with_planners(vec![Box::new(Self { state: cpu_state })]), + )), + } + } +} + +impl ParquetMorselizer { + /// Perform the CPU-only setup for opening a parquet file. + fn prepare_open_file( + &self, + partitioned_file: PartitionedFile, + ) -> Result { + let file_range = partitioned_file.range.clone(); + let extensions = partitioned_file.extensions.clone(); + let file_name = partitioned_file.object_meta.location.to_string(); + let file_metrics = + ParquetFileMetrics::new(self.partition_index, &file_name, &self.metrics); + let baseline_metrics = BaselineMetrics::new(&self.metrics, self.partition_index); + + let metadata_size_hint = partitioned_file + .metadata_size_hint + .or(self.metadata_size_hint); + + let async_file_reader: Box = + self.parquet_file_reader_factory.create_reader( + self.partition_index, + partitioned_file.clone(), + metadata_size_hint, + &self.metrics, + )?; + + // Calculate the output schema from the original projection (before literal replacement) + // so we get correct field names from column references + let logical_file_schema = Arc::clone(self.table_schema.file_schema()); + let output_schema = Arc::new( + self.projection + .project_schema(self.table_schema.table_schema())?, + ); + + // Build a combined map for replacing column references with literal values. + // This includes: + // 1. Partition column values from the file path (e.g., region=us-west-2) + // 2. Constant columns detected from file statistics (where min == max) + // + // Although partition columns *are* constant columns, we don't want to rely on + // statistics for them being populated if we can use the partition values + // (which are guaranteed to be present). + // + // For example, given a partition column `region` and predicate + // `region IN ('us-east-1', 'eu-central-1')` with file path + // `/data/region=us-west-2/...`, the predicate is rewritten to + // `'us-west-2' IN ('us-east-1', 'eu-central-1')` which simplifies to FALSE. + // + // While partition column optimization is done during logical planning, + // there are cases where partition columns may appear in more complex + // predicates that cannot be simplified until we open the file (such as + // dynamic predicates). + let mut literal_columns: HashMap = self + .table_schema + .table_partition_cols() + .iter() + .zip(partitioned_file.partition_values.iter()) + .map(|(field, value)| (field.name().clone(), value.clone())) + .collect(); + // Add constant columns from file statistics. + // Note that if there are statistics for partition columns there will be overlap, + // but since we use a HashMap, we'll just overwrite the partition values with the + // constant values from statistics (which should be the same). + literal_columns.extend(constant_columns_from_stats( + partitioned_file.statistics.as_deref(), + &logical_file_schema, + )); + + let mut projection = self.projection.clone(); + let mut predicate = self.predicate.clone(); + if !literal_columns.is_empty() { + projection = projection.try_map_exprs(|expr| { + replace_columns_with_literals(Arc::clone(&expr), &literal_columns) + })?; + predicate = predicate + .map(|p| replace_columns_with_literals(p, &literal_columns)) + .transpose()?; + } + + let predicate_creation_errors = MetricBuilder::new(&self.metrics) + .with_category(MetricCategory::Rows) + .global_counter("num_predicate_creation_errors"); + + // `FilePruner::try_new` decides whether a pruner is worthwhile (it needs + // a statistics struct, and either real column statistics or a dynamic + // filter that can prune via partition-value folding) and returns `None` + // otherwise. For a static predicate the pruner's tracker reports no + // changes, so it runs once and adds no ongoing cost. + let file_pruner = predicate.as_ref().and_then(|p| { + FilePruner::try_new( + Arc::clone(p), + &logical_file_schema, + &partitioned_file, + predicate_creation_errors.clone(), + ) + }); + + Ok(PreparedParquetOpen { + partition_index: self.partition_index, + partitioned_file, + file_range, + extensions, + file_name, + file_metrics, + baseline_metrics, + file_pruner, + metadata_size_hint, + metrics: self.metrics.clone(), + parquet_file_reader_factory: Arc::clone(&self.parquet_file_reader_factory), + async_file_reader, + batch_size: self.batch_size, + logical_file_schema: Arc::clone(&logical_file_schema), + physical_file_schema: logical_file_schema, + output_schema, + projection, + predicate, + virtual_state: self.virtual_state.as_ref().map(Arc::clone), + reorder_predicates: self.reorder_filters, + pushdown_filters: self.pushdown_filters, + force_filter_selections: self.force_filter_selections, + enable_page_index: self.enable_page_index, + enable_bloom_filter: self.enable_bloom_filter, + enable_row_group_stats_pruning: self.enable_row_group_stats_pruning, + limit: self.limit, + coerce_int96: self.coerce_int96, + coerce_int96_tz: self.coerce_int96_tz.clone(), + expr_adapter_factory: Arc::clone(&self.expr_adapter_factory), + predicate_creation_errors, + max_predicate_cache_size: self.max_predicate_cache_size, + reverse_row_groups: self.reverse_row_groups, + sort_order_for_reorder: self.sort_order_for_reorder.clone(), + preserve_order: self.preserve_order, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + }) + } +} + +impl PreparedParquetOpen { + /// Attempt file-level pruning before any metadata is loaded. + /// + /// Returns `None` if the file can be skipped completely. + fn prune_file(mut self) -> Result> { + // Prune this file using the file level statistics and partition values. + // Since dynamic filters may have been updated since planning it is + // possible that we are able to prune files now that we couldn't prune at + // planning time. The `FilePruner` (built when the predicate is dynamic or + // the file carries statistics) also watches any still-active dynamic + // filter, so the + // `EarlyStoppingStream` wrapping the scan can re-check after each batch + // and end the stream early once a tightened filter proves the file can + // be skipped. + // + // File-level statistics may prune the file without loading any row + // groups or metadata. Partition column predicates are already folded to + // literals (see `replace_columns_with_literals` above), so a dynamic + // filter that references only partition columns can prune here too even + // when the file has no column statistics, e.g. + // `select * from t order by partition_col limit 10`. + if let Some(file_pruner) = &mut self.file_pruner + && file_pruner.should_prune()? + { + self.file_metrics + .files_ranges_pruned_statistics + .add_pruned(1); + return Ok(None); + } + + self.file_metrics + .files_ranges_pruned_statistics + .add_matched(1); + Ok(Some(self)) + } + + /// Load parquet metadata after file-level pruning is complete. + async fn load(mut self) -> Result { + // Don't load the page index yet. Since it is not stored inline in + // the footer, loading the page index if it is not needed will do + // unnecessary I/O. We decide later if it is needed to evaluate the + // pruning predicates. Thus default to not requesting it from the + // underlying reader. + let mut options = + ArrowReaderOptions::new().with_page_index_policy(PageIndexPolicy::Skip); + if let Some(schema) = self.partitioned_file.arrow_schema.as_ref() { + options = options.with_schema(Arc::clone(schema)); + } + #[cfg(feature = "parquet_encryption")] + let mut options = options; + #[cfg(feature = "parquet_encryption")] + if let Some(fd_val) = &self.file_decryption_properties { + options = options.with_file_decryption_properties(Arc::clone(fd_val)); + } + + let mut metadata_timer = self.file_metrics.metadata_load_time.timer(); + // Begin by loading the metadata from the underlying reader (note + // the returned metadata may actually include page indexes as some + // readers may return page indexes even when not requested -- for + // example when they are cached) + let reader_metadata = + ArrowReaderMetadata::load_async(&mut self.async_file_reader, options.clone()) + .await?; + metadata_timer.stop(); + drop(metadata_timer); + + Ok(MetadataLoadedParquetOpen { + prepared: self, + reader_metadata, + options, + }) + } +} + +impl MetadataLoadedParquetOpen { + /// Prepare file-schema coercions and pruning predicates once metadata is + /// loaded. + fn prepare_filters(self) -> Result { + let MetadataLoadedParquetOpen { + mut prepared, + mut reader_metadata, + mut options, + } = self; + + // Note about schemas: we are actually dealing with **3 different schemas** here: + // - The table schema as defined by the TableProvider. + // This is what the user sees, what they get when they `SELECT * FROM table`, etc. + // - The logical file schema: this is the table schema minus any hive partition columns and projections. + // This is what the physical file schema is coerced to. + // - The physical file schema: this is the schema that the arrow-rs + // parquet reader will actually produce for the file's columns. Any + // virtual columns (see [`crate::TableSchema::virtual_columns`]) are + // produced separately by the reader and are not part of this schema. + let mut physical_file_schema = Arc::clone(reader_metadata.schema()); + + // The schema loaded from the file may not be the same as the + // desired schema (for example if we want to instruct the parquet + // reader to read strings using Utf8View instead). Update if necessary + let mut metadata_dirty = false; + if let Some(merged) = apply_file_schema_type_coercions( + &prepared.logical_file_schema, + &physical_file_schema, + ) { + physical_file_schema = Arc::new(merged); + options = options.with_schema(Arc::clone(&physical_file_schema)); + metadata_dirty = true; + } + + if let Some(ref coerce) = prepared.coerce_int96 + && let Some(merged) = Int96Coercer::new( + reader_metadata.parquet_schema(), + &physical_file_schema, + coerce, + ) + .with_timezone(prepared.coerce_int96_tz.clone()) + .coerce() + { + physical_file_schema = Arc::new(merged); + options = options.with_schema(Arc::clone(&physical_file_schema)); + metadata_dirty = true; + } + + // Arrow-rs appends virtual columns to the supplied schema internally, + // so any `with_schema` coercion above must stay limited to file columns. + if let Some(state) = prepared.virtual_state.as_ref() { + options = options.with_virtual_columns((*state.virtual_columns).clone())?; + metadata_dirty = true; + } + + if metadata_dirty { + reader_metadata = ArrowReaderMetadata::try_new( + Arc::clone(reader_metadata.metadata()), + options.clone(), + )?; + } + + // Adapt the projection & filter predicate to the physical file schema. + // This evaluates missing columns and inserts any necessary casts. + // After rewriting to the file schema, further simplifications may be possible. + // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` + // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). + // Additionally, if any casts were inserted we can move casts from the column to the literal side: + // `CAST(col AS INT) = 5` can become `col = CAST(5 AS )`, which can be evaluated statically. + // + // When the schemas are identical and there is no predicate, the + // rewriter is a no-op: column indices already match (partition + // columns are appended after file columns in the table schema), + // types are the same, and there are no missing columns. Skip the + // tree walk entirely in that case. + let needs_rewrite = prepared.predicate.is_some() + || prepared.logical_file_schema != physical_file_schema; + if needs_rewrite { + // When virtual columns are requested, augment the logical and + // physical schemas passed to the rewriter/simplifier with those + // fields. The rewriter identity-rewrites references found in both + // schemas, keeping virtual-column references as `Column` rather + // than replacing them with null literals; the simplifier needs + // them present so it can resolve their data types while walking + // expression trees. We keep `physical_file_schema` itself as the + // pure file schema so downstream predicate pushdown, pruning, and + // row filter construction stay unaffected. + let (logical_for_rewrite, physical_for_rewrite) = + if let Some(state) = prepared.virtual_state.as_ref() { + ( + Arc::clone(&state.logical_schema_with_virtual), + append_fields(&physical_file_schema, &state.virtual_columns), + ) + } else { + ( + Arc::clone(&prepared.logical_file_schema), + Arc::clone(&physical_file_schema), + ) + }; + let rewriter = prepared.expr_adapter_factory.create( + Arc::clone(&logical_for_rewrite), + Arc::clone(&physical_for_rewrite), + )?; + let simplifier = PhysicalExprSimplifier::new(&physical_for_rewrite); + prepared.predicate = prepared + .predicate + .map(|p| simplifier.simplify(rewriter.rewrite(p)?)) + .transpose()?; + prepared.projection = prepared + .projection + .try_map_exprs(|p| simplifier.simplify(rewriter.rewrite(p)?))?; + } + prepared.physical_file_schema = Arc::clone(&physical_file_schema); + + // Build predicates for this specific file + let pruning_predicate = build_pruning_predicates( + prepared.predicate.as_ref(), + &physical_file_schema, + &prepared.predicate_creation_errors, + ); + + // Only build page pruning predicate if page index is enabled + let page_pruning_predicate = if prepared.enable_page_index { + prepared.predicate.as_ref().and_then(|predicate| { + let p = build_page_pruning_predicate(predicate, &physical_file_schema); + (p.filter_number() > 0).then_some(p) + }) + } else { + None + }; + + Ok(FiltersPreparedParquetOpen { + loaded: MetadataLoadedParquetOpen { + prepared, + reader_metadata, + options, + }, + pruning_predicate, + page_pruning_predicate, + }) + } +} + +impl FiltersPreparedParquetOpen { + /// Load the page index if pruning requires it and metadata did not include it. + async fn load_page_index(mut self) -> Result { + // The page index is not stored inline in the parquet footer so the + // metadata load above may not have read the page index structures yet. + // If we need them for reading and they aren't yet loaded, we need to + // load them now. + if self.page_pruning_predicate.is_some() { + self.loaded.reader_metadata = load_page_index( + self.loaded.reader_metadata, + &mut self.loaded.prepared.async_file_reader, + self.loaded + .options + .clone() + .with_page_index_policy(PageIndexPolicy::Optional), + ) + .await?; + } + + Ok(self) + } + + /// Prune row groups using file ranges and parquet metadata. + fn prune_row_groups(self) -> Result { + let loaded = &self.loaded; + let prepared = &loaded.prepared; + let file_metadata = Arc::clone(loaded.reader_metadata.metadata()); + let rg_metadata = file_metadata.row_groups(); + + // Determine which row groups to actually read. The idea is to skip + // as many row groups as possible based on the metadata and query + let mut row_groups = RowGroupAccessPlanFilter::new(create_initial_plan( + &prepared.file_name, + &prepared.extensions, + rg_metadata.len(), + )?); + + // If there is a range restricting what parts of the file to read + if let Some(range) = prepared.file_range.as_ref() { + row_groups.prune_by_range(rg_metadata, range); + } + + // If there is a predicate that can be evaluated against the metadata + if let Some(predicate) = self.pruning_predicate.as_ref().map(|p| p.as_ref()) { + if prepared.enable_row_group_stats_pruning { + row_groups.prune_by_statistics( + &prepared.physical_file_schema, + loaded.reader_metadata.parquet_schema(), + rg_metadata, + predicate, + &prepared.file_metrics, + ); + } else { + // Update metrics: statistics unavailable, so all row groups are + // matched (not pruned) + prepared + .file_metrics + .row_groups_pruned_statistics + .add_matched(row_groups.remaining_row_group_count()); + } + + if !prepared.enable_bloom_filter || row_groups.is_empty() { + // Update metrics: bloom filter unavailable, so all row groups are + // matched (not pruned) + prepared + .file_metrics + .row_groups_pruned_bloom_filter + .add_matched(row_groups.remaining_row_group_count()); + } + } else { + // Update metrics: no predicate, so all row groups are matched (not pruned) + let remaining = row_groups.remaining_row_group_count(); + prepared + .file_metrics + .row_groups_pruned_statistics + .add_matched(remaining); + prepared + .file_metrics + .row_groups_pruned_bloom_filter + .add_matched(remaining); + } + + Ok(RowGroupsPrunedParquetOpen { + prepared: self, + row_groups, + }) + } +} + +impl RowGroupsPrunedParquetOpen { + /// Load bloom filters needed for pruning when enabled and a pruning predicate exists. + async fn load_bloom_filters(mut self) -> Result { + let num_row_groups = self + .prepared + .loaded + .reader_metadata + .metadata() + .num_row_groups(); + let mut row_group_bloom_filters = + vec![BloomFilterStatistics::new(); num_row_groups]; + + if let Some(predicate) = + self.prepared.pruning_predicate.as_ref().map(|p| p.as_ref()) + && self.prepared.loaded.prepared.enable_bloom_filter + && !self.row_groups.is_empty() + { + // Use the existing reader for bloom filter I/O; + // replace with a fresh reader for decoding below. + let reader_metadata = self.prepared.loaded.reader_metadata.clone(); + let replacement_reader = { + let prepared = &self.prepared.loaded.prepared; + prepared.parquet_file_reader_factory.create_reader( + prepared.partition_index, + prepared.partitioned_file.clone(), + prepared.metadata_size_hint, + &prepared.metrics, + )? + }; + + let prepared = &mut self.prepared.loaded.prepared; + let mut builder = ParquetRecordBatchStreamBuilder::new_with_metadata( + mem::replace(&mut prepared.async_file_reader, replacement_reader), + reader_metadata, + ); + let parquet_columns: Vec<(String, usize, Type)> = predicate + .literal_columns() + .into_iter() + .filter_map(|column_name| { + let parquet_schema = builder.parquet_schema(); + let (column_idx, _) = parquet_column( + parquet_schema, + &prepared.physical_file_schema, + &column_name, + )?; + Some(( + column_name, + column_idx, + parquet_schema.column(column_idx).physical_type(), + )) + }) + .collect(); + + for idx in self.row_groups.row_group_indexes() { + let mut row_group_filters = + BloomFilterStatistics::with_capacity(parquet_columns.len()); + for (column_name, column_idx, physical_type) in &parquet_columns { + let bf: Sbbf = match builder + .get_row_group_column_bloom_filter(idx, *column_idx) + .await + { + Ok(Some(bf)) => bf, + Ok(None) => continue, + Err(e) => { + debug!("Ignoring error reading bloom filter: {e}"); + prepared.file_metrics.predicate_evaluation_errors.add(1); + continue; + } + }; + row_group_filters.insert(column_name, bf, *physical_type); + } + row_group_bloom_filters[idx] = row_group_filters; + } + } + + Ok(BloomFiltersLoadedParquetOpen { + prepared: self, + row_group_bloom_filters, + }) + } +} + +impl BloomFiltersLoadedParquetOpen { + /// Apply bloom filter pruning using already loaded bloom filters. + fn prune_bloom_filters(mut self) -> RowGroupsPrunedParquetOpen { + let bloom_filter_eval_time = self + .prepared + .prepared + .loaded + .prepared + .file_metrics + .bloom_filter_eval_time + .clone(); + let _timer_guard = bloom_filter_eval_time.timer(); + if let Some(predicate) = self + .prepared + .prepared + .pruning_predicate + .as_ref() + .map(|p| p.as_ref()) + && self.prepared.prepared.loaded.prepared.enable_bloom_filter + && !self.prepared.row_groups.is_empty() + { + self.prepared.row_groups.prune_by_bloom_filters( + predicate, + &self.prepared.prepared.loaded.prepared.file_metrics, + &self.row_group_bloom_filters, + ); + } + + self.prepared + } +} + +impl RowGroupsPrunedParquetOpen { + /// Build the final parquet stream once all pruning work is complete. + fn build_stream(self) -> Result>> { + let RowGroupsPrunedParquetOpen { + prepared, + mut row_groups, + } = self; + let FiltersPreparedParquetOpen { + loaded, + pruning_predicate: _, + page_pruning_predicate, + } = prepared; + let MetadataLoadedParquetOpen { + prepared, + reader_metadata, + options: _, + } = loaded; + + let file_metadata = Arc::clone(reader_metadata.metadata()); + let rg_metadata = file_metadata.row_groups(); + + // Prune by limit if limit is set and limit order is not sensitive + if let (Some(limit), false) = (prepared.limit, prepared.preserve_order) { + row_groups.prune_by_limit(limit, rg_metadata, &prepared.file_metrics); + } + + // Build the access plan. Fully matched row groups have all rows + // satisfying the predicate, so page pruning and row filter evaluation + // can be skipped for them. + let mut access_plan = row_groups.build(); + + // Page index pruning: if all data on individual pages can + // be ruled using page metadata, rows from other columns + // with that range can be skipped as well. + if prepared.enable_page_index + && !access_plan.is_empty() + && let Some(page_pruning_predicate) = page_pruning_predicate + { + let page_pruning_result = page_pruning_predicate + .prune_plan_with_page_index_and_metrics( + access_plan, + &prepared.physical_file_schema, + reader_metadata.parquet_schema(), + file_metadata.as_ref(), + &prepared.file_metrics, + ); + access_plan = page_pruning_result.access_plan; + ParquetFileMetrics::add_page_index_pages_skipped_by_fully_matched( + &prepared.metrics, + prepared.partition_index, + &prepared.file_name, + page_pruning_result.pages_skipped_by_fully_matched, + ); + } + + // Prepare access plans, then apply row-group ordering tweaks per + // run. Two composable steps: + // + // 1. `reorder_by_statistics`: sort row groups by `min(col)` ASC. + // Fixes out-of-order row groups (e.g. append-heavy workloads). + // Skipped gracefully when statistics aren't available or the + // sort expression isn't a plain column. + // + // 2. `reverse`: flip the iteration order for DESC requests, applied + // AFTER any reorder so the reversed order is correct whether or + // not reorder changed anything. Also handles `row_selection` + // remapping. + // + // For sorted data: reorder is a no-op, reverse gives perfect DESC. + // For unsorted data: reorder fixes the order, reverse flips for DESC. + // + // Both inputs come from the sort-pushdown channel — + // `ParquetSource::try_pushdown_sort` sets `sort_order_for_reorder` + // and/or `reverse_row_groups`. + let prepare_access_plan = + |plan: ParquetAccessPlan| -> Result { + let mut prepared_plan = plan.prepare(rg_metadata)?; + if let Some(sort_order) = prepared.sort_order_for_reorder.as_ref() { + prepared_plan = prepared_plan.reorder_by_statistics( + sort_order, + file_metadata.as_ref(), + &prepared.physical_file_schema, + )?; + } + if prepared.reverse_row_groups { + prepared_plan = prepared_plan.reverse(file_metadata.as_ref())?; + } + Ok(prepared_plan) + }; + + let arrow_reader_metrics = ArrowReaderMetrics::enabled(); + + // Build the decoder projection (mask + per-batch transform) in a + // single call. Encapsulating it behind `DecoderProjection` keeps the + // opener's orchestration body focused on filter / decoder / stream + // wiring. + let decoder_projection = DecoderProjection::try_new( + &prepared.projection, + &prepared.physical_file_schema, + reader_metadata.parquet_schema(), + &prepared.output_schema, + prepared.virtual_state.as_deref(), + )?; + + let (decoder, pending_decoders, remaining_limit) = { + let pushdown_predicate = prepared + .pushdown_filters + .then_some(prepared.predicate.as_ref()) + .flatten(); + let mut row_filter_generator = RowFilterGenerator::new( + pushdown_predicate, + &prepared.physical_file_schema, + file_metadata.as_ref(), + prepared.reorder_predicates, + &prepared.file_metrics, + ); + + // Split into consecutive runs of row groups that share the same filter + // requirement. Fully matched row groups skip the RowFilter; others need it. + // Reverse the run order for reverse scans so the combined decoder stream + // preserves the requested global row group order. + let mut runs = access_plan.split_runs(row_filter_generator.has_row_filter()); + if prepared.reverse_row_groups { + runs.reverse(); + } + let run_count = runs.len(); + let decoder_limit = prepared.limit.filter(|_| run_count == 1); + let remaining_limit = prepared.limit.filter(|_| run_count > 1); + + let decoder_config = DecoderBuilderConfig { + projection_mask: decoder_projection.projection_mask(), + batch_size: prepared.batch_size, + arrow_reader_metrics: &arrow_reader_metrics, + force_filter_selections: prepared.force_filter_selections, + decoder_limit, + }; + + // Build a decoder per run. + let mut decoders = VecDeque::with_capacity(runs.len()); + for run in runs { + let prepared_access_plan = prepare_access_plan(run.access_plan)?; + let mut builder = + decoder_config.build(prepared_access_plan, reader_metadata.clone()); + if run.needs_filter { + if let Some(row_filter) = row_filter_generator.next_filter() { + builder = builder.with_row_filter(row_filter); + } + if let Some(max_predicate_cache_size) = + prepared.max_predicate_cache_size + { + builder = builder + .with_max_predicate_cache_size(max_predicate_cache_size); + } + } + decoders.push_back(builder.build()?); + } + + let decoder = decoders + .pop_front() + .expect("at least one decoder must be created"); + (decoder, decoders, remaining_limit) + }; + + let predicate_cache_inner_records = + prepared.file_metrics.predicate_cache_inner_records.clone(); + let predicate_cache_records = + prepared.file_metrics.predicate_cache_records.clone(); + + let files_ranges_pruned_statistics = + prepared.file_metrics.files_ranges_pruned_statistics.clone(); + let stream = PushDecoderStreamState { + decoder, + pending_decoders, + remaining_limit, + reader: prepared.async_file_reader, + decoder_projection, + arrow_reader_metrics, + predicate_cache_inner_records, + predicate_cache_records, + baseline_metrics: prepared.baseline_metrics, + } + .into_stream(); + + // Wrap the stream so a dynamic filter can stop the file scan early, but + // only when the pruner is still watching a filter that can change + // mid-scan. For a static (or already-complete) predicate the up-front + // `prune_file` check already captured everything that can be pruned, so + // per-batch re-checking would only add overhead. + match prepared.file_pruner { + Some(file_pruner) if file_pruner.is_watching() => { + Ok(EarlyStoppingStream::new( + stream, + file_pruner, + files_ranges_pruned_statistics, + ) + .boxed()) + } + _ => Ok(stream), + } + } +} + +type ConstantColumns = HashMap; + +/// Extract constant column values from statistics, keyed by column name in the logical file schema. +fn constant_columns_from_stats( + statistics: Option<&Statistics>, + file_schema: &SchemaRef, +) -> ConstantColumns { + let mut constants = HashMap::new(); + let Some(statistics) = statistics else { + return constants; + }; + + let num_rows = match statistics.num_rows { + Precision::Exact(num_rows) => Some(num_rows), + _ => None, + }; + + for (idx, column_stats) in statistics + .column_statistics + .iter() + .take(file_schema.fields().len()) + .enumerate() + { + let field = file_schema.field(idx); + if let Some(value) = + constant_value_from_stats(column_stats, num_rows, field.data_type()) + { + constants.insert(field.name().clone(), value); + } + } + + constants +} + +fn constant_value_from_stats( + column_stats: &ColumnStatistics, + num_rows: Option, + data_type: &DataType, +) -> Option { + if let (Precision::Exact(min), Precision::Exact(max)) = + (&column_stats.min_value, &column_stats.max_value) + && min == max + && !min.is_null() + && matches!(column_stats.null_count, Precision::Exact(0)) + { + // Cast to the expected data type if needed (e.g., Utf8 -> Dictionary) + if min.data_type() != *data_type { + return min.cast_to(data_type).ok(); + } + return Some(min.clone()); + } + + if let (Some(num_rows), Precision::Exact(nulls)) = + (num_rows, &column_stats.null_count) + && *nulls == num_rows + { + return ScalarValue::try_new_null(data_type).ok(); + } + + None +} + +/// Return the initial [`ParquetAccessPlan`] +/// +/// If the user has supplied one as an extension, use that +/// otherwise return a plan that scans all row groups +/// +/// Returns an error if an invalid `ParquetAccessPlan` is provided +/// +/// Note: file_name is only used for error messages +fn create_initial_plan( + file_name: &str, + extensions: &datafusion_datasource::FileExtensions, + row_group_count: usize, +) -> Result { + if let Some(access_plan) = extensions.get::() { + let plan_len = access_plan.len(); + if plan_len != row_group_count { + return exec_err!( + "Invalid ParquetAccessPlan for {file_name}. Specified {plan_len} row groups, but file has {row_group_count}" + ); + } + return Ok(access_plan.clone()); + } + + // default to scanning all row groups + Ok(ParquetAccessPlan::new_all(row_group_count)) +} + +/// Build a page pruning predicate from an optional predicate expression. +/// If the predicate is None or the predicate cannot be converted to a page pruning +/// predicate, return None. +pub(crate) fn build_page_pruning_predicate( + predicate: &Arc, + file_schema: &SchemaRef, +) -> Arc { + Arc::new(PagePruningAccessPlanFilter::new( + predicate, + Arc::clone(file_schema), + )) +} + +pub(crate) fn build_pruning_predicates( + predicate: Option<&Arc>, + file_schema: &SchemaRef, + predicate_creation_errors: &Count, +) -> Option> { + let predicate = predicate.as_ref()?; + build_pruning_predicate( + Arc::clone(predicate), + file_schema, + predicate_creation_errors, + ) +} + +/// Returns a `ArrowReaderMetadata` with the page index loaded, loading +/// it from the underlying `AsyncFileReader` if necessary. +async fn load_page_index( + reader_metadata: ArrowReaderMetadata, + input: &mut T, + options: ArrowReaderOptions, +) -> Result { + let parquet_metadata = reader_metadata.metadata(); + let missing_column_index = parquet_metadata.column_index().is_none(); + let missing_offset_index = parquet_metadata.offset_index().is_none(); + // You may ask yourself: why are we even checking if the page index is already loaded here? + // Didn't we explicitly *not* load it above? + // Well it's possible that a custom implementation of `AsyncFileReader` gives you + // the page index even if you didn't ask for it (e.g. because it's cached) + // so it's important to check that here to avoid extra work. + if missing_column_index || missing_offset_index { + let m = Arc::try_unwrap(Arc::clone(parquet_metadata)) + .unwrap_or_else(|e| e.as_ref().clone()); + let mut reader = ParquetMetaDataReader::new_with_metadata(m) + .with_page_index_policy(PageIndexPolicy::Optional); + reader.load_page_index(input).await?; + let new_parquet_metadata = reader.finish()?; + let new_arrow_reader = + ArrowReaderMetadata::try_new(Arc::new(new_parquet_metadata), options)?; + Ok(new_arrow_reader) + } else { + // No need to load the page index again, just return the existing metadata + Ok(reader_metadata) + } +} + +#[cfg(test)] +mod test { + use super::*; + use super::{ConstantColumns, ParquetMorselizer, constant_columns_from_stats}; + use crate::{DefaultParquetFileReaderFactory, RowGroupAccess}; + use arrow::array::RecordBatch; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use bytes::{BufMut, BytesMut}; + use datafusion_common::{ + ColumnStatistics, ScalarValue, Statistics, internal_err, record_batch, + stats::Precision, + }; + use datafusion_datasource::morsel::{Morsel, Morselizer}; + use datafusion_datasource::{PartitionedFile, TableSchema, TableSchemaBuilder}; + use datafusion_expr::{col, lit}; + use datafusion_physical_expr::{ + PhysicalExpr, + expressions::{Column, DynamicFilterPhysicalExpr, Literal}, + planner::logical2physical, + projection::ProjectionExprs, + }; + use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, replace_columns_with_literals, + }; + use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use futures::StreamExt; + use futures::stream::BoxStream; + use object_store::{ObjectStore, ObjectStoreExt, memory::InMemory, path::Path}; + use parquet::arrow::ArrowWriter; + use parquet::file::properties::WriterProperties; + use std::collections::VecDeque; + use std::sync::Arc; + + /// Builder for creating [`ParquetMorselizer`] instances with sensible defaults for tests. + /// This helps reduce code duplication and makes it clear what differs between test cases. + struct ParquetMorselizerBuilder { + store: Option>, + table_schema: Option, + partition_index: usize, + projection_indices: Option>, + projection: Option, + batch_size: usize, + limit: Option, + predicate: Option>, + metadata_size_hint: Option, + metrics: ExecutionPlanMetricsSet, + pushdown_filters: bool, + reorder_filters: bool, + force_filter_selections: bool, + enable_page_index: bool, + enable_bloom_filter: bool, + enable_row_group_stats_pruning: bool, + coerce_int96: Option, + max_predicate_cache_size: Option, + reverse_row_groups: bool, + preserve_order: bool, + } + + impl ParquetMorselizerBuilder { + /// Create a new builder with sensible defaults for tests. + fn new() -> Self { + Self { + store: None, + table_schema: None, + partition_index: 0, + projection_indices: None, + projection: None, + batch_size: 1024, + limit: None, + predicate: None, + metadata_size_hint: None, + metrics: ExecutionPlanMetricsSet::new(), + pushdown_filters: false, + reorder_filters: false, + force_filter_selections: false, + enable_page_index: false, + enable_bloom_filter: false, + enable_row_group_stats_pruning: false, + coerce_int96: None, + max_predicate_cache_size: None, + reverse_row_groups: false, + preserve_order: false, + } + } + + /// Set the object store (required for building). + fn with_store(mut self, store: Arc) -> Self { + self.store = Some(store); + self + } + + /// Create a simple table schema from a file schema (for files without partition columns). + fn with_schema(mut self, file_schema: SchemaRef) -> Self { + self.table_schema = Some(TableSchema::from(file_schema)); + self + } + + /// Set a custom table schema (for files with partition columns). + fn with_table_schema(mut self, table_schema: TableSchema) -> Self { + self.table_schema = Some(table_schema); + self + } + + /// Set projection by column indices. + /// + /// The indices are resolved against the **file schema**, not the full + /// table schema. Callers that need to project partition columns or + /// virtual columns must use [`Self::with_projection`] and construct a + /// [`ProjectionExprs`] against [`TableSchema::table_schema`]. + fn with_projection_indices(mut self, indices: &[usize]) -> Self { + self.projection_indices = Some(indices.to_vec()); + self + } + + /// Set an explicit projection. + /// + /// Prefer this over [`Self::with_projection_indices`] whenever the + /// projection must reference partition or virtual columns, since + /// `with_projection_indices` resolves its indices against the file + /// schema only. + fn with_projection(mut self, projection: ProjectionExprs) -> Self { + self.projection = Some(projection); + self + } + + /// Set the predicate. + fn with_predicate(mut self, predicate: Arc) -> Self { + self.predicate = Some(predicate); + self + } + + /// Enable pushdown filters. + fn with_pushdown_filters(mut self, enable: bool) -> Self { + self.pushdown_filters = enable; + self + } + + /// Enable filter reordering. + fn with_reorder_filters(mut self, enable: bool) -> Self { + self.reorder_filters = enable; + self + } + + /// Enable row group stats pruning. + fn with_row_group_stats_pruning(mut self, enable: bool) -> Self { + self.enable_row_group_stats_pruning = enable; + self + } + + /// Enable page index. + fn with_enable_page_index(mut self, enable: bool) -> Self { + self.enable_page_index = enable; + self + } + + /// Set a row limit. + fn with_limit(mut self, limit: usize) -> Self { + self.limit = Some(limit); + self + } + + /// Set reverse row groups flag. + fn with_reverse_row_groups(mut self, enable: bool) -> Self { + self.reverse_row_groups = enable; + self + } + + /// Build the ParquetMorselizer instance, unwrapping validation errors. + /// + /// # Panics + /// + /// Panics if required fields (store, schema/table_schema) are not set, + /// or if virtual-column validation fails. Use [`Self::try_build`] + /// when the test wants to assert on the validation error. + fn build(self) -> ParquetMorselizer { + self.try_build().expect("ParquetMorselizerBuilder::build") + } + + /// Build the ParquetMorselizer instance, returning any morselizer-level + /// validation error (e.g. unsupported virtual extension type, or a + /// predicate that references a virtual column with + /// `pushdown_filters=true`). + /// + /// # Panics + /// + /// Panics if required fields (store, schema/table_schema) are not set. + fn try_build(self) -> Result { + let store = self + .store + .expect("ParquetMorselizerBuilder: store must be set via with_store()"); + let table_schema = self.table_schema.expect( + "ParquetMorselizerBuilder: table_schema must be set via with_schema() or with_table_schema()", + ); + let file_schema = Arc::clone(table_schema.file_schema()); + + let projection = if let Some(projection) = self.projection { + projection + } else if let Some(indices) = self.projection_indices { + ProjectionExprs::from_indices(&indices, &file_schema) + } else { + // Default: project all columns + let all_indices: Vec = (0..file_schema.fields().len()).collect(); + ProjectionExprs::from_indices(&all_indices, &file_schema) + }; + + let virtual_state = build_virtual_columns_state( + table_schema.virtual_columns(), + table_schema.file_schema(), + self.predicate.as_ref(), + self.pushdown_filters, + )?; + + Ok(ParquetMorselizer { + partition_index: self.partition_index, + projection, + batch_size: self.batch_size, + limit: self.limit, + preserve_order: self.preserve_order, + predicate: self.predicate, + table_schema, + metadata_size_hint: self.metadata_size_hint, + metrics: self.metrics, + parquet_file_reader_factory: Arc::new( + DefaultParquetFileReaderFactory::new(store), + ), + pushdown_filters: self.pushdown_filters, + reorder_filters: self.reorder_filters, + force_filter_selections: self.force_filter_selections, + enable_page_index: self.enable_page_index, + enable_bloom_filter: self.enable_bloom_filter, + enable_row_group_stats_pruning: self.enable_row_group_stats_pruning, + coerce_int96: self.coerce_int96, + // End-to-end coercion behavior (including timezone) is + // covered by parquet.slt. No opener-level test currently + // needs a non-default value here. + coerce_int96_tz: None, + #[cfg(feature = "parquet_encryption")] + file_decryption_properties: None, + expr_adapter_factory: Arc::new(DefaultPhysicalExprAdapterFactory), + #[cfg(feature = "parquet_encryption")] + encryption_factory: None, + max_predicate_cache_size: self.max_predicate_cache_size, + reverse_row_groups: self.reverse_row_groups, + sort_order_for_reorder: None, + virtual_state, + }) + } + } + + /// Test helper that drives a [`ParquetMorselizer`] to completion and returns + /// the first stream morsel it produces. + /// + /// This mirrors how `FileStream` consumes the morsel APIs: it repeatedly + /// plans CPU work, awaits any discovered I/O futures, and feeds the planner + /// back into the ready queue until a stream morsel is ready. + async fn open_file( + morselizer: &ParquetMorselizer, + file: PartitionedFile, + ) -> Result>> { + let mut planners = VecDeque::from([morselizer.plan_file(file)?]); + let mut morsels: VecDeque> = VecDeque::new(); + + loop { + if let Some(morsel) = morsels.pop_front() { + return Ok(Box::pin(morsel.into_stream())); + } + + let Some(planner) = planners.pop_front() else { + return Ok(Box::pin(futures::stream::empty())); + }; + + if let Some(mut plan) = planner.plan()? { + morsels.extend(plan.take_morsels()); + planners.extend(plan.take_ready_planners()); + + if let Some(pending_planner) = plan.take_pending_planner() { + planners.push_front(pending_planner.await?); + continue; + } + + if morsels.is_empty() && planners.is_empty() { + return internal_err!("planner returned an empty morsel plan"); + } + } + } + } + + fn constant_int_stats() -> (Statistics, SchemaRef) { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let statistics = Statistics { + num_rows: Precision::Exact(3), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::from(5i32)), + min_value: Precision::Exact(ScalarValue::from(5i32)), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }, + ColumnStatistics::new_unknown(), + ], + }; + (statistics, schema) + } + + #[test] + fn extract_constant_columns_non_null() { + let (statistics, schema) = constant_int_stats(); + let constants = constant_columns_from_stats(Some(&statistics), &schema); + assert_eq!(constants.len(), 1); + assert_eq!(constants.get("a"), Some(&ScalarValue::from(5i32))); + assert!(!constants.contains_key("b")); + } + + #[test] + fn extract_constant_columns_all_null() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let statistics = Statistics { + num_rows: Precision::Exact(2), + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Absent, + min_value: Precision::Absent, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }], + }; + + let constants = constant_columns_from_stats(Some(&statistics), &schema); + assert_eq!( + constants.get("a"), + Some(&ScalarValue::Utf8(None)), + "all-null column should be treated as constant null" + ); + } + + #[test] + fn rewrite_projection_to_literals() { + let (statistics, schema) = constant_int_stats(); + let constants = constant_columns_from_stats(Some(&statistics), &schema); + let projection = ProjectionExprs::from_indices(&[0, 1], &schema); + + let rewritten = projection + .try_map_exprs(|expr| replace_columns_with_literals(expr, &constants)) + .unwrap(); + let exprs = rewritten.as_ref(); + assert!(exprs[0].expr.downcast_ref::().is_some()); + assert!(exprs[1].expr.downcast_ref::().is_some()); + + // Only column `b` should remain in the projection mask + assert_eq!(rewritten.column_indices(), vec![1]); + } + + #[test] + fn rewrite_physical_expr_literal() { + let mut constants = ConstantColumns::new(); + constants.insert("a".to_string(), ScalarValue::from(7i32)); + let expr: Arc = Arc::new(Column::new("a", 0)); + + let rewritten = replace_columns_with_literals(expr, &constants).unwrap(); + assert!(rewritten.downcast_ref::().is_some()); + } + + async fn count_batches_and_rows( + mut stream: BoxStream<'static, Result>, + ) -> (usize, usize) { + let mut num_batches = 0; + let mut num_rows = 0; + while let Some(Ok(batch)) = stream.next().await { + num_rows += batch.num_rows(); + num_batches += 1; + } + (num_batches, num_rows) + } + + /// Helper to collect all int32 values from the first column of batches + async fn collect_int32_values( + mut stream: BoxStream<'static, Result>, + ) -> Vec { + use arrow::array::Array; + let mut values = vec![]; + while let Some(Ok(batch)) = stream.next().await { + let array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..array.len() { + if !array.is_null(i) { + values.push(array.value(i)); + } + } + } + values + } + + async fn write_parquet( + store: Arc, + filename: &str, + batch: RecordBatch, + ) -> usize { + write_parquet_batches(store, filename, vec![batch], None).await + } + + /// Write multiple batches to a parquet file with optional writer properties + async fn write_parquet_batches( + store: Arc, + filename: &str, + batches: Vec, + props: Option, + ) -> usize { + let mut out = BytesMut::new().writer(); + { + let schema = batches[0].schema(); + let mut writer = ArrowWriter::try_new(&mut out, schema, props).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + } + writer.finish().unwrap(); + } + let data = out.into_inner().freeze(); + let data_len = data.len(); + store.put(&Path::from(filename), data.into()).await.unwrap(); + data_len + } + + fn make_dynamic_expr(expr: Arc) -> Arc { + Arc::new(DynamicFilterPhysicalExpr::new( + expr.children().into_iter().map(Arc::clone).collect(), + expr, + )) + } + + #[tokio::test] + async fn test_prune_on_statistics() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(2)]), + ("b", Float32, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + + let schema = batch.schema(); + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ) + .with_statistics(Arc::new( + Statistics::new_unknown(&schema) + .add_column_statistics(ColumnStatistics::new_unknown()) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Float32(Some(1.0)))) + .with_max_value(Precision::Exact(ScalarValue::Float32(Some(2.0)))) + .with_null_count(Precision::Exact(1)), + ), + )); + + let make_opener = |predicate| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0, 1]) + .with_predicate(predicate) + .with_row_group_stats_pruning(true) + .build() + }; + + // A filter on "a" should not exclude any rows even if it matches the data + let expr = col("a").eq(lit(1)); + let predicate = logical2physical(&expr, &schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // A filter on `b = 5.0` should exclude all rows + let expr = col("b").eq(lit(ScalarValue::Float32(Some(5.0)))); + let predicate = logical2physical(&expr, &schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_statistics_with_dynamic_expression() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let table_schema_for_opener = TableSchemaBuilder::from(&file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))]) + .build(); + let make_opener = |predicate| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema_for_opener.clone()) + .with_projection_indices(&[0]) + .with_predicate(predicate) + .with_row_group_stats_pruning(true) + .build() + }; + + // Filter should match the partition value + let expr = col("part").eq(lit(1)); + // Mark the expression as dynamic even if it's not to force partition pruning to happen + // Otherwise we assume it already happened at the planning stage and won't re-do the work here + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should not match the partition value + let expr = col("part").eq(lit(2)); + // Mark the expression as dynamic even if it's not to force partition pruning to happen + // Otherwise we assume it already happened at the planning stage and won't re-do the work here + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = open_file(&opener, file).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_values_and_file_statistics() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(3)]), + ("b", Float64, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + file.statistics = Some(Arc::new( + Statistics::new_unknown(&file_schema) + .add_column_statistics(ColumnStatistics::new_unknown()) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Float64(Some(1.0)))) + .with_max_value(Precision::Exact(ScalarValue::Float64(Some(2.0)))) + .with_null_count(Precision::Exact(1)), + ), + )); + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float32, true), + ])); + let table_schema_for_opener = TableSchemaBuilder::from(&file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))]) + .build(); + let make_opener = |predicate| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema_for_opener.clone()) + .with_projection_indices(&[0]) + .with_predicate(predicate) + .with_row_group_stats_pruning(true) + .build() + }; + + // Filter should match the partition value and file statistics + let expr = col("part").eq(lit(1)).and(col("b").eq(lit(1.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Should prune based on partition value but not file statistics + let expr = col("part").eq(lit(2)).and(col("b").eq(lit(1.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Should prune based on file statistics but not partition value + let expr = col("part").eq(lit(1)).and(col("b").eq(lit(7.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Should prune based on both partition value and file statistics + let expr = col("part").eq(lit(2)).and(col("b").eq(lit(7.0))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_prune_on_partition_value_and_data_value() { + let store = Arc::new(InMemory::new()) as Arc; + + // Note: number 3 is missing! + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(4)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("part", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let table_schema_for_opener = TableSchemaBuilder::from(&file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))]) + .build(); + let make_opener = |predicate| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema_for_opener.clone()) + .with_projection_indices(&[0]) + .with_predicate(predicate) + .with_pushdown_filters(true) // note that this is true! + .with_reorder_filters(true) + .build() + }; + + // Filter should match the partition value and data value + let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should match the partition value but not the data value + let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + // Filter should not match the partition value but match the data value + let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 1); + + // Filter should not match the partition value or the data value + let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3))); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + /// Test that if the filter is not a dynamic filter and we have no stats we don't do extra pruning work at the file level. + #[tokio::test] + async fn test_opener_pruning_skipped_on_static_filters() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await; + + let file_schema = batch.schema(); + let mut file = PartitionedFile::new( + "part=1/file.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(1))]; + file.statistics = Some(Arc::new( + Statistics::default().add_column_statistics( + ColumnStatistics::new_unknown() + .with_min_value(Precision::Exact(ScalarValue::Int32(Some(1)))) + .with_max_value(Precision::Exact(ScalarValue::Int32(Some(3)))) + .with_null_count(Precision::Exact(0)), + ), + )); + + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("part", DataType::Int32, false), + ])); + + let table_schema_for_opener = TableSchemaBuilder::from(&file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "part", + DataType::Int32, + false, + ))]) + .build(); + let make_opener = |predicate| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema_for_opener.clone()) + .with_projection_indices(&[0]) + .with_predicate(predicate) + .build() + }; + + // This filter could prune based on statistics, but since it's not dynamic it's not applied for pruning + // (the assumption is this happened already at planning time) + let expr = col("a").eq(lit(42)); + let predicate = logical2physical(&expr, &table_schema); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // If we make the filter dynamic, it should prune. + // This allows dynamic filters to prune partitions/files even if they are populated late into execution. + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // If we have a filter that touches partition columns only and is dynamic, it should prune even if there are no stats. + file.statistics = Some(Arc::new(Statistics::new_unknown(&file_schema))); + let expr = col("part").eq(lit(2)); + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + + // Similarly a filter that combines partition and data columns should prune even if there are no stats. + let expr = col("part").eq(lit(2)).and(col("a").eq(lit(42))); + let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let opener = make_opener(predicate); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let (num_batches, num_rows) = count_batches_and_rows(stream).await; + assert_eq!(num_batches, 0); + assert_eq!(num_rows, 0); + } + + #[tokio::test] + async fn test_opener_prioritizes_partitioned_file_schema() { + let store = Arc::new(InMemory::new()) as Arc; + + let batch = record_batch!( + ("a", Int32, vec![Some(1), Some(2), Some(2)]), + ("b", Float32, vec![Some(1.0), Some(2.0), None]) + ) + .unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + + let schema = batch.schema(); + let query_file = async |schema: SchemaRef| -> Result<(usize, usize)> { + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ) + .with_arrow_schema(schema.clone()); + + let predicate = logical2physical(&col("a").eq(lit(1)), &schema); + let opener = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_predicate(predicate) + .build(); + + let stream = open_file(&opener, file.clone()).await?; + Ok(count_batches_and_rows(stream).await) + }; + + let (num_batches, num_rows) = + query_file(schema.clone()).await.expect("query_file"); + assert_eq!(num_batches, 1); + assert_eq!(num_rows, 3); + + let mismatching_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true), + ]); + assert_eq!( + query_file(SchemaRef::new(mismatching_schema)) + .await + .unwrap_err() + .message(), + "Arrow: Incompatible supplied Arrow schema: data type mismatch for field b: requested Float64 but found Float32" + ); + } + + #[tokio::test] + async fn test_reverse_scan_row_groups() { + use parquet::file::properties::WriterProperties; + + let store = Arc::new(InMemory::new()) as Arc; + + // Create multiple batches to ensure multiple row groups + let batch1 = + record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let batch2 = + record_batch!(("a", Int32, vec![Some(4), Some(5), Some(6)])).unwrap(); + let batch3 = + record_batch!(("a", Int32, vec![Some(7), Some(8), Some(9)])).unwrap(); + + // Write parquet file with multiple row groups + // Force small row groups by setting max_row_group_size + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(3)) // Force each batch into its own row group + .build(); + + let data_len = write_parquet_batches( + Arc::clone(&store), + "test.parquet", + vec![batch1.clone(), batch2, batch3], + Some(props), + ) + .await; + + let schema = batch1.schema(); + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_len).unwrap(), + ); + + let make_opener = |reverse_scan: bool| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_reverse_row_groups(reverse_scan) + .build() + }; + + // Test normal scan (forward) + let opener = make_opener(false); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let forward_values = collect_int32_values(stream).await; + + // Test reverse scan + let opener = make_opener(true); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let reverse_values = collect_int32_values(stream).await; + + // The forward scan should return data in the order written + assert_eq!(forward_values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]); + + // With reverse scan, row groups are reversed, so we expect: + // Row group 3 (7,8,9), then row group 2 (4,5,6), then row group 1 (1,2,3) + assert_eq!(reverse_values, vec![7, 8, 9, 4, 5, 6, 1, 2, 3]); + } + + #[tokio::test] + async fn test_reverse_scan_single_row_group() { + let store = Arc::new(InMemory::new()) as Arc; + + // Create a single batch (single row group) + let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let data_size = + write_parquet(Arc::clone(&store), "test.parquet", batch.clone()).await; + + let schema = batch.schema(); + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + + let make_opener = |reverse_scan: bool| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_reverse_row_groups(reverse_scan) + .build() + }; + + // With a single row group, forward and reverse should be the same + // (only the row group order is reversed, not the rows within) + let opener_forward = make_opener(false); + let stream_forward = open_file(&opener_forward, file.clone()).await.unwrap(); + let (batches_forward, _) = count_batches_and_rows(stream_forward).await; + + let opener_reverse = make_opener(true); + let stream_reverse = open_file(&opener_reverse, file).await.unwrap(); + let (batches_reverse, _) = count_batches_and_rows(stream_reverse).await; + + // Both should have the same number of batches since there's only one row group + assert_eq!(batches_forward, batches_reverse); + } + + #[tokio::test] + async fn test_reverse_scan_with_row_selection() { + use parquet::file::properties::WriterProperties; + + let store = Arc::new(InMemory::new()) as Arc; + + // Create 3 batches with DIFFERENT selection patterns + let batch1 = + record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3), Some(4)])) + .unwrap(); // 4 rows + let batch2 = + record_batch!(("a", Int32, vec![Some(5), Some(6), Some(7), Some(8)])) + .unwrap(); // 4 rows + let batch3 = + record_batch!(("a", Int32, vec![Some(9), Some(10), Some(11), Some(12)])) + .unwrap(); // 4 rows + + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(4)) + .build(); + + let data_len = write_parquet_batches( + Arc::clone(&store), + "test.parquet", + vec![batch1.clone(), batch2, batch3], + Some(props), + ) + .await; + + let schema = batch1.schema(); + + use crate::ParquetAccessPlan; + use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; + + let mut access_plan = ParquetAccessPlan::new_all(3); + // Row group 0: skip first 2, select last 2 (should get: 3, 4) + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::skip(2), RowSelector::select(2)]), + ); + // Row group 1: select all (should get: 5, 6, 7, 8) + // Row group 2: select first 2, skip last 2 (should get: 9, 10) + access_plan.scan_selection( + 2, + RowSelection::from(vec![RowSelector::select(2), RowSelector::skip(2)]), + ); + + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_len).unwrap(), + ) + .with_extension(access_plan); + + let make_opener = |reverse_scan: bool| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_reverse_row_groups(reverse_scan) + .build() + }; + + // Forward scan: RG0(3,4), RG1(5,6,7,8), RG2(9,10) + let opener = make_opener(false); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let forward_values = collect_int32_values(stream).await; + + // Forward scan should produce: RG0(3,4), RG1(5,6,7,8), RG2(9,10) + assert_eq!( + forward_values, + vec![3, 4, 5, 6, 7, 8, 9, 10], + "Forward scan should select correct rows based on RowSelection" + ); + + // Reverse scan + // CORRECT behavior: reverse row groups AND their corresponding selections + // - RG2 is read first, WITH RG2's selection (select 2, skip 2) -> 9, 10 + // - RG1 is read second, WITH RG1's selection (select all) -> 5, 6, 7, 8 + // - RG0 is read third, WITH RG0's selection (skip 2, select 2) -> 3, 4 + let opener = make_opener(true); + let stream = open_file(&opener, file).await.unwrap(); + let reverse_values = collect_int32_values(stream).await; + + // Correct expected result: row groups reversed but each keeps its own selection + // RG2 with its selection (9,10), RG1 with its selection (5,6,7,8), RG0 with its selection (3,4) + assert_eq!( + reverse_values, + vec![9, 10, 5, 6, 7, 8, 3, 4], + "Reverse scan should reverse row group order while maintaining correct RowSelection for each group" + ); + } + + #[tokio::test] + async fn test_reverse_scan_with_non_contiguous_row_groups() { + use parquet::file::properties::WriterProperties; + + let store = Arc::new(InMemory::new()) as Arc; + + // Create 4 batches (4 row groups) + let batch0 = record_batch!(("a", Int32, vec![Some(1), Some(2)])).unwrap(); + let batch1 = record_batch!(("a", Int32, vec![Some(3), Some(4)])).unwrap(); + let batch2 = record_batch!(("a", Int32, vec![Some(5), Some(6)])).unwrap(); + let batch3 = record_batch!(("a", Int32, vec![Some(7), Some(8)])).unwrap(); + + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(2)) + .build(); + + let data_len = write_parquet_batches( + Arc::clone(&store), + "test.parquet", + vec![batch0.clone(), batch1, batch2, batch3], + Some(props), + ) + .await; + + let schema = batch0.schema(); + + use crate::ParquetAccessPlan; + use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; + + // KEY: Skip RG1 (non-contiguous!) + // Only scan row groups: [0, 2, 3] + let mut access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // RG0 + RowGroupAccess::Skip, // RG1 - SKIPPED! + RowGroupAccess::Scan, // RG2 + RowGroupAccess::Scan, // RG3 + ]); + + // Add RowSelection for each scanned row group + // RG0: select first row (1), skip second (2) + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::select(1), RowSelector::skip(1)]), + ); + // RG1: skipped, no selection needed + // RG2: select first row (5), skip second (6) + access_plan.scan_selection( + 2, + RowSelection::from(vec![RowSelector::select(1), RowSelector::skip(1)]), + ); + // RG3: select first row (7), skip second (8) + access_plan.scan_selection( + 3, + RowSelection::from(vec![RowSelector::select(1), RowSelector::skip(1)]), + ); + + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_len).unwrap(), + ) + .with_extension(access_plan); + + let make_opener = |reverse_scan: bool| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_reverse_row_groups(reverse_scan) + .build() + }; + + // Forward scan: RG0(1), RG2(5), RG3(7) + // Note: RG1 is completely skipped + let opener = make_opener(false); + let stream = open_file(&opener, file.clone()).await.unwrap(); + let forward_values = collect_int32_values(stream).await; + + assert_eq!( + forward_values, + vec![1, 5, 7], + "Forward scan with non-contiguous row groups" + ); + + // Reverse scan: RG3(7), RG2(5), RG0(1) + // WITHOUT the bug fix, this would return WRONG values + // because the RowSelection would be incorrectly mapped + let opener = make_opener(true); + let stream = open_file(&opener, file).await.unwrap(); + let reverse_values = collect_int32_values(stream).await; + + assert_eq!( + reverse_values, + vec![7, 5, 1], + "Reverse scan with non-contiguous row groups should correctly map RowSelection" + ); + } + + /// Test that page pruning predicates are only built and applied when `enable_page_index` is true. + /// + /// The file has a single row group with 10 pages (10 rows each, values 1..100). + /// With page index enabled, pages whose max value <= 90 are pruned, returning only + /// the last page (rows 91..100). With page index disabled, all 100 rows are returned + /// since neither pushdown nor row-group pruning is active. + #[tokio::test] + async fn test_page_pruning_predicate_respects_enable_page_index() { + use parquet::file::properties::WriterProperties; + + let store = Arc::new(InMemory::new()) as Arc; + + // 100 rows with values 1..=100, written as a single row group with 10 rows per page + let values: Vec = (1..=100).collect(); + let batch = record_batch!(( + "a", + Int32, + values.iter().map(|v| Some(*v)).collect::>() + )) + .unwrap(); + let props = WriterProperties::builder() + .set_data_page_row_count_limit(10) + .set_write_batch_size(10) + .build(); + let schema = batch.schema(); + let data_size = write_parquet_batches( + Arc::clone(&store), + "test.parquet", + vec![batch], + Some(props), + ) + .await; + + let file = PartitionedFile::new("test.parquet".to_string(), data_size as u64); + + // predicate: a > 90 — should allow page index to prune first 9 pages + let predicate = logical2physical(&col("a").gt(lit(90i32)), &schema); + + let make_morselizer = |enable_page_index| { + ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_predicate(Arc::clone(&predicate)) + .with_enable_page_index(enable_page_index) + // disable pushdown and row-group pruning so the only pruning path is page index + .with_pushdown_filters(false) + .with_row_group_stats_pruning(false) + .build() + }; + let (_, rows_with_page_index) = count_batches_and_rows( + open_file(&make_morselizer(true), file.clone()) + .await + .unwrap(), + ) + .await; + let (_, rows_without_page_index) = count_batches_and_rows( + open_file(&make_morselizer(false), file).await.unwrap(), + ) + .await; + + assert_eq!( + rows_with_page_index, 10, + "page index should prune 9 of 10 pages" + ); + assert_eq!( + rows_without_page_index, 100, + "without page index all rows are returned" + ); + } + + async fn fully_matched_split_test_file( + store: Arc, + ) -> (SchemaRef, PartitionedFile) { + use parquet::file::properties::WriterProperties; + + let batch0 = + record_batch!(("a", Int32, vec![Some(1), Some(2), Some(3)])).unwrap(); + let batch1 = + record_batch!(("a", Int32, vec![Some(4), Some(5), Some(6)])).unwrap(); + let batch2 = + record_batch!(("a", Int32, vec![Some(7), Some(1), Some(2)])).unwrap(); + + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(3)) + .build(); + + let data_len = write_parquet_batches( + Arc::clone(&store), + "test.parquet", + vec![batch0.clone(), batch1, batch2], + Some(props), + ) + .await; + + let schema = batch0.schema(); + let file = PartitionedFile::new( + "test.parquet".to_string(), + u64::try_from(data_len).unwrap(), + ); + (schema, file) + } + + #[tokio::test] + async fn test_fully_matched_runs_respect_global_limit() { + let store = Arc::new(InMemory::new()) as Arc; + let (schema, file) = fully_matched_split_test_file(Arc::clone(&store)).await; + let predicate = logical2physical(&col("a").gt_eq(lit(3)), &schema); + + let opener = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_predicate(predicate) + .with_pushdown_filters(true) + .with_row_group_stats_pruning(true) + .with_limit(4) + .build(); + + let values = collect_int32_values(open_file(&opener, file).await.unwrap()).await; + assert_eq!(values, vec![3, 4, 5, 6]); + } + + #[tokio::test] + async fn test_fully_matched_runs_preserve_reverse_order() { + let store = Arc::new(InMemory::new()) as Arc; + let (schema, file) = fully_matched_split_test_file(Arc::clone(&store)).await; + let predicate = logical2physical(&col("a").gt_eq(lit(3)), &schema); + + let opener = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_schema(Arc::clone(&schema)) + .with_projection_indices(&[0]) + .with_predicate(predicate) + .with_pushdown_filters(true) + .with_row_group_stats_pruning(true) + .with_reverse_row_groups(true) + .build(); + + let values = collect_int32_values(open_file(&opener, file).await.unwrap()).await; + assert_eq!(values, vec![7, 4, 5, 6, 3]); + } + + #[test] + fn test_split_decoder_runs_no_fully_matched() { + // All row groups need filtering: single run. + let plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + let runs = plan.split_runs(true); + assert_eq!(runs.len(), 1); + assert!(runs[0].needs_filter); + assert_eq!(runs[0].access_plan.row_group_indexes(), vec![0, 1, 2]); + } + + #[test] + fn test_split_decoder_runs_all_fully_matched() { + // All row groups are fully matched: single run, no filter. + let mut plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + plan.mark_fully_matched(0); + plan.mark_fully_matched(1); + plan.mark_fully_matched(2); + + let runs = plan.split_runs(true); + assert_eq!(runs.len(), 1); + assert!(!runs[0].needs_filter); + assert_eq!(runs[0].access_plan.row_group_indexes(), vec![0, 1, 2]); + } + + #[test] + fn test_split_decoder_runs_mixed() { + // [F, M, M, F, M] creates 4 runs preserving order. + let mut plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // 0: filtered + RowGroupAccess::Scan, // 1: matched + RowGroupAccess::Scan, // 2: matched + RowGroupAccess::Scan, // 3: filtered + RowGroupAccess::Scan, // 4: matched + ]); + plan.mark_fully_matched(1); + plan.mark_fully_matched(2); + plan.mark_fully_matched(4); + + let runs = plan.split_runs(true); + assert_eq!(runs.len(), 4); + + assert!(runs[0].needs_filter); + assert_eq!(runs[0].access_plan.row_group_indexes(), vec![0]); + + assert!(!runs[1].needs_filter); + assert_eq!(runs[1].access_plan.row_group_indexes(), vec![1, 2]); + + assert!(runs[2].needs_filter); + assert_eq!(runs[2].access_plan.row_group_indexes(), vec![3]); + + assert!(!runs[3].needs_filter); + assert_eq!(runs[3].access_plan.row_group_indexes(), vec![4]); + } + + #[test] + fn test_split_decoder_runs_with_skipped_groups() { + // Skipped row groups are excluded from all runs. + let mut plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // 0: filtered + RowGroupAccess::Skip, // 1: pruned + RowGroupAccess::Scan, // 2: matched + RowGroupAccess::Scan, // 3: filtered + ]); + plan.mark_fully_matched(2); + + let runs = plan.split_runs(true); + assert_eq!(runs.len(), 3); + + assert!(runs[0].needs_filter); + assert_eq!(runs[0].access_plan.row_group_indexes(), vec![0]); + + assert!(!runs[1].needs_filter); + assert_eq!(runs[1].access_plan.row_group_indexes(), vec![2]); + + assert!(runs[2].needs_filter); + assert_eq!(runs[2].access_plan.row_group_indexes(), vec![3]); + } + + /// Helpers for tests that exercise parquet virtual columns + /// (e.g. `row_number`) plumbed through `TableSchema`/`ParquetOpener`. + mod virtual_columns { + use super::*; + use arrow::array::{Array, Int64Array}; + use arrow::datatypes::FieldRef; + use parquet::arrow::RowNumber; + + /// Build a parquet `row_number` virtual column field. Spark's + /// `_tmp_metadata_row_index` is declared nullable, so the default + /// matches that contract; tests that need `nullable=false` can + /// override via `with_nullable`. + fn row_number_field(name: &str, nullable: bool) -> FieldRef { + Arc::new( + Field::new(name, DataType::Int64, nullable) + .with_extension_type(RowNumber), + ) + } + + /// Collect every `Int64` value from the given column in every batch + /// of a stream. Used to verify the `row_number` column end to end. + async fn collect_int64_values( + mut stream: BoxStream<'static, Result>, + column: usize, + ) -> Vec { + let mut out = vec![]; + while let Some(batch) = stream.next().await { + let batch = batch.unwrap(); + let array = batch + .column(column) + .as_any() + .downcast_ref::() + .expect("expected Int64 column"); + for i in 0..array.len() { + assert!( + !array.is_null(i), + "row_number values produced by the reader must not be null" + ); + out.push(array.value(i)); + } + } + out + } + + /// Write a parquet file containing `num_row_groups` groups of + /// `rows_per_group` rows with a single `value` Int64 column. + /// Values are `0..num_row_groups*rows_per_group`. + async fn write_grouped_file( + store: &Arc, + path: &str, + num_row_groups: usize, + rows_per_group: usize, + ) -> (SchemaRef, usize) { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + false, + )])); + let mut batches = Vec::with_capacity(num_row_groups); + for g in 0..num_row_groups { + let start = (g * rows_per_group) as i64; + let values: Vec = (start..start + rows_per_group as i64).collect(); + batches.push( + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from(values))], + ) + .unwrap(), + ); + } + let props = WriterProperties::builder() + .set_max_row_group_row_count(Some(rows_per_group)) + .build(); + let data_size = + write_parquet_batches(Arc::clone(store), path, batches, Some(props)) + .await; + (schema, data_size) + } + + #[tokio::test] + async fn test_row_index_basic() { + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, data_size) = + write_grouped_file(&store, "basic.parquet", 1, 5).await; + + let rn_field = row_number_field("row_number", false); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + // Project [value, row_number] — indices in table_schema are + // [0 file:value, 1 virtual:row_number]. + let projection = + ProjectionExprs::from_indices(&[0, 1], table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .build(); + + let file = PartitionedFile::new( + "basic.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + let stream = open_file(&morselizer, file).await.unwrap(); + let row_numbers = collect_int64_values(stream, 1).await; + assert_eq!(row_numbers, vec![0, 1, 2, 3, 4]); + } + + #[tokio::test] + async fn test_row_index_projection_only() { + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, data_size) = + write_grouped_file(&store, "proj_only.parquet", 1, 4).await; + + let rn_field = row_number_field("row_number", false); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + // Project only the virtual column (index 1). + let projection = + ProjectionExprs::from_indices(&[1], table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .build(); + + let file = PartitionedFile::new( + "proj_only.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + let stream = open_file(&morselizer, file).await.unwrap(); + let row_numbers = collect_int64_values(stream, 0).await; + assert_eq!(row_numbers, vec![0, 1, 2, 3]); + } + + #[tokio::test] + async fn test_row_index_multi_row_group() { + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, data_size) = + write_grouped_file(&store, "multi_rg.parquet", 3, 100).await; + + let rn_field = row_number_field("row_number", false); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + let projection = + ProjectionExprs::from_indices(&[0, 1], table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .build(); + + let file = PartitionedFile::new( + "multi_rg.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + let stream = open_file(&morselizer, file).await.unwrap(); + let row_numbers = collect_int64_values(stream, 1).await; + let expected: Vec = (0..300).collect(); + assert_eq!(row_numbers, expected); + } + + #[tokio::test] + async fn test_row_index_with_row_group_skip() { + // 3 row groups of 100 rows. A predicate that excludes the middle + // row group (values 100..200) must leave absolute row numbers + // 0..100 and 200..300 intact — not 0..200. This guards against + // the arrow-rs bug fixed in apache/arrow-rs#8863. + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, data_size) = + write_grouped_file(&store, "rg_skip.parquet", 3, 100).await; + + let rn_field = row_number_field("row_number", false); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + let projection = + ProjectionExprs::from_indices(&[0, 1], table_schema.table_schema()); + + // `value < 100 OR value >= 200` prunes the middle row group via + // min/max statistics. + let expr = col("value") + .lt(lit(100i64)) + .or(col("value").gt_eq(lit(200i64))); + let predicate = logical2physical(&expr, table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .with_predicate(predicate) + .with_row_group_stats_pruning(true) + .build(); + + let file = PartitionedFile::new( + "rg_skip.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + let stream = open_file(&morselizer, file).await.unwrap(); + let row_numbers = collect_int64_values(stream, 1).await; + let expected: Vec = (0..100).chain(200..300).collect(); + assert_eq!(row_numbers, expected); + } + + #[tokio::test] + async fn test_row_index_with_partition_cols() { + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, data_size) = + write_grouped_file(&store, "part=5/data.parquet", 1, 3).await; + + let rn_field = row_number_field("row_number", false); + let partition_col = Arc::new(Field::new("part", DataType::Int32, false)); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_table_partition_cols(vec![Arc::clone(&partition_col)]) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + // table_schema layout: [value(0), part(1), row_number(2)]. + let projection = + ProjectionExprs::from_indices(&[0, 1, 2], table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .build(); + + let mut file = PartitionedFile::new( + "part=5/data.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + file.partition_values = vec![ScalarValue::Int32(Some(5))]; + + let stream = open_file(&morselizer, file).await.unwrap(); + let mut stream = stream; + let batch = stream.next().await.unwrap().unwrap(); + assert!(stream.next().await.is_none()); + + assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.schema().field(0).name(), "value"); + assert_eq!(batch.schema().field(1).name(), "part"); + assert_eq!(batch.schema().field(2).name(), "row_number"); + + let part = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(part.iter().all(|v| v == Some(5))); + + let rn = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let rn_values: Vec = (0..rn.len()).map(|i| rn.value(i)).collect(); + assert_eq!(rn_values, vec![0, 1, 2]); + } + + #[tokio::test] + async fn test_row_index_nullable_int64() { + // Spark declares `_tmp_metadata_row_index` nullable. Verify the + // nullability flag flows through unchanged. + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, data_size) = + write_grouped_file(&store, "nullable.parquet", 1, 3).await; + + let rn_field = row_number_field("_tmp_metadata_row_index", true); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + let projection = + ProjectionExprs::from_indices(&[0, 1], table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .build(); + + let file = PartitionedFile::new( + "nullable.parquet".to_string(), + u64::try_from(data_size).unwrap(), + ); + let mut stream = open_file(&morselizer, file).await.unwrap(); + let batch = stream.next().await.unwrap().unwrap(); + + let schema_field = batch.schema().field(1).clone(); + assert_eq!(schema_field.name(), "_tmp_metadata_row_index"); + assert_eq!(schema_field.data_type(), &DataType::Int64); + assert!( + schema_field.is_nullable(), + "nullable flag should be preserved for Spark's row index field" + ); + } + + #[tokio::test] + async fn test_unsupported_virtual_extension_type_rejected() { + // Guard: opener must reject virtual columns carrying extension + // types outside the tested allowlist, rather than silently + // forwarding them to arrow-rs (where they would produce columns + // we have not validated against DataFusion's projection and + // predicate paths). + let store = Arc::new(InMemory::new()) as Arc; + let (file_schema, _data_size) = + write_grouped_file(&store, "unsupported.parquet", 1, 1).await; + + // RowGroupIndex is a real arrow-rs virtual type but is not in + // SUPPORTED_VIRTUAL_EXTENSION_TYPES until a test is added for it. + let rg_field = Arc::new( + Field::new("row_group_index", DataType::Int64, false) + .with_extension_type(parquet::arrow::RowGroupIndex), + ); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![rg_field]) + .build(); + let projection = + ProjectionExprs::from_indices(&[0, 1], table_schema.table_schema()); + + // Validation now happens at morselizer-build time (once per scan + // partition), not once per file inside `prepare_open_file`. + let err = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(&store)) + .with_table_schema(table_schema) + .with_projection(projection) + .try_build() + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("parquet.virtual.row_group_index"), + "error should name the unsupported extension type, got: {msg}" + ); + } + + /// Build a morselizer + file for a 5-row single-row-group parquet at + /// `path`, with a single `row_number` virtual column and the given + /// physical predicate applied to + /// `table_schema = [value(0), row_number(1)]`. + async fn build_pushdown_morselizer( + store: &Arc, + path: &str, + predicate_expr: datafusion_expr::Expr, + pushdown_filters: bool, + ) -> Result<(ParquetMorselizer, PartitionedFile)> { + let (file_schema, data_size) = write_grouped_file(store, path, 1, 5).await; + let rn_field = row_number_field("row_number", false); + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(vec![Arc::clone(&rn_field)]) + .build(); + let projection = + ProjectionExprs::from_indices(&[0, 1], table_schema.table_schema()); + let predicate = + logical2physical(&predicate_expr, table_schema.table_schema()); + + let morselizer = ParquetMorselizerBuilder::new() + .with_store(Arc::clone(store)) + .with_table_schema(table_schema) + .with_projection(projection) + .with_predicate(predicate) + .with_pushdown_filters(pushdown_filters) + .try_build()?; + + let file = + PartitionedFile::new(path.to_string(), u64::try_from(data_size).unwrap()); + Ok((morselizer, file)) + } + + // The predicate-vs-virtual-column check rejects callers that bypass + // `ParquetSource::try_pushdown_filters` (which keeps virtual-col + // filters above the scan as a `FilterExec`) and set the predicate + // directly on the source with pushdown enabled. Without this guard, + // arrow-rs's `RowFilter` would silently drop the virtual-col conjunct + // and produce wrong results. + #[tokio::test] + async fn test_row_index_predicate_pushdown_mixed_or_errors() { + let store = Arc::new(InMemory::new()) as Arc; + let expr = col("row_number") + .eq(lit(2i64)) + .or(col("value").eq(lit(4i64))); + let err = + build_pushdown_morselizer(&store, "pushdown_mixed.parquet", expr, true) + .await + .unwrap_err(); + assert!( + err.to_string().contains("try_pushdown_filters"), + "error should mention try_pushdown_filters, got: {err}" + ); + } + + #[tokio::test] + async fn test_row_index_predicate_pushdown_virtual_only_errors() { + let store = Arc::new(InMemory::new()) as Arc; + let expr = col("row_number").eq(lit(2i64)); + let err = build_pushdown_morselizer( + &store, + "pushdown_virtual_only.parquet", + expr, + true, + ) + .await + .unwrap_err(); + assert!( + err.to_string().contains("try_pushdown_filters"), + "error should mention try_pushdown_filters, got: {err}" + ); + } + + #[tokio::test] + async fn test_row_index_predicate_allowed_when_pushdown_disabled() { + // Guards the `pushdown_filters=false` path: the predicate is only + // used for stats pruning (a no-op for row_number) and must not + // trip the virtual-column check. + let store = Arc::new(InMemory::new()) as Arc; + let expr = col("row_number").eq(lit(2i64)); + let (morselizer, file) = + build_pushdown_morselizer(&store, "pushdown_off.parquet", expr, false) + .await + .unwrap(); + + let stream = open_file(&morselizer, file).await.unwrap(); + let (_batches, rows) = count_batches_and_rows(stream).await; + assert_eq!(rows, 5); + } + } +} diff --git a/datafusion/datasource-parquet/src/page_filter.rs b/datafusion/datasource-parquet/src/page_filter.rs index 9f4e52c513cf5..795a63268b6a9 100644 --- a/datafusion/datasource-parquet/src/page_filter.rs +++ b/datafusion/datasource-parquet/src/page_filter.rs @@ -28,9 +28,9 @@ use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, }; -use datafusion_common::pruning::PruningStatistics; use datafusion_common::ScalarValue; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; +use datafusion_common::pruning::PruningStatistics; +use datafusion_physical_expr::{PhysicalExpr, split_conjunction}; use datafusion_pruning::PruningPredicate; use log::{debug, trace}; @@ -115,6 +115,26 @@ pub struct PagePruningAccessPlanFilter { predicates: Vec, } +/// Result of applying page-index pruning to a [`ParquetAccessPlan`]. +pub(crate) struct PagePruningResult { + pub(crate) access_plan: ParquetAccessPlan, + /// Pages skipped because the containing row group was fully matched by + /// row-group statistics. + pub(crate) pages_skipped_by_fully_matched: usize, +} + +impl PagePruningResult { + fn new( + access_plan: ParquetAccessPlan, + pages_skipped_by_fully_matched: usize, + ) -> Self { + Self { + access_plan, + pages_skipped_by_fully_matched, + } + } +} + impl PagePruningAccessPlanFilter { /// Create a new [`PagePruningAccessPlanFilter`] from a physical /// expression. @@ -155,50 +175,101 @@ impl PagePruningAccessPlanFilter { /// parquet page index, if any pub fn prune_plan_with_page_index( &self, - mut access_plan: ParquetAccessPlan, + access_plan: ParquetAccessPlan, arrow_schema: &Schema, parquet_schema: &SchemaDescriptor, parquet_metadata: &ParquetMetaData, file_metrics: &ParquetFileMetrics, ) -> ParquetAccessPlan { + self.prune_plan_with_page_index_and_metrics( + access_plan, + arrow_schema, + parquet_schema, + parquet_metadata, + file_metrics, + ) + .access_plan + } + + /// Returns an updated [`ParquetAccessPlan`] and metrics by applying predicates + /// to the parquet page index, if any. + pub(crate) fn prune_plan_with_page_index_and_metrics( + &self, + mut access_plan: ParquetAccessPlan, + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, + parquet_metadata: &ParquetMetaData, + file_metrics: &ParquetFileMetrics, + ) -> PagePruningResult { // scoped timer updates on drop let _timer_guard = file_metrics.page_index_eval_time.timer(); if self.predicates.is_empty() { - return access_plan; + return PagePruningResult::new(access_plan, 0); } let page_index_predicates = &self.predicates; let groups = parquet_metadata.row_groups(); if groups.is_empty() { - return access_plan; + return PagePruningResult::new(access_plan, 0); } if parquet_metadata.offset_index().is_none() || parquet_metadata.column_index().is_none() { debug!( - "Can not prune pages due to lack of indexes. Have offset: {}, column index: {}", - parquet_metadata.offset_index().is_some(), parquet_metadata.column_index().is_some() - ); - return access_plan; + "Can not prune pages due to lack of indexes. Have offset: {}, column index: {}", + parquet_metadata.offset_index().is_some(), + parquet_metadata.column_index().is_some() + ); + return PagePruningResult::new(access_plan, 0); }; // track the total number of rows that should be skipped let mut total_skip = 0; // track the total number of rows that should not be skipped let mut total_select = 0; + // track the total number of pages that should be skipped + let mut total_pages_skip = 0; + // track the total number of pages that should not be skipped + let mut total_pages_select = 0; + // track pages for which page-index pruning was skipped because the + // containing row group was already proven fully matched by statistics + let mut total_pages_skipped_by_fully_matched = 0; // for each row group specified in the access plan let row_group_indexes = access_plan.row_group_indexes(); for row_group_index in row_group_indexes { + // Skip page pruning for fully matched row groups: all rows are + // known to satisfy the predicate, so page-level pruning is wasted work. + if access_plan.is_fully_matched(row_group_index) { + let page_count = + fully_matched_page_count(row_group_index, parquet_metadata); + total_pages_skipped_by_fully_matched += page_count; + + continue; + } // The selection for this particular row group let mut overall_selection = None; + + let total_pages_in_group = + parquet_metadata.offset_index().map_or(0, |offset_index| { + offset_index[row_group_index] + .first() + .map_or(0, |column| column.page_locations.len()) + }); + // stores the indexes of the matched pages + let mut matched_pages_in_group: HashSet = + HashSet::from_iter(0..total_pages_in_group); + for predicate in page_index_predicates { - let column = predicate - .required_columns() - .single_column() - .expect("Page pruning requires single column predicates"); + let Some(column) = predicate.required_columns().single_column() else { + debug!( + "Ignoring multi-column page pruning predicate: {:?}", + predicate.predicate_expr() + ); + continue; + }; let converter = StatisticsConverter::try_new( column.name(), @@ -225,16 +296,25 @@ impl PagePruningAccessPlanFilter { file_metrics, ); - let Some(selection) = selection else { + let Some((selection, pages)) = selection else { trace!("No pages pruned in prune_pages_in_one_row_group"); continue; }; - debug!("Use filter and page index to create RowSelection {:?} from predicate: {:?}", + debug!( + "Use filter and page index to create RowSelection {:?} from predicate: {:?}", &selection, predicate.predicate_expr(), ); + let matched_pages_indexes: HashSet<_> = pages + .into_iter() + .enumerate() + .filter(|x| x.1) + .map(|x| x.0) + .collect(); + matched_pages_in_group.retain(|x| matched_pages_indexes.contains(x)); + overall_selection = update_selection(overall_selection, selection); // if the overall selection has ruled out all rows, no need to @@ -253,7 +333,9 @@ impl PagePruningAccessPlanFilter { let rows_selected = overall_selection.row_count(); if rows_selected > 0 { let rows_skipped = overall_selection.skipped_row_count(); - trace!("Overall selection from predicate skipped {rows_skipped}, selected {rows_selected}: {overall_selection:?}"); + trace!( + "Overall selection from predicate skipped {rows_skipped}, selected {rows_selected}: {overall_selection:?}" + ); total_skip += rows_skipped; total_select += rows_selected; access_plan.scan_selection(row_group_index, overall_selection) @@ -267,14 +349,27 @@ impl PagePruningAccessPlanFilter { skipping all {rows_skipped} rows in row group {row_group_index}" ); } + } else { + total_select += + parquet_metadata.row_group(row_group_index).num_rows() as usize; } + + let pages_matched = matched_pages_in_group.len(); + total_pages_select += pages_matched; + total_pages_skip += total_pages_in_group - pages_matched; } file_metrics.page_index_rows_pruned.add_pruned(total_skip); file_metrics .page_index_rows_pruned .add_matched(total_select); - access_plan + file_metrics + .page_index_pages_pruned + .add_pruned(total_pages_skip); + file_metrics + .page_index_pages_pruned + .add_matched(total_pages_select); + PagePruningResult::new(access_plan, total_pages_skipped_by_fully_matched) } /// Returns the number of filters in the [`PagePruningAccessPlanFilter`] @@ -293,7 +388,21 @@ fn update_selection( } } -/// Returns a [`RowSelection`] for the rows in this row group to scan. +/// Returns the number of pages for which page-index pruning is skipped because +/// the containing row group is fully matched by row-group statistics. +fn fully_matched_page_count( + row_group_index: usize, + parquet_metadata: &ParquetMetaData, +) -> usize { + parquet_metadata.offset_index().map_or(0, |offset_index| { + offset_index[row_group_index] + .first() + .map_or(0, |column| column.page_locations.len()) + }) +} + +/// Returns a [`RowSelection`] for the rows in this row group to scan, in addition to a vec of +/// booleans that state if each page was matched (true) or not (false). /// /// This Row Selection is formed from the page index and the predicate skips row /// ranges that can be ruled out based on the predicate. @@ -306,7 +415,7 @@ fn prune_pages_in_one_row_group( converter: StatisticsConverter<'_>, parquet_metadata: &ParquetMetaData, metrics: &ParquetFileMetrics, -) -> Option { +) -> Option<(RowSelection, Vec)> { let pruning_stats = PagesPruningStatistics::try_new(row_group_index, converter, parquet_metadata)?; @@ -358,7 +467,8 @@ fn prune_pages_in_one_row_group( RowSelector::skip(sum_row) }; vec.push(selector); - Some(RowSelection::from(vec)) + + Some((RowSelection::from(vec), values)) } /// Implement [`PruningStatistics`] for one column's PageIndex (column_index + offset_index) @@ -488,7 +598,7 @@ impl PruningStatistics for PagesPruningStatistics<'_> { } } - fn row_counts(&self, _column: &datafusion_common::Column) -> Option { + fn row_counts(&self) -> Option { match self.converter.data_page_row_counts( self.offset_index, self.row_group_metadatas, diff --git a/datafusion/datasource-parquet/src/push_decoder.rs b/datafusion/datasource-parquet/src/push_decoder.rs new file mode 100644 index 0000000000000..3156b9e35fe24 --- /dev/null +++ b/datafusion/datasource-parquet/src/push_decoder.rs @@ -0,0 +1,221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Push-based Parquet decoder setup and stream driver. +//! +//! This module owns the push-decoder lifecycle: +//! +//! - [`DecoderBuilderConfig`] holds the shared options applied to every +//! [`ParquetPushDecoderBuilder`] in a file scan, exposing a single `build` +//! entry point per decoder run. +//! - [`PushDecoderStreamState`] is the per-file stream driver that polls one +//! or more decoders to completion, yielding projected [`RecordBatch`]es. +//! A scan can produce multiple decoders (for example, when fully matched +//! row groups split it into runs with different filter requirements); the +//! state machine drains them in order so the output is contiguous. +//! +//! The opener constructs both halves and hands the state off to +//! [`PushDecoderStreamState::into_stream`] for consumption. + +use std::collections::VecDeque; + +use arrow::array::RecordBatch; +use futures::StreamExt; +use futures::stream::BoxStream; +use parquet::DecodeResult; +use parquet::arrow::ProjectionMask; +use parquet::arrow::arrow_reader::metrics::ArrowReaderMetrics; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, RowSelectionPolicy}; +use parquet::arrow::async_reader::AsyncFileReader; +use parquet::arrow::push_decoder::{ParquetPushDecoder, ParquetPushDecoderBuilder}; + +use datafusion_common::{DataFusionError, Result}; +use datafusion_physical_plan::metrics::{BaselineMetrics, Gauge}; + +use crate::access_plan::PreparedAccessPlan; +use crate::decoder_projection::DecoderProjection; + +/// Shared options applied to every [`ParquetPushDecoderBuilder`] in a file scan. +/// +/// A single scan may produce multiple decoders (for example, when fully matched +/// row groups split the scan into consecutive runs with different filter +/// requirements). All decoders in that scan share the same projection, batch +/// size, metrics sink, and selection policy. +pub(crate) struct DecoderBuilderConfig<'a> { + /// Projection mask installed on every decoder in the scan. Sourced from + /// the file's [`DecoderProjection`]. + pub(crate) projection_mask: &'a ProjectionMask, + pub(crate) batch_size: usize, + pub(crate) arrow_reader_metrics: &'a ArrowReaderMetrics, + pub(crate) force_filter_selections: bool, + pub(crate) decoder_limit: Option, +} + +impl DecoderBuilderConfig<'_> { + /// Build a [`ParquetPushDecoderBuilder`] for a single decoder run. + /// + /// The caller is expected to attach the run-specific + /// [`RowFilter`](parquet::arrow::arrow_reader::RowFilter) and predicate + /// cache size on the returned builder. + pub(crate) fn build( + &self, + prepared_access_plan: PreparedAccessPlan, + metadata: ArrowReaderMetadata, + ) -> ParquetPushDecoderBuilder { + let mut builder = ParquetPushDecoderBuilder::new_with_metadata(metadata) + .with_projection(self.projection_mask.clone()) + .with_batch_size(self.batch_size) + .with_metrics(self.arrow_reader_metrics.clone()); + if self.force_filter_selections { + builder = builder.with_row_selection_policy(RowSelectionPolicy::Selectors); + } + if let Some(row_selection) = prepared_access_plan.row_selection { + builder = builder.with_row_selection(row_selection); + } + builder = builder.with_row_groups(prepared_access_plan.row_group_indexes); + if let Some(limit) = self.decoder_limit { + builder = builder.with_limit(limit); + } + builder + } +} + +/// State for a stream that decodes a single Parquet file using a push-based decoder. +/// +/// The [`transition`](Self::transition) method drives the decoder in a loop: it requests +/// byte ranges from the [`AsyncFileReader`], pushes the fetched data into the +/// [`ParquetPushDecoder`], and yields projected [`RecordBatch`]es until the file is +/// fully consumed. +pub(crate) struct PushDecoderStreamState { + pub(crate) decoder: ParquetPushDecoder, + /// Additional decoders to process after the current one finishes. + /// Used when fully matched row groups split the scan into consecutive + /// runs with different filter configurations, maintaining original order. + pub(crate) pending_decoders: VecDeque, + /// Global remaining row limit across all decoder runs. + /// + /// Decoder-local limits are only safe for single-run scans. When the scan + /// is split across multiple decoders, the combined stream limit is enforced + /// here instead. + pub(crate) remaining_limit: Option, + pub(crate) reader: Box, + /// Per-file projection: the mask installed on every decoder and the + /// per-batch transform applied by [`Self::project_batch`]. + pub(crate) decoder_projection: DecoderProjection, + pub(crate) arrow_reader_metrics: ArrowReaderMetrics, + pub(crate) predicate_cache_inner_records: Gauge, + pub(crate) predicate_cache_records: Gauge, + pub(crate) baseline_metrics: BaselineMetrics, +} + +impl PushDecoderStreamState { + /// Drive the state machine to completion as a [`futures::Stream`] of record batches. + /// + /// The returned stream is fused and boxed so the caller can wrap it (for + /// example, with an early-stopping adapter) without naming the unfold type. + pub(crate) fn into_stream(self) -> BoxStream<'static, Result> { + futures::stream::unfold(self, |state| async move { state.transition().await }) + .fuse() + .boxed() + } + + /// Advances the decoder state machine until the next [`RecordBatch`] is + /// produced, the file is fully consumed, or an error occurs. + /// + /// On each iteration the decoder is polled via [`ParquetPushDecoder::try_decode`]: + /// - [`NeedsData`](DecodeResult::NeedsData) – the requested byte ranges are + /// fetched from the [`AsyncFileReader`] and fed back into the decoder. + /// - [`Data`](DecodeResult::Data) – a decoded batch is projected and returned. + /// - [`Finished`](DecodeResult::Finished) – signals end-of-stream (`None`). + /// + /// Takes `self` by value (rather than `&mut self`) so the generated future + /// owns the state directly. This avoids a Stacked Borrows violation under + /// miri where `&mut self` creates a single opaque borrow that conflicts + /// with `unfold`'s ownership across yield points. + async fn transition(mut self) -> Option<(Result, Self)> { + loop { + if self.remaining_limit == Some(0) { + return None; + } + match self.decoder.try_decode() { + Ok(DecodeResult::NeedsData(ranges)) => { + let data = self + .reader + .get_byte_ranges(ranges.clone()) + .await + .map_err(DataFusionError::from); + match data { + Ok(data) => { + if let Err(e) = self.decoder.push_ranges(ranges, data) { + return Some((Err(DataFusionError::from(e)), self)); + } + } + Err(e) => return Some((Err(e), self)), + } + } + Ok(DecodeResult::Data(batch)) => { + let batch = if let Some(remaining_limit) = self.remaining_limit { + if batch.num_rows() > remaining_limit { + self.remaining_limit = Some(0); + batch.slice(0, remaining_limit) + } else { + self.remaining_limit = + Some(remaining_limit - batch.num_rows()); + batch + } + } else { + batch + }; + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + self.copy_arrow_reader_metrics(); + let result = self.project_batch(&batch); + timer.stop(); + // Release the borrow on baseline_metrics before moving self + drop(timer); + return Some((result, self)); + } + Ok(DecodeResult::Finished) => { + // If there are pending decoders (e.g. for consecutive runs + // with different filter configurations), switch to the next. + if let Some(next) = self.pending_decoders.pop_front() { + self.decoder = next; + continue; + } + return None; + } + Err(e) => { + return Some((Err(DataFusionError::from(e)), self)); + } + } + } + } + + /// Copies metrics from ArrowReaderMetrics (the metrics collected by the + /// arrow-rs parquet reader) to the parquet file metrics for DataFusion + fn copy_arrow_reader_metrics(&self) { + if let Some(v) = self.arrow_reader_metrics.records_read_from_inner() { + self.predicate_cache_inner_records.set(v); + } + if let Some(v) = self.arrow_reader_metrics.records_read_from_cache() { + self.predicate_cache_records.set(v); + } + } + + fn project_batch(&self, batch: &RecordBatch) -> Result { + self.decoder_projection.map(batch) + } +} diff --git a/datafusion/datasource-parquet/src/reader.rs b/datafusion/datasource-parquet/src/reader.rs index 59a5da7b9d97c..482bf8dced4f8 100644 --- a/datafusion/datasource-parquet/src/reader.rs +++ b/datafusion/datasource-parquet/src/reader.rs @@ -18,15 +18,15 @@ //! [`ParquetFileReaderFactory`] and [`DefaultParquetFileReaderFactory`] for //! low level control of parquet file readers -use crate::metadata::DFParquetMetadata; use crate::ParquetFileMetrics; +use crate::metadata::DFParquetMetadata; use bytes::Bytes; use datafusion_datasource::PartitionedFile; use datafusion_execution::cache::cache_manager::FileMetadata; use datafusion_execution::cache::cache_manager::FileMetadataCache; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use futures::future::BoxFuture; use futures::FutureExt; +use futures::future::BoxFuture; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; @@ -37,7 +37,7 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -/// Interface for reading parquet files. +/// Interface for reading Apache Parquet files. /// /// The combined implementations of [`ParquetFileReaderFactory`] and /// [`AsyncFileReader`] can be used to provide custom data access operations @@ -289,7 +289,8 @@ impl AsyncFileReader for CachedParquetFileReader { fn get_metadata<'a>( &'a mut self, - #[allow(unused_variables)] options: Option<&'a ArrowReaderOptions>, + #[cfg_attr(not(feature = "parquet_encryption"), expect(unused_variables))] + options: Option<&'a ArrowReaderOptions>, ) -> BoxFuture<'a, parquet::errors::Result>> { let object_meta = self.partitioned_file.object_meta.clone(); let metadata_cache = Arc::clone(&self.metadata_cache); diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 660b32f486120..f19dbd6c6fa63 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -50,38 +50,46 @@ //! 2. Determine whether each predicate can be evaluated as an `ArrowPredicate`. //! 3. Determine, for each predicate, the total compressed size of all //! columns required to evaluate the predicate. -//! 4. Determine, for each predicate, whether all columns required to -//! evaluate the expression are sorted. -//! 5. Re-order the predicate by total size (from step 3). -//! 6. Partition the predicates according to whether they are sorted (from step 4) -//! 7. "Compile" each predicate `Expr` to a `DatafusionArrowPredicate`. -//! 8. Build the `RowFilter` with the sorted predicates followed by -//! the unsorted predicates. Within each partition, predicates are -//! still be sorted by size. - -use std::cmp::Ordering; +//! 4. Re-order predicates by total size (from step 3). +//! 5. "Compile" each predicate `Expr` to a `DatafusionArrowPredicate`. +//! 6. Build the `RowFilter` from the ordered predicates. +//! +//! List-aware predicates (for example, `array_has`, `array_has_all`, and +//! `array_has_any`) can be evaluated directly during Parquet decoding. +//! Struct field access via `get_field` is also supported when the accessed +//! leaf is a primitive type. Filters that reference entire struct columns +//! rather than individual fields cannot be pushed down and are instead +//! evaluated after the full batches are materialized. +//! +//! For example, given a struct column `s {name: Utf8, value: Int32}`: +//! - `WHERE s['value'] > 5` — pushed down (accesses a primitive leaf) +//! - `WHERE s IS NOT NULL` — not pushed down (references the whole struct) + use std::collections::BTreeSet; use std::sync::Arc; use arrow::array::BooleanArray; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; -use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; +use datafusion_functions::core::getfield::GetFieldFunc; use parquet::arrow::ProjectionMask; +use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::file::metadata::ParquetMetaData; +use parquet::schema::types::SchemaDescriptor; +use datafusion_common::Result; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::Result; -use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::reassign_expr_columns; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; +use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; +use datafusion_physical_expr::{PhysicalExpr, split_conjunction}; use datafusion_physical_plan::metrics; use super::ParquetFileMetrics; +use super::supported_predicates::supports_list_predicates; /// A "compiled" predicate passed to `ParquetRecordBatchStream` to perform /// row-level filtering during parquet decoding. @@ -92,12 +100,17 @@ use super::ParquetFileMetrics; /// /// An expression can be evaluated as a `DatafusionArrowPredicate` if it: /// * Does not reference any projected columns -/// * Does not reference columns with non-primitive types (e.g. structs / lists) +/// * References either primitive columns or list columns used by +/// supported predicates (such as `array_has_all` or NULL checks). +/// * References struct fields via `get_field` where the accessed leaf +/// is a primitive type (e.g. `get_field(struct_col, 'field') > 5`). +/// Direct references to whole struct columns are still evaluated after +/// decoding. #[derive(Debug)] pub(crate) struct DatafusionArrowPredicate { /// the filter expression physical_expr: Arc, - /// Path to the columns in the parquet schema required to evaluate the + /// Path to the leaf columns in the parquet schema required to evaluate the /// expression projection_mask: ProjectionMask, /// how many rows were filtered out by this predicate @@ -106,32 +119,25 @@ pub(crate) struct DatafusionArrowPredicate { rows_matched: metrics::Count, /// how long was spent evaluating this predicate time: metrics::Time, - /// used to perform type coercion while filtering rows - schema_mapper: Arc, } impl DatafusionArrowPredicate { /// Create a new `DatafusionArrowPredicate` from a `FilterCandidate` pub fn try_new( candidate: FilterCandidate, - metadata: &ParquetMetaData, rows_pruned: metrics::Count, rows_matched: metrics::Count, time: metrics::Time, ) -> Result { let physical_expr = - reassign_expr_columns(candidate.expr, &candidate.filter_schema)?; + reassign_expr_columns(candidate.expr, &candidate.read_plan.projected_schema)?; Ok(Self { physical_expr, - projection_mask: ProjectionMask::roots( - metadata.file_metadata().schema_descr(), - candidate.projection, - ), + projection_mask: candidate.read_plan.projection_mask, rows_pruned, rows_matched, time, - schema_mapper: candidate.schema_mapper, }) } } @@ -142,8 +148,6 @@ impl ArrowPredicate for DatafusionArrowPredicate { } fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult { - let batch = self.schema_mapper.map_batch(batch)?; - // scoped timer updates on drop let mut timer = self.time.timer(); @@ -180,75 +184,45 @@ pub(crate) struct FilterCandidate { /// the filter and to order the filters when `reorder_predicates` is true. /// This is generated by summing the compressed size of all columns that the filter references. required_bytes: usize, - /// Can this filter use an index (e.g. a page index) to prune rows? - can_use_index: bool, - /// The projection to read from the file schema to get the columns - /// required to pass through a `SchemaMapper` to the table schema - /// upon which we then evaluate the filter expression. - projection: Vec, - /// A `SchemaMapper` used to map batches read from the file schema to - /// the filter's projection of the table schema. - schema_mapper: Arc, - /// The projected table schema that this filter references - filter_schema: SchemaRef, + /// The resolved Parquet read plan (leaf indices + projected schema). + read_plan: ParquetReadPlan, +} + +/// The result of resolving which Parquet leaf columns and Arrow schema fields +/// are needed to evaluate an expression against a Parquet file +/// +/// This is the shared output of the column resolution pipeline used by both +/// the row filter to build `ArrowPredicate`s and the opener to build `ProjectionMask`s +#[derive(Debug, Clone)] +pub(crate) struct ParquetReadPlan { + /// Projection mask built from leaf column indices in the Parquet schema. + /// Using a `ProjectionMask` directly (rather than raw indices) prevents + /// bugs from accidentally mixing up root vs leaf indices. + pub projection_mask: ProjectionMask, + /// The projected Arrow schema containing only the columns/fields required + /// Struct types are pruned to include only the accessed sub-fields + pub projected_schema: SchemaRef, } /// Helper to build a `FilterCandidate`. /// -/// This will do several things +/// This will do several things: /// 1. Determine the columns required to evaluate the expression /// 2. Calculate data required to estimate the cost of evaluating the filter -/// 3. Rewrite column expressions in the predicate which reference columns not -/// in the particular file schema. -/// -/// # Schema Rewrite -/// -/// When parquet files are read in the context of "schema evolution" there are -/// potentially wo schemas: -/// -/// 1. The table schema (the columns of the table that the parquet file is part of) -/// 2. The file schema (the columns actually in the parquet file) -/// -/// There are times when the table schema contains columns that are not in the -/// file schema, such as when new columns have been added in new parquet files -/// but old files do not have the columns. /// -/// When a file is missing a column from the table schema, the value of the -/// missing column is filled in by a `SchemaAdapter` (by default as `NULL`). -/// -/// When a predicate is pushed down to the parquet reader, the predicate is -/// evaluated in the context of the file schema. -/// For each predicate we build a filter schema which is the projection of the table -/// schema that contains only the columns that this filter references. -/// If any columns from the file schema are missing from a particular file they are -/// added by the `SchemaAdapter`, by default as `NULL`. +/// Note: This does *not* handle any adaptation of the expression to the file schema. +/// The expression must already be adapted before being passed in here, generally using +/// [`PhysicalExprAdapter`](datafusion_physical_expr_adapter::PhysicalExprAdapter). struct FilterCandidateBuilder { expr: Arc, - /// The schema of this parquet file. - /// Columns may have different types from the table schema and there may be - /// columns in the file schema that are not in the table schema or columns that - /// are in the table schema that are not in the file schema. + /// The Arrow schema of this parquet file (the result of converting the + /// parquet schema to Arrow, potentially with type coercions applied). file_schema: SchemaRef, - /// The schema of the table (merged schema) -- columns may be in different - /// order than in the file and have columns that are not in the file schema - table_schema: SchemaRef, - /// A `SchemaAdapterFactory` used to map the file schema to the table schema. - schema_adapter_factory: Arc, } impl FilterCandidateBuilder { - pub fn new( - expr: Arc, - file_schema: Arc, - table_schema: Arc, - schema_adapter_factory: Arc, - ) -> Self { - Self { - expr, - file_schema, - table_schema, - schema_adapter_factory, - } + pub fn new(expr: Arc, file_schema: Arc) -> Self { + Self { expr, file_schema } } /// Attempt to build a `FilterCandidate` from the expression @@ -259,118 +233,737 @@ impl FilterCandidateBuilder { /// * `Ok(None)` if the expression cannot be used as an ArrowFilter /// * `Err(e)` if an error occurs while building the candidate pub fn build(self, metadata: &ParquetMetaData) -> Result> { - let Some(required_indices_into_table_schema) = - pushdown_columns(&self.expr, &self.table_schema)? - else { - return Ok(None); - }; - - let projected_table_schema = Arc::new( - self.table_schema - .project(&required_indices_into_table_schema)?, - ); - - let (schema_mapper, projection_into_file_schema) = self - .schema_adapter_factory - .create(Arc::clone(&projected_table_schema), self.table_schema) - .map_schema(&self.file_schema)?; - - let required_bytes = size_of_columns(&projection_into_file_schema, metadata)?; - let can_use_index = columns_sorted(&projection_into_file_schema, metadata)?; - - Ok(Some(FilterCandidate { - expr: self.expr, - required_bytes, - can_use_index, - projection: projection_into_file_schema, - schema_mapper: Arc::clone(&schema_mapper), - filter_schema: Arc::clone(&projected_table_schema), - })) + Ok( + build_parquet_read_plan(&self.expr, &self.file_schema, metadata)?.map( + |(read_plan, required_bytes)| FilterCandidate { + expr: self.expr, + required_bytes, + read_plan, + }, + ), + ) } } -// a struct that implements TreeNodeRewriter to traverse a PhysicalExpr tree structure to determine -// if any column references in the expression would prevent it from being predicate-pushed-down. -// if non_primitive_columns || projected_columns, it can't be pushed down. -// can't be reused between calls to `rewrite`; each construction must be used only once. +/// Traverses a `PhysicalExpr` tree to determine if any column references would +/// prevent the expression from being pushed down to the parquet decoder. +/// +/// An expression cannot be pushed down if it references: +/// - Unsupported nested columns (whole struct references or list fields that are +/// not covered by the supported predicate set) +/// - Columns that don't exist in the file schema +/// +/// Struct field access via `get_field` is supported when the resolved leaf type +/// is primitive (e.g. `get_field(struct_col, 'field') > 5`). struct PushdownChecker<'schema> { /// Does the expression require any non-primitive columns (like structs)? non_primitive_columns: bool, - /// Does the expression reference any columns that are in the table - /// schema but not in the file schema? - /// This includes partition columns and projected columns. + /// Does the expression reference any columns not present in the file schema? projected_columns: bool, - // Indices into the table schema of the columns required to evaluate the expression - required_columns: BTreeSet, - table_schema: &'schema Schema, + /// Indices into the file schema of columns required to evaluate the expression. + /// Does not include struct columns accessed via `get_field`. + required_columns: Vec, + /// Struct field accesses via `get_field`. + struct_field_accesses: Vec, + /// Whether nested list columns are supported by the predicate semantics. + allow_list_columns: bool, + /// The Arrow schema of the parquet file. + file_schema: &'schema Schema, } impl<'schema> PushdownChecker<'schema> { - fn new(table_schema: &'schema Schema) -> Self { + fn new(file_schema: &'schema Schema, allow_list_columns: bool) -> Self { Self { non_primitive_columns: false, projected_columns: false, - required_columns: BTreeSet::default(), - table_schema, + required_columns: Vec::new(), + struct_field_accesses: Vec::new(), + allow_list_columns, + file_schema, } } + /// Checks whether a struct's root column exists in the file schema and, if so, + /// records its index so the entire struct is decoded for filter evaluation. + /// + /// This is called when we see a `get_field` expression that resolves to a + /// primitive leaf type. We only need the *root* column index because the + /// Parquet reader decodes all leaves of a struct together. + /// + /// # Example + /// + /// Given file schema `{a: Int32, s: Struct(foo: Utf8, bar: Int64)}` and the + /// expression `get_field(s, 'foo') = 'hello'`: + /// + /// - `column_name` = `"s"` (the root struct column) + /// - `file_schema.index_of("s")` returns `1` + /// - We push `1` into `required_columns` + /// - Return `None` (no issue — traversal continues in the caller) + /// + /// If `"s"` is not in the file schema (e.g. a projected-away column), we set + /// `projected_columns = true` and return `Jump` to skip the subtree. + fn check_struct_field_column( + &mut self, + column_name: &str, + field_path: Vec, + ) -> Option { + let Ok(idx) = self.file_schema.index_of(column_name) else { + self.projected_columns = true; + return Some(TreeNodeRecursion::Jump); + }; + + self.struct_field_accesses.push(StructFieldAccess { + root_index: idx, + field_path, + }); + + None + } + fn check_single_column(&mut self, column_name: &str) -> Option { - if let Ok(idx) = self.table_schema.index_of(column_name) { - self.required_columns.insert(idx); - if DataType::is_nested(self.table_schema.field(idx).data_type()) { - self.non_primitive_columns = true; + let idx = match self.file_schema.index_of(column_name) { + Ok(idx) => idx, + Err(_) => { + // Column does not exist in the file schema, so we can't push this down. + self.projected_columns = true; return Some(TreeNodeRecursion::Jump); } + }; + + // Duplicates are handled by dedup() in into_sorted_columns() + self.required_columns.push(idx); + let data_type = self.file_schema.field(idx).data_type(); + + if DataType::is_nested(data_type) { + self.handle_nested_type(data_type) } else { - // If the column does not exist in the (un-projected) table schema then - // it must be a projected column. - self.projected_columns = true; - return Some(TreeNodeRecursion::Jump); + None } + } - None + /// Determines whether a nested data type can be pushed down to Parquet decoding. + /// + /// Returns `Some(TreeNodeRecursion::Jump)` if the nested type prevents pushdown, + /// `None` if the type is supported and pushdown can continue. + fn handle_nested_type(&mut self, data_type: &DataType) -> Option { + if self.is_nested_type_supported(data_type) { + None + } else { + // Block pushdown for unsupported nested types: + // - Structs (regardless of predicate support) + // - Lists without supported predicates + self.non_primitive_columns = true; + Some(TreeNodeRecursion::Jump) + } + } + + /// Checks if a nested data type is supported for list column pushdown. + /// + /// List columns are only supported if: + /// 1. The data type is a list variant (List, LargeList, or FixedSizeList) + /// 2. The expression contains supported list predicates (e.g., array_has_all) + fn is_nested_type_supported(&self, data_type: &DataType) -> bool { + let is_list = matches!( + data_type, + DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) + ); + self.allow_list_columns && is_list } #[inline] fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } + + /// Consumes the checker and returns sorted, deduplicated column indices + /// wrapped in a `PushdownColumns` struct. + /// + /// This method sorts the column indices and removes duplicates. The sort + /// is required because downstream code relies on column indices being in + /// ascending order for correct schema projection. + fn into_sorted_columns(mut self) -> PushdownColumns { + self.required_columns.sort_unstable(); + self.required_columns.dedup(); + PushdownColumns { + required_columns: self.required_columns, + struct_field_accesses: self.struct_field_accesses, + } + } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { type Node = Arc; fn f_down(&mut self, node: &Self::Node) -> Result { - if let Some(column) = node.as_any().downcast_ref::() { - if let Some(recursion) = self.check_single_column(column.name()) { - return Ok(recursion); + // Handle struct field access like `s['foo']['bar'] > 10`. + // + // DataFusion represents nested field access as `get_field(Column("s"), "foo")` + // (or chained: `get_field(get_field(Column("s"), "foo"), "bar")`). + // + // We intercept the outermost `get_field` on the way *down* the tree so + // the visitor never reaches the raw `Column("s")` node. Without this, + // `check_single_column` would see that `s` is a Struct and reject it. + // + // The strategy: + // 1. Match `get_field` whose first arg is a `Column` (the struct root). + // 2. Check that the *resolved* return type is primitive — meaning we've + // drilled all the way to a leaf (e.g. `s['foo']` → Utf8). + // 3. Record the root column index via `check_struct_field_column` and + // return `Jump` to skip visiting the children (the Column and the + // literal field-name args), since we've already handled them. + // + // If the return type is still nested (e.g. `s['nested_struct']` → Struct), + // we fall through and let normal traversal continue, which will + // eventually reject the expression when it hits the struct Column. + if let Some(func) = + ScalarFunctionExpr::try_downcast_func::(node.as_ref()) + { + let args = func.args(); + + if let Some(column) = args.first().and_then(|a| a.downcast_ref::()) { + // for Map columns, get_field performs a runtime key lookup rather than a + // schema-level field access so the entire Map column must be read, + // we skip the struct field optimization and defer to normal Column traversal + let is_map_column = self + .file_schema + .index_of(column.name()) + .ok() + .map(|idx| { + matches!( + self.file_schema.field(idx).data_type(), + DataType::Map(_, _) + ) + }) + .unwrap_or(false); + + let return_type = func.return_type(); + + if !is_map_column + && (!DataType::is_nested(return_type) + || self.is_nested_type_supported(return_type)) + { + // try to resolve all field name arguments to strinrg literals + // if any argument is not a string literal, we can not determine the exact + // leaf path so we fall back to reading the entire struct root column + let field_path = args[1..] + .iter() + .map(|arg| { + arg.downcast_ref::().and_then(|lit| { + lit.value().try_as_str().flatten().map(|s| s.to_string()) + }) + }) + .collect(); + + match field_path { + Some(path) => { + if let Some(recursion) = + self.check_struct_field_column(column.name(), path) + { + return Ok(recursion); + } + } + None => { + // Could not resolve field path — fall back to + // reading the entire struct root column. + if let Some(recursion) = + self.check_single_column(column.name()) + { + return Ok(recursion); + } + } + } + + return Ok(TreeNodeRecursion::Jump); + } } } + if let Some(column) = node.downcast_ref::() + && let Some(recursion) = self.check_single_column(column.name()) + { + return Ok(recursion); + } + Ok(TreeNodeRecursion::Continue) } } -// Checks if a given expression can be pushed down into `DataSourceExec` as opposed to being evaluated -// post-parquet-scan in a `FilterExec`. If it can be pushed down, this returns all the -// columns in the given expression so that they can be used in the parquet scanning, along with the -// expression rewritten as defined in [`PushdownChecker::f_up`] +/// Describes the nested column behavior for filter pushdown. +/// +/// This enum makes explicit the different states a predicate can be in +/// with respect to nested column handling during Parquet decoding. +/// Result of checking which columns are required for filter pushdown. +#[derive(Debug)] +struct PushdownColumns { + /// Sorted, unique column indices into the file schema required to evaluate + /// the filter expression. Must be in ascending order for correct schema + /// projection matching. Does not include struct columns accessed via `get_field`. + required_columns: Vec, + /// Struct field accesses via `get_field`. Each entry records the root struct + /// column index and the field path being accessed. + struct_field_accesses: Vec, +} + +/// Records a struct field access via `get_field(struct_col, 'field1', 'field2', ...)`. +/// +/// This allows the row filter to project only the specific Parquet leaf columns +/// needed by the filter, rather than all leaves of the struct. +#[derive(Debug, Clone)] +struct StructFieldAccess { + /// Arrow root column index of the struct in the file schema. + root_index: usize, + /// Field names forming the path into the struct. + /// e.g., `["value"]` for `s['value']`, `["outer", "inner"]` for `s['outer']['inner']`. + field_path: Vec, +} + +/// Checks if a given expression can be pushed down to the parquet decoder. +/// +/// Returns `Some(PushdownColumns)` if the expression can be pushed down, +/// where the struct contains the indices into the file schema of all columns +/// required to evaluate the expression. +/// +/// Returns `None` if the expression cannot be pushed down (e.g., references +/// unsupported nested types or columns not in the file). fn pushdown_columns( expr: &Arc, - table_schema: &Schema, -) -> Result>> { - let mut checker = PushdownChecker::new(table_schema); + file_schema: &Schema, +) -> Result> { + let allow_list_columns = supports_list_predicates(expr); + let mut checker = PushdownChecker::new(file_schema, allow_list_columns); expr.visit(&mut checker)?; - Ok((!checker.prevents_pushdown()) - .then_some(checker.required_columns.into_iter().collect())) + Ok((!checker.prevents_pushdown()).then(|| checker.into_sorted_columns())) +} + +/// Resolves which Parquet leaf columns and Arrow schema fields are needed +/// to evaluate `expr` against a Parquet file +/// +/// Returns `Ok(Some((plan, required_bytes)))` when the expression can be +/// evaluated using only pushdown-compatible columns. `Ok(None)` when it +/// cannot (it references whole struct columns or columns missing from disk). +/// +/// The `required_bytes` is the total compressed size of all referenced columns +/// across all row groups, used to estimate filter evaluation cost. +/// +/// Note: this is a shared entry point used by both row filter construction and +/// the opener's projection logic +pub(crate) fn build_parquet_read_plan( + expr: &Arc, + file_schema: &Schema, + metadata: &ParquetMetaData, +) -> Result> { + let schema_descr = metadata.file_metadata().schema_descr(); + + let Some(required_columns) = pushdown_columns(expr, file_schema)? else { + return Ok(None); + }; + + let root_indices = &required_columns.required_columns; + + let mut leaf_indices = + leaf_indices_for_roots(root_indices.iter().copied(), schema_descr); + + let struct_leaf_indices = resolve_struct_field_leaves( + &required_columns.struct_field_accesses, + file_schema, + schema_descr, + ); + leaf_indices.extend_from_slice(&struct_leaf_indices); + leaf_indices.sort_unstable(); + leaf_indices.dedup(); + + let required_bytes = size_of_columns(&leaf_indices, metadata)?; + + let projection_mask = + ProjectionMask::leaves(schema_descr, leaf_indices.iter().copied()); + + let projected_schema = build_filter_schema( + file_schema, + root_indices, + &required_columns.struct_field_accesses, + ); + + Ok(Some(( + ParquetReadPlan { + projection_mask, + projected_schema, + }, + required_bytes, + ))) +} + +/// Builds a unified [`ParquetReadPlan`] for a set of projection expressions +/// +/// Unlike [`build_parquet_read_plan`] (which is used for filter pushdown and +/// returns `None` when an expression references unsupported nested types or +/// missing columns), this function always succeeds. It collects every column +/// that *can* be resolved in the file and produces a leaf-level projection +/// mask. Columns missing from the file are silently skipped since the projection +/// layer handles those by inserting nulls. +pub(crate) fn build_projection_read_plan( + exprs: impl IntoIterator>, + file_schema: &Schema, + schema_descr: &SchemaDescriptor, +) -> ParquetReadPlan { + // fast path: if every expression is a plain Column reference, skip all + // struct analysis and use root-level projection directly + let exprs = exprs.into_iter().collect::>(); + let all_plain_columns = exprs.iter().all(|e| e.downcast_ref::().is_some()); + + if all_plain_columns { + let mut root_indices: Vec = exprs + .iter() + .map(|e| e.downcast_ref::().unwrap().index()) + .collect(); + root_indices.sort_unstable(); + root_indices.dedup(); + + let projection_mask = + ProjectionMask::roots(schema_descr, root_indices.iter().copied()); + let projected_schema = Arc::new( + file_schema + .project(&root_indices) + .expect("valid column indices"), + ); + + return ParquetReadPlan { + projection_mask, + projected_schema, + }; + } + + // secondary fast path: if the schema has no struct columns, we can skip + // PushdownChecker traversal and use root-level projection + let has_struct_columns = file_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::Struct(_))); + + if !has_struct_columns { + let mut root_indices = exprs + .into_iter() + .flat_map(|e| collect_columns(&e).into_iter().map(|col| col.index())) + .collect::>(); + + root_indices.sort_unstable(); + root_indices.dedup(); + + let projection_mask = + ProjectionMask::roots(schema_descr, root_indices.iter().copied()); + + let projected_schema = Arc::new( + file_schema + .project(&root_indices) + .expect("valid column indices"), + ); + + return ParquetReadPlan { + projection_mask, + projected_schema, + }; + } + + let mut all_root_indices = Vec::new(); + let mut all_struct_accesses = Vec::new(); + + for expr in exprs { + let mut checker = PushdownChecker::new(file_schema, true); + let _ = expr.visit(&mut checker); + let columns = checker.into_sorted_columns(); + + all_root_indices.extend_from_slice(&columns.required_columns); + all_struct_accesses.extend(columns.struct_field_accesses); + } + + all_root_indices.sort_unstable(); + all_root_indices.dedup(); + + // when no struct field accesses were found, fall back to root-level projection + // to match the performance of the simple path + if all_struct_accesses.is_empty() { + let projection_mask = + ProjectionMask::roots(schema_descr, all_root_indices.iter().copied()); + let projected_schema = Arc::new( + file_schema + .project(&all_root_indices) + .expect("valid column indices"), + ); + + return ParquetReadPlan { + projection_mask, + projected_schema, + }; + } + + let leaf_indices = { + let mut out = + leaf_indices_for_roots(all_root_indices.iter().copied(), schema_descr); + let struct_leaf_indices = + resolve_struct_field_leaves(&all_struct_accesses, file_schema, schema_descr); + + out.extend_from_slice(&struct_leaf_indices); + out.sort_unstable(); + out.dedup(); + + out + }; + + let projection_mask = + ProjectionMask::leaves(schema_descr, leaf_indices.iter().copied()); + + let projected_schema = + build_filter_schema(file_schema, &all_root_indices, &all_struct_accesses); + + ParquetReadPlan { + projection_mask, + projected_schema, + } +} + +fn leaf_indices_for_roots( + root_indices: I, + schema_descr: &SchemaDescriptor, +) -> Vec +where + I: IntoIterator, +{ + // Always map root (Arrow) indices to Parquet leaf indices via the schema + // descriptor. Arrow root indices only equal Parquet leaf indices when the + // schema has no group columns (Struct, Map, etc.); when group columns + // exist, their children become separate leaves and shift all subsequent + // leaf indices. + // Struct columns are unsupported. + let root_set: BTreeSet<_> = root_indices.into_iter().collect(); + + (0..schema_descr.num_columns()) + .filter(|leaf_idx| { + root_set.contains(&schema_descr.get_column_root_idx(*leaf_idx)) + }) + .collect() +} + +/// Resolves struct field access to specific Parquet leaf column indices +/// +/// For every `StructFieldAccess`, finds the leaf columns in the Parquet schema +/// whose path matches the struct root name + field path. This avoids reading all +/// leaves of a struct when only specific fields are needed +fn resolve_struct_field_leaves( + accesses: &[StructFieldAccess], + file_schema: &Schema, + schema_descr: &SchemaDescriptor, +) -> Vec { + let mut leaf_indices = Vec::new(); + + for access in accesses { + let root_name = file_schema.field(access.root_index).name(); + let prefix = std::iter::once(root_name.as_str()) + .chain(access.field_path.iter().map(|p| p.as_str())) + .collect::>(); + + for leaf_idx in 0..schema_descr.num_columns() { + let col = schema_descr.column(leaf_idx); + let col_path = col.path().parts(); + + // A leaf matches if its path starts with our prefix. + // e.g., prefix=["s", "value"] matches leaf path ["s", "value"] + // prefix=["s", "outer"] matches ["s", "outer", "inner"] + + // a leaf matches if its path starts with our prefix + // for example: prefix=["s", "value"] matches leaf path ["s", "value"] + // prefix=["s", "outer"] matches ["s", "outer", "inner"] + let leaf_matches_path = col_path.len() >= prefix.len() + && col_path.iter().zip(prefix.iter()).all(|(a, b)| a == b); + + if leaf_matches_path { + leaf_indices.push(leaf_idx); + } + } + } + + leaf_indices } -/// Recurses through expr as a tree, finds all `column`s, and checks if any of them would prevent -/// this expression from being predicate pushed down. If any of them would, this returns false. -/// Otherwise, true. -/// Note that the schema passed in here is *not* the physical file schema (as it is not available at that point in time); -/// it is the schema of the table that this expression is being evaluated against minus any projected columns and partition columns. +/// Builds a filter schema that includes only the fields actually accessed by the +/// filter expression. +/// +/// For regular (non-struct) columns, the full field type is used. +/// For struct columns accessed via `get_field`, a pruned struct type is created +/// containing only the fields along the access path. Note: it must match the schema +/// that the Parquet reader produces when projecting specific struct leaves +fn build_filter_schema( + file_schema: &Schema, + regular_indices: &[usize], + struct_field_accesses: &[StructFieldAccess], +) -> SchemaRef { + let regular_set: BTreeSet = regular_indices.iter().copied().collect(); + + let all_indices = regular_indices + .iter() + .copied() + .chain( + struct_field_accesses + .iter() + .map(|&StructFieldAccess { root_index, .. }| root_index), + ) + .collect::>(); + + let fields = all_indices + .iter() + .map(|&idx| { + let field = file_schema.field(idx); + + // if this column appears as a regular (whole-column) reference, + // keep the full type + // + // Pruning is only valid when the column is accessed exclusively + // through struct field accesses + if regular_set.contains(&idx) { + return Arc::new(field.clone()); + } + + // collect all field paths that access this root struct column + let field_paths = struct_field_accesses + .iter() + .filter_map( + |&StructFieldAccess { + root_index, + ref field_path, + }| { + (root_index == idx).then_some(field_path.as_slice()) + }, + ) + .collect::>(); + + if field_paths.is_empty() { + return Arc::new(field.clone()); + } + + let pruned_data_type = prune_struct_type(field.data_type(), &field_paths); + Arc::new(Field::new( + field.name(), + pruned_data_type, + field.is_nullable(), + )) + }) + .collect::>(); + + Arc::new(Schema::new_with_metadata( + fields, + file_schema.metadata().clone(), + )) +} + +fn prune_struct_type(dt: &DataType, paths: &[&[String]]) -> DataType { + let DataType::Struct(fields) = dt else { + return dt.clone(); + }; + + let needed = paths + .iter() + .filter_map(|p| p.first().map(|s| s.as_str())) + .collect::>(); + + let pruned_fields = fields + .iter() + .filter_map(|f| { + if !needed.contains(f.name().as_str()) { + return None; + } + + let sub_paths = paths + .iter() + .filter_map(|path| { + if path.first().map(|s| s.as_str()) == Some(f.name()) { + Some(&path[1..]) + } else { + None + } + }) + .filter(|sub| !sub.is_empty()) + .collect::>(); + + let out = if sub_paths.is_empty() { + // Leaf of access path — keep the field as-is. + Arc::clone(f) + } else { + // Recurse into nested struct. + let pruned = prune_struct_type(f.data_type(), &sub_paths); + Arc::new(Field::new(f.name(), pruned, f.is_nullable())) + }; + + Some(out) + }) + .collect::>(); + + DataType::Struct(pruned_fields.into()) +} + +/// Checks if a predicate expression can be pushed down to the parquet decoder. +/// +/// Returns `true` if all columns referenced by the expression: +/// - Exist in the provided schema +/// - Are primitive types OR list columns with supported predicates +/// (e.g., `array_has`, `array_has_all`, `array_has_any`, IS NULL, IS NOT NULL) +/// - Are struct columns accessed via `get_field` where the leaf type is primitive +/// - Direct references to whole struct columns will prevent pushdown +/// +/// # Arguments +/// * `expr` - The filter expression to check +/// * `file_schema` - The Arrow schema of the parquet file (or table schema when +/// the file schema is not yet available during planning) +/// +/// # Examples +/// +/// Primitive column filters can be pushed down: +/// ```ignore +/// use datafusion_expr::{col, Expr}; +/// use datafusion_common::ScalarValue; +/// use arrow::datatypes::{DataType, Field, Schema}; +/// use std::sync::Arc; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("age", DataType::Int32, false), +/// ])); +/// +/// // Primitive filter: can be pushed down +/// let expr = col("age").gt(Expr::Literal(ScalarValue::Int32(Some(30)), None)); +/// let expr = logical2physical(&expr, &schema); +/// assert!(can_expr_be_pushed_down_with_schemas(&expr, &schema)); +/// ``` +/// +/// Struct column filters cannot be pushed down: +/// ```ignore +/// use arrow::datatypes::Fields; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("person", DataType::Struct( +/// Fields::from(vec![Field::new("name", DataType::Utf8, true)]) +/// ), true), +/// ])); +/// +/// // Struct filter: cannot be pushed down +/// let expr = col("person").is_not_null(); +/// let expr = logical2physical(&expr, &schema); +/// assert!(!can_expr_be_pushed_down_with_schemas(&expr, &schema)); +/// ``` +/// +/// List column filters with supported predicates can be pushed down: +/// ```ignore +/// use datafusion_functions_nested::expr_fn::{array_has_all, make_array}; +/// +/// let schema = Arc::new(Schema::new(vec![ +/// Field::new("tags", DataType::List( +/// Arc::new(Field::new("item", DataType::Utf8, true)) +/// ), true), +/// ])); +/// +/// // Array filter with supported predicate: can be pushed down +/// let expr = array_has_all(col("tags"), make_array(vec![ +/// Expr::Literal(ScalarValue::Utf8(Some("rust".to_string())), None) +/// ])); +/// let expr = logical2physical(&expr, &schema); +/// assert!(can_expr_be_pushed_down_with_schemas(&expr, &schema)); +/// ``` pub fn can_expr_be_pushed_down_with_schemas( expr: &Arc, file_schema: &Schema, @@ -381,7 +974,7 @@ pub fn can_expr_be_pushed_down_with_schemas( } } -/// Calculate the total compressed size of all `Column`'s required for +/// Calculate the total compressed size of all leaf columns required for /// predicate `Expr`. /// /// This value represents the total amount of IO required to evaluate the @@ -398,38 +991,33 @@ fn size_of_columns(columns: &[usize], metadata: &ParquetMetaData) -> Result Result { - // TODO How do we know this? - Ok(false) -} - -/// Build a [`RowFilter`] from the given predicate `Expr` if possible +/// # Arguments +/// * `expr` - The filter predicate, already adapted to reference columns in `file_schema` +/// * `file_schema` - The Arrow schema of the parquet file (the result of converting +/// the parquet schema to Arrow, potentially with type coercions applied) +/// * `metadata` - Parquet file metadata used for cost estimation +/// * `reorder_predicates` - If true, reorder predicates to minimize I/O +/// * `file_metrics` - Metrics for tracking filter performance /// -/// # returns -/// * `Ok(Some(row_filter))` if the expression can be used as RowFilter -/// * `Ok(None)` if the expression cannot be used as an RowFilter +/// # Returns +/// * `Ok(Some(row_filter))` if the expression can be used as a RowFilter +/// * `Ok(None)` if the expression cannot be used as a RowFilter /// * `Err(e)` if an error occurs while building the filter /// -/// Note that the returned `RowFilter` may not contains all conjuncts in the -/// original expression. This is because some conjuncts may not be able to be -/// evaluated as an `ArrowPredicate` and will be ignored. +/// Note: The returned `RowFilter` may not contain all conjuncts from the original +/// expression. Conjuncts that cannot be evaluated as an `ArrowPredicate` are ignored. /// /// For example, if the expression is `a = 1 AND b = 2 AND c = 3` and `b = 2` -/// can not be evaluated for some reason, the returned `RowFilter` will contain -/// `a = 1` and `c = 3`. +/// cannot be evaluated for some reason, the returned `RowFilter` will contain +/// only `a = 1` and `c = 3`. pub fn build_row_filter( expr: &Arc, - physical_file_schema: &SchemaRef, - predicate_file_schema: &SchemaRef, + file_schema: &SchemaRef, metadata: &ParquetMetaData, reorder_predicates: bool, file_metrics: &ParquetFileMetrics, - schema_adapter_factory: &Arc, ) -> Result> { let rows_pruned = &file_metrics.pushdown_rows_pruned; let rows_matched = &file_metrics.pushdown_rows_matched; @@ -443,13 +1031,8 @@ pub fn build_row_filter( let mut candidates: Vec = predicates .into_iter() .map(|expr| { - FilterCandidateBuilder::new( - Arc::clone(expr), - Arc::clone(physical_file_schema), - Arc::clone(predicate_file_schema), - Arc::clone(schema_adapter_factory), - ) - .build(metadata) + FilterCandidateBuilder::new(Arc::clone(expr), Arc::clone(file_schema)) + .build(metadata) }) .collect::, _>>()? .into_iter() @@ -462,22 +1045,35 @@ pub fn build_row_filter( } if reorder_predicates { - candidates.sort_unstable_by(|c1, c2| { - match c1.can_use_index.cmp(&c2.can_use_index) { - Ordering::Equal => c1.required_bytes.cmp(&c2.required_bytes), - ord => ord, - } - }); + candidates.sort_unstable_by_key(|c| c.required_bytes); } + // To avoid double-counting metrics when multiple predicates are used: + // - All predicates should count rows_pruned (cumulative pruned rows) + // - Only the last predicate should count rows_matched (final result) + // This ensures: rows_matched + rows_pruned = total rows processed + let total_candidates = candidates.len(); + candidates .into_iter() - .map(|candidate| { + .enumerate() + .map(|(idx, candidate)| { + let is_last = idx == total_candidates - 1; + + // All predicates share the pruned counter (cumulative) + let predicate_rows_pruned = rows_pruned.clone(); + + // Only the last predicate tracks matched rows (final result) + let predicate_rows_matched = if is_last { + rows_matched.clone() + } else { + metrics::Count::new() + }; + DatafusionArrowPredicate::try_new( candidate, - metadata, - rows_pruned.clone(), - rows_matched.clone(), + predicate_rows_pruned, + predicate_rows_matched, time.clone(), ) .map(|pred| Box::new(pred) as _) @@ -486,24 +1082,105 @@ pub fn build_row_filter( .map(|filters| Some(RowFilter::new(filters))) } +/// Builds row filters for decoder runs. +/// +/// A [`RowFilter`] must be owned by a decoder, so scans split across multiple +/// decoder runs need a fresh filter for each run that evaluates row predicates. +/// The first filter is built eagerly during construction so callers can cheaply +/// query [`has_row_filter`](Self::has_row_filter) before splitting the scan. +pub(crate) struct RowFilterGenerator<'a> { + predicate: Option<&'a Arc>, + physical_file_schema: &'a SchemaRef, + file_metadata: &'a ParquetMetaData, + reorder_predicates: bool, + file_metrics: &'a ParquetFileMetrics, + first_row_filter: Option, +} + +impl<'a> RowFilterGenerator<'a> { + pub(crate) fn new( + predicate: Option<&'a Arc>, + physical_file_schema: &'a SchemaRef, + file_metadata: &'a ParquetMetaData, + reorder_predicates: bool, + file_metrics: &'a ParquetFileMetrics, + ) -> Self { + let mut generator = Self { + predicate, + physical_file_schema, + file_metadata, + reorder_predicates, + file_metrics, + first_row_filter: None, + }; + generator.first_row_filter = generator.build(); + generator + } + + pub(crate) fn has_row_filter(&self) -> bool { + self.first_row_filter.is_some() + } + + pub(crate) fn next_filter(&mut self) -> Option { + self.first_row_filter.take().or_else(|| self.build()) + } + + fn build(&self) -> Option { + let predicate = self.predicate?; + match build_row_filter( + predicate, + self.physical_file_schema, + self.file_metadata, + self.reorder_predicates, + self.file_metrics, + ) { + Ok(Some(filter)) => Some(filter), + Ok(None) => None, + Err(e) => { + log::debug!( + "Ignoring error building row filter for '{predicate:?}': {e}" + ); + None + } + } + } +} + #[cfg(test)] mod test { use super::*; + use arrow::datatypes::Fields; use datafusion_common::ScalarValue; + use arrow::array::{ + Int32Array, ListBuilder, StringArray, StringBuilder, StructArray, + }; use arrow::datatypes::{Field, TimeUnit::Nanosecond}; - use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory; - use datafusion_expr::{col, Expr}; + use datafusion_expr::{Expr, col}; + use datafusion_functions::core::get_field; + use datafusion_functions_nested::array_has::{ + array_has_all_udf, array_has_any_udf, array_has_udf, + }; + use datafusion_functions_nested::expr_fn::{ + array_has, array_has_all, array_has_any, make_array, + }; use datafusion_physical_expr::planner::logical2physical; - use datafusion_physical_plan::metrics::{Count, Time}; + use datafusion_physical_expr_adapter::{ + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory, + }; + use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, Time}; + use parquet::arrow::ArrowWriter; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use parquet::arrow::parquet_to_arrow_schema; use parquet::file::reader::{FileReader, SerializedFileReader}; + use tempfile::NamedTempFile; + + use datafusion_physical_expr::expressions::Column as PhysicalColumn; - // We should ignore predicate that read non-primitive columns + // List predicates used by the decoder should be accepted for pushdown #[test] - fn test_filter_candidate_builder_ignore_complex_types() { + fn test_filter_candidate_builder_supports_list_types() { let testdata = datafusion_common::test_util::parquet_test_data(); let file = std::fs::File::open(format!("{testdata}/list_columns.parquet")) .expect("opening file"); @@ -519,19 +1196,20 @@ mod test { let expr = col("int64_list").is_not_null(); let expr = logical2physical(&expr, &table_schema); - let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); let table_schema = Arc::new(table_schema.clone()); - let candidate = FilterCandidateBuilder::new( - expr, - table_schema.clone(), - table_schema, - schema_adapter_factory, - ) - .build(metadata) - .expect("building candidate"); + let list_index = table_schema + .index_of("int64_list") + .expect("list column should exist"); - assert!(candidate.is_none()); + let candidate = FilterCandidateBuilder::new(expr, table_schema) + .build(metadata) + .expect("building candidate") + .expect("list pushdown should be supported"); + + let expected_mask = + ProjectionMask::leaves(metadata.file_metadata().schema_descr(), [list_index]); + assert_eq!(candidate.read_plan.projection_mask, expected_mask); } #[test] @@ -559,21 +1237,18 @@ mod test { None, )); let expr = logical2physical(&expr, &table_schema); - let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); - let table_schema = Arc::new(table_schema.clone()); - let candidate = FilterCandidateBuilder::new( - expr, - file_schema.clone(), - table_schema.clone(), - schema_adapter_factory, - ) - .build(&metadata) - .expect("building candidate") - .expect("candidate expected"); + let expr = DefaultPhysicalExprAdapterFactory {} + .create(Arc::new(table_schema.clone()), Arc::clone(&file_schema)) + .expect("creating expr adapter") + .rewrite(expr) + .expect("rewriting expression"); + let candidate = FilterCandidateBuilder::new(expr, file_schema.clone()) + .build(&metadata) + .expect("building candidate") + .expect("candidate expected"); let mut row_filter = DatafusionArrowPredicate::try_new( candidate, - &metadata, Count::new(), Count::new(), Time::new(), @@ -600,20 +1275,19 @@ mod test { None, )); let expr = logical2physical(&expr, &table_schema); - let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory); - let candidate = FilterCandidateBuilder::new( - expr, - file_schema, - table_schema, - schema_adapter_factory, - ) - .build(&metadata) - .expect("building candidate") - .expect("candidate expected"); + // Rewrite the expression to add CastExpr for type coercion + let expr = DefaultPhysicalExprAdapterFactory {} + .create(Arc::new(table_schema), Arc::clone(&file_schema)) + .expect("creating expr adapter") + .rewrite(expr) + .expect("rewriting expression"); + let candidate = FilterCandidateBuilder::new(expr, file_schema) + .build(&metadata) + .expect("building candidate") + .expect("candidate expected"); let mut row_filter = DatafusionArrowPredicate::try_new( candidate, - &metadata, Count::new(), Count::new(), Time::new(), @@ -625,14 +1299,233 @@ mod test { } #[test] - fn nested_data_structures_prevent_pushdown() { + fn struct_data_structures_prevent_pushdown() { + let table_schema = Arc::new(Schema::new(vec![Field::new( + "struct_col", + DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ), + true, + )])); + + let expr = col("struct_col").is_not_null(); + let expr = logical2physical(&expr, &table_schema); + + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + #[test] + fn mixed_primitive_and_struct_prevents_pushdown() { + // Even when a predicate contains both primitive and unsupported nested columns, + // the entire predicate should not be pushed down because the struct column + // cannot be evaluated during Parquet decoding. + let table_schema = Arc::new(Schema::new(vec![ + Field::new( + "struct_col", + DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ), + true, + ), + Field::new("int_col", DataType::Int32, false), + ])); + + // Expression: (struct_col IS NOT NULL) AND (int_col = 5) + // Even though int_col is primitive, the presence of struct_col in the + // conjunction should prevent pushdown of the entire expression. + let expr = col("struct_col") + .is_not_null() + .and(col("int_col").eq(Expr::Literal(ScalarValue::Int32(Some(5)), None))); + let expr = logical2physical(&expr, &table_schema); + + // The entire expression should not be pushed down + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + + // However, just the int_col predicate alone should be pushable + let expr_int_only = + col("int_col").eq(Expr::Literal(ScalarValue::Int32(Some(5)), None)); + let expr_int_only = logical2physical(&expr_int_only, &table_schema); + assert!(can_expr_be_pushed_down_with_schemas( + &expr_int_only, + &table_schema + )); + } + + #[test] + fn nested_lists_allow_pushdown_checks() { let table_schema = Arc::new(get_lists_table_schema()); let expr = col("utf8_list").is_not_null(); let expr = logical2physical(&expr, &table_schema); check_expression_can_evaluate_against_schema(&expr, &table_schema); - assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + #[test] + fn array_has_all_pushdown_filters_rows() { + // Test array_has_all: checks if array contains all of ["c"] + // Rows with "c": row 1 and row 2 + let expr = array_has_all( + col("letters"), + make_array(vec![Expr::Literal( + ScalarValue::Utf8(Some("c".to_string())), + None, + )]), + ); + test_array_predicate_pushdown("array_has_all", expr, 1, 2, true); + } + + /// Helper function to test array predicate pushdown functionality. + /// + /// Creates a Parquet file with a list column, applies the given predicate, + /// and verifies that rows are correctly filtered during decoding. + fn test_array_predicate_pushdown( + func_name: &str, + predicate_expr: Expr, + expected_pruned: usize, + expected_matched: usize, + expect_list_support: bool, + ) { + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = Arc::new(Schema::new(vec![Field::new( + "letters", + DataType::List(item_field), + true, + )])); + + let mut builder = ListBuilder::new(StringBuilder::new()); + // Row 0: ["a", "b"] + builder.values().append_value("a"); + builder.values().append_value("b"); + builder.append(true); + + // Row 1: ["c"] + builder.values().append_value("c"); + builder.append(true); + + // Row 2: ["c", "d"] + builder.values().append_value("c"); + builder.values().append_value("d"); + builder.append(true); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(builder.finish())]) + .expect("record batch"); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), schema, None).expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let parquet_reader_builder = + ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = parquet_reader_builder.metadata().clone(); + let file_schema = parquet_reader_builder.schema().clone(); + + let expr = logical2physical(&predicate_expr, &file_schema); + if expect_list_support { + assert!(supports_list_predicates(&expr)); + } + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = + ParquetFileMetrics::new(0, &format!("{func_name}.parquet"), &metrics); + + let row_filter = + build_row_filter(&expr, &file_schema, &metadata, false, &file_metrics) + .expect("building row filter") + .expect("row filter should exist"); + + let reader = parquet_reader_builder + .with_row_filter(row_filter) + .build() + .expect("build reader"); + + let mut total_rows = 0; + for batch in reader { + let batch = batch.expect("record batch"); + total_rows += batch.num_rows(); + } + + assert_eq!( + file_metrics.pushdown_rows_pruned.value(), + expected_pruned, + "{func_name}: expected {expected_pruned} pruned rows" + ); + assert_eq!( + file_metrics.pushdown_rows_matched.value(), + expected_matched, + "{func_name}: expected {expected_matched} matched rows" + ); + assert_eq!( + total_rows, expected_matched, + "{func_name}: expected {expected_matched} total rows" + ); + } + + #[test] + fn array_has_pushdown_filters_rows() { + // Test array_has: checks if "c" is in the array + // Rows with "c": row 1 and row 2 + let expr = array_has( + col("letters"), + Expr::Literal(ScalarValue::Utf8(Some("c".to_string())), None), + ); + test_array_predicate_pushdown("array_has", expr, 1, 2, true); + } + + #[test] + fn array_has_any_pushdown_filters_rows() { + // Test array_has_any: checks if array contains any of ["a", "d"] + // Row 0 has "a", row 2 has "d" - both should match + let expr = array_has_any( + col("letters"), + make_array(vec![ + Expr::Literal(ScalarValue::Utf8(Some("a".to_string())), None), + Expr::Literal(ScalarValue::Utf8(Some("d".to_string())), None), + ]), + ); + test_array_predicate_pushdown("array_has_any", expr, 1, 2, true); + } + + #[test] + fn array_has_udf_pushdown_filters_rows() { + let expr = array_has_udf().call(vec![ + col("letters"), + Expr::Literal(ScalarValue::Utf8(Some("c".to_string())), None), + ]); + + test_array_predicate_pushdown("array_has_udf", expr, 1, 2, true); + } + + #[test] + fn array_has_all_udf_pushdown_filters_rows() { + let expr = array_has_all_udf().call(vec![ + col("letters"), + make_array(vec![Expr::Literal( + ScalarValue::Utf8(Some("c".to_string())), + None, + )]), + ]); + + test_array_predicate_pushdown("array_has_all_udf", expr, 1, 2, true); + } + + #[test] + fn array_has_any_udf_pushdown_filters_rows() { + let expr = array_has_any_udf().call(vec![ + col("letters"), + make_array(vec![ + Expr::Literal(ScalarValue::Utf8(Some("a".to_string())), None), + Expr::Literal(ScalarValue::Utf8(Some("d".to_string())), None), + ]), + ]); + + test_array_predicate_pushdown("array_has_any_udf", expr, 1, 2, true); } #[test] @@ -693,6 +1586,534 @@ mod test { .expect("parsing schema") } + /// Regression test: when a schema has Struct columns, Arrow field indices diverge + /// from Parquet leaf indices (Struct children become separate leaves). The + /// `PrimitiveOnly` fast-path in `leaf_indices_for_roots` assumes they are equal, + /// so a filter on a primitive column *after* a Struct gets the wrong leaf index. + /// + /// Schema: + /// Arrow indices: col_a=0 struct_col=1 col_b=2 + /// Parquet leaves: col_a=0 struct_col.x=1 struct_col.y=2 col_b=3 + /// + /// A filter on col_b should project Parquet leaf 3, but the bug causes it to + /// project leaf 2 (struct_col.y). + #[test] + fn test_filter_pushdown_leaf_index_with_struct_in_schema() { + use arrow::array::{Int32Array, StringArray, StructArray}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("col_a", DataType::Int32, false), + Field::new( + "struct_col", + DataType::Struct( + vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + ] + .into(), + ), + true, + ), + Field::new("col_b", DataType::Utf8, false), + ])); + + let col_a = Arc::new(Int32Array::from(vec![1, 2, 3])); + let struct_col = Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![10, 20, 30])) as _, + ), + ( + Arc::new(Field::new("y", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![100, 200, 300])) as _, + ), + ])); + let col_b = Arc::new(StringArray::from(vec!["aaa", "target", "zzz"])); + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![col_a, struct_col, col_b]) + .unwrap(); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), Arc::clone(&schema), None) + .expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let builder = ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = builder.metadata().clone(); + let file_schema = builder.schema().clone(); + + // sanity check: 4 Parquet leaves, 3 Arrow fields + assert_eq!(metadata.file_metadata().schema_descr().num_columns(), 4); + assert_eq!(file_schema.fields().len(), 3); + + // build a filter candidate for `col_b = 'target'` through the public API + let expr = col("col_b").eq(Expr::Literal( + ScalarValue::Utf8(Some("target".to_string())), + None, + )); + let expr = logical2physical(&expr, &file_schema); + + let candidate = FilterCandidateBuilder::new(expr, file_schema) + .build(&metadata) + .expect("building candidate") + .expect("filter on primitive col_b should be pushable"); + + // col_b is Parquet leaf 3 (shifted by struct_col's two children). + let expected_mask = + ProjectionMask::leaves(metadata.file_metadata().schema_descr(), [3]); + assert_eq!( + candidate.read_plan.projection_mask, expected_mask, + "projection_mask should select only leaf 3 for col_b" + ); + } + + /// get_field(struct_col, 'a') on a struct with a primitive leaf should allow pushdown. + #[test] + fn get_field_on_struct_allows_pushdown() { + let table_schema = Arc::new(Schema::new(vec![Field::new( + "struct_col", + DataType::Struct( + vec![Arc::new(Field::new("a", DataType::Int32, true))].into(), + ), + true, + )])); + + // get_field(struct_col, 'a') > 5 + let get_field_expr = get_field().call(vec![ + col("struct_col"), + Expr::Literal(ScalarValue::Utf8(Some("a".to_string())), None), + ]); + let expr = get_field_expr.gt(Expr::Literal(ScalarValue::Int32(Some(5)), None)); + let expr = logical2physical(&expr, &table_schema); + + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + /// get_field on a struct field that resolves to a nested type should still block pushdown. + #[test] + fn get_field_on_nested_leaf_prevents_pushdown() { + let inner_struct = DataType::Struct( + vec![Arc::new(Field::new("x", DataType::Int32, true))].into(), + ); + let table_schema = Arc::new(Schema::new(vec![Field::new( + "struct_col", + DataType::Struct( + vec![Arc::new(Field::new("nested", inner_struct, true))].into(), + ), + true, + )])); + + // get_field(struct_col, 'nested') IS NOT NULL — the leaf is still a struct + let get_field_expr = get_field().call(vec![ + col("struct_col"), + Expr::Literal(ScalarValue::Utf8(Some("nested".to_string())), None), + ]); + let expr = get_field_expr.is_not_null(); + let expr = logical2physical(&expr, &table_schema); + + assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + /// get_field returning a list inside a struct should allow pushdown when + /// wrapped in a supported list predicate like `array_has_any`. + /// e.g. `array_has_any(get_field(s, 'items'), make_array('x'))` + #[test] + fn get_field_list_leaf_with_array_predicate_allows_pushdown() { + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let table_schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct( + vec![ + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Field::new("items", DataType::List(item_field), true)), + ] + .into(), + ), + true, + )])); + + // array_has_any(get_field(s, 'items'), make_array('x')) + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("items".to_string())), None), + ]); + let expr = array_has_any( + get_field_expr, + make_array(vec![Expr::Literal( + ScalarValue::Utf8(Some("x".to_string())), + None, + )]), + ); + let expr = logical2physical(&expr, &table_schema); + + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + /// get_field on a struct produces correct Parquet leaf indices. + #[test] + fn get_field_filter_candidate_has_correct_leaf_indices() { + use arrow::array::{Int32Array, StringArray, StructArray}; + + // Schema: id (Int32), s (Struct{value: Int32, label: Utf8}) + // Parquet leaves: id=0, s.value=1, s.label=2 + let struct_fields: Fields = vec![ + Arc::new(Field::new("value", DataType::Int32, false)), + Arc::new(Field::new("label", DataType::Utf8, false)), + ] + .into(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(struct_fields.clone()), false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StructArray::new( + struct_fields, + vec![ + Arc::new(Int32Array::from(vec![10, 20, 30])) as _, + Arc::new(StringArray::from(vec!["a", "b", "c"])) as _, + ], + None, + )), + ], + ) + .unwrap(); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), Arc::clone(&schema), None) + .expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let builder = ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = builder.metadata().clone(); + let file_schema = builder.schema().clone(); + + // get_field(s, 'value') > 5 + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]); + let expr = get_field_expr.gt(Expr::Literal(ScalarValue::Int32(Some(5)), None)); + let expr = logical2physical(&expr, &file_schema); + + let candidate = FilterCandidateBuilder::new(expr, file_schema) + .build(&metadata) + .expect("building candidate") + .expect("get_field filter on struct should be pushable"); + + // The filter accesses only s.value, so only Parquet leaf 1 is needed. + // Leaf 2 (s.label) is not read, reducing unnecessary I/O. + let expected_mask = + ProjectionMask::leaves(metadata.file_metadata().schema_descr(), [1]); + assert_eq!( + candidate.read_plan.projection_mask, expected_mask, + "projection_mask should select only the accessed struct field leaf" + ); + } + + /// Deeply nested get_field: get_field(struct_col, 'outer', 'inner') where the + /// leaf is primitive should allow pushdown. The logical simplifier flattens + /// nested get_field(get_field(col, 'a'), 'b') into get_field(col, 'a', 'b'). + #[test] + fn get_field_deeply_nested_allows_pushdown() { + let table_schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Struct( + vec![Arc::new(Field::new( + "outer", + DataType::Struct( + vec![Arc::new(Field::new("inner", DataType::Int32, true))].into(), + ), + true, + ))] + .into(), + ), + true, + )])); + + // s['outer']['inner'] > 5 + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("outer".to_string())), None), + Expr::Literal(ScalarValue::Utf8(Some("inner".to_string())), None), + ]); + let expr = get_field_expr.gt(Expr::Literal(ScalarValue::Int32(Some(5)), None)); + let expr = logical2physical(&expr, &table_schema); + + assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); + } + + /// End-to-end: deeply nested get_field filter produces correct leaf indices + /// and the filter actually works against a Parquet file. + #[test] + fn get_field_deeply_nested_filter_candidate() { + use arrow::array::{Int32Array, StringArray, StructArray}; + + // Schema: id (Int32), s (Struct{outer: Struct{extra: Int32, inner: Int32}, tag: Utf8}) + // Parquet leaves: id=0, s.outer.extra=1, s.outer.inner=2, s.tag=3 + let inner_fields: Fields = vec![ + Arc::new(Field::new("extra", DataType::Int32, false)), + Arc::new(Field::new("inner", DataType::Int32, false)), + ] + .into(); + let outer_fields: Fields = vec![ + Arc::new(Field::new( + "outer", + DataType::Struct(inner_fields.clone()), + false, + )), + Arc::new(Field::new("tag", DataType::Utf8, false)), + ] + .into(); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(outer_fields.clone()), false), + ])); + + let inner_struct = StructArray::new( + inner_fields, + vec![ + Arc::new(Int32Array::from(vec![100, 200, 300])) as _, + Arc::new(Int32Array::from(vec![10, 20, 30])) as _, + ], + None, + ); + let outer_struct = StructArray::new( + outer_fields, + vec![ + Arc::new(inner_struct) as _, + Arc::new(StringArray::from(vec!["x", "y", "z"])) as _, + ], + None, + ); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(outer_struct), + ], + ) + .unwrap(); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), Arc::clone(&schema), None) + .expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let builder = ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = builder.metadata().clone(); + let file_schema = builder.schema().clone(); + + // Parquet should have 4 leaves: id=0, s.outer.extra=1, s.outer.inner=2, s.tag=3 + assert_eq!(metadata.file_metadata().schema_descr().num_columns(), 4); + + // get_field(s, 'outer', 'inner') > 15 + // Should only need leaf 2 (s.outer.inner), not leaf 1 (s.outer.extra) or leaf 3 (s.tag). + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("outer".to_string())), None), + Expr::Literal(ScalarValue::Utf8(Some("inner".to_string())), None), + ]); + let expr = get_field_expr.gt(Expr::Literal(ScalarValue::Int32(Some(15)), None)); + let expr = logical2physical(&expr, &file_schema); + + let candidate = FilterCandidateBuilder::new(expr, file_schema) + .build(&metadata) + .expect("building candidate") + .expect("deeply nested get_field filter should be pushable"); + + // Only s.outer.inner (leaf 2) should be projected, + let expected_mask = + ProjectionMask::leaves(metadata.file_metadata().schema_descr(), [2]); + assert_eq!( + candidate.read_plan.projection_mask, expected_mask, + "projection_mask should select only leaf 2 for s.outer.inner, skipping sibling and cousin leaves" + ); + } + + /// End-to-end: get_field filter on a struct column with multiple fields + /// reads only the needed leaf and correctly filters rows during Parquet decoding. + #[test] + fn get_field_end_to_end_filters_rows() { + // Schema: id (Int32), s (Struct{value: Int32, label: Utf8}) + // Parquet leaves: id=0, s.value=1, s.label=2 + let struct_fields: Fields = vec![ + Arc::new(Field::new("value", DataType::Int32, false)), + Arc::new(Field::new("label", DataType::Utf8, false)), + ] + .into(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(struct_fields.clone()), false), + ])); + + // +----+--------------------------+ + // | id | s | + // +----+--------------------------+ + // | 1 | {value: 10, label: "a"} | + // | 2 | {value: 20, label: "b"} | + // | 3 | {value: 30, label: "c"} | + // +----+--------------------------+ + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StructArray::new( + struct_fields, + vec![ + Arc::new(Int32Array::from(vec![10, 20, 30])) as _, + Arc::new(StringArray::from(vec!["a", "b", "c"])) as _, + ], + None, + )), + ], + ) + .unwrap(); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), Arc::clone(&schema), None) + .expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let parquet_reader_builder = + ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = parquet_reader_builder.metadata().clone(); + let file_schema = parquet_reader_builder.schema().clone(); + + // get_field(s, 'value') > 15 — should match rows with value=20 and value=30 + let get_field_expr = get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]); + let predicate_expr = + get_field_expr.gt(Expr::Literal(ScalarValue::Int32(Some(15)), None)); + let expr = logical2physical(&predicate_expr, &file_schema); + + let metrics = ExecutionPlanMetricsSet::new(); + let file_metrics = ParquetFileMetrics::new(0, "struct_e2e.parquet", &metrics); + + let row_filter = + build_row_filter(&expr, &file_schema, &metadata, false, &file_metrics) + .expect("building row filter") + .expect("row filter should exist"); + + let reader = parquet_reader_builder + .with_row_filter(row_filter) + .build() + .expect("build reader"); + + let mut total_rows = 0; + for batch in reader { + let batch = batch.expect("record batch"); + total_rows += batch.num_rows(); + } + + assert_eq!(total_rows, 2, "expected 2 rows matching value > 15"); + assert_eq!(file_metrics.pushdown_rows_pruned.value(), 1); + assert_eq!(file_metrics.pushdown_rows_matched.value(), 2); + } + + #[test] + fn projection_read_plan_preserves_full_struct() { + // Schema: id (Int32), s (Struct{value: Int32, label: Utf8}) + // Parquet leaves: id=0, s.value=1, s.label=2 + let struct_fields: Fields = vec![ + Arc::new(Field::new("value", DataType::Int32, false)), + Arc::new(Field::new("label", DataType::Utf8, false)), + ] + .into(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("s", DataType::Struct(struct_fields.clone()), false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StructArray::new( + struct_fields, + vec![ + Arc::new(Int32Array::from(vec![10, 20, 30])) as _, + Arc::new(StringArray::from(vec!["a", "b", "c"])) as _, + ], + None, + )), + ], + ) + .unwrap(); + + let file = NamedTempFile::new().expect("temp file"); + let mut writer = + ArrowWriter::try_new(file.reopen().unwrap(), Arc::clone(&schema), None) + .expect("writer"); + writer.write(&batch).expect("write batch"); + writer.close().expect("close writer"); + + let reader_file = file.reopen().expect("reopen file"); + let builder = ParquetRecordBatchReaderBuilder::try_new(reader_file) + .expect("reader builder"); + let metadata = builder.metadata().clone(); + let file_schema = builder.schema().clone(); + let schema_descr = metadata.file_metadata().schema_descr(); + + // Simulate SELECT * output projection: Column("id") and Column("s") + // Plus a get_field(s, 'value') expression from the pushed-down filter + let exprs: Vec> = vec![ + Arc::new(PhysicalColumn::new("id", 0)), + Arc::new(PhysicalColumn::new("s", 1)), + logical2physical( + &get_field().call(vec![ + col("s"), + Expr::Literal(ScalarValue::Utf8(Some("value".to_string())), None), + ]), + &file_schema, + ), + ]; + + let read_plan = build_projection_read_plan(exprs, &file_schema, schema_descr); + + // The projected schema must have the FULL struct type because Column("s") + // is in the projection. It should NOT be narrowed to Struct{value: Int32}. + let s_field = read_plan.projected_schema.field_with_name("s").unwrap(); + assert_eq!( + s_field.data_type(), + &DataType::Struct( + vec![ + Arc::new(Field::new("value", DataType::Int32, false)), + Arc::new(Field::new("label", DataType::Utf8, false)), + ] + .into() + ), + ); + + // all3 Parquet leaves should be in the projection mask + let expected_mask = ProjectionMask::leaves(schema_descr, [0, 1, 2]); + assert_eq!(read_plan.projection_mask, expected_mask,); + } + /// Sanity check that the given expression could be evaluated against the given schema without any errors. /// This will fail if the expression references columns that are not in the schema or if the types of the columns are incompatible, etc. fn check_expression_can_evaluate_against_schema( diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 90e4e10d5ae8f..07f4fe92cf308 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -15,26 +15,26 @@ // specific language governing permissions and limitations // under the License. -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::sync::Arc; use super::{ParquetAccessPlan, ParquetFileMetrics}; -use arrow::array::{ArrayRef, BooleanArray}; +// Re-exported so the existing `crate::row_group_filter::BloomFilterStatistics` +// path keeps resolving for in-crate callers (e.g. `opener`). +pub(crate) use crate::bloom_filter::BloomFilterStatistics; +use arrow::array::{ArrayRef, BooleanArray, UInt64Array}; use arrow::datatypes::Schema; use datafusion_common::pruning::PruningStatistics; use datafusion_common::{Column, Result, ScalarValue}; use datafusion_datasource::FileRange; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, IsNullExpr, NotExpr}; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprSimplifier}; use datafusion_pruning::PruningPredicate; use parquet::arrow::arrow_reader::statistics::StatisticsConverter; -use parquet::arrow::parquet_column; -use parquet::basic::Type; -use parquet::data_type::Decimal; +use parquet::file::metadata::RowGroupMetaData; use parquet::schema::types::SchemaDescriptor; -use parquet::{ - arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder}, - bloom_filter::Sbbf, - file::metadata::RowGroupMetaData, -}; /// Reduces the [`ParquetAccessPlan`] based on row group level metadata. /// @@ -65,11 +65,159 @@ impl RowGroupAccessPlanFilter { self.access_plan.row_group_index_iter().count() } - /// Returns the inner access plan + /// Return indexes of row groups that still need to be scanned. + pub fn row_group_indexes(&self) -> impl Iterator + '_ { + self.access_plan.row_group_index_iter() + } + + /// Returns the inner access plan. pub fn build(self) -> ParquetAccessPlan { self.access_plan } + /// Returns a reference to the inner access plan. + /// + /// Test-only accessor used by the shared assertion helpers in + /// [`crate::test_util`]. + #[cfg(test)] + pub(crate) fn access_plan(&self) -> &ParquetAccessPlan { + &self.access_plan + } + + /// Returns the is_fully_matched vector. + pub fn is_fully_matched(&self) -> &Vec { + self.access_plan.fully_matched() + } + + /// Prunes the access plan based on the limit and fully contained row groups. + /// + /// The pruning works by leveraging the concept of fully matched row groups. Consider a query like: + /// `WHERE species LIKE 'Alpine%' AND s >= 50 LIMIT N` + /// + /// After initial filtering, row groups can be classified into three states: + /// + /// 1. Not Matching / Pruned + /// 2. Partially Matching (Row Group/Page contains some matches) + /// 3. Fully Matching (Entire range is within predicate) + /// + /// +-----------------------------------------------------------------------+ + /// | NOT MATCHING | + /// | Row group 1 | + /// | +-----------------------------------+-----------------------------+ | + /// | | SPECIES | S | | + /// | +-----------------------------------+-----------------------------+ | + /// | | Snow Vole | 7 | | + /// | | Brown Bear | 133 ✅ | | + /// | | Gray Wolf | 82 ✅ | | + /// | +-----------------------------------+-----------------------------+ | + /// +-----------------------------------------------------------------------+ + /// + /// +---------------------------------------------------------------------------+ + /// | PARTIALLY MATCHING | + /// | | + /// | Row group 2 Row group 4 | + /// | +------------------+--------------+ +------------------+----------+ | + /// | | SPECIES | S | | SPECIES | S | | + /// | +------------------+--------------+ +------------------+----------+ | + /// | | Lynx | 71 ✅ | | Europ. Mole | 4 | | + /// | | Red Fox | 40 | | Polecat | 16 | | + /// | | Alpine Bat ✅ | 6 | | Alpine Ibex ✅ | 97 ✅ | | + /// | +------------------+--------------+ +------------------+----------+ | + /// +---------------------------------------------------------------------------+ + /// + /// +-----------------------------------------------------------------------+ + /// | FULLY MATCHING | + /// | Row group 3 | + /// | +-----------------------------------+-----------------------------+ | + /// | | SPECIES | S | | + /// | +-----------------------------------+-----------------------------+ | + /// | | Alpine Ibex ✅ | 101 ✅ | | + /// | | Alpine Goat ✅ | 76 ✅ | | + /// | | Alpine Sheep ✅ | 83 ✅ | | + /// | +-----------------------------------+-----------------------------+ | + /// +-----------------------------------------------------------------------+ + /// + /// ### Identification of Fully Matching Row Groups + /// + /// DataFusion identifies row groups where ALL rows satisfy the filter by inverting the + /// predicate and checking if statistics prove the inverted version is false for the group. + /// + /// For example, prefix matches like `species LIKE 'Alpine%'` are pruned using ranges: + /// 1. Candidate Range: `species >= 'Alpine' AND species < 'Alpinf'` + /// 2. Inverted Condition (to prove full match): `species < 'Alpine' OR species >= 'Alpinf'` + /// 3. Statistical Evaluation (check if any row *could* satisfy the inverted condition): + /// `min < 'Alpine' OR max >= 'Alpinf'` + /// + /// If this evaluation is **false**, it proves no row can fail the original filter, + /// so the row group is **FULLY MATCHING**. + /// + /// ### Impact of Statistics Truncation + /// + /// The precision of pruning depends on the metadata quality. Truncated statistics + /// may prevent the system from proving a full match. + /// + /// **Example**: `WHERE species LIKE 'Alpine%'` (Target range: `['Alpine', 'Alpinf')`) + /// + /// | Truncation Length | min / max | Inverted Evaluation | Status | + /// |-------------------|---------------------|---------------------------------------------------------------------|------------------------| + /// | **Length 6** | `Alpine` / `Alpine` | `"Alpine" < "Alpine" (F) OR "Alpine" >= "Alpinf" (F)` -> **false** | **FULLY MATCHING** | + /// | **Length 3** | `Alp` / `Alq` | `"Alp" < "Alpine" (T) OR "Alq" >= "Alpinf" (T)` -> **true** | **PARTIALLY MATCHING** | + /// + /// Even though Row Group 3 only contains matching rows, truncation to length 3 makes + /// the statistics `[Alp, Alq]` too broad to prove it (they could include "Alpha"). + /// The system must conservatively scan the group. + /// + /// Without limit pruning: Scan Partition 2 → Partition 3 → Partition 4 (until limit reached) + /// With limit pruning: If Partition 3 contains enough rows to satisfy the limit, + /// skip Partitions 2 and 4 entirely and go directly to Partition 3. + /// + /// This optimization is particularly effective when: + /// - The limit is small relative to the total dataset size + /// - There are row groups that are fully matched by the filter predicates + /// - The fully matched row groups contain sufficient rows to satisfy the limit + /// + /// For more information, see the [paper](https://arxiv.org/pdf/2504.11540)'s "Pruning for LIMIT Queries" part + pub fn prune_by_limit( + &mut self, + limit: usize, + rg_metadata: &[RowGroupMetaData], + metrics: &ParquetFileMetrics, + ) { + let mut fully_matched_row_group_indexes: Vec = Vec::new(); + let mut fully_matched_rows_count: usize = 0; + + // Iterate through the currently accessible row groups and try to + // find a set of matching row groups that can satisfy the limit + for &idx in self.access_plan.row_group_indexes().iter() { + if self.access_plan.is_fully_matched(idx) { + let row_group_row_count = rg_metadata[idx].num_rows() as usize; + fully_matched_row_group_indexes.push(idx); + fully_matched_rows_count += row_group_row_count; + if fully_matched_rows_count >= limit { + break; + } + } + } + + // If we can satisfy the limit with fully matching row groups, + // rewrite the plan to do so + if fully_matched_rows_count >= limit { + let original_num_accessible_row_groups = + self.access_plan.row_group_indexes().len(); + let new_num_accessible_row_groups = fully_matched_row_group_indexes.len(); + let pruned_count = original_num_accessible_row_groups + .saturating_sub(new_num_accessible_row_groups); + metrics.limit_pruned_row_groups.add_pruned(pruned_count); + + let mut new_access_plan = ParquetAccessPlan::new_none(rg_metadata.len()); + for &idx in &fully_matched_row_group_indexes { + new_access_plan.scan(idx); + new_access_plan.mark_fully_matched(idx); + } + self.access_plan = new_access_plan; + } + } + /// Prune remaining row groups to only those within the specified range. /// /// Updates this set to mark row groups that should not be scanned @@ -130,20 +278,36 @@ impl RowGroupAccessPlanFilter { parquet_schema, row_group_metadatas, arrow_schema, + // Preserve the existing row-group pruning behavior. This path only + // proves whether matching rows may exist, so it uses the + // StatisticsConverter default for older parquet-rs files where a + // missing null count can mean there are zero nulls. + missing_null_counts_as_zero: true, }; // try to prune the row groups in a single call match predicate.prune(&pruning_stats) { Ok(values) => { - // values[i] is false means the predicate could not be true for row group i + let mut fully_contained_candidates_original_idx: Vec = Vec::new(); for (idx, &value) in row_group_indexes.iter().zip(values.iter()) { if !value { self.access_plan.skip(*idx); metrics.row_groups_pruned_statistics.add_pruned(1); } else { metrics.row_groups_pruned_statistics.add_matched(1); + fully_contained_candidates_original_idx.push(*idx); } } + + // Check if any of the matched row groups are fully contained by the predicate + self.identify_fully_matched_row_groups( + &fully_contained_candidates_original_idx, + arrow_schema, + parquet_schema, + groups, + predicate, + metrics, + ); } // stats filter array could not be built, so we can't prune Err(e) => { @@ -153,62 +317,122 @@ impl RowGroupAccessPlanFilter { } } - /// Prune remaining row groups using available bloom filters and the + /// Identifies row groups that are fully matched by the predicate. + /// + /// This optimization checks whether all rows in a row group satisfy the predicate + /// by inverting the predicate and checking if it prunes the row group. If the + /// inverted predicate prunes a row group, it means no rows match the inverted + /// predicate, which implies all rows match the original predicate. + /// + /// Note: This optimization is relatively inexpensive for a limited number of row groups. + fn identify_fully_matched_row_groups( + &mut self, + candidate_row_group_indices: &[usize], + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, + groups: &[RowGroupMetaData], + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, + ) { + if candidate_row_group_indices.is_empty() { + return; + } + + let mut inverted_expr: Arc = + Arc::new(NotExpr::new(Arc::clone(predicate.orig_expr()))); + + // Rows where the predicate evaluates to NULL do not pass the filter. + // Include NULL checks in the inverted expression so a row group is only + // considered fully matched when every referenced column is known non-null. + // This is conservative for null-accepting predicates, but fully matched + // row groups must not have false positives. + let mut columns = collect_columns(predicate.orig_expr()) + .into_iter() + .filter(|column| arrow_schema.field(column.index()).is_nullable()) + .collect::>(); + columns.sort_by(|a, b| { + a.index() + .cmp(&b.index()) + .then_with(|| a.name().cmp(b.name())) + }); + + for column in columns { + inverted_expr = Arc::new(BinaryExpr::new( + inverted_expr, + Operator::Or, + Arc::new(IsNullExpr::new(Arc::new(column))), + )); + } + + // Simplify the inverted expression (e.g., NOT(c1 = 0) -> c1 != 0) + // before building the pruning predicate + let simplifier = PhysicalExprSimplifier::new(arrow_schema); + let Ok(inverted_expr) = simplifier.simplify(inverted_expr) else { + return; + }; + + let Ok(inverted_predicate) = + PruningPredicate::try_new(inverted_expr, Arc::clone(predicate.schema())) + else { + return; + }; + + let inverted_pruning_stats = RowGroupPruningStatistics { + parquet_schema, + row_group_metadatas: candidate_row_group_indices + .iter() + .map(|&i| &groups[i]) + .collect::>(), + arrow_schema, + // Fully matched row groups require a stronger proof: every row + // must pass the predicate. Missing null counts are unknown here; + // treating them as zero can incorrectly mark nullable row groups as + // fully matched and make limit pruning unsound. + missing_null_counts_as_zero: false, + }; + + let Ok(inverted_values) = inverted_predicate.prune(&inverted_pruning_stats) + else { + return; + }; + + for (i, &original_row_group_idx) in candidate_row_group_indices.iter().enumerate() + { + // If the inverted predicate *also* prunes this row group (meaning inverted_values[i] is false), + // it implies that *all* rows in this group satisfy the original predicate. + if !inverted_values[i] { + self.access_plan.mark_fully_matched(original_row_group_idx); + metrics.row_groups_pruned_statistics.add_fully_matched(1); + } + } + } + + /// Prune remaining row groups using loaded bloom filters and the /// [`PruningPredicate`]. /// - /// Updates this set with row groups that should not be scanned + /// Updates this set with row groups that should not be scanned. + /// `row_group_bloom_filters[idx]` contains the bloom filters for the + /// parquet row group at index `idx`. /// /// # Panics - /// if the builder does not have the same number of row groups as this set - pub async fn prune_by_bloom_filters( + /// if `row_group_bloom_filters` does not have the same number of row groups as this set + pub(crate) fn prune_by_bloom_filters( &mut self, - arrow_schema: &Schema, - builder: &mut ParquetRecordBatchStreamBuilder, predicate: &PruningPredicate, metrics: &ParquetFileMetrics, + row_group_bloom_filters: &[BloomFilterStatistics], ) { // scoped timer updates on drop let _timer_guard = metrics.bloom_filter_eval_time.timer(); - assert_eq!(builder.metadata().num_row_groups(), self.access_plan.len()); - for idx in 0..self.access_plan.len() { + assert_eq!(row_group_bloom_filters.len(), self.access_plan.len()); + for (idx, stats) in row_group_bloom_filters.iter().enumerate() { if !self.access_plan.should_scan(idx) { continue; } - // Attempt to find bloom filters for filtering this row group - let literal_columns = predicate.literal_columns(); - let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); - - for column_name in literal_columns { - let Some((column_idx, _field)) = - parquet_column(builder.parquet_schema(), arrow_schema, &column_name) - else { - continue; - }; - - let bf = match builder - .get_row_group_column_bloom_filter(idx, column_idx) - .await - { - Ok(Some(bf)) => bf, - Ok(None) => continue, // no bloom filter for this column - Err(e) => { - log::debug!("Ignoring error reading bloom filter: {e}"); - metrics.predicate_evaluation_errors.add(1); - continue; - } - }; - let physical_type = - builder.parquet_schema().column(column_idx).physical_type(); - - column_sbbf.insert(column_name.to_string(), (bf, physical_type)); - } - - let stats = BloomFilterStatistics { column_sbbf }; - // Can this group be pruned? - let prune_group = match predicate.prune(&stats) { + let prune_group = match predicate.prune(stats) { Ok(values) => !values[0], Err(e) => { log::debug!( @@ -228,149 +452,13 @@ impl RowGroupAccessPlanFilter { } } } -/// Implements [`PruningStatistics`] for Parquet Split Block Bloom Filters (SBBF) -struct BloomFilterStatistics { - /// Maps column name to the parquet bloom filter and parquet physical type - column_sbbf: HashMap, -} - -impl BloomFilterStatistics { - /// Helper function for checking if [`Sbbf`] filter contains [`ScalarValue`]. - /// - /// In case the type of scalar is not supported, returns `true`, assuming that the - /// value may be present. - fn check_scalar(sbbf: &Sbbf, value: &ScalarValue, parquet_type: &Type) -> bool { - match value { - ScalarValue::Utf8(Some(v)) - | ScalarValue::Utf8View(Some(v)) - | ScalarValue::LargeUtf8(Some(v)) => sbbf.check(&v.as_str()), - ScalarValue::Binary(Some(v)) - | ScalarValue::BinaryView(Some(v)) - | ScalarValue::LargeBinary(Some(v)) => sbbf.check(v), - ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), - ScalarValue::Boolean(Some(v)) => sbbf.check(v), - ScalarValue::Float64(Some(v)) => sbbf.check(v), - ScalarValue::Float32(Some(v)) => sbbf.check(v), - ScalarValue::Int64(Some(v)) => sbbf.check(v), - ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::UInt64(Some(v)) => sbbf.check(v), - ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { - Type::INT32 => { - //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 - // All physical type are little-endian - if *p > 9 { - //DECIMAL can be used to annotate the following types: - // - // int32: for 1 <= precision <= 9 - // int64: for 1 <= precision <= 18 - return true; - } - let b = (*v as i32).to_le_bytes(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Int32 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::INT64 => { - if *p > 18 { - return true; - } - let b = (*v as i64).to_le_bytes(); - let decimal = Decimal::Int64 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::FIXED_LEN_BYTE_ARRAY => { - // keep with from_bytes_to_i128 - let b = v.to_be_bytes().to_vec(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Bytes { - value: b.into(), - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - _ => true, - }, - ScalarValue::Dictionary(_, inner) => { - BloomFilterStatistics::check_scalar(sbbf, inner, parquet_type) - } - _ => true, - } - } -} - -impl PruningStatistics for BloomFilterStatistics { - fn min_values(&self, _column: &Column) -> Option { - None - } - - fn max_values(&self, _column: &Column) -> Option { - None - } - - fn num_containers(&self) -> usize { - 1 - } - - fn null_counts(&self, _column: &Column) -> Option { - None - } - - fn row_counts(&self, _column: &Column) -> Option { - None - } - - /// Use bloom filters to determine if we are sure this column can not - /// possibly contain `values` - /// - /// The `contained` API returns false if the bloom filters knows that *ALL* - /// of the values in a column are not present. - fn contained( - &self, - column: &Column, - values: &HashSet, - ) -> Option { - let (sbbf, parquet_type) = self.column_sbbf.get(column.name.as_str())?; - - // Bloom filters are probabilistic data structures that can return false - // positives (i.e. it might return true even if the value is not - // present) however, the bloom filter will return `false` if the value is - // definitely not present. - - let known_not_present = values - .iter() - .map(|value| BloomFilterStatistics::check_scalar(sbbf, value, parquet_type)) - // The row group doesn't contain any of the values if - // all the checks are false - .all(|v| !v); - - let contains = if known_not_present { - Some(false) - } else { - // Given the bloom filter is probabilistic, we can't be sure that - // the row group actually contains the values. Return `None` to - // indicate this uncertainty - None - }; - - Some(BooleanArray::from(vec![contains])) - } -} /// Wraps a slice of [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] struct RowGroupPruningStatistics<'a> { parquet_schema: &'a SchemaDescriptor, row_group_metadatas: Vec<&'a RowGroupMetaData>, arrow_schema: &'a Schema, + missing_null_counts_as_zero: bool, } impl<'a> RowGroupPruningStatistics<'a> { @@ -387,7 +475,8 @@ impl<'a> RowGroupPruningStatistics<'a> { &column.name, self.arrow_schema, self.parquet_schema, - )?) + )? + .with_missing_null_counts_as_zero(self.missing_null_counts_as_zero)) } } @@ -415,13 +504,13 @@ impl PruningStatistics for RowGroupPruningStatistics<'_> { .map(|counts| Arc::new(counts) as ArrayRef) } - fn row_counts(&self, column: &Column) -> Option { - // row counts are the same for all columns in a row group - self.statistics_converter(column) - .and_then(|c| Ok(c.row_group_row_counts(self.metadata_iter())?)) - .ok() - .flatten() - .map(|counts| Arc::new(counts) as ArrayRef) + fn row_counts(&self) -> Option { + // Row counts are container-level — read directly from row group metadata. + let counts: UInt64Array = self + .metadata_iter() + .map(|rg| Some(rg.num_rows() as u64)) + .collect(); + Some(Arc::new(counts) as ArrayRef) } fn contained( @@ -436,18 +525,15 @@ impl PruningStatistics for RowGroupPruningStatistics<'_> { #[cfg(test)] mod tests { use std::ops::Rem; - use std::sync::Arc; use super::*; - use crate::reader::ParquetFileReader; + use crate::test_util::ExpectedPruning; use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::{DataType, Field}; - use datafusion_common::Result; - use datafusion_expr::{cast, col, lit, Expr}; + use datafusion_expr::{cast, col, lit}; use datafusion_physical_expr::planner::logical2physical; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; - use parquet::arrow::async_reader::ParquetObjectReader; use parquet::arrow::ArrowSchemaConverter; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; @@ -556,6 +642,65 @@ mod tests { assert_pruned(row_groups, ExpectedPruning::Some(vec![1])) } + #[test] + fn row_group_fully_matched_requires_known_non_null_predicate_columns() { + use datafusion_expr::{col, lit}; + + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let expr = logical2physical(&col("c1").gt(lit(15)), &schema); + let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + + let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); + let schema_descr = get_test_schema_descr(vec![field]); + + // All three row groups have non-null values in the predicate range, + // so none are pruned. Only the second row group can be proven fully + // matched because it is the only one with a known zero null count. + let rg_with_null = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(16), + Some(20), + None, + Some(1), + false, + )], + ); + let rg_without_null = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(16), + Some(20), + None, + Some(0), + false, + )], + ); + let rg_unknown_null_count = get_row_group_meta_data( + &schema_descr, + vec![ParquetStatistics::int32( + Some(16), + Some(20), + None, + None, + false, + )], + ); + + let metrics = parquet_file_metrics(); + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(3)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rg_with_null, rg_without_null, rg_unknown_null_count], + &pruning_predicate, + &metrics, + ); + + assert_eq!(row_groups.access_plan.row_group_indexes(), vec![0, 1, 2]); + assert_eq!(row_groups.is_fully_matched(), &vec![false, true, false]); + } + #[test] fn row_group_pruning_predicate_missing_stats() { use datafusion_expr::{col, lit}; @@ -1227,362 +1372,7 @@ mod tests { ParquetFileMetrics::new(0, "file.parquet", &metrics) } - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_all_pruned() - // generate pruning predicate `(String = "Hello_Not_exists")` - .run(col(r#""String""#).eq(lit("Hello_Not_Exists"))) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_all_pruned() - // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` - .run( - lit("1").eq(lit("1")).and( - col(r#""String""#) - .eq(lit("Hello_Not_Exists")) - .or(col(r#""String""#).eq(lit("Hello_Not_Exists2"))), - ), - ) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr_view() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_all_pruned() - // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` - .run( - lit("1").eq(lit("1")).and( - col(r#""String""#) - .eq(Expr::Literal( - ScalarValue::Utf8View(Some(String::from("Hello_Not_Exists"))), - None, - )) - .or(col(r#""String""#).eq(Expr::Literal( - ScalarValue::Utf8View(Some(String::from( - "Hello_Not_Exists2", - ))), - None, - ))), - ), - ) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { - // load parquet file - let testdata = datafusion_common::test_util::parquet_test_data(); - let file_name = "data_index_bloom_encoding_stats.parquet"; - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - // generate pruning predicate - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); - - let expr = col(r#""String""#).in_list( - (1..25) - .map(|i| lit(format!("Hello_Not_Exists{i}"))) - .collect::>(), - false, - ); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - file_name, - data, - &pruning_predicate, - ) - .await - .unwrap(); - assert!(pruned_row_groups.access_plan.row_group_indexes().is_empty()); - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_with_exists_value() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_none_pruned() - // generate pruning predicate `(String = "Hello")` - .run(col(r#""String""#).eq(lit("Hello"))) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_with_exists_2_values() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_none_pruned() - // generate pruning predicate `(String = "Hello") OR (String = "the quick")` - .run( - col(r#""String""#) - .eq(lit("Hello")) - .or(col(r#""String""#).eq(lit("the quick"))), - ) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_none_pruned() - // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` - .run( - col(r#""String""#) - .eq(lit("Hello")) - .or(col(r#""String""#).eq(lit("the quick"))) - .or(col(r#""String""#).eq(lit("are you"))), - ) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values_view() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_none_pruned() - // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` - .run( - col(r#""String""#) - .eq(Expr::Literal( - ScalarValue::Utf8View(Some(String::from("Hello"))), - None, - )) - .or(col(r#""String""#).eq(Expr::Literal( - ScalarValue::Utf8View(Some(String::from("the quick"))), - None, - ))) - .or(col(r#""String""#).eq(Expr::Literal( - ScalarValue::Utf8View(Some(String::from("are you"))), - None, - ))), - ) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { - BloomFilterTest::new_data_index_bloom_encoding_stats() - .with_expect_none_pruned() - // generate pruning predicate `(String = "foo") OR (String != "bar")` - .run( - col(r#""String""#) - .not_eq(lit("foo")) - .or(col(r#""String""#).not_eq(lit("bar"))), - ) - .await - } - - #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_without_bloom_filter() { - // generate pruning predicate on a column without a bloom filter - BloomFilterTest::new_all_types() - .with_expect_none_pruned() - .run(col(r#""string_col""#).eq(lit("0"))) - .await - } - - // What row groups are expected to be left after pruning - #[derive(Debug)] - enum ExpectedPruning { - All, - /// Only the specified row groups are expected to REMAIN (not what is pruned) - Some(Vec), - None, - } - - impl ExpectedPruning { - /// asserts that the pruned row group match this expectation - fn assert(&self, row_groups: &RowGroupAccessPlanFilter) { - let num_row_groups = row_groups.access_plan.len(); - assert!(num_row_groups > 0); - let num_pruned = (0..num_row_groups) - .filter_map(|i| { - if row_groups.access_plan.should_scan(i) { - None - } else { - Some(1) - } - }) - .sum::(); - - match self { - Self::All => { - assert_eq!( - num_row_groups, num_pruned, - "Expected all row groups to be pruned, but got {row_groups:?}" - ); - } - ExpectedPruning::None => { - assert_eq!( - num_pruned, 0, - "Expected no row groups to be pruned, but got {row_groups:?}" - ); - } - ExpectedPruning::Some(expected) => { - let actual = row_groups.access_plan.row_group_indexes(); - assert_eq!(expected, &actual, "Unexpected row groups pruned. Expected {expected:?}, got {actual:?}"); - } - } - } - } - fn assert_pruned(row_groups: RowGroupAccessPlanFilter, expected: ExpectedPruning) { expected.assert(&row_groups); } - - struct BloomFilterTest { - file_name: String, - schema: Schema, - // which row groups are expected to be left after pruning - post_pruning_row_groups: ExpectedPruning, - } - - impl BloomFilterTest { - /// Return a test for data_index_bloom_encoding_stats.parquet - /// Note the values in the `String` column are: - /// ```sql - /// > select * from './parquet-testing/data/data_index_bloom_encoding_stats.parquet'; - /// +-----------+ - /// | String | - /// +-----------+ - /// | Hello | - /// | This is | - /// | a | - /// | test | - /// | How | - /// | are you | - /// | doing | - /// | today | - /// | the quick | - /// | brown fox | - /// | jumps | - /// | over | - /// | the lazy | - /// | dog | - /// +-----------+ - /// ``` - fn new_data_index_bloom_encoding_stats() -> Self { - Self { - file_name: String::from("data_index_bloom_encoding_stats.parquet"), - schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), - post_pruning_row_groups: ExpectedPruning::None, - } - } - - // Return a test for alltypes_plain.parquet - fn new_all_types() -> Self { - Self { - file_name: String::from("alltypes_plain.parquet"), - schema: Schema::new(vec![Field::new( - "string_col", - DataType::Utf8, - false, - )]), - post_pruning_row_groups: ExpectedPruning::None, - } - } - - /// Expect all row groups to be pruned - pub fn with_expect_all_pruned(mut self) -> Self { - self.post_pruning_row_groups = ExpectedPruning::All; - self - } - - /// Expect all row groups not to be pruned - pub fn with_expect_none_pruned(mut self) -> Self { - self.post_pruning_row_groups = ExpectedPruning::None; - self - } - - /// Prune this file using the specified expression and check that the expected row groups are left - async fn run(self, expr: Expr) { - let Self { - file_name, - schema, - post_pruning_row_groups, - } = self; - - let testdata = datafusion_common::test_util::parquet_test_data(); - let path = format!("{testdata}/{file_name}"); - let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - - let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( - &file_name, - data, - &pruning_predicate, - ) - .await - .unwrap(); - - post_pruning_row_groups.assert(&pruned_row_groups); - } - } - - /// Evaluates the pruning predicate on the specified row groups and returns the row groups that are left - async fn test_row_group_bloom_filter_pruning_predicate( - file_name: &str, - data: bytes::Bytes, - pruning_predicate: &PruningPredicate, - ) -> Result { - use datafusion_datasource::PartitionedFile; - use object_store::{ObjectMeta, ObjectStore}; - - let object_meta = ObjectMeta { - location: object_store::path::Path::parse(file_name).expect("creating path"), - last_modified: chrono::DateTime::from(std::time::SystemTime::now()), - size: data.len() as u64, - e_tag: None, - version: None, - }; - let in_memory = object_store::memory::InMemory::new(); - in_memory - .put(&object_meta.location, data.into()) - .await - .expect("put parquet file into in memory object store"); - - let metrics = ExecutionPlanMetricsSet::new(); - let file_metrics = - ParquetFileMetrics::new(0, object_meta.location.as_ref(), &metrics); - let inner = - ParquetObjectReader::new(Arc::new(in_memory), object_meta.location.clone()) - .with_file_size(object_meta.size); - - let partitioned_file = PartitionedFile { - object_meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; - - let reader = ParquetFileReader { - inner, - file_metrics: file_metrics.clone(), - partitioned_file, - }; - let mut builder = ParquetRecordBatchStreamBuilder::new(reader).await.unwrap(); - - let access_plan = ParquetAccessPlan::new_all(builder.metadata().num_row_groups()); - let mut pruned_row_groups = RowGroupAccessPlanFilter::new(access_plan); - pruned_row_groups - .prune_by_bloom_filters( - pruning_predicate.schema(), - &mut builder, - pruning_predicate, - &file_metrics, - ) - .await; - - Ok(pruned_row_groups) - } } diff --git a/datafusion/datasource-parquet/src/schema_coercion.rs b/datafusion/datasource-parquet/src/schema_coercion.rs new file mode 100644 index 0000000000000..4598bb525be32 --- /dev/null +++ b/datafusion/datasource-parquet/src/schema_coercion.rs @@ -0,0 +1,843 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Arrow-schema coercion utilities used by the Parquet reader to make a +//! file schema match the table schema (binary→string, regular→view, +//! INT96→Timestamp). +//! +//! These helpers are independent of the [`ParquetFormat`](crate::file_format::ParquetFormat) +//! type and several have been re-exported at the crate root for use by +//! callers outside the format implementation. + +use std::cell::RefCell; +use std::collections::{HashMap, HashSet}; +use std::rc::Rc; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit}; +use parquet::basic::Type; +use parquet::schema::types::SchemaDescriptor; + +/// Apply necessary schema type coercions to make file schema match table schema. +/// +/// This function performs two main types of transformations in a single pass: +/// 1. Binary types to string types conversion - Converts binary data types to their +/// corresponding string types when the table schema expects string data +/// 2. Regular to view types conversion - Converts standard string/binary types to +/// view types when the table schema uses view types +/// +/// # Arguments +/// * `table_schema` - The table schema containing the desired types +/// * `file_schema` - The file schema to be transformed +/// +/// # Returns +/// * `Some(Schema)` - If any transformations were applied, returns the transformed schema +/// * `None` - If no transformations were needed +pub fn apply_file_schema_type_coercions( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut needs_view_transform = false; + let mut needs_string_transform = false; + + // Create a mapping of table field names to their data types for fast lookup + // and simultaneously check if we need any transformations + let table_fields: HashMap<_, _> = table_schema + .fields() + .iter() + .map(|f| { + let dt = f.data_type(); + // Check if we need view type transformation + if matches!(dt, &DataType::Utf8View | &DataType::BinaryView) { + needs_view_transform = true; + } + // Check if we need string type transformation + if matches!( + dt, + &DataType::Utf8 | &DataType::LargeUtf8 | &DataType::Utf8View + ) { + needs_string_transform = true; + } + + (f.name(), dt) + }) + .collect(); + + // Early return if no transformation needed + if !needs_view_transform && !needs_string_transform { + return None; + } + + let transformed_fields: Vec> = file_schema + .fields() + .iter() + .map(|field| { + let field_name = field.name(); + let field_type = field.data_type(); + + // Look up the corresponding field type in the table schema + if let Some(table_type) = table_fields.get(field_name) { + match (table_type, field_type) { + // table schema uses string type, coerce the file schema to use string type + ( + &DataType::Utf8, + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + return field_with_new_type(field, DataType::Utf8); + } + // table schema uses large string type, coerce the file schema to use large string type + ( + &DataType::LargeUtf8, + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + return field_with_new_type(field, DataType::LargeUtf8); + } + // table schema uses string view type, coerce the file schema to use view type + ( + &DataType::Utf8View, + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + return field_with_new_type(field, DataType::Utf8View); + } + // Handle view type conversions + (&DataType::Utf8View, DataType::Utf8 | DataType::LargeUtf8) => { + return field_with_new_type(field, DataType::Utf8View); + } + (&DataType::BinaryView, DataType::Binary | DataType::LargeBinary) => { + return field_with_new_type(field, DataType::BinaryView); + } + _ => {} + } + } + + // If no transformation is needed, keep the original field + Arc::clone(field) + }) + .collect(); + + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) +} + +/// Coerces the file schema's Timestamps to the provided TimeUnit if the +/// Parquet schema contains INT96. +/// +/// Deprecated wrapper around [`Int96Coercer`]; use the builder directly +/// instead — it also supports attaching a timezone via +/// [`Int96Coercer::with_timezone`]. +#[deprecated(since = "53.2.0", note = "use `Int96Coercer` instead")] +pub fn coerce_int96_to_resolution( + parquet_schema: &SchemaDescriptor, + file_schema: &Schema, + time_unit: &TimeUnit, +) -> Option { + Int96Coercer::new(parquet_schema, file_schema, time_unit).coerce() +} + +/// Builder for coercing INT96-originated Timestamp columns in `file_schema` +/// to a specific [`TimeUnit`], optionally attaching a timezone. +/// +/// INT96 is the legacy Parquet representation that systems like Spark use for +/// timestamps. Arrow surfaces it as `Timestamp(Nanosecond, None)`, but the +/// underlying values are written as UTC-adjusted instants. Use this builder +/// to: +/// +/// - Coerce INT96-derived columns to a smaller [`TimeUnit`] (e.g. microseconds) +/// to extend the representable date range. +/// - Optionally attach a timezone so the resulting Arrow type carries the +/// timezone-aware semantic (`Timestamp(unit, Some(tz))`). Without a +/// timezone, INT96-derived columns become `Timestamp(unit, None)` — the +/// historical default. +/// +/// Returns `None` if `file_schema` contains no INT96-derived columns. +/// +/// # Example +/// +/// ```ignore +/// use std::sync::Arc; +/// use arrow::datatypes::TimeUnit; +/// use datafusion_datasource_parquet::Int96Coercer; +/// +/// let coerced = Int96Coercer::new(parquet_schema, file_schema, &TimeUnit::Microsecond) +/// .with_timezone(Some(Arc::from("UTC"))) +/// .coerce(); +/// ``` +pub struct Int96Coercer<'a> { + parquet_schema: &'a SchemaDescriptor, + file_schema: &'a Schema, + time_unit: &'a TimeUnit, + timezone: Option>, +} + +impl<'a> Int96Coercer<'a> { + /// Create a new builder. INT96-derived columns will coerce to + /// `Timestamp(time_unit, None)` unless [`Self::with_timezone`] is set. + pub fn new( + parquet_schema: &'a SchemaDescriptor, + file_schema: &'a Schema, + time_unit: &'a TimeUnit, + ) -> Self { + Self { + parquet_schema, + file_schema, + time_unit, + timezone: None, + } + } + + /// Attach a timezone to INT96-derived columns. When `Some`, INT96-derived + /// columns coerce to `Timestamp(time_unit, Some(timezone))` instead of + /// the default `Timestamp(time_unit, None)`. Spark and other systems + /// write INT96 as UTC-adjusted instants, so callers that need the + /// resulting Arrow type to be timezone-aware should pass + /// `Some(Arc::from("UTC"))`. + pub fn with_timezone(mut self, timezone: Option>) -> Self { + self.timezone = timezone; + self + } + + /// Run the coercion, returning the rewritten schema or `None` if + /// `file_schema` contains no INT96-derived columns. + pub fn coerce(self) -> Option { + let Self { + parquet_schema, + file_schema, + time_unit, + timezone, + } = self; + coerce_int96_to_resolution_impl( + parquet_schema, + file_schema, + time_unit, + timezone.as_ref(), + ) + } +} + +fn coerce_int96_to_resolution_impl( + parquet_schema: &SchemaDescriptor, + file_schema: &Schema, + time_unit: &TimeUnit, + timezone: Option<&Arc>, +) -> Option { + // Traverse the parquet_schema columns looking for int96 physical types. If encountered, insert + // the field's full path into a set. + let int96_fields: HashSet<_> = parquet_schema + .columns() + .iter() + .filter(|f| f.physical_type() == Type::INT96) + .map(|f| f.path().string()) + .collect(); + + if int96_fields.is_empty() { + // The schema doesn't contain any int96 fields, so skip the remaining logic. + return None; + } + + // Do a DFS into the schema using a stack, looking for timestamp(nanos) fields that originated + // as int96 to coerce to the provided time_unit. + + type NestedFields = Rc>>; + type StackContext<'a> = ( + Vec<&'a str>, // The Parquet column path (e.g., "c0.list.element.c1") for the current field. + &'a FieldRef, // The current field to be processed. + NestedFields, // The parent's fields that this field will be (possibly) type-coerced and + // inserted into. All fields have a parent, so this is not an Option type. + Option, // Nested types need to create their own vector of fields for their + // children. For primitive types this will remain None. For nested + // types it is None the first time they are processed. Then, we + // instantiate a vector for its children, push the field back onto the + // stack to be processed again, and DFS into its children. The next + // time we process the field, we know we have DFS'd into the children + // because this field is Some. + ); + + // This is our top-level fields from which we will construct our schema. We pass this into our + // initial stack context as the parent fields, and the DFS populates it. + let fields = Rc::new(RefCell::new(Vec::with_capacity(file_schema.fields.len()))); + + // TODO: It might be possible to only DFS into nested fields that we know contain an int96 if we + // use some sort of LPM data structure to check if we're currently DFS'ing nested types that are + // in a column path that contains an int96. That can be a future optimization for large schemas. + let transformed_schema = { + // Populate the stack with our top-level fields. + let mut stack: Vec = file_schema + .fields() + .iter() + .rev() + .map(|f| (vec![f.name().as_str()], f, Rc::clone(&fields), None)) + .collect(); + + // Pop fields to DFS into until we have exhausted the stack. + while let Some((parquet_path, current_field, parent_fields, child_fields)) = + stack.pop() + { + match (current_field.data_type(), child_fields) { + (DataType::Struct(unprocessed_children), None) => { + // This is the first time popping off this struct. We don't yet know the + // correct types of its children (i.e., if they need coercing) so we create + // a vector for child_fields, push the struct node back onto the stack to be + // processed again (see below) after processing all its children. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity( + unprocessed_children.len(), + ))); + // Note that here we push the struct back onto the stack with its + // parent_fields in the same position, now with Some(child_fields). + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + // Push all the children in reverse to maintain original schema order due to + // stack processing. + for child in unprocessed_children.into_iter().rev() { + let mut child_path = parquet_path.clone(); + // Build up a normalized path that we'll use as a key into the original + // int96_fields set above to test if this originated as int96. + child_path.push("."); + child_path.push(child.name()); + // Note that here we push the field onto the stack using the struct's + // new child_fields vector as the field's parent_fields. + stack.push((child_path, child, Rc::clone(&child_fields), None)); + } + } + (DataType::Struct(unprocessed_children), Some(processed_children)) => { + // This is the second time popping off this struct. The child_fields vector + // now contains each field that has been DFS'd into, and we can construct + // the resulting struct with correct child types. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), unprocessed_children.len()); + let processed_struct = Field::new_struct( + current_field.name(), + processed_children.as_slice(), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_struct)); + } + (DataType::List(unprocessed_child), None) => { + // This is the first time popping off this list. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + // Spark uses a definition for arrays/lists that results in a group + // named "list" that is not maintained when parsing to Arrow. We just push + // this name into the path. + child_path.push(".list."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::List(_), Some(processed_children)) => { + // This is the second time popping off this list. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_list = Field::new_list( + current_field.name(), + Arc::clone(&processed_children[0]), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_list)); + } + (DataType::Map(unprocessed_child, _), None) => { + // This is the first time popping off this map. See struct docs above. + let child_fields = Rc::new(RefCell::new(Vec::with_capacity(1))); + stack.push(( + parquet_path.clone(), + current_field, + parent_fields, + Some(Rc::clone(&child_fields)), + )); + let mut child_path = parquet_path.clone(); + child_path.push("."); + child_path.push(unprocessed_child.name()); + stack.push(( + child_path.clone(), + unprocessed_child, + Rc::clone(&child_fields), + None, + )); + } + (DataType::Map(_, sorted), Some(processed_children)) => { + // This is the second time popping off this map. See struct docs above. + let processed_children = processed_children.borrow(); + assert_eq!(processed_children.len(), 1); + let processed_map = Field::new( + current_field.name(), + DataType::Map(Arc::clone(&processed_children[0]), *sorted), + current_field.is_nullable(), + ); + parent_fields.borrow_mut().push(Arc::new(processed_map)); + } + (DataType::Timestamp(TimeUnit::Nanosecond, None), None) + if int96_fields.contains(parquet_path.concat().as_str()) => + // We found a timestamp(nanos) and it originated as int96. Coerce it to the correct + // time_unit, optionally attaching the requested timezone. + { + parent_fields.borrow_mut().push(field_with_new_type( + current_field, + DataType::Timestamp(*time_unit, timezone.cloned()), + )); + } + // Other types can be cloned as they are. + _ => parent_fields.borrow_mut().push(Arc::clone(current_field)), + } + } + assert_eq!(fields.borrow().len(), file_schema.fields.len()); + Schema::new_with_metadata( + fields.borrow_mut().clone(), + file_schema.metadata.clone(), + ) + }; + + Some(transformed_schema) +} + +/// Coerces the file schema if the table schema uses a view type. +#[deprecated( + since = "47.0.0", + note = "Use `apply_file_schema_type_coercions` instead" +)] +pub fn coerce_file_schema_to_view_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| { + let dt = f.data_type(); + if dt.equals_datatype(&DataType::Utf8View) + || dt.equals_datatype(&DataType::BinaryView) + { + transform = true; + } + (f.name(), dt) + }) + .collect(); + + if !transform { + return None; + } + + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + (Some(DataType::Utf8View), DataType::Utf8 | DataType::LargeUtf8) => { + field_with_new_type(field, DataType::Utf8View) + } + ( + Some(DataType::BinaryView), + DataType::Binary | DataType::LargeBinary, + ) => field_with_new_type(field, DataType::BinaryView), + _ => Arc::clone(field), + }, + ) + .collect(); + + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) +} + +/// If the table schema uses a string type, coerce the file schema to use a string type. +/// +/// See [`ParquetFormat::binary_as_string`](crate::file_format::ParquetFormat::binary_as_string) for details +#[deprecated( + since = "47.0.0", + note = "Use `apply_file_schema_type_coercions` instead" +)] +pub fn coerce_file_schema_to_string_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| (f.name(), f.data_type())) + .collect(); + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + // table schema uses string type, coerce the file schema to use string type + ( + Some(DataType::Utf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8) + } + // table schema uses large string type, coerce the file schema to use large string type + ( + Some(DataType::LargeUtf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::LargeUtf8) + } + // table schema uses string view type, coerce the file schema to use view type + ( + Some(DataType::Utf8View), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8View) + } + _ => Arc::clone(field), + }, + ) + .collect(); + + if !transform { + None + } else { + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) + } +} + +/// Create a new field with the specified data type, copying the other +/// properties from the input field +fn field_with_new_type(field: &FieldRef, new_type: DataType) -> FieldRef { + Arc::new(field.as_ref().clone().with_data_type(new_type)) +} + +/// Transform a schema to use view types for Utf8 and Binary +/// +/// See [`ParquetFormat::force_view_types`](crate::file_format::ParquetFormat::force_view_types) for details +pub fn transform_schema_to_view(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + field_with_new_type(field, DataType::Utf8View) + } + DataType::Binary | DataType::LargeBinary => { + field_with_new_type(field, DataType::BinaryView) + } + _ => Arc::clone(field), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} + +/// Transform a schema so that any binary types are strings +pub fn transform_binary_to_string(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Binary => field_with_new_type(field, DataType::Utf8), + DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), + DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), + _ => Arc::clone(field), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} +#[cfg(test)] +mod tests { + use parquet::arrow::parquet_to_arrow_schema; + + use super::*; + + use parquet::schema::parser::parse_message_type; + + #[test] + fn coerce_int96_to_resolution_with_mixed_timestamps() { + // Unclear if Spark (or other writer) could generate a file with mixed timestamps like this, + // but we want to test the scenario just in case since it's at least a valid schema as far + // as the Parquet spec is concerned. + let spark_schema = " + message spark_schema { + optional int96 c0; + optional int64 c1 (TIMESTAMP(NANOS,true)); + optional int64 c2 (TIMESTAMP(NANOS,false)); + optional int64 c3 (TIMESTAMP(MILLIS,true)); + optional int64 c4 (TIMESTAMP(MILLIS,false)); + optional int64 c5 (TIMESTAMP(MICROS,true)); + optional int64 c6 (TIMESTAMP(MICROS,false)); + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = Int96Coercer::new(&descr, &arrow_schema, &TimeUnit::Microsecond) + .coerce() + .unwrap(); + + // Only the first field (c0) should be converted to a microsecond timestamp because it's the + // only timestamp that originated from an INT96. + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + Field::new( + "c3", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new("c4", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new( + "c5", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + Field::new("c6", DataType::Timestamp(TimeUnit::Microsecond, None), true), + ]); + + assert_eq!(result, expected_schema); + } + + #[test] + fn coerce_int96_to_resolution_with_tz_applies_timezone() { + // Same input schema as `coerce_int96_to_resolution_with_mixed_timestamps`, but with a + // non-empty `timezone` argument. Only c0 (the INT96 column) should pick up the timezone; + // the other timestamp columns must keep whatever timezone they were declared with. + let spark_schema = " + message spark_schema { + optional int96 c0; + optional int64 c1 (TIMESTAMP(NANOS,true)); + optional int64 c2 (TIMESTAMP(NANOS,false)); + optional int64 c3 (TIMESTAMP(MILLIS,true)); + optional int64 c4 (TIMESTAMP(MILLIS,false)); + optional int64 c5 (TIMESTAMP(MICROS,true)); + optional int64 c6 (TIMESTAMP(MICROS,false)); + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = Int96Coercer::new(&descr, &arrow_schema, &TimeUnit::Microsecond) + .with_timezone(Some(Arc::from("UTC"))) + .coerce() + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())), + true, + ), + Field::new("c2", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + Field::new( + "c3", + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + true, + ), + Field::new("c4", DataType::Timestamp(TimeUnit::Millisecond, None), true), + Field::new( + "c5", + DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), + true, + ), + Field::new("c6", DataType::Timestamp(TimeUnit::Microsecond, None), true), + ]); + + assert_eq!(result, expected_schema); + } + + #[test] + fn coerce_int96_to_resolution_with_nested_types() { + // This schema is derived from Comet's CometFuzzTestSuite ParquetGenerator only using int96 + // primitive types with generateStruct, generateArray, and generateMap set to true, with one + // additional field added to c4's struct to make sure all fields in a struct get modified. + // https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala + let spark_schema = " + message spark_schema { + optional int96 c0; + optional group c1 { + optional int96 c0; + } + optional group c2 { + optional group c0 (LIST) { + repeated group list { + optional int96 element; + } + } + } + optional group c3 (LIST) { + repeated group list { + optional int96 element; + } + } + optional group c4 (LIST) { + repeated group list { + optional group element { + optional int96 c0; + optional int96 c1; + } + } + } + optional group c5 (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + optional group c6 (LIST) { + repeated group list { + optional group element (MAP) { + repeated group key_value { + required int96 key; + optional int96 value; + } + } + } + } + } + "; + + let schema = parse_message_type(spark_schema).expect("should parse schema"); + let descr = SchemaDescriptor::new(Arc::new(schema)); + + let arrow_schema = parquet_to_arrow_schema(&descr, None).unwrap(); + + let result = Int96Coercer::new(&descr, &arrow_schema, &TimeUnit::Microsecond) + .coerce() + .unwrap(); + + let expected_schema = Schema::new(vec![ + Field::new("c0", DataType::Timestamp(TimeUnit::Microsecond, None), true), + Field::new_struct( + "c1", + vec![Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )], + true, + ), + Field::new_struct( + "c2", + vec![Field::new_list( + "c0", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + )], + true, + ), + Field::new_list( + "c3", + Field::new( + "element", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + true, + ), + Field::new_list( + "c4", + Field::new_struct( + "element", + vec![ + Field::new( + "c0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "c1", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ], + true, + ), + true, + ), + Field::new_map( + "c5", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + Field::new_list( + "c6", + Field::new_map( + "element", + "key_value", + Field::new( + "key", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "value", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + false, + true, + ), + true, + ), + ]); + + assert_eq!(result, expected_schema); + } +} diff --git a/datafusion/datasource-parquet/src/sink.rs b/datafusion/datasource-parquet/src/sink.rs new file mode 100644 index 0000000000000..f15f67aab0a87 --- /dev/null +++ b/datafusion/datasource-parquet/src/sink.rs @@ -0,0 +1,754 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ParquetSink`] — DataFusion `DataSink` implementation that writes one +//! or more Parquet files to an [`ObjectStore`], optionally with parallel +//! per-column and per-row-group serialization. + +use std::fmt; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow::datatypes::{Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion_common::config::TableParquetOptions; +use datafusion_common::{DataFusionError, HashMap, Result, internal_datafusion_err}; +use datafusion_common_runtime::{JoinSet, SpawnedTask}; +use datafusion_datasource::display::FileGroupDisplay; +use datafusion_datasource::file_compression_type::FileCompressionType; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; +use datafusion_datasource::sink::DataSink; +use datafusion_datasource::write::demux::DemuxedStreamReceiver; +use datafusion_datasource::write::{ + ObjectWriterBuilder, SharedBuffer, get_writer_schema, +}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_plan::metrics::{ + ElapsedComputeFutureExt, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, + MetricsSet, Time, +}; +use datafusion_physical_plan::{DisplayAs, DisplayFormatType}; +use object_store::ObjectStore; +use object_store::buffered::BufWriter; +use object_store::path::Path; +use parquet::arrow::arrow_writer::{ + ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, ArrowRowGroupWriterFactory, + ArrowWriterOptions, compute_leaves, +}; +use parquet::arrow::{ArrowWriter, AsyncArrowWriter}; +#[cfg(feature = "parquet_encryption")] +use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::file::metadata::{ParquetMetaData, SortingColumn}; +use parquet::file::properties::{ + DEFAULT_MAX_ROW_GROUP_ROW_COUNT, WriterProperties, WriterPropertiesBuilder, +}; +use parquet::file::writer::SerializedFileWriter; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{self, Receiver, Sender}; + +/// Initial writing buffer size. Note this is just a size hint for efficiency. It +/// will grow beyond the set value if needed. +const INITIAL_BUFFER_BYTES: usize = 1048576; + +/// When writing parquet files in parallel, if the buffered Parquet data exceeds +/// this size, it is flushed to object store +const BUFFER_FLUSH_BYTES: usize = 1024000; + +/// Implements [`DataSink`] for writing to a parquet file. +pub struct ParquetSink { + /// Config options for writing data + config: FileSinkConfig, + /// Underlying parquet options + parquet_options: TableParquetOptions, + /// File metadata from successfully produced parquet files. The Mutex is only used + /// to allow inserting to HashMap from behind borrowed reference in DataSink::write_all. + written: Arc>>, + /// Optional sorting columns to write to Parquet metadata + sorting_columns: Option>, + /// Metrics for tracking write operations + metrics: ExecutionPlanMetricsSet, +} + +impl Debug for ParquetSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetSink").finish() + } +} + +impl DisplayAs for ParquetSink { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "ParquetSink(file_groups=",)?; + FileGroupDisplay(&self.config.file_group).fmt_as(t, f)?; + write!(f, ")") + } + DisplayFormatType::TreeRender => { + // TODO: collect info + write!(f, "") + } + } + } +} + +impl ParquetSink { + /// Create from config. + pub fn new(config: FileSinkConfig, parquet_options: TableParquetOptions) -> Self { + Self { + config, + parquet_options, + written: Default::default(), + sorting_columns: None, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Set sorting columns for the Parquet file metadata. + pub fn with_sorting_columns( + mut self, + sorting_columns: Option>, + ) -> Self { + self.sorting_columns = sorting_columns; + self + } + + /// Retrieve the file metadata for the written files, keyed to the path + /// which may be partitioned (in the case of hive style partitioning). + pub fn written(&self) -> HashMap { + self.written.lock().clone() + } + + /// Create writer properties based upon configuration settings, + /// including partitioning and the inclusion of arrow schema metadata. + async fn create_writer_props( + &self, + runtime: &Arc, + path: &Path, + ) -> Result { + let schema = self.config.output_schema(); + + // TODO: avoid this clone in follow up PR, where the writer properties & schema + // are calculated once on `ParquetSink::new` + let mut parquet_opts = self.parquet_options.clone(); + if !self.parquet_options.global.skip_arrow_metadata { + parquet_opts.arrow_schema(schema); + } + + let mut builder = WriterPropertiesBuilder::try_from(&parquet_opts)?; + + // Set sorting columns if configured + if let Some(ref sorting_columns) = self.sorting_columns { + builder = builder.set_sorting_columns(Some(sorting_columns.clone())); + } + + builder = set_writer_encryption_properties( + builder, + runtime, + parquet_opts, + schema, + path, + ) + .await?; + Ok(builder.build()) + } + + /// Creates an AsyncArrowWriter which serializes a parquet file to an ObjectStore + /// AsyncArrowWriters are used when individual parquet file serialization is not parallelized + async fn create_async_arrow_writer( + &self, + location: &Path, + object_store: Arc, + context: &Arc, + parquet_props: WriterProperties, + ) -> Result> { + let buf_writer = BufWriter::with_capacity( + object_store, + location.clone(), + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + ); + let options = ArrowWriterOptions::new() + .with_properties(parquet_props) + .with_skip_arrow_metadata(self.parquet_options.global.skip_arrow_metadata); + + let writer = AsyncArrowWriter::try_new_with_options( + buf_writer, + get_writer_schema(&self.config), + options, + )?; + Ok(writer) + } + + /// Parquet options + pub fn parquet_options(&self) -> &TableParquetOptions { + &self.parquet_options + } +} + +#[cfg(feature = "parquet_encryption")] +async fn set_writer_encryption_properties( + builder: WriterPropertiesBuilder, + runtime: &Arc, + parquet_opts: TableParquetOptions, + schema: &Arc, + path: &Path, +) -> Result { + if let Some(file_encryption_properties) = parquet_opts.crypto.file_encryption { + // Encryption properties have been specified directly + return Ok(builder.with_file_encryption_properties(Arc::new( + FileEncryptionProperties::try_from(file_encryption_properties)?, + ))); + } else if let Some(encryption_factory_id) = &parquet_opts.crypto.factory_id.as_ref() { + // Encryption properties will be generated by an encryption factory + let encryption_factory = + runtime.parquet_encryption_factory(encryption_factory_id)?; + let file_encryption_properties = encryption_factory + .get_file_encryption_properties( + &parquet_opts.crypto.factory_options, + schema, + path, + ) + .await?; + if let Some(file_encryption_properties) = file_encryption_properties { + return Ok( + builder.with_file_encryption_properties(file_encryption_properties) + ); + } + } + Ok(builder) +} + +#[cfg(not(feature = "parquet_encryption"))] +async fn set_writer_encryption_properties( + builder: WriterPropertiesBuilder, + _runtime: &Arc, + _parquet_opts: TableParquetOptions, + _schema: &Arc, + _path: &Path, +) -> Result { + Ok(builder) +} + +#[async_trait] +impl FileSink for ParquetSink { + fn config(&self) -> &FileSinkConfig { + &self.config + } + + async fn spawn_writer_tasks_and_join( + &self, + context: &Arc, + demux_task: SpawnedTask>, + mut file_stream_rx: DemuxedStreamReceiver, + object_store: Arc, + ) -> Result { + let rows_written_counter = MetricBuilder::new(&self.metrics) + .with_category(MetricCategory::Rows) + .global_counter("rows_written"); + // Note: bytes_written is the sum of compressed row group sizes, which + // may differ slightly from the actual on-disk file size (excludes footer, + // page indexes, and other Parquet metadata overhead). + let bytes_written_counter = MetricBuilder::new(&self.metrics) + .with_category(MetricCategory::Bytes) + .global_counter("bytes_written"); + let elapsed_compute = MetricBuilder::new(&self.metrics).elapsed_compute(0); + + let parquet_opts = &self.parquet_options; + + let mut file_write_tasks: JoinSet< + std::result::Result<(Path, ParquetMetaData), DataFusionError>, + > = JoinSet::new(); + + let runtime = context.runtime_env(); + let parallel_options = ParallelParquetWriterOptions { + max_parallel_row_groups: parquet_opts + .global + .maximum_parallel_row_group_writers, + max_buffered_record_batches_per_stream: parquet_opts + .global + .maximum_buffered_record_batches_per_stream, + }; + + while let Some((path, mut rx)) = file_stream_rx.recv().await { + let parquet_props = self.create_writer_props(&runtime, &path).await?; + // CDC requires the sequential writer: the chunker state lives in ArrowWriter + // and persists across row groups. The parallel path bypasses ArrowWriter entirely. + if !parquet_opts.global.allow_single_file_parallelism + || parquet_opts.global.content_defined_chunking.enabled + { + let mut writer = self + .create_async_arrow_writer( + &path, + Arc::clone(&object_store), + context, + parquet_props.clone(), + ) + .await?; + let reservation = MemoryConsumer::new(format!("ParquetSink[{path}]")) + .register(context.memory_pool()); + file_write_tasks.spawn( + async move { + while let Some(batch) = rx.recv().await { + writer.write(&batch).await?; + reservation.try_resize(writer.memory_size())?; + } + let parquet_meta_data = writer + .close() + .await + .map_err(|e| DataFusionError::ParquetError(Box::new(e)))?; + Ok((path, parquet_meta_data)) + } + .with_elapsed_compute(elapsed_compute.clone()), + ); + } else { + let writer = ObjectWriterBuilder::new( + // Parquet files as a whole are never compressed, since they + // manage compressed blocks themselves. + FileCompressionType::UNCOMPRESSED, + &path, + Arc::clone(&object_store), + ) + .with_buffer_size(Some( + context + .session_config() + .options() + .execution + .objectstore_writer_buffer_size, + )) + .build()?; + let ctx = ParquetFileWriteContext { + schema: get_writer_schema(&self.config), + props: Arc::new(parquet_props), + skip_arrow_metadata: self.parquet_options.global.skip_arrow_metadata, + parallel_options: Arc::new(parallel_options.clone()), + pool: Arc::clone(context.memory_pool()), + }; + let encoding_time = elapsed_compute.clone(); + file_write_tasks.spawn(async move { + let parquet_meta_data = output_single_parquet_file_parallelized( + writer, + rx, + ctx, + encoding_time, + ) + .await?; + Ok((path, parquet_meta_data)) + }); + } + } + + while let Some(result) = file_write_tasks.join_next().await { + match result { + Ok(r) => { + let (path, parquet_meta_data) = r?; + let file_rows = parquet_meta_data.file_metadata().num_rows() as usize; + let file_bytes: usize = parquet_meta_data + .row_groups() + .iter() + .map(|rg| rg.compressed_size() as usize) + .sum(); + rows_written_counter.add(file_rows); + bytes_written_counter.add(file_bytes); + let mut written_files = self.written.lock(); + written_files + .try_insert(path.clone(), parquet_meta_data) + .map_err(|e| internal_datafusion_err!("duplicate entry detected for partitioned file {path}: {e}"))?; + drop(written_files); + } + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + demux_task + .join_unwind() + .await + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + + Ok(rows_written_counter.value() as u64) + } +} + +#[async_trait] +impl DataSink for ParquetSink { + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn schema(&self) -> &SchemaRef { + self.config.output_schema() + } + + async fn write_all( + &self, + data: SendableRecordBatchStream, + context: &Arc, + ) -> Result { + FileSink::write_all(self, data, context).await + } +} + +/// Consumes a stream of [ArrowLeafColumn] via a channel and serializes them using an [ArrowColumnWriter] +/// Once the channel is exhausted, returns the ArrowColumnWriter. +async fn column_serializer_task( + mut rx: Receiver, + mut writer: ArrowColumnWriter, + reservation: MemoryReservation, + encoding_time: Time, +) -> Result<(ArrowColumnWriter, MemoryReservation)> { + while let Some(col) = rx.recv().await { + let _timer = encoding_time.timer(); + writer.write(&col)?; + reservation.try_resize(writer.memory_size())?; + } + Ok((writer, reservation)) +} + +type ColumnWriterTask = SpawnedTask>; +type ColSender = Sender; + +/// Spawns a parallel serialization task for each column +/// Returns join handles for each columns serialization task along with a send channel +/// to send arrow arrays to each serialization task. +fn spawn_column_parallel_row_group_writer( + col_writers: Vec, + max_buffer_size: usize, + pool: &Arc, + encoding_time: &Time, +) -> Result<(Vec, Vec)> { + let num_columns = col_writers.len(); + + let mut col_writer_tasks = Vec::with_capacity(num_columns); + let mut col_array_channels = Vec::with_capacity(num_columns); + for writer in col_writers.into_iter() { + // Buffer size of this channel limits the number of arrays queued up for column level serialization + let (send_array, receive_array) = + mpsc::channel::(max_buffer_size); + col_array_channels.push(send_array); + + let reservation = + MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); + let task = SpawnedTask::spawn(column_serializer_task( + receive_array, + writer, + reservation, + encoding_time.clone(), + )); + col_writer_tasks.push(task); + } + + Ok((col_writer_tasks, col_array_channels)) +} + +/// Settings related to writing parquet files in parallel +#[derive(Clone)] +struct ParallelParquetWriterOptions { + max_parallel_row_groups: usize, + max_buffered_record_batches_per_stream: usize, +} + +/// Write configuration inputs shared across all parallel tasks that encode a +/// single Parquet file. These values are invariant for the duration of one file +/// write and do not change per row-group or per column. +/// +/// Separating these from per-call parameters (`object_store_writer`, `data`, +/// `encoding_time`) keeps the deep parallel call chain below the argument-count +/// limit without mixing configuration with runtime state. +#[derive(Clone)] +struct ParquetFileWriteContext { + schema: Arc, + props: Arc, + skip_arrow_metadata: bool, + parallel_options: Arc, + pool: Arc, +} + +/// This is the return type of calling [ArrowColumnWriter].close() on each column +/// i.e. the Vec of encoded columns which can be appended to a row group +type RBStreamSerializeResult = Result<(Vec, MemoryReservation, usize)>; + +/// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective +/// parallel column serializers. +async fn send_arrays_to_col_writers( + col_array_channels: &[ColSender], + rb: &RecordBatch, + schema: Arc, +) -> Result<()> { + // Each leaf column has its own channel, increment next_channel for each leaf column sent. + let mut next_channel = 0; + for (array, field) in rb.columns().iter().zip(schema.fields()) { + for c in compute_leaves(field, array)? { + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if col_array_channels[next_channel].send(c).await.is_err() { + return Ok(()); + } + + next_channel += 1; + } + } + + Ok(()) +} + +/// Spawns a tokio task which joins the parallel column writer tasks, +/// and finalizes the row group +fn spawn_rg_join_and_finalize_task( + column_writer_tasks: Vec, + rg_rows: usize, + pool: &Arc, + encoding_time: Time, +) -> SpawnedTask { + let rg_reservation = + MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); + + SpawnedTask::spawn(async move { + let num_cols = column_writer_tasks.len(); + let mut finalized_rg = Vec::with_capacity(num_cols); + for task in column_writer_tasks.into_iter() { + let (writer, _col_reservation) = task + .join_unwind() + .await + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + let encoded_size = writer.get_estimated_total_bytes(); + rg_reservation.grow(encoded_size); + let _timer = encoding_time.timer(); + finalized_rg.push(writer.close()?); + } + + Ok((finalized_rg, rg_reservation, rg_rows)) + }) +} + +/// This task coordinates the serialization of a parquet file in parallel. +/// As the query produces RecordBatches, these are written to a RowGroup +/// via parallel [ArrowColumnWriter] tasks. Once the desired max rows per +/// row group is reached, the parallel tasks are joined on another separate task +/// and sent to a concatenation task. This task immediately continues to work +/// on the next row group in parallel. So, parquet serialization is parallelized +/// across both columns and row_groups, with a theoretical max number of parallel tasks +/// given by n_columns * num_row_groups. +fn spawn_parquet_parallel_serialization_task( + row_group_writer_factory: ArrowRowGroupWriterFactory, + mut data: Receiver, + serialize_tx: Sender>, + ctx: ParquetFileWriteContext, + encoding_time: Time, +) -> SpawnedTask> { + SpawnedTask::spawn(async move { + let max_buffer_rb = ctx.parallel_options.max_buffered_record_batches_per_stream; + let max_row_group_rows = ctx + .props + .max_row_group_row_count() + .unwrap_or(DEFAULT_MAX_ROW_GROUP_ROW_COUNT); + let mut row_group_index = 0; + let col_writers = + row_group_writer_factory.create_column_writers(row_group_index)?; + let (mut column_writer_handles, mut col_array_channels) = + spawn_column_parallel_row_group_writer( + col_writers, + max_buffer_rb, + &ctx.pool, + &encoding_time, + )?; + let mut current_rg_rows = 0; + + while let Some(mut rb) = data.recv().await { + // This loop allows the "else" block to repeatedly split the RecordBatch to handle the case + // when max_row_group_rows < execution.batch_size as an alternative to a recursive async + // function. + loop { + if current_rg_rows + rb.num_rows() < max_row_group_rows { + send_arrays_to_col_writers( + &col_array_channels, + &rb, + Arc::clone(&ctx.schema), + ) + .await?; + current_rg_rows += rb.num_rows(); + break; + } else { + let rows_left = max_row_group_rows - current_rg_rows; + let a = rb.slice(0, rows_left); + send_arrays_to_col_writers( + &col_array_channels, + &a, + Arc::clone(&ctx.schema), + ) + .await?; + + // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup + // on a separate task, so that we can immediately start on the next RG before waiting + // for the current one to finish. + drop(col_array_channels); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + max_row_group_rows, + &ctx.pool, + encoding_time.clone(), + ); + + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if serialize_tx.send(finalize_rg_task).await.is_err() { + return Ok(()); + } + + current_rg_rows = 0; + rb = rb.slice(rows_left, rb.num_rows() - rows_left); + + row_group_index += 1; + let col_writers = row_group_writer_factory + .create_column_writers(row_group_index)?; + (column_writer_handles, col_array_channels) = + spawn_column_parallel_row_group_writer( + col_writers, + max_buffer_rb, + &ctx.pool, + &encoding_time, + )?; + } + } + } + + drop(col_array_channels); + // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows + if current_rg_rows > 0 { + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + current_rg_rows, + &ctx.pool, + encoding_time.clone(), + ); + + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if serialize_tx.send(finalize_rg_task).await.is_err() { + return Ok(()); + } + } + + Ok(()) + }) +} + +/// Consume RowGroups serialized by other parallel tasks and concatenate them in +/// to the final parquet file, while flushing finalized bytes to an [ObjectStore] +async fn concatenate_parallel_row_groups( + mut parquet_writer: SerializedFileWriter, + merged_buff: SharedBuffer, + mut serialize_rx: Receiver>, + mut object_store_writer: Box, + pool: Arc, +) -> Result { + let file_reservation = + MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); + + while let Some(task) = serialize_rx.recv().await { + let result = task.join_unwind().await; + let (serialized_columns, rg_reservation, _cnt) = + result.map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + + let mut rg_out = parquet_writer.next_row_group()?; + for chunk in serialized_columns { + chunk.append_to_row_group(&mut rg_out)?; + rg_reservation.free(); + + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + file_reservation.try_resize(buff_to_flush.len())?; + + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { + object_store_writer + .write_all(buff_to_flush.as_slice()) + .await?; + buff_to_flush.clear(); + file_reservation.try_resize(buff_to_flush.len())?; // will set to zero + } + } + rg_out.close()?; + } + + let parquet_meta_data = parquet_writer.close()?; + let final_buff = merged_buff.buffer.try_lock().unwrap(); + + object_store_writer.write_all(final_buff.as_slice()).await?; + object_store_writer.shutdown().await?; + file_reservation.free(); + + Ok(parquet_meta_data) +} + +/// Parallelizes the serialization of a single parquet file, by first serializing N +/// independent RecordBatch streams in parallel to RowGroups in memory. Another +/// task then stitches these independent RowGroups together and streams this large +/// single parquet file to an ObjectStore in multiple parts. +async fn output_single_parquet_file_parallelized( + object_store_writer: Box, + data: Receiver, + ctx: ParquetFileWriteContext, + encoding_time: Time, +) -> Result { + let max_rowgroups = ctx.parallel_options.max_parallel_row_groups; + // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel + let (serialize_tx, serialize_rx) = + mpsc::channel::>(max_rowgroups); + + let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let options = ArrowWriterOptions::new() + .with_properties((*ctx.props).clone()) + .with_skip_arrow_metadata(ctx.skip_arrow_metadata); + let writer = ArrowWriter::try_new_with_options( + merged_buff.clone(), + Arc::clone(&ctx.schema), + options, + )?; + let (writer, row_group_writer_factory) = writer.into_serialized_writer()?; + + let pool = Arc::clone(&ctx.pool); + let launch_serialization_task = spawn_parquet_parallel_serialization_task( + row_group_writer_factory, + data, + serialize_tx, + ctx, + encoding_time, + ); + let parquet_meta_data = concatenate_parallel_row_groups( + writer, + merged_buff, + serialize_rx, + object_store_writer, + pool, + ) + .await?; + + launch_serialization_task + .join_unwind() + .await + .map_err(|e| DataFusionError::ExecutionJoin(Box::new(e)))??; + Ok(parquet_meta_data) +} diff --git a/datafusion/datasource-parquet/src/sort.rs b/datafusion/datasource-parquet/src/sort.rs new file mode 100644 index 0000000000000..c1cf4e8b7824e --- /dev/null +++ b/datafusion/datasource-parquet/src/sort.rs @@ -0,0 +1,1108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort-related utilities for Parquet scanning + +use arrow::datatypes::Schema; +use datafusion_common::{Result, ScalarValue}; +use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; +use parquet::file::metadata::ParquetMetaData; +use std::collections::HashMap; + +/// Reverse a row selection to match reversed row group order. +/// +/// When scanning row groups in reverse order, we need to adjust the row selection +/// to account for the new ordering. This function: +/// 1. Maps each selection to its corresponding row group +/// 2. Reverses the order of row groups +/// 3. Reconstructs the row selection for the new order +/// +/// # Arguments +/// * `row_selection` - Original row selection (only covers row groups that are scanned) +/// * `parquet_metadata` - Metadata containing row group information +/// * `row_groups_to_scan` - Indexes of row groups that will be scanned (in original order) +/// +/// # Returns +/// A new `RowSelection` adjusted for reversed row group order +/// +/// # Important Notes +/// The input `row_selection` only covers the row groups specified in `row_groups_to_scan`. +/// Row groups that are skipped (not in `row_groups_to_scan`) are not represented in the +/// `row_selection` at all. This function needs `row_groups_to_scan` to correctly map +/// the selection back to the original row groups. +pub fn reverse_row_selection( + row_selection: &RowSelection, + parquet_metadata: &ParquetMetaData, + row_groups_to_scan: &[usize], +) -> Result { + let rg_metadata = parquet_metadata.row_groups(); + + // Build a mapping of row group index to its row range, but ONLY for + // the row groups that are actually being scanned. + // + // IMPORTANT: The row numbers in this mapping are RELATIVE to the scanned row groups, + // not absolute positions in the file. + // + // Example: If row_groups_to_scan = [0, 2, 3] and each has 100 rows: + // RG0: rows 0-99 (relative to scanned data) + // RG2: rows 100-199 (relative to scanned data, NOT 200-299 in file!) + // RG3: rows 200-299 (relative to scanned data, NOT 300-399 in file!) + let mut rg_row_ranges: Vec<(usize, usize, usize)> = + Vec::with_capacity(row_groups_to_scan.len()); + let mut current_row = 0; + for &rg_idx in row_groups_to_scan { + let rg = &rg_metadata[rg_idx]; + let num_rows = rg.num_rows() as usize; + rg_row_ranges.push((rg_idx, current_row, current_row + num_rows)); + current_row += num_rows; // This is relative row number, NOT absolute file position + } + + // Map selections to row groups + let mut rg_selections: HashMap> = HashMap::new(); + + let mut current_file_row = 0; + for selector in row_selection.iter() { + let selector_end = current_file_row + selector.row_count; + + // Find which row groups this selector spans + for (rg_idx, rg_start, rg_end) in rg_row_ranges.iter() { + if current_file_row < *rg_end && selector_end > *rg_start { + // This selector overlaps with this row group + let overlap_start = current_file_row.max(*rg_start); + let overlap_end = selector_end.min(*rg_end); + let overlap_count = overlap_end - overlap_start; + + if overlap_count > 0 { + let entry = rg_selections.entry(*rg_idx).or_default(); + if selector.skip { + entry.push(RowSelector::skip(overlap_count)); + } else { + entry.push(RowSelector::select(overlap_count)); + } + } + } + } + + current_file_row = selector_end; + } + + // Build new selection for reversed row group order + // Only iterate over the row groups that are being scanned, in reverse order + let mut reversed_selectors = Vec::new(); + for &rg_idx in row_groups_to_scan.iter().rev() { + if let Some(selectors) = rg_selections.get(&rg_idx) { + reversed_selectors.extend(selectors.iter().cloned()); + } else { + // No specific selection for this row group means select all rows in it + if let Some((_, start, end)) = + rg_row_ranges.iter().find(|(idx, _, _)| *idx == rg_idx) + { + reversed_selectors.push(RowSelector::select(end - start)); + } + } + } + + Ok(RowSelection::from(reversed_selectors)) +} + +/// Reorder a file list so the most "promising" files are read first, +/// matching `PreparedAccessPlan::reorder_by_statistics` at the +/// row-group level: key off the file's `min(col)`, and let the sort +/// direction follow the request (ASC by `min` for ASC requests, DESC +/// by `min` for DESC requests). +/// +/// Keeping both layers consistent matters because they share the same +/// convergence story for TopK's dynamic filter: file `i`'s `min` is a +/// lower bound on every row group inside it, so the order chosen here +/// is a natural prefix of the order `reorder_by_statistics` will +/// produce within each file. +/// +/// No-op when: +/// * `sort_order` is `None` (sort pushdown didn't fire); +/// * the leading sort expression is not a plain `Column`; or +/// * the column is not in `table_schema`. +/// +/// Files missing statistics sort to the end so present-stats files +/// run first. +pub(crate) fn reorder_files_by_min_statistics( + mut files: Vec, + sort_order: Option<&LexOrdering>, + reverse_row_groups: bool, + table_schema: &Schema, +) -> Vec { + let Some((col_name, descending)) = + extract_topk_sort_info(sort_order, reverse_row_groups) + else { + return files; + }; + + let Ok(col_idx) = table_schema.index_of(&col_name) else { + return files; + }; + + files.sort_by(|a, b| { + let key_a = file_min_value(a, col_idx); + let key_b = file_min_value(b, col_idx); + match (key_a, key_b) { + (Some(va), Some(vb)) => { + let cmp = va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal); + if descending { cmp.reverse() } else { cmp } + } + (Some(_), None) => std::cmp::Ordering::Less, + (None, Some(_)) => std::cmp::Ordering::Greater, + (None, None) => std::cmp::Ordering::Equal, + } + }); + + log::debug!( + "Reordered {} files by min({}) {} for TopK optimization", + files.len(), + col_name, + if descending { "DESC" } else { "ASC" } + ); + + files +} + +/// Extract the `(column name, descending)` tuple used by file-level +/// reordering. Returns `None` when the sort order isn't set or the +/// leading sort expression isn't a plain `Column`. +fn extract_topk_sort_info( + sort_order: Option<&LexOrdering>, + reverse_row_groups: bool, +) -> Option<(String, bool)> { + let sort_order = sort_order?; + let first = sort_order.first(); + let col = first.expr.downcast_ref::()?; + Some((col.name().to_string(), reverse_row_groups)) +} + +/// File's per-column `min` for the reorder key. +fn file_min_value(file: &PartitionedFile, col_idx: usize) -> Option { + let stats = file.statistics.as_ref()?; + stats + .column_statistics + .get(col_idx)? + .min_value + .get_value() + .cloned() +} + +#[cfg(test)] +mod tests { + use crate::ParquetAccessPlan; + use crate::RowGroupAccess; + use arrow::datatypes::{DataType, Field, Schema}; + use bytes::Bytes; + use parquet::arrow::ArrowWriter; + use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; + use parquet::file::reader::FileReader; + use parquet::file::serialized_reader::SerializedFileReader; + use std::sync::Arc; + + /// Helper function to create a ParquetMetaData with specified row group sizes + /// by actually writing a parquet file in memory + fn create_test_metadata( + row_group_sizes: Vec, + ) -> parquet::file::metadata::ParquetMetaData { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let mut buffer = Vec::new(); + { + let props = parquet::file::properties::WriterProperties::builder().build(); + let mut writer = + ArrowWriter::try_new(&mut buffer, schema.clone(), Some(props)).unwrap(); + + for &size in &row_group_sizes { + let array = arrow::array::Int32Array::from(vec![1; size as usize]); + let batch = arrow::record_batch::RecordBatch::try_new( + schema.clone(), + vec![Arc::new(array)], + ) + .unwrap(); + writer.write(&batch).unwrap(); + writer.flush().unwrap(); + } + writer.close().unwrap(); + } + + let bytes = Bytes::from(buffer); + let reader = SerializedFileReader::new(bytes).unwrap(); + reader.metadata().clone() + } + + #[test] + fn test_prepared_access_plan_reverse_simple() { + // Test: all row groups are scanned, no row selection + let metadata = create_test_metadata(vec![100, 100, 100]); + + let access_plan = ParquetAccessPlan::new_all(3); + let rg_metadata = metadata.row_groups(); + + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + // Verify original plan + assert_eq!(prepared_plan.row_group_indexes, vec![0, 1, 2]); + + // No row selection originally due to scanning all rows + assert_eq!(prepared_plan.row_selection, None); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // Verify row groups are reversed + assert_eq!(reversed_plan.row_group_indexes, vec![2, 1, 0]); + + // If no selection originally, after reversal should still select all rows, + // and the selection should be None + assert_eq!(reversed_plan.row_selection, None); + } + + #[test] + fn test_prepared_access_plan_reverse_with_selection() { + // Test: simple row selection that spans multiple row groups + let metadata = create_test_metadata(vec![100, 100, 100]); + + let mut access_plan = ParquetAccessPlan::new_all(3); + + // Select first 50 rows from first row group, skip rest + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::select(50), RowSelector::skip(50)]), + ); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!( + original_selected, reversed_selected, + "Total selected rows should remain the same" + ); + } + + #[test] + fn test_prepared_access_plan_reverse_multi_row_group_selection() { + // Test: row selection spanning multiple row groups + let metadata = create_test_metadata(vec![100, 100, 100]); + + let mut access_plan = ParquetAccessPlan::new_all(3); + + // Create selection that spans RG0 and RG1 + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::skip(50), RowSelector::select(50)]), + ); + access_plan.scan_selection( + 1, + RowSelection::from(vec![RowSelector::select(50), RowSelector::skip(50)]), + ); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(original_selected, reversed_selected); + } + + #[test] + fn test_prepared_access_plan_reverse_empty_selection() { + // Test: all rows are skipped + let metadata = create_test_metadata(vec![100, 100, 100]); + + let mut access_plan = ParquetAccessPlan::new_all(3); + + // Skip all rows in all row groups + for i in 0..3 { + access_plan + .scan_selection(i, RowSelection::from(vec![RowSelector::skip(100)])); + } + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // Should still skip all rows + let total_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(total_selected, 0); + } + + #[test] + fn test_prepared_access_plan_reverse_different_row_group_sizes() { + // Test: row groups with different sizes + let metadata = create_test_metadata(vec![50, 150, 100]); + + let mut access_plan = ParquetAccessPlan::new_all(3); + + // Create complex selection pattern + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::skip(25), RowSelector::select(25)]), + ); + access_plan.scan_selection(1, RowSelection::from(vec![RowSelector::select(150)])); + access_plan.scan_selection( + 2, + RowSelection::from(vec![RowSelector::select(50), RowSelector::skip(50)]), + ); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(original_selected, reversed_selected); + } + + #[test] + fn test_prepared_access_plan_reverse_single_row_group() { + // Test: single row group case + let metadata = create_test_metadata(vec![100]); + + let mut access_plan = ParquetAccessPlan::new_all(1); + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::select(50), RowSelector::skip(50)]), + ); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // With single row group, row_group_indexes should remain [0] + assert_eq!(reversed_plan.row_group_indexes, vec![0]); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(original_selected, reversed_selected); + assert_eq!(original_selected, 50); + } + + #[test] + fn test_prepared_access_plan_reverse_complex_pattern() { + // Test: complex pattern with multiple select/skip segments + let metadata = create_test_metadata(vec![100, 100, 100]); + + let mut access_plan = ParquetAccessPlan::new_all(3); + + // Complex pattern: select some, skip some, select some more + access_plan.scan_selection( + 0, + RowSelection::from(vec![ + RowSelector::select(30), + RowSelector::skip(40), + RowSelector::select(30), + ]), + ); + access_plan.scan_selection( + 1, + RowSelection::from(vec![RowSelector::skip(50), RowSelector::select(50)]), + ); + access_plan.scan_selection(2, RowSelection::from(vec![RowSelector::select(100)])); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(original_selected, reversed_selected); + assert_eq!(original_selected, 210); // 30 + 30 + 50 + 100 + } + + #[test] + fn test_prepared_access_plan_reverse_with_skipped_row_groups() { + // This is the KEY test case for the bug fix! + // Test scenario where some row groups are completely skipped (not in scan plan) + let metadata = create_test_metadata(vec![100, 100, 100, 100]); + + // Scenario: RG0 (scan all), RG1 (completely skipped), RG2 (partial), RG3 (scan all) + // Only row groups [0, 2, 3] are in the scan plan + let mut access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // RG0 + RowGroupAccess::Skip, // RG1 - NOT in scan plan! + RowGroupAccess::Scan, // RG2 + RowGroupAccess::Scan, // RG3 + ]); + + // Add row selections for the scanned row groups + // Note: The RowSelection only covers row groups [0, 2, 3] (300 rows total) + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::select(100)]), // RG0: all 100 rows + ); + // RG1 is skipped, no selection needed + access_plan.scan_selection( + 2, + RowSelection::from(vec![ + RowSelector::select(25), // RG2: first 25 rows + RowSelector::skip(75), // RG2: skip last 75 rows + ]), + ); + access_plan.scan_selection( + 3, + RowSelection::from(vec![RowSelector::select(100)]), // RG3: all 100 rows + ); + + let rg_metadata = metadata.row_groups(); + + // Step 1: Create PreparedAccessPlan + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + // Verify original plan + assert_eq!(prepared_plan.row_group_indexes, vec![0, 2, 3]); + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + assert_eq!(original_selected, 225); // 100 + 25 + 100 + + // Step 2: Reverse the plan (this is the production code path) + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // Verify reversed results + // Row group order should be reversed: [3, 2, 0] + assert_eq!( + reversed_plan.row_group_indexes, + vec![3, 2, 0], + "Row groups should be reversed" + ); + + // Verify row selection is also correctly reversed + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!( + reversed_selected, 225, + "Total selected rows should remain the same" + ); + + // Verify the reversed selection structure + // After reversal, the order becomes: RG3, RG2, RG0 + // - RG3: select(100) + // - RG2: select(25), skip(75) (note: internal order preserved, not reversed) + // - RG0: select(100) + // + // After RowSelection::from() merges adjacent selectors of the same type: + // - RG3's select(100) + RG2's select(25) = select(125) + // - RG2's skip(75) remains as skip(75) + // - RG0's select(100) remains as select(100) + let selectors: Vec<_> = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + assert_eq!(selectors.len(), 3); + + // RG3 (100) + RG2 first part (25) merged into select(125) + assert!(!selectors[0].skip); + assert_eq!(selectors[0].row_count, 125); + + // RG2: skip last 75 rows + assert!(selectors[1].skip); + assert_eq!(selectors[1].row_count, 75); + + // RG0: select all 100 rows + assert!(!selectors[2].skip); + assert_eq!(selectors[2].row_count, 100); + } + + #[test] + fn test_prepared_access_plan_reverse_alternating_row_groups() { + // Test with alternating scan/skip pattern + let metadata = create_test_metadata(vec![100, 100, 100, 100]); + + // Scan RG0 and RG2, skip RG1 and RG3 + let mut access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // RG0 + RowGroupAccess::Skip, // RG1 + RowGroupAccess::Scan, // RG2 + RowGroupAccess::Skip, // RG3 + ]); + + access_plan.scan_selection(0, RowSelection::from(vec![RowSelector::select(100)])); + access_plan.scan_selection(2, RowSelection::from(vec![RowSelector::select(100)])); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + // Original: [0, 2] + assert_eq!(prepared_plan.row_group_indexes, vec![0, 2]); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // After reverse: [2, 0] + assert_eq!(reversed_plan.row_group_indexes, vec![2, 0]); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(original_selected, reversed_selected); + assert_eq!(original_selected, 200); + } + + #[test] + fn test_prepared_access_plan_reverse_middle_row_group_only() { + // Test selecting only the middle row group + let metadata = create_test_metadata(vec![100, 100, 100]); + + let mut access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, // RG0 + RowGroupAccess::Scan, // RG1 + RowGroupAccess::Skip, // RG2 + ]); + + access_plan.scan_selection( + 1, + RowSelection::from(vec![RowSelector::select(100)]), // Select all of RG1 + ); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + let original_selected: usize = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + // Original: [1] + assert_eq!(prepared_plan.row_group_indexes, vec![1]); + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // After reverse: still [1] (only one row group) + assert_eq!(reversed_plan.row_group_indexes, vec![1]); + + let reversed_selected: usize = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(original_selected, reversed_selected); + assert_eq!(original_selected, 100); + } + + #[test] + fn test_prepared_access_plan_reverse_with_skipped_row_groups_detailed() { + // This is the KEY test case for the bug fix! + // Test scenario where some row groups are completely skipped (not in scan plan) + // This version includes DETAILED verification of the selector distribution + let metadata = create_test_metadata(vec![100, 100, 100, 100]); + + // Scenario: RG0 (scan all), RG1 (completely skipped), RG2 (partial), RG3 (scan all) + // Only row groups [0, 2, 3] are in the scan plan + let mut access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // RG0 + RowGroupAccess::Skip, // RG1 - NOT in scan plan! + RowGroupAccess::Scan, // RG2 + RowGroupAccess::Scan, // RG3 + ]); + + // Add row selections for the scanned row groups + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::select(100)]), // RG0: all 100 rows + ); + // RG1 is skipped, no selection needed + access_plan.scan_selection( + 2, + RowSelection::from(vec![ + RowSelector::select(25), // RG2: first 25 rows + RowSelector::skip(75), // RG2: skip last 75 rows + ]), + ); + access_plan.scan_selection( + 3, + RowSelection::from(vec![RowSelector::select(100)]), // RG3: all 100 rows + ); + + let rg_metadata = metadata.row_groups(); + + // Step 1: Create PreparedAccessPlan + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + // Verify original plan in detail + assert_eq!(prepared_plan.row_group_indexes, vec![0, 2, 3]); + + // Detailed verification of original selection + let orig_selectors: Vec<_> = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + + // Original structure should be: + // RG0: select(100) + // RG2: select(25), skip(75) + // RG3: select(100) + // After merging by RowSelection::from(): select(125), skip(75), select(100) + assert_eq!( + orig_selectors.len(), + 3, + "Original should have 3 selectors after merging" + ); + assert!( + !orig_selectors[0].skip && orig_selectors[0].row_count == 125, + "Original: First selector should be select(125) from RG0(100) + RG2(25)" + ); + assert!( + orig_selectors[1].skip && orig_selectors[1].row_count == 75, + "Original: Second selector should be skip(75) from RG2" + ); + assert!( + !orig_selectors[2].skip && orig_selectors[2].row_count == 100, + "Original: Third selector should be select(100) from RG3" + ); + + let original_selected: usize = orig_selectors + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + assert_eq!(original_selected, 225); // 100 + 25 + 100 + + // Step 2: Reverse the plan (this is the production code path) + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // Verify reversed results + // Row group order should be reversed: [3, 2, 0] + assert_eq!( + reversed_plan.row_group_indexes, + vec![3, 2, 0], + "Row groups should be reversed" + ); + + // Detailed verification of reversed selection + let rev_selectors: Vec<_> = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + + // After reversal, the order becomes: RG3, RG2, RG0 + // - RG3: select(100) + // - RG2: select(25), skip(75) (note: internal order preserved, not reversed) + // - RG0: select(100) + // + // After RowSelection::from() merges adjacent selectors of the same type: + // - RG3's select(100) + RG2's select(25) = select(125) + // - RG2's skip(75) remains as skip(75) + // - RG0's select(100) remains as select(100) + + assert_eq!( + rev_selectors.len(), + 3, + "Reversed should have 3 selectors after merging" + ); + + // First selector: RG3 (100) + RG2 first part (25) merged into select(125) + assert!( + !rev_selectors[0].skip && rev_selectors[0].row_count == 125, + "Reversed: First selector should be select(125) from RG3(100) + RG2(25), got skip={} count={}", + rev_selectors[0].skip, + rev_selectors[0].row_count + ); + + // Second selector: RG2 skip last 75 rows + assert!( + rev_selectors[1].skip && rev_selectors[1].row_count == 75, + "Reversed: Second selector should be skip(75) from RG2, got skip={} count={}", + rev_selectors[1].skip, + rev_selectors[1].row_count + ); + + // Third selector: RG0 select all 100 rows + assert!( + !rev_selectors[2].skip && rev_selectors[2].row_count == 100, + "Reversed: Third selector should be select(100) from RG0, got skip={} count={}", + rev_selectors[2].skip, + rev_selectors[2].row_count + ); + + // Verify row selection is also correctly reversed (total count) + let reversed_selected: usize = rev_selectors + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!( + reversed_selected, 225, + "Total selected rows should remain the same" + ); + } + + #[test] + fn test_prepared_access_plan_reverse_complex_pattern_detailed() { + // Test: complex pattern with detailed verification + let metadata = create_test_metadata(vec![100, 100, 100]); + + let mut access_plan = ParquetAccessPlan::new_all(3); + + // Complex pattern: select some, skip some, select some more + access_plan.scan_selection( + 0, + RowSelection::from(vec![ + RowSelector::select(30), + RowSelector::skip(40), + RowSelector::select(30), + ]), + ); + access_plan.scan_selection( + 1, + RowSelection::from(vec![RowSelector::skip(50), RowSelector::select(50)]), + ); + access_plan.scan_selection(2, RowSelection::from(vec![RowSelector::select(100)])); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + // Verify original selection structure in detail + let orig_selectors: Vec<_> = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + + // RG0: select(30), skip(40), select(30) + // RG1: skip(50), select(50) + // RG2: select(100) + // Sequential: sel(30), skip(40), sel(30), skip(50), sel(50), sel(100) + // After merge: sel(30), skip(40), sel(30), skip(50), sel(150) + + let original_selected: usize = orig_selectors + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + assert_eq!(original_selected, 210); // 30 + 30 + 50 + 100 + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // Verify reversed selection structure + let rev_selectors: Vec<_> = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + + // After reversal: RG2, RG1, RG0 + // RG2: select(100) + // RG1: skip(50), select(50) + // RG0: select(30), skip(40), select(30) + // Sequential: sel(100), skip(50), sel(50), sel(30), skip(40), sel(30) + // After merge: sel(100), skip(50), sel(80), skip(40), sel(30) + + let reversed_selected: usize = rev_selectors + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!( + reversed_selected, 210, + "Total selected rows should remain the same (30 + 30 + 50 + 100)" + ); + + // Verify row group order + assert_eq!(reversed_plan.row_group_indexes, vec![2, 1, 0]); + } + + #[test] + fn test_prepared_access_plan_reverse_alternating_detailed() { + // Test with alternating scan/skip pattern with detailed verification + let metadata = create_test_metadata(vec![100, 100, 100, 100]); + + // Scan RG0 and RG2, skip RG1 and RG3 + let mut access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, // RG0 + RowGroupAccess::Skip, // RG1 + RowGroupAccess::Scan, // RG2 + RowGroupAccess::Skip, // RG3 + ]); + + access_plan.scan_selection( + 0, + RowSelection::from(vec![RowSelector::select(30), RowSelector::skip(70)]), + ); + access_plan.scan_selection( + 2, + RowSelection::from(vec![RowSelector::skip(20), RowSelector::select(80)]), + ); + + let rg_metadata = metadata.row_groups(); + let prepared_plan = access_plan + .prepare(rg_metadata) + .expect("Failed to create PreparedAccessPlan"); + + // Original: [0, 2] + assert_eq!(prepared_plan.row_group_indexes, vec![0, 2]); + + // Verify original selection + let orig_selectors: Vec<_> = prepared_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + + // Original: + // RG0: select(30), skip(70) + // RG2: skip(20), select(80) + // Sequential: sel(30), skip(90), sel(80) + // (RG0's skip(70) + RG2's skip(20) = skip(90)) + + let original_selected: usize = orig_selectors + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + assert_eq!(original_selected, 110); // 30 + 80 + + let reversed_plan = prepared_plan + .reverse(&metadata) + .expect("Failed to reverse PreparedAccessPlan"); + + // After reverse: [2, 0] + assert_eq!(reversed_plan.row_group_indexes, vec![2, 0]); + + // Verify reversed selection + let rev_selectors: Vec<_> = reversed_plan + .row_selection + .as_ref() + .unwrap() + .iter() + .collect(); + + // After reversal: RG2, RG0 + // RG2: skip(20), select(80) + // RG0: select(30), skip(70) + // Sequential: skip(20), sel(110), skip(70) + // (RG2's select(80) + RG0's select(30) = select(110)) + + let reversed_selected: usize = rev_selectors + .iter() + .filter(|s| !s.skip) + .map(|s| s.row_count) + .sum(); + + assert_eq!(reversed_selected, 110); // Should still be 30 + 80 + + // Detailed verification of structure + assert_eq!(rev_selectors.len(), 3, "Reversed should have 3 selectors"); + + assert!( + rev_selectors[0].skip && rev_selectors[0].row_count == 20, + "First selector should be skip(20) from RG2" + ); + + assert!( + !rev_selectors[1].skip && rev_selectors[1].row_count == 110, + "Second selector should be select(110) from RG2(80) + RG0(30)" + ); + + assert!( + rev_selectors[2].skip && rev_selectors[2].row_count == 70, + "Third selector should be skip(70) from RG0" + ); + } +} diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index da7bc125d2f6a..8228cd273eae6 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -16,45 +16,48 @@ // under the License. //! ParquetSource implementation for reading parquet files -use std::any::Any; use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; -use crate::opener::build_pruning_predicates; -use crate::opener::ParquetOpener; -use crate::row_filter::can_expr_be_pushed_down_with_schemas; use crate::DefaultParquetFileReaderFactory; use crate::ParquetFileReaderFactory; +use crate::opener::ParquetMorselizer; +use crate::opener::build_pruning_predicates; +use crate::opener::build_virtual_columns_state; +use crate::row_filter::can_expr_be_pushed_down_with_schemas; use datafusion_common::config::ConfigOptions; #[cfg(feature = "parquet_encryption")] use datafusion_common::config::EncryptionFactoryOptions; use datafusion_datasource::as_file_source; use datafusion_datasource::file_stream::FileOpener; -use datafusion_datasource::schema_adapter::{ - DefaultSchemaAdapterFactory, SchemaAdapterFactory, -}; +use datafusion_datasource::morsel::Morselizer; +use arrow::array::timezone::Tz; use arrow::datatypes::TimeUnit; -use datafusion_common::config::TableParquetOptions; use datafusion_common::DataFusionError; +use datafusion_common::config::TableParquetOptions; +use datafusion_datasource::TableSchema; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::TableSchema; -use datafusion_physical_expr::conjunction; +use datafusion_physical_expr::projection::ProjectionExprs; +use datafusion_physical_expr::{EquivalenceProperties, conjunction}; use datafusion_physical_expr_adapter::DefaultPhysicalExprAdapterFactory; -use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use datafusion_physical_plan::DisplayFormatType; +use datafusion_physical_plan::SortOrderPushdownResult; use datafusion_physical_plan::filter_pushdown::PushedDown; use datafusion_physical_plan::filter_pushdown::{ FilterPushdownPropagation, PushedDownPredicate, }; use datafusion_physical_plan::metrics::Count; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion_physical_plan::DisplayFormatType; +use log::warn; #[cfg(feature = "parquet_encryption")] use datafusion_execution::parquet_encryption::EncryptionFactory; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; use object_store::ObjectStore; #[cfg(feature = "parquet_encryption")] @@ -133,7 +136,7 @@ use parquet::encryption::decrypt::FileDecryptionProperties; /// details. /// /// * Schema evolution: read parquet files with different schemas into a unified -/// table schema. See [`SchemaAdapterFactory`] for more details. +/// table schema. See [`DefaultPhysicalExprAdapterFactory`] for more details. /// /// * metadata_size_hint: controls the number of bytes read from the end of the /// file in the initial I/O when the default [`ParquetFileReaderFactory`]. If a @@ -182,7 +185,7 @@ use parquet::encryption::decrypt::FileDecryptionProperties; /// // Split a single DataSourceExec into multiple DataSourceExecs, one for each file /// let exec = parquet_exec(); /// let data_source = exec.data_source(); -/// let base_config = data_source.as_any().downcast_ref::().unwrap(); +/// let base_config = data_source.downcast_ref::().unwrap(); /// let existing_file_groups = &base_config.file_groups; /// let new_execs = existing_file_groups /// .iter() @@ -229,7 +232,7 @@ use parquet::encryption::decrypt::FileDecryptionProperties; /// access_plan.skip(4); /// // provide the plan as extension to the FileScanConfig /// let partitioned_file = PartitionedFile::new("my_file.parquet", 1234) -/// .with_extensions(Arc::new(access_plan)); +/// .with_extension(access_plan); /// // create a FileScanConfig to scan this file /// let config = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), Arc::new(ParquetSource::new(schema()))) /// .with_file(partitioned_file).build(); @@ -240,17 +243,17 @@ use parquet::encryption::decrypt::FileDecryptionProperties; /// /// For a complete example, see the [`advanced_parquet_index` example]). /// -/// [`parquet_index_advanced` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_parquet_index.rs +/// [`parquet_index_advanced` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/data_io/parquet_advanced_index.rs /// /// # Execution Overview /// /// * Step 1: `DataSourceExec::execute` is called, returning a `FileStream` -/// configured to open parquet files with a `ParquetOpener`. +/// configured to morselize parquet files with a `ParquetMorselizer`. /// -/// * Step 2: When the stream is polled, the `ParquetOpener` is called to open -/// the file. +/// * Step 2: When the stream is polled, the `ParquetMorselizer` is called to +/// plan the file. /// -/// * Step 3: The `ParquetOpener` gets the [`ParquetMetaData`] (file metadata) +/// * Step 3: The `ParquetMorselizer` gets the [`ParquetMetaData`] (file metadata) /// via [`ParquetFileReaderFactory`], creating a `ParquetAccessPlan` by /// applying predicates to metadata. The plan and projections are used to /// determine what pages must be read. @@ -260,12 +263,12 @@ use parquet::encryption::decrypt::FileDecryptionProperties; /// [`Self::with_pushdown_filters`]). /// /// * Step 5: As each [`RecordBatch`] is read, it may be adapted by a -/// [`SchemaAdapter`] to match the table schema. By default missing columns are -/// filled with nulls, but this can be customized via [`SchemaAdapterFactory`]. +/// [`DefaultPhysicalExprAdapterFactory`] to match the table schema. By default missing columns are +/// filled with nulls, but this can be customized via [`PhysicalExprAdapterFactory`]. /// /// [`RecordBatch`]: arrow::record_batch::RecordBatch -/// [`SchemaAdapter`]: datafusion_datasource::schema_adapter::SchemaAdapter /// [`ParquetMetadata`]: parquet::file::metadata::ParquetMetaData +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory #[derive(Clone, Debug)] pub struct ParquetSource { /// Options for reading Parquet files @@ -280,14 +283,21 @@ pub struct ParquetSource { pub(crate) predicate: Option>, /// Optional user defined parquet file reader factory pub(crate) parquet_file_reader_factory: Option>, - /// Optional user defined schema adapter - pub(crate) schema_adapter_factory: Option>, /// Batch size configuration pub(crate) batch_size: Option, /// Optional hint for the size of the parquet metadata pub(crate) metadata_size_hint: Option, + /// Projection to apply to the output. + pub(crate) projection: ProjectionExprs, #[cfg(feature = "parquet_encryption")] pub(crate) encryption_factory: Option>, + /// If true, the opener flips row-group iteration order. Within- + /// row-group order is on-disk order, so the scan is `Inexact` and + /// a `SortExec` is kept in the plan. + reverse_row_groups: bool, + /// Sort order driving `PreparedAccessPlan::reorder_by_statistics` + /// in the opener. + sort_order_for_reorder: Option, } impl ParquetSource { @@ -297,17 +307,23 @@ impl ParquetSource { /// Uses default `TableParquetOptions`. /// To set custom options, use [ParquetSource::with_table_parquet_options`]. pub fn new(table_schema: impl Into) -> Self { + let table_schema = table_schema.into(); + // Projection over the full table schema (file columns + partition columns) + let full_schema = table_schema.table_schema(); + let indices: Vec = (0..full_schema.fields().len()).collect(); Self { - table_schema: table_schema.into(), + projection: ProjectionExprs::from_indices(&indices, full_schema), + table_schema, table_parquet_options: TableParquetOptions::default(), metrics: ExecutionPlanMetricsSet::new(), predicate: None, parquet_file_reader_factory: None, - schema_adapter_factory: None, batch_size: None, metadata_size_hint: None, #[cfg(feature = "parquet_encryption")] encryption_factory: None, + reverse_row_groups: false, + sort_order_for_reorder: None, } } @@ -331,7 +347,11 @@ impl ParquetSource { self } - /// Set predicate information + /// Set predicate information. + /// + /// Predicates referencing virtual columns must go through + /// [`Self::try_pushdown_filters`]. Passing them here with pushdown + /// enabled trips a debug assert in the opener. #[expect(clippy::needless_pass_by_value)] pub fn with_predicate(&self, predicate: Arc) -> Self { let mut conf = self.clone(); @@ -404,6 +424,11 @@ impl ParquetSource { self.table_parquet_options.global.reorder_filters } + /// Return the value of [`datafusion_common::config::ParquetOptions::force_filter_selections`] + fn force_filter_selections(&self) -> bool { + self.table_parquet_options.global.force_filter_selections + } + /// If enabled, the reader will read the page index /// This is used to optimize filter pushdown /// via `RowSelector` and `RowFilter` by @@ -445,28 +470,6 @@ impl ParquetSource { self.table_parquet_options.global.max_predicate_cache_size } - /// Applies schema adapter factory from the FileScanConfig if present. - /// - /// # Arguments - /// * `conf` - FileScanConfig that may contain a schema adapter factory - /// # Returns - /// The converted FileSource with schema adapter factory applied if provided - pub fn apply_schema_adapter( - self, - conf: &FileScanConfig, - ) -> datafusion_common::Result> { - let file_source: Arc = self.into(); - - // If the FileScanConfig.file_source() has a schema adapter factory, apply it - if let Some(factory) = conf.file_source().schema_adapter_factory() { - file_source.with_schema_adapter_factory( - Arc::::clone(&factory), - ) - } else { - Ok(file_source) - } - } - #[cfg(feature = "parquet_encryption")] fn get_encryption_factory_with_config( &self, @@ -479,6 +482,16 @@ impl ParquetSource { )), } } + + #[cfg(test)] + pub(crate) fn with_reverse_row_groups(mut self, reverse_row_groups: bool) -> Self { + self.reverse_row_groups = reverse_row_groups; + self + } + #[cfg(test)] + pub(crate) fn reverse_row_groups(&self) -> bool { + self.reverse_row_groups + } } /// Parses datafusion.common.config.ParquetOptions.coerce_int96 String to a arrow_schema.datatype.TimeUnit @@ -499,6 +512,19 @@ pub(crate) fn parse_coerce_int96_string( } } +/// Validates that `tz` is a parseable IANA timezone and returns it as an +/// `Arc` for use in `Timestamp(_, Some(tz))` types. +pub(crate) fn parse_coerce_int96_tz_string( + tz: &str, +) -> datafusion_common::Result> { + tz.parse::().map_err(|e| { + DataFusionError::Configuration(format!( + "Invalid parquet coerce_int96_tz {tz:?}: {e}" + )) + })?; + Ok(Arc::::from(tz)) +} + /// Allows easy conversion from ParquetSource to Arc<dyn FileSource> impl From for Arc { fn from(source: ParquetSource) -> Self { @@ -508,52 +534,26 @@ impl From for Arc { impl FileSource for ParquetSource { fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> datafusion_common::Result> { + datafusion_common::internal_err!( + "ParquetSource::create_file_opener called but it supports the Morsel API, please use that instead" + ) + } + + fn create_morselizer( &self, object_store: Arc, base_config: &FileScanConfig, partition: usize, - ) -> Arc { - let projection = base_config - .file_column_projection_indices() - .unwrap_or_else(|| (0..base_config.file_schema().fields().len()).collect()); - - let (expr_adapter_factory, schema_adapter_factory) = match ( - base_config.expr_adapter_factory.as_ref(), - self.schema_adapter_factory.as_ref(), - ) { - (Some(expr_adapter_factory), Some(schema_adapter_factory)) => { - // Use both the schema adapter factory and the expr adapter factory. - // This results in the SchemaAdapter being used for projections (e.g. a column was selected that is a UInt32 in the file and a UInt64 in the table schema) - // but the PhysicalExprAdapterFactory being used for predicate pushdown and stats pruning. - ( - Some(Arc::clone(expr_adapter_factory)), - Arc::clone(schema_adapter_factory), - ) - } - (Some(expr_adapter_factory), None) => { - // If no custom schema adapter factory is provided but an expr adapter factory is provided use the expr adapter factory alongside the default schema adapter factory. - // This means that the PhysicalExprAdapterFactory will be used for predicate pushdown and stats pruning, while the default schema adapter factory will be used for projections. - ( - Some(Arc::clone(expr_adapter_factory)), - Arc::new(DefaultSchemaAdapterFactory) as _, - ) - } - (None, Some(schema_adapter_factory)) => { - // If a custom schema adapter factory is provided but no expr adapter factory is provided use the custom SchemaAdapter for both projections and predicate pushdown. - // This maximizes compatibility with existing code that uses the SchemaAdapter API and did not explicitly opt into the PhysicalExprAdapterFactory API. - (None, Arc::clone(schema_adapter_factory) as _) - } - (None, None) => { - // If no custom schema adapter factory or expr adapter factory is provided, use the default schema adapter factory and the default physical expr adapter factory. - // This means that the default SchemaAdapter will be used for projections (e.g. a column was selected that is a UInt32 in the file and a UInt64 in the table schema) - // and the default PhysicalExprAdapterFactory will be used for predicate pushdown and stats pruning. - // This is the default behavior with not customization and means that most users of DataFusion will be cut over to the new PhysicalExprAdapterFactory API. - ( - Some(Arc::new(DefaultPhysicalExprAdapterFactory) as _), - Arc::new(DefaultSchemaAdapterFactory) as _, - ) - } - }; + ) -> datafusion_common::Result> { + let expr_adapter_factory = base_config + .expr_adapter_factory + .clone() + .unwrap_or_else(|| Arc::new(DefaultPhysicalExprAdapterFactory) as _); let parquet_file_reader_factory = self.parquet_file_reader_factory.clone().unwrap_or_else(|| { @@ -566,7 +566,8 @@ impl FileSource for ParquetSource { .crypto .file_decryption .clone() - .map(FileDecryptionProperties::from) + .map(FileDecryptionProperties::try_from) + .transpose()? .map(Arc::new); let coerce_int96 = self @@ -575,38 +576,78 @@ impl FileSource for ParquetSource { .coerce_int96 .as_ref() .map(|time_unit| parse_coerce_int96_string(time_unit.as_str()).unwrap()); + let coerce_int96_tz = self + .table_parquet_options + .global + .coerce_int96_tz + .as_ref() + .map(|tz| parse_coerce_int96_tz_string(tz)) + .transpose()?; + if coerce_int96_tz.is_some() && coerce_int96.is_none() { + warn!( + "coerce_int96_tz is set but coerce_int96 is not; the timezone will be ignored" + ); + } + + // Validate virtual columns (extension-type allowlist) and, when + // pushdown is enabled, reject predicates that reference them. Both + // checks depend only on morselizer-level state, so we pay their cost + // once per scan partition rather than per file. + // + // Gating predicate validation on `pushdown_filters` is deliberate: + // when pushdown is off the predicate stays above the scan as a + // `FilterExec` and resolves virtual columns there; the row-filter + // ban only applies to the pushdown path. + let virtual_state = build_virtual_columns_state( + self.table_schema.virtual_columns(), + self.table_schema.file_schema(), + self.predicate.as_ref(), + self.pushdown_filters(), + )?; - Arc::new(ParquetOpener { + Ok(Box::new(ParquetMorselizer { partition_index: partition, - projection: Arc::from(projection), + projection: self.projection.clone(), batch_size: self .batch_size - .expect("Batch size must set before creating ParquetOpener"), + .expect("Batch size must set before creating ParquetMorselizer"), limit: base_config.limit, + preserve_order: base_config.preserve_order, predicate: self.predicate.clone(), - logical_file_schema: Arc::clone(base_config.file_schema()), - partition_fields: base_config.table_partition_cols().clone(), + table_schema: self.table_schema.clone(), metadata_size_hint: self.metadata_size_hint, metrics: self.metrics().clone(), parquet_file_reader_factory, pushdown_filters: self.pushdown_filters(), reorder_filters: self.reorder_filters(), + force_filter_selections: self.force_filter_selections(), enable_page_index: self.enable_page_index(), enable_bloom_filter: self.bloom_filter_on_read(), enable_row_group_stats_pruning: self.table_parquet_options.global.pruning, - schema_adapter_factory, coerce_int96, + coerce_int96_tz, #[cfg(feature = "parquet_encryption")] file_decryption_properties, expr_adapter_factory, #[cfg(feature = "parquet_encryption")] encryption_factory: self.get_encryption_factory_with_config(), max_predicate_cache_size: self.max_predicate_cache_size(), - }) + reverse_row_groups: self.reverse_row_groups, + sort_order_for_reorder: self.sort_order_for_reorder.clone(), + virtual_state, + })) } - fn as_any(&self) -> &dyn Any { - self + fn reorder_files( + &self, + files: Vec, + ) -> Vec { + crate::sort::reorder_files_by_min_statistics( + files, + self.sort_order_for_reorder.as_ref(), + self.reverse_row_groups, + self.table_schema.table_schema(), + ) } fn table_schema(&self) -> &TableSchema { @@ -623,8 +664,17 @@ impl FileSource for ParquetSource { Arc::new(conf) } - fn with_projection(&self, _config: &FileScanConfig) -> Arc { - Arc::new(Self { ..self.clone() }) + fn try_pushdown_projection( + &self, + projection: &ProjectionExprs, + ) -> datafusion_common::Result>> { + let mut source = self.clone(); + source.projection = self.projection.try_merge(projection)?; + Ok(Some(Arc::new(source))) + } + + fn projection(&self) -> Option<&ProjectionExprs> { + Some(&self.projection) } fn metrics(&self) -> &ExecutionPlanMetricsSet { @@ -645,17 +695,26 @@ impl FileSource for ParquetSource { write!(f, "{predicate_string}")?; - // Try to build a the pruning predicates. + // Inexact sort-pushdown markers: surface both flags so + // readers can see the optimization fired. + if let Some(sort_order) = &self.sort_order_for_reorder { + write!(f, ", sort_order_for_reorder=[{sort_order}]")?; + } + if self.reverse_row_groups { + write!(f, ", reverse_row_groups=true")?; + } + + // Try to build the pruning predicates. // These are only generated here because it's useful to have *some* // idea of what pushdown is happening when viewing plans. - // However it is important to note that these predicates are *not* + // However, it is important to note that these predicates are *not* // necessarily the predicates that are actually evaluated: // the actual predicates are built in reference to the physical schema of // each file, which we do not have at this point and hence cannot use. - // Instead we use the logical schema of the file (the table schema without partition columns). + // Instead, we use the logical schema of the file (the table schema without partition columns). if let Some(predicate) = &self.predicate { let predicate_creation_errors = Count::new(); - if let (Some(pruning_predicate), _) = build_pruning_predicates( + if let Some(pruning_predicate) = build_pruning_predicates( Some(predicate), self.table_schema.table_schema(), &predicate_creation_errors, @@ -690,7 +749,12 @@ impl FileSource for ParquetSource { filters: Vec>, config: &ConfigOptions, ) -> datafusion_common::Result>> { - let table_schema = self.table_schema.table_schema(); + // Use the schema excluding virtual columns: virtual columns (e.g. + // Parquet `row_number`) are produced by the reader itself and cannot + // be referenced inside a RowFilter, so predicates that reference them + // must not be marked as pushed down — otherwise the scan would + // silently drop them and produce wrong results. + let pushable_schema = self.table_schema.schema_without_virtual_columns(); // Determine if based on configs we should push filters down. // If either the table / scan itself or the config has pushdown enabled, // we will push down the filters. @@ -706,7 +770,7 @@ impl FileSource for ParquetSource { let filters: Vec = filters .into_iter() .map(|filter| { - if can_expr_be_pushed_down_with_schemas(&filter, table_schema) { + if can_expr_be_pushed_down_with_schemas(&filter, pushable_schema) { PushedDownPredicate::supported(filter) } else { PushedDownPredicate::unsupported(filter) @@ -753,18 +817,175 @@ impl FileSource for ParquetSource { .with_updated_node(source)) } - fn with_schema_adapter_factory( + /// Try to optimize the scan to produce data in the requested sort order. + /// + /// Inputs: + /// 1. The query's required ordering (`order` parameter) + /// 2. The source's equivalence properties (`eq_properties`) + /// + /// # Returns + /// - `Exact`: the source's natural ordering already satisfies the + /// request. The surrounding `SortExec` can be eliminated provided + /// files within each group are non-overlapping (verified by + /// `FileScanConfig`). + /// - `Inexact`: the source can approximate the request via two + /// composable runtime steps — stats-based row-group reorder + /// (skipped when the leading sort key isn't a plain `Column` + /// in the file schema) and row-group iteration reverse. A + /// `SortExec` is still required for full correctness, but limit + /// pushdown and `TopK` benefit immediately. + /// - `Unsupported`: no approximation is available. + /// + /// # How the Inexact result is communicated + /// + /// The result is carried through two fields on `ParquetSource`: + /// + /// - `sort_order_for_reorder`: set to the request's `LexOrdering` + /// whenever the pushdown fires, regardless of whether the + /// leading expression is a plain `Column`. The opener invokes + /// `PreparedAccessPlan::reorder_by_statistics`, which skips + /// when the expression can't be looked up in parquet metadata. + /// Exposing the field unconditionally keeps `EXPLAIN` honest + /// about what the source was asked to approximate. + /// - `reverse_row_groups`: drives the opener's iteration flip. + /// When stats reorder applies (column-in-schema), this is just + /// the request's direction — the reorder produces ASC-by-min, + /// so reverse iff the query asks for DESC. When stats reorder + /// doesn't apply but the reversed source ordering satisfies + /// the request (function-wrapped case), this is always `true` + /// because we're flipping the file's natural order. + fn try_pushdown_sort( &self, - schema_adapter_factory: Arc, - ) -> datafusion_common::Result> { - Ok(Arc::new(Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self.clone() - })) - } + order: &[PhysicalSortExpr], + eq_properties: &EquivalenceProperties, + ) -> datafusion_common::Result>> { + if order.is_empty() { + return Ok(SortOrderPushdownResult::Unsupported); + } + + // Check if the natural (non-reversed) ordering already satisfies the request. + // Parquet metadata guarantees within-file ordering, so if the ordering matches + // we can return Exact. FileScanConfig will verify that files within each group + // are non-overlapping before declaring the entire scan as Exact. + if eq_properties.ordering_satisfy(order.iter().cloned())? { + return Ok(SortOrderPushdownResult::Exact { + inner: Arc::new(self.clone()) as Arc, + }); + } + + // If the source's declared ordering is a non-empty *proper* prefix + // of the request (e.g. source `[a DESC, b ASC]`, request + // `[a DESC, b ASC, c DESC]`), decline pushdown so the outer + // `SortExec`'s `sort_prefix` optimisation — prefix-aware early + // termination in `TopK` — can still fire. Firing the Inexact + // pipeline below would invalidate the source's `output_ordering` + // (the runtime row-group reorder is approximate, so we can't + // honour the declared ordering anymore), which is exactly what + // `EnforceSorting` needs to derive `sort_prefix`. On data that + // is already in prefix order the stats-based reorder is mostly + // a no-op anyway, so the trade-off is plainly bad. + for prefix_len in 1..order.len() { + let prefix = order[..prefix_len].to_vec(); + if eq_properties.ordering_satisfy(prefix.iter().cloned())? { + return Ok(SortOrderPushdownResult::Unsupported); + } + } + + // Inexact pushdown. Two independent signals; either is enough + // to produce an approximate ordering, and they compose when + // both apply: + // + // 1. `column_in_file_schema`: the request's leading sort key is + // a plain `Column` present in the file schema. The opener + // can sort row groups by that column's `min` via parquet + // statistics. Drives `sort_order_for_reorder`'s actual use. + // + // 2. `reversed_satisfies`: the source's declared ordering, + // when reversed, satisfies the request. This is strictly + // more powerful than the column-in-schema check because it + // runs the request through `EquivalenceProperties`'s full + // reasoning machinery: + // + // - Function monotonicity: e.g. file declares + // `[extract_year_month(ws) DESC, ws DESC]`, request is + // `[ws ASC]`; the reversed ordering still satisfies the + // request via `extract_year_month`'s monotonicity even + // though parquet has no stats keyed by the function + // expression itself. + // - Constant columns from filters: equivalence classes can + // mark columns as constant under a predicate, allowing + // requested orderings on those columns to be trivially + // satisfied. + // - Other equivalence relationships (e.g. `a = b` transfers + // ordering between `a` and `b`). + // + // `reorder_by_statistics` can't substitute for any of the + // above because it can only look up min/max for a plain + // physical column. + // + // `sort_order_for_reorder` is set in both cases so EXPLAIN + // shows what the source was asked to approximate; the opener + // skips stats-based reorder when the leading expression isn't + // a plain `Column`. + // + // The reversal flips each `PhysicalSortExpr` (both descending + // and nulls_first) and rebuilds an `EquivalenceProperties` so + // the request can be tested against the reversed orderings + // via the same `ordering_satisfy` API. + let reversed_eq_properties = { + let mut new = eq_properties.clone(); + new.clear_orderings(); + let reversed_orderings = eq_properties + .oeq_class() + .iter() + .map(|ordering| { + ordering + .iter() + .map(|expr| expr.reverse()) + .collect::>() + }) + .collect::>(); + new.add_orderings(reversed_orderings); + new + }; + let reversed_satisfies = + reversed_eq_properties.ordering_satisfy(order.iter().cloned())?; + let sort_order = LexOrdering::new(order.iter().cloned()); + let column_in_file_schema = sort_order.as_ref().is_some_and(|s| { + s.first() + .expr + .downcast_ref::() + .is_some_and(|col| { + self.table_schema + .file_schema() + .field_with_name(col.name()) + .is_ok() + }) + }); - fn schema_adapter_factory(&self) -> Option> { - self.schema_adapter_factory.clone() + if !column_in_file_schema && !reversed_satisfies { + return Ok(SortOrderPushdownResult::Unsupported); + } + + // `reverse_row_groups` has different starting points in the + // two cases: + // - With stats reorder (column-in-schema): the reorder produces + // ASC-by-min, so reverse iff the request is DESC. + // - Without stats reorder (reversed-eq fallback): we flip the + // file's natural order, so always reverse. + let is_descending = sort_order + .as_ref() + .is_some_and(|s| s.first().options.descending); + let mut new_source = self.clone(); + new_source.sort_order_for_reorder = sort_order; + new_source.reverse_row_groups = if column_in_file_schema { + is_descending + } else { + true + }; + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(new_source) as Arc, + }) } } @@ -775,7 +996,7 @@ mod tests { use datafusion_physical_expr::expressions::lit; #[test] - #[allow(deprecated)] + #[expect(deprecated)] fn test_parquet_source_predicate_same_as_filter() { let predicate = lit(true); @@ -784,4 +1005,672 @@ mod tests { // same value. but filter() call Arc::clone internally assert_eq!(parquet_source.predicate(), parquet_source.filter().as_ref()); } + + #[test] + fn test_reverse_scan_default_value() { + use arrow::datatypes::Schema; + + let schema = Arc::new(Schema::empty()); + let source = ParquetSource::new(schema); + + assert!(!source.reverse_row_groups()); + } + + #[test] + fn test_reverse_scan_with_setter() { + use arrow::datatypes::Schema; + + let schema = Arc::new(Schema::empty()); + + let source = ParquetSource::new(schema.clone()).with_reverse_row_groups(true); + assert!(source.reverse_row_groups()); + + let source = source.with_reverse_row_groups(false); + assert!(!source.reverse_row_groups()); + } + + #[test] + fn test_reverse_scan_clone_preserves_value() { + use arrow::datatypes::Schema; + + let schema = Arc::new(Schema::empty()); + + let source = ParquetSource::new(schema).with_reverse_row_groups(true); + let cloned = source.clone(); + + assert!(cloned.reverse_row_groups()); + assert_eq!(source.reverse_row_groups(), cloned.reverse_row_groups()); + } + + #[test] + fn test_reverse_scan_with_other_options() { + use arrow::datatypes::Schema; + + let schema = Arc::new(Schema::empty()); + let options = TableParquetOptions::default(); + + let source = ParquetSource::new(schema) + .with_table_parquet_options(options) + .with_metadata_size_hint(8192) + .with_reverse_row_groups(true); + + assert!(source.reverse_row_groups()); + assert_eq!(source.metadata_size_hint, Some(8192)); + } + + #[test] + fn test_reverse_scan_builder_pattern() { + use arrow::datatypes::Schema; + + let schema = Arc::new(Schema::empty()); + + let source = ParquetSource::new(schema) + .with_reverse_row_groups(true) + .with_reverse_row_groups(false) + .with_reverse_row_groups(true); + + assert!(source.reverse_row_groups()); + } + + #[test] + fn test_reverse_scan_independent_of_predicate() { + use arrow::datatypes::Schema; + use datafusion_physical_expr::expressions::lit; + + let schema = Arc::new(Schema::empty()); + let predicate = lit(true); + + let source = ParquetSource::new(schema) + .with_predicate(predicate) + .with_reverse_row_groups(true); + + assert!(source.reverse_row_groups()); + assert!(source.filter().is_some()); + } + + /// Helpers for the `try_pushdown_sort` regression tests below. + mod pushdown_sort_helpers { + use super::*; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + pub(super) fn schema_with_a_int() -> Arc { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])) + } + + pub(super) fn sort_expr_on( + schema: &Arc, + name: &str, + descending: bool, + ) -> PhysicalSortExpr { + let idx = schema.index_of(name).unwrap(); + PhysicalSortExpr { + expr: Arc::new(Column::new(name, idx)), + options: SortOptions { + descending, + nulls_first: true, + }, + } + } + } + + /// When neither natural nor reversed ordering matches the request, + /// but the sort column is a plain `Column` present in the file + /// schema, `try_pushdown_sort` returns `Inexact` with + /// `sort_order_for_reorder` set so the opener can reorder row + /// groups by min/max statistics at runtime. + #[test] + fn try_pushdown_sort_returns_inexact_when_column_in_schema_asc() { + use datafusion_physical_expr::EquivalenceProperties; + use pushdown_sort_helpers::*; + + let schema = schema_with_a_int(); + let source = ParquetSource::new(Arc::clone(&schema)); + let order = vec![sort_expr_on(&schema, "a", false)]; + // No declared natural ordering on the source. + let eq = EquivalenceProperties::new(Arc::clone(&schema)); + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("expected Inexact, got a different variant"); + }; + // Downcast back to `ParquetSource` to inspect the fields the + // opener reads to drive `reorder_by_statistics` / `reverse`. + let inner_parquet = inner + .downcast_ref::() + .expect("inner is ParquetSource"); + let sort_order = inner_parquet + .sort_order_for_reorder + .as_ref() + .expect("sort_order_for_reorder must be set so the opener can reorder"); + assert_eq!(sort_order.first().expr.to_string(), "a@0"); + // ASC request must not flip RG iteration order. + assert!( + !inner_parquet.reverse_row_groups(), + "ASC request must not set reverse_row_groups", + ); + } + + /// Same as above but for DESC. `reverse_row_groups` must also be + /// `true` so the opener flips iteration order. + #[test] + fn try_pushdown_sort_returns_inexact_when_column_in_schema_desc() { + use datafusion_physical_expr::EquivalenceProperties; + use pushdown_sort_helpers::*; + + let schema = schema_with_a_int(); + let source = ParquetSource::new(Arc::clone(&schema)); + let order = vec![sort_expr_on(&schema, "a", true)]; + let eq = EquivalenceProperties::new(Arc::clone(&schema)); + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("expected Inexact, got a different variant"); + }; + let inner_parquet = inner + .downcast_ref::() + .expect("inner is ParquetSource"); + assert!(inner_parquet.sort_order_for_reorder.is_some()); + assert!( + inner_parquet.reverse_row_groups(), + "DESC request must set reverse_row_groups", + ); + } + + /// A non-`Column` leading sort expression (e.g. `a + 1`, + /// `date_trunc('hour', ts)`) with no declared source ordering + /// yields `Unsupported` — parquet stats need a column name to + /// look up min/max, and there's no source ordering to reverse. + #[test] + fn try_pushdown_sort_returns_unsupported_for_non_column_sort_expr() { + use arrow::compute::SortOptions; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, lit}; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use pushdown_sort_helpers::*; + + let schema = schema_with_a_int(); + let source = ParquetSource::new(Arc::clone(&schema)); + + // `a + 1` — not a plain Column. + let order = vec![PhysicalSortExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + datafusion_expr::Operator::Plus, + lit(1i32), + )), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let eq = EquivalenceProperties::new(Arc::clone(&schema)); + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + assert!( + matches!(result, SortOrderPushdownResult::Unsupported), + "non-Column sort expression must yield Unsupported", + ); + } + + /// A sort column missing from the file schema with no declared + /// source ordering yields `Unsupported` — there are no parquet + /// stats for that column and no source ordering to reverse. + #[test] + fn try_pushdown_sort_returns_unsupported_when_column_not_in_file_schema() { + use arrow::compute::SortOptions; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use pushdown_sort_helpers::*; + + let schema = schema_with_a_int(); + let source = ParquetSource::new(Arc::clone(&schema)); + + // Reference a column ("b") that does not exist in the file + // schema (which only has "a"). The Column expression itself is + // well-formed; only its name is unknown to the file. + let order = vec![PhysicalSortExpr { + expr: Arc::new(Column::new("b", 0)), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let eq = EquivalenceProperties::new(Arc::clone(&schema)); + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + assert!( + matches!(result, SortOrderPushdownResult::Unsupported), + "column not in file schema must yield Unsupported", + ); + } + + /// Regression: when the source declares `[a DESC]` and the request is + /// `[a ASC]`, both `column_in_file_schema` and `reversed_satisfies` + /// are true. `reverse_row_groups` must follow the *request's* + /// direction (false for ASC) — not the source's, and not the OR of + /// both signals. The opener's stats-based reorder produces + /// ASC-by-min, so an ASC request needs no further flip; flipping + /// would incorrectly emit DESC. + #[test] + fn try_pushdown_sort_source_desc_request_asc_does_not_reverse() { + use datafusion_physical_expr::EquivalenceProperties; + use pushdown_sort_helpers::*; + + let schema = schema_with_a_int(); + let source = ParquetSource::new(Arc::clone(&schema)); + // Source declares `[a DESC]`. + let mut eq = EquivalenceProperties::new(Arc::clone(&schema)); + eq.add_ordering(vec![sort_expr_on(&schema, "a", true)]); + // Request `[a ASC]`. + let order = vec![sort_expr_on(&schema, "a", false)]; + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("expected Inexact, got a different variant"); + }; + let inner_parquet = inner + .downcast_ref::() + .expect("inner is ParquetSource"); + assert!( + inner_parquet.sort_order_for_reorder.is_some(), + "sort_order_for_reorder must be set", + ); + assert!( + !inner_parquet.reverse_row_groups(), + "ASC request on source-DESC must not set reverse_row_groups; \ + a stale `reversed_satisfies || is_descending` formula would \ + incorrectly flip iteration to DESC after the stats reorder", + ); + } + + /// A sort column that is *not* in the file schema (here: a partition + /// column `b`) but the source's *reversed* declared ordering does + /// satisfy the request. Pushdown fires via the reversed-equivalence + /// path; `sort_order_for_reorder` is still set (so EXPLAIN reflects + /// what the source was asked to approximate, even though the opener + /// will skip its stats reorder because `b` has no per-RG min/max in + /// the parquet file), and `reverse_row_groups` is `true` because we + /// flip the file's natural order rather than re-sort by stats. + #[test] + fn try_pushdown_sort_returns_inexact_via_reversed_eq_when_column_not_in_file_schema() + { + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_datasource::TableSchema; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + // File schema is just `[a]`; `b` lives as a partition column on + // top, so it appears in the table schema but not the file schema + // — the same shape `column_in_file_schema` discards. + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let partition_b = Arc::new(Field::new("b", DataType::Int32, true)); + let table_schema = TableSchema::builder(file_schema) + .with_table_partition_cols(vec![partition_b]) + .build(); + let source = ParquetSource::new(table_schema); + + // EquivalenceProperties is built on the *full* table schema so + // it can carry an ordering on `b`. + let full_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + // Construct the request first, then derive the declared + // ordering as its reverse, so `ordering_satisfy` on the + // reversed-eq holds exactly. `PhysicalSortExpr::reverse` flips + // both `descending` and `nulls_first`, so spelling the + // declared ordering directly is error-prone. + let request_expr = PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }; + let declared = request_expr.reverse(); + let mut eq = EquivalenceProperties::new(Arc::clone(&full_schema)); + eq.add_ordering(vec![declared]); + let order = vec![request_expr]; + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("expected Inexact, got a different variant"); + }; + let inner_parquet = inner + .downcast_ref::() + .expect("inner is ParquetSource"); + assert!( + inner_parquet.sort_order_for_reorder.is_some(), + "sort_order_for_reorder must be set so EXPLAIN reflects the request", + ); + assert!( + inner_parquet.reverse_row_groups(), + "request reached via reversed_satisfies (column-not-in-file-schema) \ + must set reverse_row_groups to flip the file's natural order", + ); + } + + /// Regression: when the source's declared ordering is a non-empty + /// *proper* prefix of the request, `try_pushdown_sort` must return + /// `Unsupported` so that the outer `SortExec` can keep its + /// `sort_prefix` annotation and `TopK` can early-terminate within + /// each prefix block. Letting the Phase 3 Inexact pipeline fire + /// would drop the source's `output_ordering`, destroying the + /// information `EnforceSorting` needs to compute `sort_prefix`. + #[test] + fn try_pushdown_sort_preserves_sort_prefix_when_source_declares_prefix_ordering() { + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ])); + let source = ParquetSource::new(Arc::clone(&schema)); + + // Source declares `[a DESC, b ASC NULLS LAST]` — the same prefix + // the SortExec input will see. + let mut eq = EquivalenceProperties::new(Arc::clone(&schema)); + eq.add_ordering(vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + ]); + + // Request `[a DESC, b ASC NULLS LAST, c DESC]` — three columns; + // source's two-column declaration is a strict prefix. + let order = vec![ + PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("b", 1)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: Arc::new(Column::new("c", 2)), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }, + ]; + + let result = source.try_pushdown_sort(&order, &eq).unwrap(); + assert!( + matches!(result, SortOrderPushdownResult::Unsupported), + "source ordering [a DESC, b ASC NULLS LAST] is a proper prefix \ + of the request — `try_pushdown_sort` must return Unsupported so \ + the SortExec sort_prefix optimisation can fire", + ); + } + + /// Helpers for the `reorder_files` regression tests below. + mod reorder_files_helpers { + use super::*; + use datafusion_common::stats::Precision; + use datafusion_common::{ColumnStatistics, ScalarValue, Statistics}; + use datafusion_datasource::PartitionedFile; + + pub(super) fn file_with_min(name: &str, min: Option) -> PartitionedFile { + let mut pf = PartitionedFile::new(name.to_string(), 0); + let min_value = min + .map(|v| Precision::Exact(ScalarValue::Int32(Some(v)))) + .unwrap_or(Precision::Absent); + pf.statistics = Some(Arc::new(Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + null_count: Precision::Absent, + max_value: Precision::Absent, + min_value, + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Absent, + }], + })); + pf + } + + pub(super) fn names(files: &[PartitionedFile]) -> Vec<&str> { + files + .iter() + .map(|f| f.object_meta.location.as_ref()) + .collect() + } + } + + /// ASC TopK: `reorder_files` keys off file `min` and sorts ASC, + /// so the file with the smallest `min` is read first. This + /// matches `PreparedAccessPlan::reorder_by_statistics` at the + /// row-group level (also `min ASC`). + #[test] + fn reorder_files_sorts_asc_by_min_for_asc_request() { + use pushdown_sort_helpers::*; + use reorder_files_helpers::*; + + let schema = schema_with_a_int(); + let mut source = ParquetSource::new(Arc::clone(&schema)); + source.sort_order_for_reorder = + Some(LexOrdering::new(vec![sort_expr_on(&schema, "a", false)]).unwrap()); + // ASC request → `reverse_row_groups` left at its default `false`. + + let reordered = source.reorder_files(vec![ + file_with_min("middle", Some(50)), + file_with_min("small", Some(10)), + file_with_min("large", Some(100)), + ]); + + assert_eq!(names(&reordered), vec!["small", "middle", "large"]); + } + + /// DESC TopK: same `min` key, but sorted DESC — file with the + /// largest `min` first. Mirrors the row-group strategy of + /// "ASC-by-min then `reverse`". + #[test] + fn reorder_files_sorts_desc_by_min_for_desc_request() { + use pushdown_sort_helpers::*; + use reorder_files_helpers::*; + + let schema = schema_with_a_int(); + let mut source = + ParquetSource::new(Arc::clone(&schema)).with_reverse_row_groups(true); + source.sort_order_for_reorder = + Some(LexOrdering::new(vec![sort_expr_on(&schema, "a", true)]).unwrap()); + + let reordered = source.reorder_files(vec![ + file_with_min("middle", Some(50)), + file_with_min("small", Some(10)), + file_with_min("large", Some(100)), + ]); + + assert_eq!(names(&reordered), vec!["large", "middle", "small"]); + } + + /// Files without statistics sort to the *end* so present-stats + /// files run first regardless of direction. Verified for both + /// ASC and DESC. + #[test] + fn reorder_files_pushes_missing_stats_to_the_end() { + use pushdown_sort_helpers::*; + use reorder_files_helpers::*; + + let schema = schema_with_a_int(); + let mut source = ParquetSource::new(Arc::clone(&schema)); + source.sort_order_for_reorder = + Some(LexOrdering::new(vec![sort_expr_on(&schema, "a", false)]).unwrap()); + + let reordered = source.reorder_files(vec![ + file_with_min("no_stats", None), + file_with_min("has_min", Some(10)), + ]); + assert_eq!(names(&reordered), vec!["has_min", "no_stats"]); + + // Same for DESC. + let mut source = + ParquetSource::new(Arc::clone(&schema)).with_reverse_row_groups(true); + source.sort_order_for_reorder = + Some(LexOrdering::new(vec![sort_expr_on(&schema, "a", true)]).unwrap()); + let reordered = source.reorder_files(vec![ + file_with_min("no_stats", None), + file_with_min("has_min", Some(10)), + ]); + assert_eq!(names(&reordered), vec!["has_min", "no_stats"]); + } + + /// When no sort pushdown has fired (`sort_order_for_reorder` is + /// `None`), `reorder_files` is a no-op and preserves input order. + #[test] + fn reorder_files_is_a_no_op_without_pushdown() { + use pushdown_sort_helpers::*; + use reorder_files_helpers::*; + + let schema = schema_with_a_int(); + let source = ParquetSource::new(schema); + // No `sort_order_for_reorder` set on the source. + + let input = vec![ + file_with_min("c", Some(30)), + file_with_min("a", Some(10)), + file_with_min("b", Some(20)), + ]; + let reordered = source.reorder_files(input.clone()); + assert_eq!(names(&reordered), names(&input)); + } + + /// `sort_order_for_reorder` is surfaced in both `EXPLAIN` (Default) + /// and `EXPLAIN VERBOSE` / `EXPLAIN ANALYZE` (Verbose) so readers + /// and snapshot tests can see the inexact sort-pushdown fired. + #[test] + fn sort_order_for_reorder_shown_in_explain() { + use pushdown_sort_helpers::*; + + // `std::fmt::Formatter` can't be constructed outside core fmt + // machinery, so we drive `fmt_extra` through a Display adapter + // and read the rendered string back with `format!`. + struct DisplayHelper<'a> { + source: &'a ParquetSource, + mode: DisplayFormatType, + } + impl std::fmt::Display for DisplayHelper<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.source.fmt_extra(self.mode, f) + } + } + + let schema = schema_with_a_int(); + let mut source = ParquetSource::new(Arc::clone(&schema)); + let order = LexOrdering::new(vec![sort_expr_on(&schema, "a", false)]).unwrap(); + source.sort_order_for_reorder = Some(order); + + for mode in [DisplayFormatType::Default, DisplayFormatType::Verbose] { + let out = format!( + "{}", + DisplayHelper { + source: &source, + mode, + }, + ); + assert!( + out.contains("sort_order_for_reorder=[a@0 ASC]"), + "{mode:?} display must surface sort_order_for_reorder, got: {out}", + ); + } + } + + #[test] + fn test_try_pushdown_filters_rejects_virtual_column_refs() { + // Virtual columns are produced by the reader and cannot be referenced + // inside a RowFilter. `try_pushdown_filters` must report such filters + // as `PushedDown::No` so the FilterExec above the scan stays in + // place — otherwise the scan would silently drop the predicate and + // produce wrong results. + use arrow::datatypes::{DataType, Field, FieldRef, Schema}; + use datafusion_common::config::ConfigOptions; + use datafusion_datasource::TableSchema; + use datafusion_expr::{col, lit as logical_lit}; + use datafusion_physical_expr::planner::logical2physical; + use datafusion_physical_plan::filter_pushdown::PushedDown; + use parquet::arrow::RowNumber; + + let file_schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + false, + )])); + let row_number_field: FieldRef = Arc::new( + Field::new("row_number", DataType::Int64, false) + .with_extension_type(RowNumber), + ); + let table_schema = TableSchema::builder(file_schema) + .with_virtual_columns(vec![row_number_field]) + .build(); + + let source = ParquetSource::new(table_schema).with_pushdown_filters(true); + + let full_schema = source.table_schema.table_schema(); + + let pushable = logical2physical(&col("value").eq(logical_lit(1i64)), full_schema); + let virtual_only = + logical2physical(&col("row_number").eq(logical_lit(2i64)), full_schema); + let mixed = logical2physical( + &col("row_number") + .eq(logical_lit(2i64)) + .or(col("value").eq(logical_lit(4i64))), + full_schema, + ); + + let config = ConfigOptions::default(); + let prop = source + .try_pushdown_filters(vec![pushable, virtual_only, mixed], &config) + .expect("try_pushdown_filters must not error"); + + assert_eq!(prop.filters.len(), 3); + assert!( + matches!(prop.filters[0], PushedDown::Yes), + "file-column filter should be pushable" + ); + assert!( + matches!(prop.filters[1], PushedDown::No), + "filter referencing only a virtual column must not be pushed down" + ); + assert!( + matches!(prop.filters[2], PushedDown::No), + "filter mixing a virtual column with a file column must not be \ + pushed down (row filter would silently drop it)" + ); + } } diff --git a/datafusion/datasource-parquet/src/supported_predicates.rs b/datafusion/datasource-parquet/src/supported_predicates.rs new file mode 100644 index 0000000000000..5c6b5f3ec9a2d --- /dev/null +++ b/datafusion/datasource-parquet/src/supported_predicates.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Registry of physical expressions that support nested list column pushdown +//! to the Parquet decoder. +//! +//! This module provides a trait-based approach for determining which predicates +//! can be safely evaluated on nested list columns during Parquet decoding. + +use std::sync::Arc; + +use datafusion_physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; +use datafusion_physical_expr::{PhysicalExpr, ScalarFunctionExpr}; + +/// Trait for physical expressions that support list column pushdown during +/// Parquet decoding. +/// +/// This trait provides a type-safe mechanism for identifying expressions that +/// can be safely pushed down to the Parquet decoder for evaluation on nested +/// list columns. +/// +/// # Implementation Notes +/// +/// Expression types in external crates cannot directly implement this trait +/// due to Rust's orphan rules. Instead, we use a blanket implementation that +/// delegates to a registration mechanism. +/// +/// # Examples +/// +/// ```ignore +/// use datafusion_physical_expr::PhysicalExpr; +/// use datafusion_datasource_parquet::SupportsListPushdown; +/// +/// let expr: Arc = ...; +/// if expr.supports_list_pushdown() { +/// // Can safely push down to Parquet decoder +/// } +/// ``` +pub trait SupportsListPushdown { + /// Returns `true` if this expression supports list column pushdown. + fn supports_list_pushdown(&self) -> bool; +} + +/// Blanket implementation for all physical expressions. +/// +/// This delegates to specialized predicates that check whether the concrete +/// expression type is registered as supporting list pushdown. This design +/// allows the trait to work with expression types defined in external crates. +impl SupportsListPushdown for dyn PhysicalExpr { + fn supports_list_pushdown(&self) -> bool { + is_null_check(self) || is_supported_scalar_function(self) + } +} + +/// Checks if an expression is a NULL or NOT NULL check. +/// +/// These checks are universally supported for all column types. +fn is_null_check(expr: &dyn PhysicalExpr) -> bool { + expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() +} + +/// Checks if an expression is a scalar function registered for list pushdown. +/// +/// Returns `true` if the expression is a `ScalarFunctionExpr` whose function +/// is in the registry of supported operations. +fn is_supported_scalar_function(expr: &dyn PhysicalExpr) -> bool { + scalar_function_name(expr).is_some_and(|name| { + // Registry of verified array functions + matches!(name, "array_has" | "array_has_all" | "array_has_any") + }) +} + +fn scalar_function_name(expr: &dyn PhysicalExpr) -> Option<&str> { + expr.downcast_ref::() + .map(ScalarFunctionExpr::name) +} + +/// Checks whether the given physical expression contains a supported nested +/// predicate (for example, `array_has_all`). +/// +/// This function recursively traverses the expression tree to determine if +/// any node contains predicates that support list column pushdown to the +/// Parquet decoder. +/// +/// # Supported predicates +/// +/// - `IS NULL` and `IS NOT NULL` checks on any column type +/// - Array functions: `array_has`, `array_has_all`, `array_has_any` +/// +/// # Returns +/// +/// `true` if the expression or any of its children contain supported predicates. +pub fn supports_list_predicates(expr: &Arc) -> bool { + expr.supports_list_pushdown() + || expr + .children() + .iter() + .any(|child| supports_list_predicates(child)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_null_check_detection() { + use datafusion_physical_expr::expressions::Column; + + let col_expr: Arc = Arc::new(Column::new("test", 0)); + assert!(!is_null_check(col_expr.as_ref())); + + // IsNullExpr and IsNotNullExpr detection requires actual instances + // which need schema setup - tested in integration tests + } + + #[test] + fn test_supported_scalar_functions() { + use datafusion_physical_expr::expressions::Column; + + let col_expr: Arc = Arc::new(Column::new("test", 0)); + + // Non-function expressions should return false + assert!(!is_supported_scalar_function(col_expr.as_ref())); + + // Testing with actual ScalarFunctionExpr requires function setup + // and is better suited for integration tests + } +} diff --git a/datafusion/datasource-parquet/src/test_data/ndv_test.parquet b/datafusion/datasource-parquet/src/test_data/ndv_test.parquet new file mode 100644 index 0000000000000000000000000000000000000000..3ecbe320f506efd450c6c2ebd31fd626571db80f GIT binary patch literal 1141 zcmZwHOHUI~6bJBgr_J)2zRYXL5Ay_(98U+#%5_W`z zDGL*HVOSbBZjCFK#uYKfL_dHbegId-rSU&ZYY1_YU+(nGz30A8I+Pnua^hbCd{gS? z8w~Ffmx$g-hKGnY7^5Hy`$-6fSUY}*%o{-5>o$_x_?S0@piF*7e!ov7B`8x91Rw~z zpcz6ChTYHtt-wKrJ)nUOZLk*%*a!Qe9U^c5I-nDx5CanqLL9o_5FCbXI08Lz6cUhx z6!bzG`rsHGhYa)wW$}cwlatO)vWT35({KhX$iV=dg*=>t0-T2na1k!SWhlZG7=)`Z z1lM2~uEPl2fKj*!x8OG1fx9pU_h1|*;69Xsa&{8u?L1|-AujS)jN3)t$`KhN+!`_T z6~F@`ctHXm#>@3_{>FIXF9e?1Dgn_gfmN2Smw=cHrA8Ll-+9sb52F-wrn7mz$Q5U{ zR1^h6Go|UuM1m`ngcMiez5k+VRMj`e3){0-$Lh&FxmqFL`8xcyHkD6zw1uCIHj^mR z4@_bi22lfPE1LXN&PzgNS+N9jaE@6f`|xqyDo`}DWbN$k6_;6rmg5wVoXT7wS{2E9 zF4P1@5jkOTK`-{9;_QX;BYSdUzC2Z#E}`_f7!x$1YR97Pt6VNUsXUyWTXF&cd=s6W z#)#AnrW<ni=Pc$6H3Dz@-I>y!Kbo2f4 zsNb2nCYx(MLz5di&7ZQNNn7`quD1!5 zRApz(N%ykJNvCeMV4tR}zEvn5&*KF11Chnb+), + /// No row groups are expected to be pruned. + None, +} + +impl ExpectedPruning { + /// Asserts that the pruned row groups match this expectation. + pub(crate) fn assert(&self, row_groups: &RowGroupAccessPlanFilter) { + let access_plan = row_groups.access_plan(); + let num_row_groups = access_plan.len(); + assert!(num_row_groups > 0); + let num_pruned = (0..num_row_groups) + .filter_map(|i| { + if access_plan.should_scan(i) { + None + } else { + Some(1) + } + }) + .sum::(); + + match self { + Self::All => { + assert_eq!( + num_row_groups, num_pruned, + "Expected all row groups to be pruned, but got {row_groups:?}" + ); + } + ExpectedPruning::None => { + assert_eq!( + num_pruned, 0, + "Expected no row groups to be pruned, but got {row_groups:?}" + ); + } + ExpectedPruning::Some(expected) => { + let actual = access_plan.row_group_indexes(); + assert_eq!( + expected, &actual, + "Unexpected row groups pruned. Expected {expected:?}, got {actual:?}" + ); + } + } + } +} diff --git a/datafusion/datasource-parquet/src/virtual_column.rs b/datafusion/datasource-parquet/src/virtual_column.rs new file mode 100644 index 0000000000000..2290ad2aeab9d --- /dev/null +++ b/datafusion/datasource-parquet/src/virtual_column.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Typed wrapper for parquet virtual columns. +//! +//! arrow-rs identifies virtual columns via arrow extension types carried on +//! the `FieldRef`. [`ParquetVirtualColumn`] lifts that contract into the type +//! system so callers validate at the boundary (via `TryFrom<&FieldRef>`) +//! rather than string-comparing extension-type names deep inside the reader. + +use arrow::datatypes::FieldRef; +use arrow_schema::extension::ExtensionType; +use datafusion_common::{DataFusionError, Result, not_impl_err}; +use parquet::arrow::RowNumber; +use std::sync::Arc; + +/// A parquet virtual column validated to have a supported arrow extension +/// type. +/// +/// Construct via [`TryFrom<&FieldRef>`]; add a new variant (and update the +/// `TryFrom` impl) when DataFusion gains support for another arrow-rs virtual +/// extension type. +#[derive(Debug, Clone)] +pub enum ParquetVirtualColumn { + /// Absolute row number within the parquet file. Backed by arrow-rs's + /// [`RowNumber`] extension type. + RowNumber(FieldRef), +} + +impl ParquetVirtualColumn { + pub fn field(&self) -> &FieldRef { + match self { + Self::RowNumber(field) => field, + } + } +} + +impl From for FieldRef { + fn from(col: ParquetVirtualColumn) -> Self { + match col { + ParquetVirtualColumn::RowNumber(field) => field, + } + } +} + +impl TryFrom<&FieldRef> for ParquetVirtualColumn { + type Error = DataFusionError; + + fn try_from(field: &FieldRef) -> Result { + let Some(name) = field.extension_type_name() else { + return not_impl_err!( + "Virtual column '{}' is missing an Arrow extension type; \ + supported extension types: [{}]", + field.name(), + RowNumber::NAME + ); + }; + match name { + n if n == RowNumber::NAME => Ok(Self::RowNumber(Arc::clone(field))), + other => not_impl_err!( + "Virtual column '{}' uses unsupported Arrow extension type '{}'; \ + supported types: [{}]. Add a ParquetVirtualColumn variant and \ + a test for this type before wiring it through.", + field.name(), + other, + RowNumber::NAME + ), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field}; + + #[test] + fn row_number_field_converts() { + let field: FieldRef = Arc::new( + Field::new("row_number", DataType::Int64, false) + .with_extension_type(RowNumber), + ); + let col = ParquetVirtualColumn::try_from(&field).expect("valid row_number"); + assert!(matches!(col, ParquetVirtualColumn::RowNumber(_))); + assert_eq!(col.field().name(), "row_number"); + } + + #[test] + fn missing_extension_type_rejected() { + let field: FieldRef = Arc::new(Field::new("plain", DataType::Int64, false)); + let err = ParquetVirtualColumn::try_from(&field).unwrap_err(); + assert!( + err.to_string().contains("missing an Arrow extension type"), + "got: {err}" + ); + } + + #[test] + fn unsupported_extension_type_rejected() { + // RowGroupIndex is a real arrow-rs virtual type not yet in our enum. + let field: FieldRef = Arc::new( + Field::new("row_group_index", DataType::Int64, false) + .with_extension_type(parquet::arrow::RowGroupIndex), + ); + let err = ParquetVirtualColumn::try_from(&field).unwrap_err(); + assert!( + err.to_string().contains("parquet.virtual.row_group_index"), + "error should name the offending extension type, got: {err}" + ); + } +} diff --git a/datafusion/datasource/Cargo.toml b/datafusion/datasource/Cargo.toml index 96e91b46eeac3..40e2271f45205 100644 --- a/datafusion/datasource/Cargo.toml +++ b/datafusion/datasource/Cargo.toml @@ -31,12 +31,13 @@ version.workspace = true all-features = true [features] +backtrace = ["datafusion-common/backtrace"] compression = ["async-compression", "liblzma", "bzip2", "flate2", "zstd", "tokio-util"] default = ["compression"] [dependencies] arrow = { workspace = true } -async-compression = { version = "0.4.30", features = [ +async-compression = { version = "0.4.40", features = [ "bzip2", "gzip", "xz", @@ -45,7 +46,7 @@ async-compression = { version = "0.4.30", features = [ ], optional = true } async-trait = { workspace = true } bytes = { workspace = true } -bzip2 = { version = "0.6.1", optional = true } +bzip2 = { workspace = true, optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } @@ -56,22 +57,24 @@ datafusion-physical-expr-adapter = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-session = { workspace = true } -flate2 = { version = "1.1.4", optional = true } +flate2 = { workspace = true, optional = true } futures = { workspace = true } -glob = "0.3.0" +glob = { workspace = true } itertools = { workspace = true } liblzma = { workspace = true, optional = true } log = { workspace = true } object_store = { workspace = true } +parking_lot = { workspace = true } rand = { workspace = true } tempfile = { workspace = true, optional = true } tokio = { workspace = true } tokio-util = { version = "0.7.17", features = ["io"], optional = true } url = { workspace = true } -zstd = { version = "0.13", optional = true, default-features = false } +zstd = { workspace = true, optional = true } [dev-dependencies] criterion = { workspace = true } +insta = { workspace = true } tempfile = { workspace = true } # Note: add additional linter rules in lib.rs. diff --git a/datafusion/datasource/benches/split_groups_by_statistics.rs b/datafusion/datasource/benches/split_groups_by_statistics.rs index d51fdfc0a6e90..e2ae4a9753df8 100644 --- a/datafusion/datasource/benches/split_groups_by_statistics.rs +++ b/datafusion/datasource/benches/split_groups_by_statistics.rs @@ -24,7 +24,7 @@ use datafusion_datasource::{generate_test_files, verify_sort_integrity}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; pub fn compare_split_groups_by_statistics_algorithms(c: &mut Criterion) { let file_schema = Arc::new(Schema::new(vec![Field::new( diff --git a/datafusion/datasource/src/decoder.rs b/datafusion/datasource/src/decoder.rs index 654569f741138..9f9fc0d94bb1c 100644 --- a/datafusion/datasource/src/decoder.rs +++ b/datafusion/datasource/src/decoder.rs @@ -24,9 +24,9 @@ use arrow::error::ArrowError; use bytes::Buf; use bytes::Bytes; use datafusion_common::Result; -use futures::stream::BoxStream; use futures::StreamExt as _; -use futures::{ready, Stream}; +use futures::stream::BoxStream; +use futures::{Stream, ready}; use std::collections::VecDeque; use std::fmt; use std::task::Poll; @@ -175,17 +175,19 @@ pub fn deserialize_stream<'a>( mut input: impl Stream> + Unpin + Send + 'a, mut deserializer: impl BatchDeserializer + 'a, ) -> BoxStream<'a, Result> { - futures::stream::poll_fn(move |cx| loop { - match ready!(input.poll_next_unpin(cx)).transpose()? { - Some(b) => _ = deserializer.digest(b), - None => deserializer.finish(), - }; - - return match deserializer.next()? { - DeserializerOutput::RecordBatch(rb) => Poll::Ready(Some(Ok(rb))), - DeserializerOutput::InputExhausted => Poll::Ready(None), - DeserializerOutput::RequiresMoreData => continue, - }; + futures::stream::poll_fn(move |cx| { + loop { + match ready!(input.poll_next_unpin(cx)).transpose()? { + Some(b) => _ = deserializer.digest(b), + None => deserializer.finish(), + }; + + return match deserializer.next()? { + DeserializerOutput::RecordBatch(rb) => Poll::Ready(Some(Ok(rb))), + DeserializerOutput::InputExhausted => Poll::Ready(None), + DeserializerOutput::RequiresMoreData => continue, + }; + } }) .boxed() } diff --git a/datafusion/datasource/src/display.rs b/datafusion/datasource/src/display.rs index c9e979535963c..0f59e33ff9eac 100644 --- a/datafusion/datasource/src/display.rs +++ b/datafusion/datasource/src/display.rs @@ -135,7 +135,7 @@ mod tests { use super::*; use datafusion_physical_plan::{DefaultDisplay, VerboseDisplay}; - use object_store::{path::Path, ObjectMeta}; + use object_store::{ObjectMeta, path::Path}; use crate::PartitionedFile; use chrono::Utc; @@ -287,13 +287,6 @@ mod tests { version: None, }; - PartitionedFile { - object_meta, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - } + PartitionedFile::new_from_meta(object_meta) } } diff --git a/datafusion/datasource/src/file.rs b/datafusion/datasource/src/file.rs index 9ec34b5dda0cd..07460b23694b7 100644 --- a/datafusion/datasource/src/file.rs +++ b/datafusion/datasource/src/file.rs @@ -25,22 +25,34 @@ use std::sync::Arc; use crate::file_groups::FileGroupPartitioner; use crate::file_scan_config::FileScanConfig; use crate::file_stream::FileOpener; +use crate::morsel::{FileOpenerMorselizer, Morselizer}; +#[expect(deprecated)] use crate::schema_adapter::SchemaAdapterFactory; use datafusion_common::config::ConfigOptions; -use datafusion_common::{not_impl_err, Result}; -use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; +use datafusion_common::{Result, not_impl_err}; +use datafusion_physical_expr::projection::ProjectionExprs; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalExpr}; +use datafusion_physical_plan::DisplayFormatType; +use datafusion_physical_plan::SortOrderPushdownResult; use datafusion_physical_plan::filter_pushdown::{FilterPushdownPropagation, PushedDown}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion_physical_plan::DisplayFormatType; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use object_store::ObjectStore; -/// Helper function to convert any type implementing FileSource to Arc<dyn FileSource> +/// Helper function to convert any type implementing [`FileSource`] to `Arc` pub fn as_file_source(source: T) -> Arc { Arc::new(source) } -/// file format specific behaviors for elements in [`DataSource`] +/// File format specific behaviors for [`DataSource`] +/// +/// # Schema information +/// There are two important schemas for a [`FileSource`]: +/// 1. [`Self::table_schema`] -- the schema for the overall table +/// (file data plus partition columns) +/// 2. The logical output schema, comprised of [`Self::table_schema`] with +/// [`Self::projection`] applied /// /// See more details on specific implementations: /// * [`ArrowSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ArrowSource.html) @@ -50,37 +62,92 @@ pub fn as_file_source(source: T) -> Arc /// * [`ParquetSource`](https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.ParquetSource.html) /// /// [`DataSource`]: crate::source::DataSource -pub trait FileSource: Send + Sync { - /// Creates a `dyn FileOpener` based on given parameters +pub trait FileSource: Any + Send + Sync { + /// Creates a `dyn FileOpener` based on given parameters. + /// + /// Note: File sources with a native morsel implementation should return an + /// error from this method and implementing [`Self::create_morselizer`] instead. fn create_file_opener( &self, object_store: Arc, base_config: &FileScanConfig, partition: usize, - ) -> Arc; - /// Any - fn as_any(&self) -> &dyn Any; - /// Returns the table schema for this file source. + ) -> Result>; + + /// Creates a `dyn Morselizer` based on given parameters. + /// + /// The default implementation preserves existing behavior by adapting the + /// legacy [`FileOpener`] API into a [`Morselizer`]. + /// + /// It is preferred to implement the [`Morselizer`] API directly by + /// implementing this method. + fn create_morselizer( + &self, + object_store: Arc, + base_config: &FileScanConfig, + partition: usize, + ) -> Result> { + let opener = self.create_file_opener(object_store, base_config, partition)?; + Ok(Box::new(FileOpenerMorselizer::new(opener))) + } + + /// Returns the table schema for the overall table (including partition columns, if any) + /// + /// This method returns the unprojected schema: the full schema of the data + /// without [`Self::projection`] applied. /// - /// This always returns the unprojected schema (the full schema of the data). + /// The output schema of this `FileSource` is this TableSchema + /// with [`Self::projection`] applied. + /// + /// Use [`ProjectionExprs::project_schema`] to get the projected schema + /// after applying the projection. fn table_schema(&self) -> &crate::table_schema::TableSchema; + /// Initialize new type with batch size configuration fn with_batch_size(&self, batch_size: usize) -> Arc; - /// Initialize new instance with projection information - fn with_projection(&self, config: &FileScanConfig) -> Arc; - /// Returns the filter expression that will be applied during the file scan. + + /// Returns the filter expression that will be applied *during* the file scan. + /// + /// These expressions are in terms of the unprojected [`Self::table_schema`]. fn filter(&self) -> Option> { None } + + /// Return the projection that will be applied to the output stream on top + /// of [`Self::table_schema`]. + /// + /// Note you can use [`ProjectionExprs::project_schema`] on the table + /// schema to get the effective output schema of this source. + fn projection(&self) -> Option<&ProjectionExprs> { + None + } + /// Return execution plan metrics fn metrics(&self) -> &ExecutionPlanMetricsSet; + /// String representation of file source such as "csv", "json", "parquet" fn file_type(&self) -> &str; + /// Format FileType specific information fn fmt_extra(&self, _t: DisplayFormatType, _f: &mut Formatter) -> fmt::Result { Ok(()) } + /// Returns whether this file source supports repartitioning files by byte ranges. + /// + /// When this returns `true`, files can be split into multiple partitions + /// based on byte offsets for parallel reading. + /// + /// When this returns `false`, files cannot be repartitioned (e.g., CSV files + /// with `newlines_in_values` enabled cannot be split because record boundaries + /// cannot be determined by byte offset alone). + /// + /// The default implementation returns `true`. File sources that cannot support + /// repartitioning should override this method. + fn supports_repartitioning(&self) -> bool { + true + } + /// If supported by the [`FileSource`], redistribute files across partitions /// according to their size. Allows custom file formats to implement their /// own repartitioning logic. @@ -94,7 +161,8 @@ pub trait FileSource: Send + Sync { output_ordering: Option, config: &FileScanConfig, ) -> Result> { - if config.file_compression_type.is_compressed() || config.new_lines_in_values { + if config.file_compression_type.is_compressed() || !self.supports_repartitioning() + { return Ok(None); } @@ -113,6 +181,19 @@ pub trait FileSource: Send + Sync { } /// Try to push down filters into this FileSource. + /// + /// `filters` must be in terms of the unprojected table schema (file schema + /// plus partition columns), before any projection is applied. + /// + /// Any filters that this FileSource chooses to evaluate itself should be + /// returned as `PushedDown::Yes` in the result, along with a FileSource + /// instance that incorporates those filters. Such filters are logically + /// applied "during" the file scan, meaning they may refer to columns not + /// included in the final output projection. + /// + /// Filters that cannot be pushed down should be marked as `PushedDown::No`, + /// and will be evaluated by an execution plan after the file source. + /// /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result @@ -126,29 +207,160 @@ pub trait FileSource: Send + Sync { )) } - /// Set optional schema adapter factory. + /// Try to create a new FileSource that can produce data in the specified sort order. + /// + /// This method attempts to optimize data retrieval to match the requested ordering. + /// It receives both the requested ordering and equivalence properties that describe + /// the output data from this file source. + /// + /// # Parameters + /// * `order` - The requested sort ordering from the query + /// * `eq_properties` - Equivalence properties of the data that will be produced by this + /// file source. These properties describe the ordering, constant columns, and other + /// relationships in the output data, allowing the implementation to determine if + /// optimizations like reversed scanning can help satisfy the requested ordering. + /// This includes information about: + /// - The file's natural ordering (from output_ordering in FileScanConfig) + /// - Constant columns (e.g., from filters like `ticker = 'AAPL'`) + /// - Monotonic functions (e.g., `extract_year_month(timestamp)`) + /// - Other equivalence relationships + /// + /// # Examples + /// + /// ## Example 1: Simple reverse + /// ```text + /// File ordering: [a ASC, b DESC] + /// Requested: [a DESC] + /// Reversed file: [a DESC, b ASC] + /// Result: Satisfies request (prefix match) → Inexact + /// ``` + /// + /// ## Example 2: Monotonic function + /// ```text + /// File ordering: [extract_year_month(ts) ASC, ts ASC] + /// Requested: [ts DESC] + /// Reversed file: [extract_year_month(ts) DESC, ts DESC] + /// Result: Through monotonicity, satisfies [ts DESC] → Inexact + /// ``` + /// + /// # Returns + /// * `Exact` - Created a source that guarantees perfect ordering + /// * `Inexact` - Created a source optimized for ordering (e.g., reversed row groups) but not perfectly sorted + /// * `Unsupported` - Cannot optimize for this ordering /// - /// [`SchemaAdapterFactory`] allows user to specify how fields from the - /// file get mapped to that of the table schema. If you implement this - /// method, you should also implement [`schema_adapter_factory`]. + /// # Deprecation / migration notes + /// - [`Self::try_reverse_output`] was renamed to this method and deprecated since `53.0.0`. + /// Per DataFusion's deprecation guidelines, it will be removed in `59.0.0` or later + /// (6 major versions or 6 months, whichever is longer). + /// - New implementations should override [`Self::try_pushdown_sort`] directly. + /// - For backwards compatibility, the default implementation of + /// [`Self::try_pushdown_sort`] delegates to the deprecated + /// [`Self::try_reverse_output`] until it is removed. After that point, the + /// default implementation will return [`SortOrderPushdownResult::Unsupported`]. + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + eq_properties: &EquivalenceProperties, + ) -> Result>> { + #[expect(deprecated)] + self.try_reverse_output(order, eq_properties) + } + + /// Deprecated: Renamed to [`Self::try_pushdown_sort`]. + #[deprecated( + since = "53.0.0", + note = "Renamed to try_pushdown_sort. This method was never limited to reversing output. It will be removed in 59.0.0 or later." + )] + fn try_reverse_output( + &self, + _order: &[PhysicalSortExpr], + _eq_properties: &EquivalenceProperties, + ) -> Result>> { + Ok(SortOrderPushdownResult::Unsupported) + } + + /// Reorder files in the shared work queue to optimize query performance. /// - /// The default implementation returns a not implemented error. + /// For example, TopK queries benefit from reading files with the best + /// statistics first, so the dynamic filter threshold tightens quickly. /// - /// [`schema_adapter_factory`]: Self::schema_adapter_factory + /// The default implementation returns files unchanged (no reordering). + fn reorder_files( + &self, + files: Vec, + ) -> Vec { + files + } + + /// Try to push down a projection into this FileSource. + /// + /// `FileSource` implementations that support projection pushdown should + /// override this method and return a new `FileSource` instance with the + /// projection incorporated. + /// + /// If a `FileSource` does accept a projection it is expected to handle + /// the projection in it's entirety, including partition columns. + /// For example, the `FileSource` may translate that projection into a + /// file format specific projection (e.g. Parquet can push down struct field access, + /// some other file formats like Vortex can push down computed expressions into un-decoded data) + /// and also need to handle partition column projection (generally done by replacing partition column + /// references with literal values derived from each files partition values). + /// + /// Not all FileSource's can handle complex expression pushdowns. For example, + /// a CSV file source may only support simple column selections. In such cases, + /// the `FileSource` can use [`SplitProjection`] and [`ProjectionOpener`] + /// to split the projection into a pushdownable part and a non-pushdownable part. + /// These helpers also handle partition column projection. + /// + /// [`SplitProjection`]: crate::projection::SplitProjection + /// [`ProjectionOpener`]: crate::projection::ProjectionOpener + fn try_pushdown_projection( + &self, + _projection: &ProjectionExprs, + ) -> Result>> { + Ok(None) + } + + /// Deprecated: Set optional schema adapter factory. + /// + /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. + /// See `upgrading.md` for more details. + #[deprecated( + since = "53.0.0", + note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] fn with_schema_adapter_factory( &self, _factory: Arc, ) -> Result> { not_impl_err!( - "FileSource {} does not support schema adapter factory", - self.file_type() + "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." ) } - /// Returns the current schema adapter factory if set + /// Deprecated: Returns the current schema adapter factory if set. /// - /// Default implementation returns `None`. + /// `SchemaAdapterFactory` has been removed. Use `PhysicalExprAdapterFactory` instead. + /// See `upgrading.md` for more details. + #[deprecated( + since = "53.0.0", + note = "SchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + #[expect(deprecated)] fn schema_adapter_factory(&self) -> Option> { None } } + +impl dyn FileSource { + /// Returns `true` if this source is of type `T`. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Attempts to downcast this source to a concrete type `T`. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} diff --git a/datafusion/datasource/src/file_compression_type.rs b/datafusion/datasource/src/file_compression_type.rs index 9ca5d8763b74a..89efb580652b1 100644 --- a/datafusion/datasource/src/file_compression_type.rs +++ b/datafusion/datasource/src/file_compression_type.rs @@ -21,8 +21,8 @@ use std::str::FromStr; use datafusion_common::error::{DataFusionError, Result}; -use datafusion_common::parsers::CompressionTypeVariant::{self, *}; use datafusion_common::GetExt; +use datafusion_common::parsers::CompressionTypeVariant::{self, *}; #[cfg(feature = "compression")] use async_compression::tokio::bufread::{ @@ -39,10 +39,10 @@ use bytes::Bytes; use bzip2::read::MultiBzDecoder; #[cfg(feature = "compression")] use flate2::read::MultiGzDecoder; -use futures::stream::BoxStream; use futures::StreamExt; #[cfg(feature = "compression")] use futures::TryStreamExt; +use futures::stream::BoxStream; #[cfg(feature = "compression")] use liblzma::read::XzDecoder; use object_store::buffered::BufWriter; @@ -148,32 +148,70 @@ impl FileCompressionType { GZIP | BZIP2 | XZ | ZSTD => { return Err(DataFusionError::NotImplemented( "Compression feature is not enabled".to_owned(), - )) + )); } UNCOMPRESSED => s.boxed(), }) } /// Wrap the given `BufWriter` so that it performs compressed writes - /// according to this `FileCompressionType`. + /// according to this `FileCompressionType` using the default compression level. pub fn convert_async_writer( &self, w: BufWriter, ) -> Result> { + self.convert_async_writer_with_level(w, None) + } + + /// Wrap the given `BufWriter` so that it performs compressed writes + /// according to this `FileCompressionType`. + /// + /// If `compression_level` is `Some`, the encoder will use the specified + /// compression level. If `None`, the default level for each algorithm is used. + pub fn convert_async_writer_with_level( + &self, + w: BufWriter, + compression_level: Option, + ) -> Result> { + #[cfg(feature = "compression")] + use async_compression::Level; + Ok(match self.variant { #[cfg(feature = "compression")] - GZIP => Box::new(GzipEncoder::new(w)), + GZIP => match compression_level { + Some(level) => { + Box::new(GzipEncoder::with_quality(w, Level::Precise(level as i32))) + } + None => Box::new(GzipEncoder::new(w)), + }, #[cfg(feature = "compression")] - BZIP2 => Box::new(BzEncoder::new(w)), + BZIP2 => match compression_level { + Some(level) => { + Box::new(BzEncoder::with_quality(w, Level::Precise(level as i32))) + } + None => Box::new(BzEncoder::new(w)), + }, #[cfg(feature = "compression")] - XZ => Box::new(XzEncoder::new(w)), + XZ => match compression_level { + Some(level) => { + Box::new(XzEncoder::with_quality(w, Level::Precise(level as i32))) + } + None => Box::new(XzEncoder::new(w)), + }, #[cfg(feature = "compression")] - ZSTD => Box::new(ZstdEncoder::new(w)), + ZSTD => match compression_level { + Some(level) => { + Box::new(ZstdEncoder::with_quality(w, Level::Precise(level as i32))) + } + None => Box::new(ZstdEncoder::new(w)), + }, #[cfg(not(feature = "compression"))] GZIP | BZIP2 | XZ | ZSTD => { + // compression_level is not used when compression feature is disabled + let _ = compression_level; return Err(DataFusionError::NotImplemented( "Compression feature is not enabled".to_owned(), - )) + )); } UNCOMPRESSED => Box::new(w), }) @@ -210,7 +248,7 @@ impl FileCompressionType { GZIP | BZIP2 | XZ | ZSTD => { return Err(DataFusionError::NotImplemented( "Compression feature is not enabled".to_owned(), - )) + )); } UNCOMPRESSED => s.boxed(), }) @@ -237,7 +275,7 @@ impl FileCompressionType { GZIP | BZIP2 | XZ | ZSTD => { return Err(DataFusionError::NotImplemented( "Compression feature is not enabled".to_owned(), - )) + )); } UNCOMPRESSED => Box::new(r), }) diff --git a/datafusion/datasource/src/file_format.rs b/datafusion/datasource/src/file_format.rs index bb4ffded8086a..dd30881610f36 100644 --- a/datafusion/datasource/src/file_format.rs +++ b/datafusion/datasource/src/file_format.rs @@ -30,8 +30,9 @@ use crate::file_sink_config::FileSinkConfig; use arrow::datatypes::SchemaRef; use datafusion_common::file_options::file_type::FileType; -use datafusion_common::{internal_err, not_impl_err, GetExt, Result, Statistics}; +use datafusion_common::{GetExt, Result, Statistics, internal_err, not_impl_err}; use datafusion_physical_expr::LexRequirement; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_plan::ExecutionPlan; use datafusion_session::Session; @@ -41,17 +42,42 @@ use object_store::{ObjectMeta, ObjectStore}; /// Default max records to scan to infer the schema pub const DEFAULT_SCHEMA_INFER_MAX_RECORD: usize = 1000; +/// Metadata fetched from a file, including statistics and ordering. +/// +/// This struct is returned by [`FileFormat::infer_stats_and_ordering`] to +/// provide all metadata in a single read, avoiding duplicate I/O operations. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct FileMeta { + /// Statistics for the file (row counts, byte sizes, column statistics). + pub statistics: Statistics, + /// The ordering (sort order) of the file, if known. + pub ordering: Option, +} + +impl FileMeta { + /// Creates a new `FileMeta` with the given statistics and no ordering. + pub fn new(statistics: Statistics) -> Self { + Self { + statistics, + ordering: None, + } + } + + /// Sets the ordering for this file metadata. + pub fn with_ordering(mut self, ordering: Option) -> Self { + self.ordering = ordering; + self + } +} + /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across /// providers that support the same file formats. /// /// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html #[async_trait] -pub trait FileFormat: Send + Sync + fmt::Debug { - /// Returns the table provider as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - +pub trait FileFormat: Any + Send + Sync + fmt::Debug { /// Returns the extension for this FileFormat, e.g. "file.csv" -> csv fn get_ext(&self) -> String; @@ -90,6 +116,52 @@ pub trait FileFormat: Send + Sync + fmt::Debug { object: &ObjectMeta, ) -> Result; + /// Infer the ordering (sort order) for the provided object from file metadata. + /// + /// Returns `Ok(None)` if the file format does not support ordering inference + /// or if the file does not have ordering information. + /// + /// `table_schema` is the (combined) schema of the overall table + /// and may be a superset of the schema contained in this file. + /// + /// The default implementation returns `Ok(None)`. + async fn infer_ordering( + &self, + _state: &dyn Session, + _store: &Arc, + _table_schema: SchemaRef, + _object: &ObjectMeta, + ) -> Result> { + Ok(None) + } + + /// Infer both statistics and ordering from a single metadata read. + /// + /// This is more efficient than calling [`Self::infer_stats`] and + /// [`Self::infer_ordering`] separately when both are needed, as it avoids + /// reading file metadata twice. + /// + /// The default implementation calls both methods separately. File formats + /// that can extract both from a single read should override this method. + async fn infer_stats_and_ordering( + &self, + state: &dyn Session, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + let statistics = self + .infer_stats(state, store, Arc::clone(&table_schema), object) + .await?; + let ordering = self + .infer_ordering(state, store, table_schema, object) + .await?; + Ok(FileMeta { + statistics, + ordering, + }) + } + /// Take a list of files and convert it to the appropriate executor /// according to this file format. async fn create_physical_plan( @@ -117,10 +189,20 @@ pub trait FileFormat: Send + Sync + fmt::Debug { fn file_source(&self, table_schema: crate::TableSchema) -> Arc; } +impl dyn FileFormat { + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} + /// Factory for creating [`FileFormat`] instances based on session and command level options /// /// Users can provide their own `FileFormatFactory` to support arbitrary file formats -pub trait FileFormatFactory: Sync + Send + GetExt + fmt::Debug { +pub trait FileFormatFactory: Any + Sync + Send + GetExt + fmt::Debug { /// Initialize a [FileFormat] and configure based on session and command level options fn create( &self, @@ -130,10 +212,16 @@ pub trait FileFormatFactory: Sync + Send + GetExt + fmt::Debug { /// Initialize a [FileFormat] with all options set to default values fn default(&self) -> Arc; +} + +impl dyn FileFormatFactory { + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } - /// Returns the table source as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } } /// A container of [FileFormatFactory] which also implements [FileType]. diff --git a/datafusion/datasource/src/file_groups.rs b/datafusion/datasource/src/file_groups.rs index 998d09285cf1d..84594be54b504 100644 --- a/datafusion/datasource/src/file_groups.rs +++ b/datafusion/datasource/src/file_groups.rs @@ -18,10 +18,12 @@ //! Logic for managing groups of [`PartitionedFile`]s in DataFusion use crate::{FileRange, PartitionedFile}; +use arrow::compute::SortOptions; use datafusion_common::Statistics; +use datafusion_common::utils::compare_rows; use itertools::Itertools; -use std::cmp::{min, Ordering}; -use std::collections::BinaryHeap; +use std::cmp::{Ordering, min}; +use std::collections::{BinaryHeap, HashMap}; use std::iter::repeat_with; use std::mem; use std::ops::{Deref, DerefMut, Index, IndexMut}; @@ -189,15 +191,6 @@ impl FileGroupPartitioner { return None; } - // Perform redistribution only in case all files should be read from beginning to end - let has_ranges = file_groups - .iter() - .flat_map(FileGroup::iter) - .any(|f| f.range.is_some()); - if has_ranges { - return None; - } - // special case when order must be preserved if self.preserve_order_within_groups { self.repartition_preserving_order(file_groups) @@ -218,14 +211,13 @@ impl FileGroupPartitioner { let total_size = flattened_files .iter() - .map(|f| f.object_meta.size as i64) - .sum::(); - if total_size < (repartition_file_min_size as i64) || total_size == 0 { + .map(|f| f.effective_size()) + .sum::(); + if total_size < (repartition_file_min_size as u64) || total_size == 0 { return None; } - let target_partition_size = - (total_size as u64).div_ceil(target_partitions as u64); + let target_partition_size = total_size.div_ceil(target_partitions as u64); let current_partition_index: usize = 0; let current_partition_size: u64 = 0; @@ -235,13 +227,14 @@ impl FileGroupPartitioner { .into_iter() .scan( (current_partition_index, current_partition_size), - |state, source_file| { + |(current_partition_index, current_partition_size), source_file| { let mut produced_files = vec![]; - let mut range_start = 0; - while range_start < source_file.object_meta.size { + let (mut range_start, file_end) = source_file.range(); + while range_start < file_end { let range_end = min( - range_start + (target_partition_size - state.1), - source_file.object_meta.size, + range_start + + (target_partition_size - *current_partition_size), + file_end, ); let mut produced_file = source_file.clone(); @@ -249,13 +242,15 @@ impl FileGroupPartitioner { start: range_start as i64, end: range_end as i64, }); - produced_files.push((state.0, produced_file)); + produced_files.push((*current_partition_index, produced_file)); - if state.1 + (range_end - range_start) >= target_partition_size { - state.0 += 1; - state.1 = 0; + if *current_partition_size + (range_end - range_start) + >= target_partition_size + { + *current_partition_index += 1; + *current_partition_size = 0; } else { - state.1 += range_end - range_start; + *current_partition_size += range_end - range_start; } range_start = range_end; } @@ -297,7 +292,7 @@ impl FileGroupPartitioner { if group.len() == 1 { Some(ToRepartition { source_index: group_index, - file_size: group[0].object_meta.size, + file_size: group[0].effective_size(), new_groups: vec![group_index], }) } else { @@ -333,28 +328,31 @@ impl FileGroupPartitioner { // Distribute files to their newly assigned groups while let Some(to_repartition) = heap.pop() { - let range_size = to_repartition.range_size() as i64; + let range_size = to_repartition.range_size(); let ToRepartition { source_index, - file_size, + file_size: _, new_groups, } = to_repartition.into_inner(); assert_eq!(file_groups[source_index].len(), 1); let original_file = file_groups[source_index].pop().unwrap(); let last_group = new_groups.len() - 1; - let mut range_start: i64 = 0; - let mut range_end: i64 = range_size; + let (mut range_start, file_end) = original_file.range(); + let mut range_end = range_start + range_size; for (i, group_index) in new_groups.into_iter().enumerate() { let target_group = &mut file_groups[group_index]; assert!(target_group.is_empty()); // adjust last range to include the entire file if i == last_group { - range_end = file_size as i64; + range_end = file_end; } - target_group - .push(original_file.clone().with_range(range_start, range_end)); + target_group.push( + original_file + .clone() + .with_range(range_start as i64, range_end as i64), + ); range_start = range_end; range_end += range_size; } @@ -366,11 +364,27 @@ impl FileGroupPartitioner { /// Represents a group of partitioned files that'll be processed by a single thread. /// Maintains optional statistics across all files in the group. +/// +/// # Statistics +/// +/// The group-level [`FileGroup::file_statistics`] field contains merged statistics from all files +/// in the group for the **full table schema** (file columns + partition columns). +/// +/// Partition column statistics are derived from the individual file partition values: +/// - `min` = minimum partition value across all files in the group +/// - `max` = maximum partition value across all files in the group +/// - `null_count` = 0 (partition values are never null) +/// +/// This allows query optimizers to prune entire file groups based on partition bounds. #[derive(Debug, Clone)] pub struct FileGroup { /// The files in this group files: Vec, - /// Optional statistics for the data across all files in the group + /// Optional statistics for the data across all files in the group. + /// + /// These statistics cover the full table schema: file columns plus partition columns. + /// Partition column statistics are merged from individual [`PartitionedFile::statistics`], + /// which compute exact values from [`PartitionedFile::partition_values`]. statistics: Option>, } @@ -468,6 +482,65 @@ impl FileGroup { chunks } + + /// Groups files by their partition values, ensuring all files with same + /// partition values are in the same group. + /// + /// Note: May return fewer groups than `max_target_partitions` when the + /// number of unique partition values is less than the target. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key + pub fn group_by_partition_values( + self, + max_target_partitions: usize, + ) -> Vec { + if self.is_empty() || max_target_partitions == 0 { + return vec![]; + } + + let mut partition_groups: HashMap< + Vec, + Vec, + > = HashMap::new(); + + for file in self.files { + partition_groups + .entry(file.partition_values.clone()) + .or_default() + .push(file); + } + + let num_unique_partitions = partition_groups.len(); + + // Sort for deterministic bucket assignment across query executions. + let mut sorted_partitions: Vec<_> = partition_groups.into_iter().collect(); + let sort_options = + vec![ + SortOptions::default(); + sorted_partitions.first().map(|(k, _)| k.len()).unwrap_or(0) + ]; + sorted_partitions.sort_by(|a, b| { + compare_rows(&a.0, &b.0, &sort_options).unwrap_or(Ordering::Equal) + }); + + if num_unique_partitions <= max_target_partitions { + sorted_partitions + .into_iter() + .map(|(_, files)| FileGroup::new(files)) + .collect() + } else { + // Merge into max_target_partitions buckets using round-robin. + // This maintains grouping by partition value as we are merging groups which already + // contain all values for a partition key. + let mut target_groups = vec![vec![]; max_target_partitions]; + + for (idx, (_, files)) in sorted_partitions.into_iter().enumerate() { + let bucket = idx % max_target_partitions; + target_groups[bucket].extend(files); + } + + target_groups.into_iter().map(FileGroup::new).collect() + } + } } impl Index for FileGroup { @@ -559,6 +632,7 @@ impl DerefMut for CompareByRangeSize { #[cfg(test)] mod test { use super::*; + use datafusion_common::ScalarValue; /// Empty file won't get partitioned #[test] @@ -645,6 +719,68 @@ mod test { assert_partitioned_files(expected, actual); } + #[test] + fn repartition_single_file_with_range() { + // Single file, single partition into multiple partitions + let single_partition = + vec![FileGroup::new(vec![pfile("a", 123).with_range(0, 123)])]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + FileGroup::new(vec![pfile("a", 123).with_range(0, 31)]), + FileGroup::new(vec![pfile("a", 123).with_range(31, 62)]), + FileGroup::new(vec![pfile("a", 123).with_range(62, 93)]), + FileGroup::new(vec![pfile("a", 123).with_range(93, 123)]), + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_single_file_with_incomplete_range() { + // Single file, single partition into multiple partitions + let single_partition = + vec![FileGroup::new(vec![pfile("a", 123).with_range(10, 100)])]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + FileGroup::new(vec![pfile("a", 123).with_range(10, 33)]), + FileGroup::new(vec![pfile("a", 123).with_range(33, 56)]), + FileGroup::new(vec![pfile("a", 123).with_range(56, 79)]), + FileGroup::new(vec![pfile("a", 123).with_range(79, 100)]), + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_single_file_duplicated_with_range() { + // Single file, two partitions into multiple partitions + let single_partition = vec![FileGroup::new(vec![ + pfile("a", 100).with_range(0, 50), + pfile("a", 100).with_range(50, 100), + ])]; + + let actual = FileGroupPartitioner::new() + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&single_partition); + + let expected = Some(vec![ + FileGroup::new(vec![pfile("a", 100).with_range(0, 25)]), + FileGroup::new(vec![pfile("a", 100).with_range(25, 50)]), + FileGroup::new(vec![pfile("a", 100).with_range(50, 75)]), + FileGroup::new(vec![pfile("a", 100).with_range(75, 100)]), + ]); + assert_partitioned_files(expected, actual); + } + #[test] fn repartition_too_much_partitions() { // Single file, single partition into 96 partitions @@ -717,22 +853,6 @@ mod test { assert_partitioned_files(expected, actual); } - #[test] - fn repartition_no_action_ranges() { - // No action due to Some(range) in second file - let source_partitions = vec![ - FileGroup::new(vec![pfile("a", 123)]), - FileGroup::new(vec![pfile("b", 144).with_range(1, 50)]), - ]; - - let actual = FileGroupPartitioner::new() - .with_target_partitions(65) - .with_repartition_file_min_size(10) - .repartition_file_groups(&source_partitions); - - assert_partitioned_files(None, actual) - } - #[test] fn repartition_no_action_min_size() { // No action due to target_partition_size @@ -809,6 +929,26 @@ mod test { assert_partitioned_files(expected, actual); } + #[test] + fn repartition_ordered_one_large_file_with_range() { + // "Rebalance" the single large file across partitions + let source_partitions = + vec![FileGroup::new(vec![pfile("a", 100).with_range(0, 100)])]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(3) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + FileGroup::new(vec![pfile("a", 100).with_range(0, 34)]), + FileGroup::new(vec![pfile("a", 100).with_range(34, 68)]), + FileGroup::new(vec![pfile("a", 100).with_range(68, 100)]), + ]); + assert_partitioned_files(expected, actual); + } + #[test] fn repartition_ordered_one_large_one_small_file() { // "Rebalance" the single large file across empty partitions, but can't split @@ -837,6 +977,91 @@ mod test { assert_partitioned_files(expected, actual); } + #[test] + fn repartition_ordered_one_large_one_small_file_with_full_range() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![ + FileGroup::new(vec![pfile("a", 100).with_range(0, 100)]), + FileGroup::new(vec![pfile("b", 30)]), + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first third of "a" + FileGroup::new(vec![pfile("a", 100).with_range(0, 33)]), + // only b in this group (can't do this) + FileGroup::new(vec![pfile("b", 30).with_range(0, 30)]), + // second third of "a" + FileGroup::new(vec![pfile("a", 100).with_range(33, 66)]), + // final third of "a" + FileGroup::new(vec![pfile("a", 100).with_range(66, 100)]), + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_file_with_split_range() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![ + FileGroup::new(vec![pfile("a", 100).with_range(0, 50)]), + FileGroup::new(vec![pfile("a", 100).with_range(50, 100)]), + FileGroup::new(vec![pfile("b", 30)]), + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first half of first "a" + FileGroup::new(vec![pfile("a", 100).with_range(0, 25)]), + // second "a" fully (not split) + FileGroup::new(vec![pfile("a", 100).with_range(50, 100)]), + // only b in this group (can't do this) + FileGroup::new(vec![pfile("b", 30).with_range(0, 30)]), + // second half of first "a" + FileGroup::new(vec![pfile("a", 100).with_range(25, 50)]), + ]); + assert_partitioned_files(expected, actual); + } + + #[test] + fn repartition_ordered_one_large_one_small_file_with_non_full_range() { + // "Rebalance" the single large file across empty partitions, but can't split + // small file + let source_partitions = vec![ + FileGroup::new(vec![pfile("a", 100).with_range(20, 80)]), + FileGroup::new(vec![pfile("b", 30).with_range(5, 25)]), + ]; + + let actual = FileGroupPartitioner::new() + .with_preserve_order_within_groups(true) + .with_target_partitions(4) + .with_repartition_file_min_size(10) + .repartition_file_groups(&source_partitions); + + let expected = Some(vec![ + // scan first third of "a" + FileGroup::new(vec![pfile("a", 100).with_range(20, 40)]), + // only b in this group (can't split this) + FileGroup::new(vec![pfile("b", 30).with_range(5, 25)]), + // second third of "a" + FileGroup::new(vec![pfile("a", 100).with_range(40, 60)]), + // final third of "a" + FileGroup::new(vec![pfile("a", 100).with_range(60, 80)]), + ]); + assert_partitioned_files(expected, actual); + } + #[test] fn repartition_ordered_two_large_files() { // "Rebalance" two large files across empty partitions, but can't mix them @@ -998,6 +1223,13 @@ mod test { PartitionedFile::new(path, file_size) } + /// Creates a file with partition value with a static size of 10. + fn pfile_with_pv(path: &str, pv: &str) -> PartitionedFile { + let mut file = pfile(path, 10); + file.partition_values = vec![ScalarValue::from(pv)]; + file + } + /// repartition the file groups both with and without preserving order /// asserting they return the same value and returns that value fn repartition_test( @@ -1013,4 +1245,50 @@ mod test { assert_partitioned_files(repartitioned.clone(), repartitioned_preserving_sort); repartitioned } + + #[test] + fn test_group_by_partition_values_edge_cases() { + // Edge cases: empty and zero target + assert!(FileGroup::default().group_by_partition_values(4).is_empty()); + assert!( + FileGroup::new(vec![pfile("a", 100)]) + .group_by_partition_values(0) + .is_empty() + ); + } + + #[test] + fn test_group_by_partition_values_less_groups_than_target() { + // File a and b have partition value p1. + // File c has partition value p2. + // Grouping by partition value should not redistribute any files since the number of partition + // values <= max_target_partitions. + let fg = FileGroup::new(vec![ + pfile_with_pv("a", "p1"), + pfile_with_pv("b", "p1"), + pfile_with_pv("c", "p2"), + ]); + let groups = fg.group_by_partition_values(4); + assert_eq!(groups.len(), 2); + assert_eq!(groups[0].len(), 2); + assert_eq!(groups[1].len(), 1); + } + + #[test] + fn test_group_by_partition_values_more_groups_than_target() { + // Each file has a single partition value. The number of partition values > max_target_partitions, so + // they should be round-robin distributed into groups. + let fg = FileGroup::new(vec![ + pfile_with_pv("a", "p1"), + pfile_with_pv("b", "p2"), + pfile_with_pv("c", "p3"), + pfile_with_pv("d", "p4"), + pfile_with_pv("e", "p5"), + ]); + let groups = fg.group_by_partition_values(3); + assert_eq!(groups.len(), 3); + assert_eq!(groups[0].len(), 2); + assert_eq!(groups[1].len(), 2); + assert_eq!(groups[2].len(), 1); + } } diff --git a/datafusion/datasource/src/file_scan_config/mod.rs b/datafusion/datasource/src/file_scan_config/mod.rs new file mode 100644 index 0000000000000..3ebd588a0770f --- /dev/null +++ b/datafusion/datasource/src/file_scan_config/mod.rs @@ -0,0 +1,3192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`FileScanConfig`] to configure scanning of possibly partitioned +//! file sources. + +pub(crate) mod sort_pushdown; + +use crate::file_groups::FileGroup; +use crate::{ + PartitionedFile, display::FileGroupsDisplay, file::FileSource, + file_compression_type::FileCompressionType, file_stream::FileStreamBuilder, + file_stream::work_source::SharedWorkSource, source::DataSource, + statistics::MinMaxStatistics, +}; +use arrow::datatypes::Fields; +use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{ + Constraints, Result, ScalarValue, Statistics, internal_datafusion_err, internal_err, +}; +use datafusion_execution::{ + SendableRecordBatchStream, TaskContext, object_store::ObjectStoreUrl, +}; +use datafusion_expr::Operator; + +use crate::source::OpenArgs; +use datafusion_physical_expr::expressions::{BinaryExpr, Column}; +use datafusion_physical_expr::projection::ProjectionExprs; +use datafusion_physical_expr::utils::reassign_expr_columns; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning, split_conjunction}; +use datafusion_physical_expr_adapter::PhysicalExprAdapterFactory; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::SortOrderPushdownResult; +use datafusion_physical_plan::coop::cooperative; +use datafusion_physical_plan::execution_plan::SchedulingType; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, + display::{ProjectSchemaDisplay, display_orderings}, + filter_pushdown::FilterPushdownPropagation, + metrics::ExecutionPlanMetricsSet, +}; +use log::{debug, warn}; +use std::any::Any; +use std::{fmt::Debug, fmt::Formatter, fmt::Result as FmtResult, sync::Arc}; + +/// [`FileScanConfig`] represents scanning data from a group of files +/// +/// `FileScanConfig` is used to create a [`DataSourceExec`], the physical plan +/// for scanning files with a particular file format. +/// +/// The [`FileSource`] (e.g. `ParquetSource`, `CsvSource`, etc.) is responsible +/// for creating the actual execution plan to read the files based on a +/// `FileScanConfig`. Fields in a `FileScanConfig` such as Statistics represent +/// information about the files **before** any projection or filtering is +/// applied in the file source. +/// +/// Use [`FileScanConfigBuilder`] to construct a `FileScanConfig`. +/// +/// Use [`DataSourceExec::from_data_source`] to create a [`DataSourceExec`] from +/// a `FileScanConfig`. +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::{Field, Fields, DataType, Schema, SchemaRef}; +/// # use object_store::ObjectStore; +/// # use datafusion_common::Result; +/// # use datafusion_datasource::file::FileSource; +/// # use datafusion_datasource::file_groups::FileGroup; +/// # use datafusion_datasource::PartitionedFile; +/// # use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; +/// # use datafusion_datasource::file_stream::FileOpener; +/// # use datafusion_datasource::source::DataSourceExec; +/// # use datafusion_datasource::table_schema::TableSchema; +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # use datafusion_physical_expr::projection::ProjectionExprs; +/// # use datafusion_physical_plan::ExecutionPlan; +/// # use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +/// # let file_schema = Arc::new(Schema::new(vec![ +/// # Field::new("c1", DataType::Int32, false), +/// # Field::new("c2", DataType::Int32, false), +/// # Field::new("c3", DataType::Int32, false), +/// # Field::new("c4", DataType::Int32, false), +/// # ])); +/// # // Note: crate mock ParquetSource, as ParquetSource is not in the datasource crate +/// #[derive(Clone)] +/// # struct ParquetSource { +/// # table_schema: TableSchema, +/// # }; +/// # impl FileSource for ParquetSource { +/// # fn create_file_opener(&self, _: Arc, _: &FileScanConfig, _: usize) -> Result> { unimplemented!() } +/// # fn table_schema(&self) -> &TableSchema { &self.table_schema } +/// # fn with_batch_size(&self, _: usize) -> Arc { unimplemented!() } +/// # fn metrics(&self) -> &ExecutionPlanMetricsSet { unimplemented!() } +/// # fn file_type(&self) -> &str { "parquet" } +/// # // Note that this implementation drops the projection on the floor, it is not complete! +/// # fn try_pushdown_projection(&self, projection: &ProjectionExprs) -> Result>> { Ok(Some(Arc::new(self.clone()) as Arc)) } +/// # } +/// # impl ParquetSource { +/// # fn new(table_schema: impl Into) -> Self { Self {table_schema: table_schema.into()} } +/// # } +/// // create FileScan config for reading parquet files from file:// +/// let object_store_url = ObjectStoreUrl::local_filesystem(); +/// let file_source = Arc::new(ParquetSource::new(file_schema.clone())); +/// let config = FileScanConfigBuilder::new(object_store_url, file_source) +/// .with_limit(Some(1000)) // read only the first 1000 records +/// .with_projection_indices(Some(vec![2, 3])) // project columns 2 and 3 +/// .expect("Failed to push down projection") +/// // Read /tmp/file1.parquet with known size of 1234 bytes in a single group +/// .with_file(PartitionedFile::new("file1.parquet", 1234)) +/// // Read /tmp/file2.parquet 56 bytes and /tmp/file3.parquet 78 bytes +/// // in a single row group +/// .with_file_group(FileGroup::new(vec![ +/// PartitionedFile::new("file2.parquet", 56), +/// PartitionedFile::new("file3.parquet", 78), +/// ])).build(); +/// // create an execution plan from the config +/// let plan: Arc = DataSourceExec::from_data_source(config); +/// ``` +/// +/// [`DataSourceExec`]: crate::source::DataSourceExec +/// [`DataSourceExec::from_data_source`]: crate::source::DataSourceExec::from_data_source +#[derive(Clone)] +pub struct FileScanConfig { + /// Object store URL, used to get an [`ObjectStore`] instance from + /// [`RuntimeEnv::object_store`] + /// + /// This `ObjectStoreUrl` should be the prefix of the absolute url for files + /// as `file://` or `s3://my_bucket`. It should not include the path to the + /// file itself. The relevant URL prefix must be registered via + /// [`RuntimeEnv::register_object_store`] + /// + /// [`ObjectStore`]: object_store::ObjectStore + /// [`RuntimeEnv::register_object_store`]: datafusion_execution::runtime_env::RuntimeEnv::register_object_store + /// [`RuntimeEnv::object_store`]: datafusion_execution::runtime_env::RuntimeEnv::object_store + pub object_store_url: ObjectStoreUrl, + /// List of files to be processed, grouped into partitions + /// + /// Each file must have a schema of `file_schema` or a subset. If + /// a particular file has a subset, the missing columns are + /// padded with NULLs. + /// + /// DataFusion may attempt to read each partition of files + /// concurrently, however files *within* a partition will be read + /// sequentially, one after the next. + pub file_groups: Vec, + /// Table constraints + pub constraints: Constraints, + /// The maximum number of records to read from this plan. If `None`, + /// all records after filtering are returned. + pub limit: Option, + /// Whether the scan's limit is order sensitive + /// When `true`, files must be read in the exact order specified to produce + /// correct results (e.g., for `ORDER BY ... LIMIT` queries). When `false`, + /// DataFusion may reorder file processing for optimization without affecting correctness. + pub preserve_order: bool, + /// All equivalent lexicographical output orderings of this file scan, in terms of + /// [`FileSource::table_schema`]. See [`FileScanConfigBuilder::with_output_ordering`] for more + /// details. + /// + /// [`Self::eq_properties`] uses this information along with projection + /// and filtering information to compute the effective + /// [`EquivalenceProperties`] + pub output_ordering: Vec, + /// File compression type + pub file_compression_type: FileCompressionType, + /// File source such as `ParquetSource`, `CsvSource`, `JsonSource`, etc. + pub file_source: Arc, + /// Batch size while creating new batches + /// Defaults to [`datafusion_common::config::ExecutionOptions`] batch_size. + pub batch_size: Option, + /// Expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the file. + pub expr_adapter_factory: Option>, + /// Statistics for the entire table (file schema + partition columns). + /// See [`FileScanConfigBuilder::with_statistics`] for more details. + /// + /// The effective statistics are computed on-demand via + /// [`ProjectionExprs::project_statistics`]. + /// + /// Note that this field is pub(crate) because accessing it directly from outside + /// would be incorrect if there are filters being applied, thus this should be accessed + /// via [`FileScanConfig::statistics`]. + pub(crate) statistics: Statistics, + /// When true, file_groups are organized by partition column values + /// and output_partitioning will return Hash partitioning on partition columns. + /// This allows the optimizer to skip hash repartitioning for aggregates and joins + /// on partition columns. + /// + /// If the number of file partitions > target_partitions, the file partitions will be grouped + /// in a round-robin fashion such that number of file partitions = target_partitions. + pub partitioned_by_file_group: bool, +} + +/// A builder for [`FileScanConfig`]'s. +/// +/// Example: +/// +/// ```rust +/// # use std::sync::Arc; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion_datasource::file_scan_config::{FileScanConfigBuilder, FileScanConfig}; +/// # use datafusion_datasource::file_compression_type::FileCompressionType; +/// # use datafusion_datasource::file_groups::FileGroup; +/// # use datafusion_datasource::PartitionedFile; +/// # use datafusion_datasource::table_schema::TableSchema; +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # use datafusion_common::Statistics; +/// # use datafusion_datasource::file::FileSource; +/// +/// # fn main() { +/// # fn with_source(file_source: Arc) { +/// // Create a schema for our Parquet files +/// let file_schema = Arc::new(Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), +/// Field::new("value", DataType::Utf8, false), +/// ])); +/// +/// // Create partition columns +/// let partition_cols = vec![ +/// Arc::new(Field::new("date", DataType::Utf8, false)), +/// ]; +/// +/// // Create table schema with file schema and partition columns +/// let table_schema = TableSchema::builder(file_schema) +/// .with_table_partition_cols(partition_cols) +/// .build(); +/// +/// // Create a builder for scanning Parquet files from a local filesystem +/// let config = FileScanConfigBuilder::new( +/// ObjectStoreUrl::local_filesystem(), +/// file_source, +/// ) +/// // Set a limit of 1000 rows +/// .with_limit(Some(1000)) +/// // Project only the first column +/// .with_projection_indices(Some(vec![0])) +/// .expect("Failed to push down projection") +/// // Add a file group with two files +/// .with_file_group(FileGroup::new(vec![ +/// PartitionedFile::new("data/date=2024-01-01/file1.parquet", 1024), +/// PartitionedFile::new("data/date=2024-01-01/file2.parquet", 2048), +/// ])) +/// // Set compression type +/// .with_file_compression_type(FileCompressionType::UNCOMPRESSED) +/// // Build the final config +/// .build(); +/// # } +/// # } +/// ``` +#[derive(Clone)] +pub struct FileScanConfigBuilder { + object_store_url: ObjectStoreUrl, + file_source: Arc, + limit: Option, + preserve_order: bool, + constraints: Option, + file_groups: Vec, + statistics: Option, + output_ordering: Vec, + file_compression_type: Option, + batch_size: Option, + expr_adapter_factory: Option>, + partitioned_by_file_group: bool, +} + +impl FileScanConfigBuilder { + /// Create a new [`FileScanConfigBuilder`] with default settings for scanning files. + /// + /// # Parameters: + /// * `object_store_url`: See [`FileScanConfig::object_store_url`] + /// * `file_source`: See [`FileScanConfig::file_source`]. The file source must have + /// a schema set via its constructor. + pub fn new( + object_store_url: ObjectStoreUrl, + file_source: Arc, + ) -> Self { + Self { + object_store_url, + file_source, + file_groups: vec![], + statistics: None, + output_ordering: vec![], + file_compression_type: None, + limit: None, + preserve_order: false, + constraints: None, + batch_size: None, + expr_adapter_factory: None, + partitioned_by_file_group: false, + } + } + + /// Set the maximum number of records to read from this plan. + /// + /// If `None`, all records after filtering are returned. + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + /// Set whether the limit should be order-sensitive. + /// + /// When `true`, files must be read in the exact order specified to produce + /// correct results (e.g., for `ORDER BY ... LIMIT` queries). When `false`, + /// DataFusion may reorder file processing for optimization without + /// affecting correctness. + pub fn with_preserve_order(mut self, order_sensitive: bool) -> Self { + self.preserve_order = order_sensitive; + self + } + + /// Set the file source for scanning files. + /// + /// This method allows you to change the file source implementation (e.g. + /// ParquetSource, CsvSource, etc.) after the builder has been created. + pub fn with_source(mut self, file_source: Arc) -> Self { + self.file_source = file_source; + self + } + + /// Return the table schema + pub fn table_schema(&self) -> &SchemaRef { + self.file_source.table_schema().table_schema() + } + + /// Set the columns on which to project the data. Indexes that are higher than the + /// number of columns of `file_schema` refer to `table_partition_cols`. + /// + /// # Deprecated + /// Use [`Self::with_projection_indices`] instead. This method will be removed in a future release. + #[deprecated(since = "51.0.0", note = "Use with_projection_indices instead")] + pub fn with_projection(self, indices: Option>) -> Self { + match self.clone().with_projection_indices(indices) { + Ok(builder) => builder, + Err(e) => { + warn!( + "Failed to push down projection in FileScanConfigBuilder::with_projection: {e}" + ); + self + } + } + } + + /// Set the columns on which to project the data using column indices. + /// + /// This method attempts to push down the projection to the underlying file + /// source if supported. If the file source does not support projection + /// pushdown, an error is returned. + /// + /// Indexes that are higher than the number of columns of `file_schema` + /// refer to `table_partition_cols`. + pub fn with_projection_indices( + mut self, + indices: Option>, + ) -> Result { + let projection_exprs = indices.map(|indices| { + ProjectionExprs::from_indices( + &indices, + self.file_source.table_schema().table_schema(), + ) + }); + let Some(projection_exprs) = projection_exprs else { + return Ok(self); + }; + let new_source = self + .file_source + .try_pushdown_projection(&projection_exprs) + .map_err(|e| { + internal_datafusion_err!( + "Failed to push down projection in FileScanConfigBuilder::build: {e}" + ) + })?; + if let Some(new_source) = new_source { + self.file_source = new_source; + } else { + internal_err!( + "FileSource {} does not support projection pushdown", + self.file_source.file_type() + )?; + } + Ok(self) + } + + /// Set the table constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = Some(constraints); + self + } + + /// Set the statistics of the files, including partition + /// columns. Defaults to [`Statistics::new_unknown`]. + /// + /// These statistics are for the entire table (file schema + partition + /// columns) before any projection or filtering is applied. Projections are + /// applied when statistics are retrieved, and if a filter is present, + /// [`FileScanConfig::statistics`] will mark the statistics as inexact + /// (counts are not adjusted). + /// + /// Projections and filters may be applied by the file source, either by + /// [`Self::with_projection_indices`] or a preexisting + /// [`FileSource::projection`] or [`FileSource::filter`]. + pub fn with_statistics(mut self, statistics: Statistics) -> Self { + self.statistics = Some(statistics); + self + } + + /// Set the list of files to be processed, grouped into partitions. + /// + /// Each file must have a schema of `file_schema` or a subset. If + /// a particular file has a subset, the missing columns are + /// padded with NULLs. + /// + /// DataFusion may attempt to read each partition of files + /// concurrently, however files *within* a partition will be read + /// sequentially, one after the next. + pub fn with_file_groups(mut self, file_groups: Vec) -> Self { + self.file_groups = file_groups; + self + } + + /// Add a new file group + /// + /// See [`Self::with_file_groups`] for more information + pub fn with_file_group(mut self, file_group: FileGroup) -> Self { + self.file_groups.push(file_group); + self + } + + /// Add a file as a single group + /// + /// See [`Self::with_file_groups`] for more information. + pub fn with_file(self, partitioned_file: PartitionedFile) -> Self { + self.with_file_group(FileGroup::new(vec![partitioned_file])) + } + + /// Set the output ordering of the files + /// + /// The expressions are in terms of the entire table schema (file schema + + /// partition columns), before any projection or filtering from the file + /// scan is applied. + /// + /// This is used for optimization purposes, e.g. to determine if a file scan + /// can satisfy an `ORDER BY` without an additional sort. + pub fn with_output_ordering(mut self, output_ordering: Vec) -> Self { + self.output_ordering = output_ordering; + self + } + + /// Set the file compression type + pub fn with_file_compression_type( + mut self, + file_compression_type: FileCompressionType, + ) -> Self { + self.file_compression_type = Some(file_compression_type); + self + } + + /// Set the batch_size property + pub fn with_batch_size(mut self, batch_size: Option) -> Self { + self.batch_size = batch_size; + self + } + + /// Register an expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the file. + /// This can include things like: + /// - Column ordering changes + /// - Handling of missing columns + /// - Rewriting expression to use pre-computed values or file format specific optimizations + pub fn with_expr_adapter( + mut self, + expr_adapter: Option>, + ) -> Self { + self.expr_adapter_factory = expr_adapter; + self + } + + /// Set whether file groups are organized by partition column values. + /// + /// When set to true, the output partitioning will be declared as Hash partitioning + /// on the partition columns. + pub fn with_partitioned_by_file_group( + mut self, + partitioned_by_file_group: bool, + ) -> Self { + self.partitioned_by_file_group = partitioned_by_file_group; + self + } + + /// Build the final [`FileScanConfig`] with all the configured settings. + /// + /// This method takes ownership of the builder and returns the constructed `FileScanConfig`. + /// Any unset optional fields will use their default values. + /// + /// # Errors + /// Returns an error if projection pushdown fails or if schema operations fail. + pub fn build(self) -> FileScanConfig { + let Self { + object_store_url, + file_source, + limit, + preserve_order, + constraints, + file_groups, + statistics, + output_ordering, + file_compression_type, + batch_size, + expr_adapter_factory: expr_adapter, + partitioned_by_file_group, + } = self; + + let constraints = constraints.unwrap_or_default(); + let statistics = statistics.unwrap_or_else(|| { + Statistics::new_unknown(file_source.table_schema().table_schema()) + }); + let file_compression_type = + file_compression_type.unwrap_or(FileCompressionType::UNCOMPRESSED); + + // If there is an output ordering, we should preserve it. + let preserve_order = preserve_order || !output_ordering.is_empty(); + + FileScanConfig { + object_store_url, + file_source, + limit, + preserve_order, + constraints, + file_groups, + output_ordering, + file_compression_type, + batch_size, + expr_adapter_factory: expr_adapter, + statistics, + partitioned_by_file_group, + } + } +} + +impl From for FileScanConfigBuilder { + fn from(config: FileScanConfig) -> Self { + Self { + object_store_url: config.object_store_url, + file_source: Arc::::clone(&config.file_source), + file_groups: config.file_groups, + statistics: Some(config.statistics), + output_ordering: config.output_ordering, + file_compression_type: Some(config.file_compression_type), + limit: config.limit, + preserve_order: config.preserve_order, + constraints: Some(config.constraints), + batch_size: config.batch_size, + expr_adapter_factory: config.expr_adapter_factory, + partitioned_by_file_group: config.partitioned_by_file_group, + } + } +} + +impl DataSource for FileScanConfig { + fn open( + &self, + partition: usize, + context: Arc, + ) -> Result { + self.open_with_args(OpenArgs::new(partition, context)) + } + + fn open_with_args(&self, args: OpenArgs) -> Result { + let OpenArgs { + partition, + context, + sibling_state, + } = args; + let object_store = context.runtime_env().object_store(&self.object_store_url)?; + let batch_size = self + .batch_size + .unwrap_or_else(|| context.session_config().batch_size()); + + let source = self.file_source.with_batch_size(batch_size); + + let morselizer = source.create_morselizer(object_store, self, partition)?; + + // Extract the shared work source from the sibling state if it exists. + // This allows multiple sibling streams to steal work from a single + // shared queue of unopened files. + let shared_work_source = sibling_state + .as_ref() + .and_then(|state| state.downcast_ref::()) + .cloned(); + + let stream = FileStreamBuilder::new(self) + .with_partition(partition) + .with_shared_work_source(shared_work_source) + .with_morselizer(morselizer) + .with_metrics(source.metrics()) + .build()?; + Ok(Box::pin(cooperative(stream))) + } + + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let schema = self.projected_schema().map_err(|_| std::fmt::Error {})?; + let orderings = + sort_pushdown::get_projected_output_ordering(self, &schema); + + write!(f, "file_groups=")?; + FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; + + if !schema.fields().is_empty() { + if let Some(projection) = self.file_source.projection() { + // This matches what ProjectionExec does. + // TODO: can we put this into ProjectionExprs so that it's shared code? + let expr: Vec = projection + .as_ref() + .iter() + .map(|proj_expr| { + if let Some(column) = + proj_expr.expr.downcast_ref::() + { + if column.name() == proj_expr.alias { + column.name().to_string() + } else { + format!( + "{} as {}", + proj_expr.expr, proj_expr.alias + ) + } + } else { + format!("{} as {}", proj_expr.expr, proj_expr.alias) + } + }) + .collect(); + write!(f, ", projection=[{}]", expr.join(", "))?; + } else { + write!(f, ", projection={}", ProjectSchemaDisplay(&schema))?; + } + } + + if let Some(limit) = self.limit { + write!(f, ", limit={limit}")?; + } + + display_orderings(f, &orderings)?; + + if !self.constraints.is_empty() { + write!(f, ", {}", self.constraints)?; + } + + self.fmt_file_source(t, f) + } + DisplayFormatType::TreeRender => { + writeln!(f, "format={}", self.file_source.file_type())?; + self.file_source.fmt_extra(t, f)?; + let num_files = self.file_groups.iter().map(|fg| fg.len()).sum::(); + writeln!(f, "files={num_files}")?; + Ok(()) + } + } + } + + /// If supported by the underlying [`FileSource`], redistribute files across partitions according to their size. + fn repartitioned( + &self, + target_partitions: usize, + repartition_file_min_size: usize, + output_ordering: Option, + ) -> Result>> { + // When files are grouped by partition values, we cannot allow byte-range + // splitting. It would mix rows from different partition values across + // file groups, breaking the Hash partitioning. + if self.partitioned_by_file_group { + return Ok(None); + } + + let source = self.file_source.repartitioned( + target_partitions, + repartition_file_min_size, + output_ordering, + self, + )?; + + Ok(source.map(|s| Arc::new(s) as _)) + } + + /// Returns the output partitioning for this file scan. + /// + /// When `partitioned_by_file_group` is true, this returns `Partitioning::Hash` on + /// the Hive partition columns, allowing the optimizer to skip hash repartitioning + /// for aggregates and joins on those columns. + /// + /// Tradeoffs + /// - Benefit: Eliminates `RepartitionExec` and `SortExec` for queries with + /// `GROUP BY` or `ORDER BY` on partition columns. + /// - Cost: Files are grouped by partition values rather than split by byte + /// ranges, which may reduce I/O parallelism when partition sizes are uneven. + /// For simple aggregations without `ORDER BY`, this cost may outweigh the benefit. + /// + /// Follow-up Work + /// - Idea: Could allow byte-range splitting within partition-aware groups, + /// preserving I/O parallelism while maintaining partition semantics. + fn output_partitioning(&self) -> Partitioning { + if self.partitioned_by_file_group { + let partition_cols = self.table_partition_cols(); + if !partition_cols.is_empty() { + let projected_schema = match self.projected_schema() { + Ok(schema) => schema, + Err(_) => { + debug!( + "Could not get projected schema, falling back to UnknownPartitioning." + ); + return Partitioning::UnknownPartitioning(self.file_groups.len()); + } + }; + + // Build Column expressions for partition columns based on their + // position in the projected schema + let mut exprs: Vec> = Vec::new(); + for partition_col in partition_cols { + if let Some((idx, _)) = projected_schema + .fields() + .iter() + .enumerate() + .find(|(_, f)| f.name() == partition_col.name()) + { + exprs.push(Arc::new(Column::new(partition_col.name(), idx))); + } + } + + if exprs.len() == partition_cols.len() { + return Partitioning::Hash(exprs, self.file_groups.len()); + } + } + } + Partitioning::UnknownPartitioning(self.file_groups.len()) + } + + /// Computes the effective equivalence properties of this file scan, taking + /// into account the file schema, any projections or filters applied by the + /// file source, and the output ordering. + fn eq_properties(&self) -> EquivalenceProperties { + let schema = self.file_source.table_schema().table_schema(); + let mut eq_properties = EquivalenceProperties::new_with_orderings( + Arc::clone(schema), + self.validated_output_ordering(), + ) + .with_constraints(self.constraints.clone()); + + if let Some(filter) = self.file_source.filter() { + // We need to remap column indexes to match the projected schema since that's what the equivalence properties deal with. + // Note that this will *ignore* any non-projected columns: these don't factor into ordering / equivalence. + match Self::add_filter_equivalence_info(&filter, &mut eq_properties, schema) { + Ok(()) => {} + Err(e) => { + warn!("Failed to add filter equivalence info: {e}"); + #[cfg(debug_assertions)] + panic!("Failed to add filter equivalence info: {e}"); + } + } + } + + if let Some(projection) = self.file_source.projection() { + match ( + projection.project_schema(schema), + projection.projection_mapping(schema), + ) { + (Ok(output_schema), Ok(mapping)) => { + eq_properties = + eq_properties.project(&mapping, Arc::new(output_schema)); + } + (Err(e), _) | (_, Err(e)) => { + warn!("Failed to project equivalence properties: {e}"); + #[cfg(debug_assertions)] + panic!("Failed to project equivalence properties: {e}"); + } + } + } + + eq_properties + } + + fn scheduling_type(&self) -> SchedulingType { + SchedulingType::Cooperative + } + + fn partition_statistics(&self, partition: Option) -> Result> { + if let Some(partition) = partition { + // Get statistics for a specific partition + // Note: FileGroup statistics include partition columns (computed from partition_values) + if let Some(file_group) = self.file_groups.get(partition) + && let Some(stat) = file_group.file_statistics(None) + { + // Project the statistics based on the projection + let output_schema = self.projected_schema()?; + return if let Some(projection) = self.file_source.projection() { + Ok(Arc::new( + projection.project_statistics(stat.clone(), &output_schema)?, + )) + } else { + Ok(Arc::new(stat.clone())) + }; + } + // If no statistics available for this partition, return unknown + Ok(Arc::new(Statistics::new_unknown( + self.projected_schema()?.as_ref(), + ))) + } else { + // Return aggregate statistics across all partitions + let statistics = self.statistics(); + let projection = self.file_source.projection(); + let output_schema = self.projected_schema()?; + if let Some(projection) = &projection { + Ok(Arc::new( + projection.project_statistics(statistics.clone(), &output_schema)?, + )) + } else { + Ok(Arc::new(statistics)) + } + } + } + + fn with_fetch(&self, limit: Option) -> Option> { + let source = FileScanConfigBuilder::from(self.clone()) + .with_limit(limit) + .build(); + Some(Arc::new(source)) + } + + fn fetch(&self) -> Option { + self.limit + } + + fn metrics(&self) -> ExecutionPlanMetricsSet { + self.file_source.metrics().clone() + } + + fn try_swapping_with_projection( + &self, + projection: &ProjectionExprs, + ) -> Result>> { + match self.file_source.try_pushdown_projection(projection)? { + Some(new_source) => { + let mut new_file_scan_config = self.clone(); + new_file_scan_config.file_source = new_source; + Ok(Some(Arc::new(new_file_scan_config) as Arc)) + } + None => Ok(None), + } + } + + fn try_pushdown_filters( + &self, + filters: Vec>, + config: &ConfigOptions, + ) -> Result>> { + // Remap filter Column indices to match the table schema (file + partition columns). + // This is necessary because filters refer to the output schema of this `DataSource` + // (e.g., after projection pushdown has been applied) and need to be remapped to the table schema + // before being passed to the file source + // + // For example, consider a filter `c1_c2 > 5` being pushed down. If the + // `DataSource` has a projection `c1 + c2 as c1_c2`, the filter must be rewritten + // to refer to the table schema `c1 + c2 > 5` + let table_schema = self.file_source.table_schema().table_schema(); + let filters_to_remap = if let Some(projection) = self.file_source.projection() { + filters + .into_iter() + .map(|filter| projection.unproject_expr(&filter)) + .collect::>>()? + } else { + filters + }; + // Now remap column indices to match the table schema. + let remapped_filters = filters_to_remap + .into_iter() + .map(|filter| reassign_expr_columns(filter, table_schema)) + .collect::>>()?; + + let result = self + .file_source + .try_pushdown_filters(remapped_filters, config)?; + match result.updated_node { + Some(new_file_source) => { + let mut new_file_scan_config = self.clone(); + new_file_scan_config.file_source = new_file_source; + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: Some(Arc::new(new_file_scan_config) as _), + }) + } + None => { + // If the file source does not support filter pushdown, return the original config + Ok(FilterPushdownPropagation { + filters: result.filters, + updated_node: None, + }) + } + } + } + + /// Push sort requirements into file-based data sources. + /// + /// # Sort Pushdown Architecture + /// + /// When a partition (file group) contains multiple files in wrong order, + /// `validated_output_ordering()` strips the ordering and `EnforceSorting` + /// inserts a `SortExec`. This optimizer fixes the file order by sorting + /// files within each group by min/max statistics, enabling sort elimination. + /// + /// This applies to both single-partition and multi-partition plans — any + /// file group with multiple files in wrong order benefits. + /// + /// ```text + /// PushdownSort optimizer finds SortExec + /// │ + /// ▼ + /// FileScanConfig::try_pushdown_sort() + /// │ + /// ├─► FileSource returns Exact + /// │ (natural ordering satisfies request) + /// │ → rebuild_with_source: sort files by stats, verify non-overlapping + /// │ → SortExec removed, fetch (LIMIT) pushed to DataSourceExec + /// │ + /// ├─► FileSource returns Inexact + /// │ (e.g. column_in_file_schema: opener will reorder RGs at runtime) + /// │ → rebuild_with_source: sort files by stats; if the post-sort + /// │ file groups are non-overlapping AND the request now validates + /// │ AND no NULLs sit in the sort columns of non-last files, + /// │ upgrade back to Exact (SortExec removed). Otherwise stays + /// │ Inexact and SortExec is kept while the scan is still + /// │ optimised via `sort_order_for_reorder` / `reverse_row_groups`. + /// │ + /// └─► FileSource returns Unsupported + /// (e.g. expression sort key or partition column) + /// → try_sort_file_groups_by_statistics(): + /// 1. Sort files within each group by min/max statistics + /// 2. Re-check: non-overlapping + ordering valid + no NULLs? + /// YES → Exact → SortExec removed + /// NO → Inexact (files reordered, Sort stays) + /// ``` + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + let pushdown_result = self + .file_source + .try_pushdown_sort(order, &self.eq_properties())?; + + match pushdown_result { + SortOrderPushdownResult::Exact { inner } => { + let config = self.rebuild_with_source(inner, true, order)?; + // rebuild_with_source keeps output_ordering only when all groups + // are non-overlapping. If output_ordering was cleared, files + // overlap despite within-file ordering → downgrade to Inexact. + if config.output_ordering.is_empty() { + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(config), + }) + } else { + Ok(SortOrderPushdownResult::Exact { + inner: Arc::new(config), + }) + } + } + SortOrderPushdownResult::Inexact { inner } => { + let mut config = self.rebuild_with_source(inner, false, order)?; + // `rebuild_with_source` reorders files by stats; if the + // post-sort files are non-overlapping AND the request now + // validates against the new file groups, `output_ordering` + // is preserved and we can upgrade back to Exact. This + // restores the sort-elimination behaviour that lived in + // the `Unsupported` → `try_sort_file_groups_by_statistics` + // path before #21956 routed `column_in_file_schema` cases + // here. + if config.output_ordering.is_empty() { + return Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(config), + }); + } + // Upgrading to Exact: the post-sort file groups are + // non-overlapping and each file's declared ordering + // re-validates, so reading the files in their natural + // (declared-sorted) order already yields the requested + // ordering — exactly like the `Unsupported` → Exact path, + // which reads files in natural order too. + // + // Drop the runtime row-group reorder hints the Inexact + // source carried (`sort_order_for_reorder` / + // `reverse_row_groups`) by restoring the original, + // hint-free source. With the `SortExec` removed those + // hints are not just redundant but unsafe: for a DESC + // request the opener sorts row groups ASC-by-min and then + // reverses them, which reorders two row groups within a + // single file that share the same `min` incorrectly + // (e.g. a file `[10,8,8,8]` whose row groups are + // `[10,8]` and `[8,8]` would stream as `8,8,10,8`). + // The `SortExec` used to mask this; once it is gone the + // reordered stream is the final, wrong answer. + config.file_source = Arc::clone(&self.file_source); + Ok(SortOrderPushdownResult::Exact { + inner: Arc::new(config), + }) + } + SortOrderPushdownResult::Unsupported => { + self.try_sort_file_groups_by_statistics(order) + } + } + } + + fn with_preserve_order(&self, preserve_order: bool) -> Option> { + if self.preserve_order == preserve_order { + return Some(Arc::new(self.clone())); + } + + let new_config = FileScanConfig { + preserve_order, + ..self.clone() + }; + Some(Arc::new(new_config)) + } + + /// Create any shared state that should be passed between sibling streams + /// during one execution. + /// + /// This returns `None` when sibling streams must not share work, such as + /// when file order must be preserved or the file groups define the output + /// partitioning needed for the rest of the plan + fn create_sibling_state(&self) -> Option> { + if self.preserve_order || self.partitioned_by_file_group { + return None; + } + + Some(Arc::new(SharedWorkSource::from_config(self)) as Arc) + } +} + +impl FileScanConfig { + /// Returns only the output orderings that are validated against actual + /// file group statistics. + /// + /// For example, individual files may be ordered by `col1 ASC`, + /// but if we have files with these min/max statistics in a single partition / file group: + /// + /// - file1: min(col1) = 10, max(col1) = 20 + /// - file2: min(col1) = 5, max(col1) = 15 + /// + /// Because reading file1 followed by file2 would produce out-of-order output (there is overlap + /// in the ranges), we cannot retain `col1 ASC` as a valid output ordering. + /// + /// Similarly this would not be a valid order (non-overlapping ranges but not ordered): + /// + /// - file1: min(col1) = 20, max(col1) = 30 + /// - file2: min(col1) = 10, max(col1) = 15 + /// + /// On the other hand if we had: + /// + /// - file1: min(col1) = 5, max(col1) = 15 + /// - file2: min(col1) = 16, max(col1) = 25 + /// + /// Then we know that reading file1 followed by file2 will produce ordered output, + /// so `col1 ASC` would be retained. + /// + /// Note that we are checking for ordering *within* *each* file group / partition, + /// files in different partitions are read independently and do not affect each other's ordering. + /// Merging of the multiple partition streams into a single ordered stream is handled + /// upstream e.g. by `SortPreservingMergeExec`. + fn validated_output_ordering(&self) -> Vec { + let schema = self.file_source.table_schema().table_schema(); + sort_pushdown::validate_orderings( + &self.output_ordering, + schema, + &self.file_groups, + None, + ) + } + + /// Get the file schema (schema of the files without partition columns) + pub fn file_schema(&self) -> &SchemaRef { + self.file_source.table_schema().file_schema() + } + + /// Get the table partition columns + pub fn table_partition_cols(&self) -> &Fields { + self.file_source.table_schema().table_partition_cols() + } + + /// Returns the unprojected table statistics, marking them as inexact if filters are present. + /// + /// When filters are pushed down (including pruning predicates and bloom filters), + /// we can't guarantee the statistics are exact because we don't know how many + /// rows will be filtered out. + pub fn statistics(&self) -> Statistics { + if self.file_source.filter().is_some() { + self.statistics.clone().to_inexact() + } else { + self.statistics.clone() + } + } + + pub fn projected_schema(&self) -> Result> { + let schema = self.file_source.table_schema().table_schema(); + match self.file_source.projection() { + Some(proj) => Ok(Arc::new(proj.project_schema(schema)?)), + None => Ok(Arc::clone(schema)), + } + } + + fn add_filter_equivalence_info( + filter: &Arc, + eq_properties: &mut EquivalenceProperties, + schema: &Schema, + ) -> Result<()> { + // Gather valid equality pairs from the filter expression + let equal_pairs = split_conjunction(filter).into_iter().filter_map(|expr| { + // Ignore any binary expressions that reference non-existent columns in the current schema + // (e.g. due to unnecessary projections being removed) + reassign_expr_columns(Arc::clone(expr), schema) + .ok() + .and_then(|expr| match expr.downcast_ref::() { + Some(expr) if expr.op() == &Operator::Eq => { + Some((Arc::clone(expr.left()), Arc::clone(expr.right()))) + } + _ => None, + }) + }); + + for (lhs, rhs) in equal_pairs { + eq_properties.add_equal_conditions(lhs, rhs)? + } + + Ok(()) + } + + /// Returns whether newlines in values are supported. + /// + /// This method always returns `false`. The actual newlines_in_values setting + /// has been moved to [`CsvSource`] and should be accessed via + /// [`CsvSource::csv_options()`] instead. + /// + /// [`CsvSource`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.CsvSource.html + /// [`CsvSource::csv_options()`]: https://docs.rs/datafusion/latest/datafusion/datasource/physical_plan/struct.CsvSource.html#method.csv_options + #[deprecated( + since = "52.0.0", + note = "newlines_in_values has moved to CsvSource. Access it via CsvSource::csv_options().newlines_in_values instead. It will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." + )] + pub fn newlines_in_values(&self) -> bool { + false + } + + #[deprecated( + since = "52.0.0", + note = "This method is no longer used, use eq_properties instead. It will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." + )] + pub fn projected_constraints(&self) -> Constraints { + let props = self.eq_properties(); + props.constraints().clone() + } + + #[deprecated( + since = "52.0.0", + note = "This method is no longer used, use eq_properties instead. It will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." + )] + pub fn file_column_projection_indices(&self) -> Option> { + #[expect(deprecated)] + self.file_source.projection().as_ref().map(|p| { + p.ordered_column_indices() + .into_iter() + .filter(|&i| i < self.file_schema().fields().len()) + .collect::>() + }) + } + + /// Splits file groups into new groups based on statistics to enable efficient parallel processing. + /// + /// The method distributes files across a target number of partitions while ensuring + /// files within each partition maintain sort order based on their min/max statistics. + /// + /// The algorithm works by: + /// 1. Takes files sorted by minimum values + /// 2. For each file: + /// - Finds eligible groups (empty or where file's min > group's last max) + /// - Selects the smallest eligible group + /// - Creates a new group if needed + /// + /// # Parameters + /// * `table_schema`: Schema containing information about the columns + /// * `file_groups`: The original file groups to split + /// * `sort_order`: The lexicographical ordering to maintain within each group + /// * `target_partitions`: The desired number of output partitions + /// + /// # Returns + /// A new set of file groups, where files within each group are non-overlapping with respect to + /// their min/max statistics and maintain the specified sort order. + pub fn split_groups_by_statistics_with_target_partitions( + table_schema: &SchemaRef, + file_groups: &[FileGroup], + sort_order: &LexOrdering, + target_partitions: usize, + ) -> Result> { + if target_partitions == 0 { + return Err(internal_datafusion_err!( + "target_partitions must be greater than 0" + )); + } + + let flattened_files = file_groups + .iter() + .flat_map(FileGroup::iter) + .collect::>(); + + if flattened_files.is_empty() { + return Ok(vec![]); + } + + let statistics = MinMaxStatistics::new_from_files( + sort_order, + table_schema, + None, + flattened_files.iter().copied(), + )?; + + let indices_sorted_by_min = statistics.min_values_sorted(); + + // Initialize with target_partitions empty groups + let mut file_groups_indices: Vec> = vec![vec![]; target_partitions]; + + for (idx, min) in indices_sorted_by_min { + if let Some((_, group)) = file_groups_indices + .iter_mut() + .enumerate() + .filter(|(_, group)| { + group.is_empty() + || min + > statistics + .max(*group.last().expect("groups should not be empty")) + }) + .min_by_key(|(_, group)| group.len()) + { + group.push(idx); + } else { + // Create a new group if no existing group fits + file_groups_indices.push(vec![idx]); + } + } + + // Remove any empty groups + file_groups_indices.retain(|group| !group.is_empty()); + + // Assemble indices back into groups of PartitionedFiles + Ok(file_groups_indices + .into_iter() + .map(|file_group_indices| { + FileGroup::new( + file_group_indices + .into_iter() + .map(|idx| flattened_files[idx].clone()) + .collect(), + ) + }) + .collect()) + } + + /// Attempts to do a bin-packing on files into file groups, such that any two files + /// in a file group are ordered and non-overlapping with respect to their statistics. + /// It will produce the smallest number of file groups possible. + pub fn split_groups_by_statistics( + table_schema: &SchemaRef, + file_groups: &[FileGroup], + sort_order: &LexOrdering, + ) -> Result> { + let flattened_files = file_groups + .iter() + .flat_map(FileGroup::iter) + .collect::>(); + // First Fit: + // * Choose the first file group that a file can be placed into. + // * If it fits into no existing file groups, create a new one. + // + // By sorting files by min values and then applying first-fit bin packing, + // we can produce the smallest number of file groups such that + // files within a group are in order and non-overlapping. + // + // Source: Applied Combinatorics (Keller and Trotter), Chapter 6.8 + // https://www.appliedcombinatorics.org/book/s_posets_dilworth-intord.html + + if flattened_files.is_empty() { + return Ok(vec![]); + } + + let statistics = MinMaxStatistics::new_from_files( + sort_order, + table_schema, + None, + flattened_files.iter().copied(), + ) + .map_err(|e| { + e.context("construct min/max statistics for split_groups_by_statistics") + })?; + + let indices_sorted_by_min = statistics.min_values_sorted(); + let mut file_groups_indices: Vec> = vec![]; + + for (idx, min) in indices_sorted_by_min { + let file_group_to_insert = file_groups_indices.iter_mut().find(|group| { + // If our file is non-overlapping and comes _after_ the last file, + // it fits in this file group. + min > statistics.max( + *group + .last() + .expect("groups should be nonempty at construction"), + ) + }); + match file_group_to_insert { + Some(group) => group.push(idx), + None => file_groups_indices.push(vec![idx]), + } + } + + // Assemble indices back into groups of PartitionedFiles + Ok(file_groups_indices + .into_iter() + .map(|file_group_indices| { + file_group_indices + .into_iter() + .map(|idx| flattened_files[idx].clone()) + .collect() + }) + .collect()) + } + + /// Write the data_type based on file_source + fn fmt_file_source(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { + write!(f, ", file_type={}", self.file_source.file_type())?; + self.file_source.fmt_extra(t, f) + } + + /// Returns the file_source + pub fn file_source(&self) -> &Arc { + &self.file_source + } + + // Sort pushdown methods (rebuild_with_source, try_sort_file_groups_by_statistics, + // sort_files_within_groups_by_statistics, any_file_has_nulls_in_sort_columns) + // are in crate::sort_pushdown module. +} + +impl Debug for FileScanConfig { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "FileScanConfig {{")?; + write!(f, "object_store_url={:?}, ", self.object_store_url)?; + + write!(f, "statistics={:?}, ", self.statistics())?; + + DisplayAs::fmt_as(self, DisplayFormatType::Verbose, f)?; + write!(f, "}}") + } +} + +impl DisplayAs for FileScanConfig { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> FmtResult { + let schema = self.projected_schema().map_err(|_| std::fmt::Error {})?; + let orderings = sort_pushdown::get_projected_output_ordering(self, &schema); + + write!(f, "file_groups=")?; + FileGroupsDisplay(&self.file_groups).fmt_as(t, f)?; + + if !schema.fields().is_empty() { + write!(f, ", projection={}", ProjectSchemaDisplay(&schema))?; + } + + if let Some(limit) = self.limit { + write!(f, ", limit={limit}")?; + } + + display_orderings(f, &orderings)?; + + if !self.constraints.is_empty() { + write!(f, ", {}", self.constraints)?; + } + + Ok(()) + } +} + +/// Convert type to a type suitable for use as a `ListingTable` +/// partition column. Returns `Dictionary(UInt16, val_type)`, which is +/// a reasonable trade off between a reasonable number of partition +/// values and space efficiency. +/// +/// This use this to specify types for partition columns. However +/// you MAY also choose not to dictionary-encode the data or to use a +/// different dictionary type. +/// +/// Use [`wrap_partition_value_in_dict`] to wrap a [`ScalarValue`] in the same say. +pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) +} + +/// Convert a [`ScalarValue`] of partition columns to a type, as +/// described in the documentation of [`wrap_partition_type_in_dict`], +/// which can wrap the types. +pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + use crate::source::DataSourceExec; + use crate::test_util::col; + use crate::{TableSchema, TableSchemaBuilder}; + use crate::{ + generate_test_files, test_util::MockSource, tests::aggr_test_schema, + verify_sort_integrity, + }; + + use arrow::array::{Int32Array, RecordBatch}; + use arrow::datatypes::Field; + use datafusion_common::ColumnStatistics; + use datafusion_common::stats::Precision; + use datafusion_common::{Result, assert_batches_eq, internal_err}; + use datafusion_execution::TaskContext; + use datafusion_expr::SortExpr; + use datafusion_physical_expr::create_physical_sort_expr; + use datafusion_physical_expr::expressions::Literal; + use datafusion_physical_expr::projection::ProjectionExpr; + use datafusion_physical_expr::projection::ProjectionExprs; + use datafusion_physical_plan::ExecutionPlan; + use datafusion_physical_plan::execution_plan::collect; + use futures::FutureExt as _; + use futures::StreamExt as _; + use futures::stream; + use object_store::ObjectStore; + use std::fmt::Debug; + + #[derive(Clone)] + struct InexactSortPushdownSource { + metrics: ExecutionPlanMetricsSet, + table_schema: TableSchema, + } + + impl InexactSortPushdownSource { + fn new(table_schema: TableSchema) -> Self { + Self { + metrics: ExecutionPlanMetricsSet::new(), + table_schema, + } + } + } + + impl FileSource for InexactSortPushdownSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Result> { + unimplemented!() + } + + fn table_schema(&self) -> &TableSchema { + &self.table_schema + } + + fn with_batch_size(&self, _batch_size: usize) -> Arc { + Arc::new(self.clone()) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn file_type(&self) -> &str { + "mock" + } + + fn try_pushdown_sort( + &self, + _order: &[PhysicalSortExpr], + _eq_properties: &EquivalenceProperties, + ) -> Result>> { + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(self.clone()) as Arc, + }) + } + } + + #[test] + fn physical_plan_config_no_projection_tab_cols_as_field() { + let file_schema = aggr_test_schema(); + + // make a table_partition_col as a field + let table_partition_col = + Field::new("date", wrap_partition_type_in_dict(DataType::Utf8), true) + .with_metadata(HashMap::from_iter(vec![( + "key_whatever".to_owned(), + "value_whatever".to_owned(), + )])); + + let conf = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + vec![table_partition_col.clone()], + ); + + // verify the proj_schema includes the last column and exactly the same the field it is defined + let proj_schema = conf.projected_schema().unwrap(); + assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); + assert_eq!( + *proj_schema.field(file_schema.fields().len()), + table_partition_col, + "partition columns are the last columns and ust have all values defined in created field" + ); + } + + #[test] + fn test_split_groups_by_statistics() -> Result<()> { + use chrono::TimeZone; + use datafusion_common::DFSchema; + use datafusion_expr::execution_props::ExecutionProps; + use object_store::{ObjectMeta, path::Path}; + + struct File { + name: &'static str, + date: &'static str, + statistics: Vec, Option)>>, + } + impl File { + fn new( + name: &'static str, + date: &'static str, + statistics: Vec>, + ) -> Self { + Self::new_nullable( + name, + date, + statistics + .into_iter() + .map(|opt| opt.map(|(min, max)| (Some(min), Some(max)))) + .collect(), + ) + } + + fn new_nullable( + name: &'static str, + date: &'static str, + statistics: Vec, Option)>>, + ) -> Self { + Self { + name, + date, + statistics, + } + } + } + + struct TestCase { + name: &'static str, + file_schema: Schema, + files: Vec, + sort: Vec, + expected_result: Result>, &'static str>, + } + + use datafusion_expr::col; + let cases = vec![ + TestCase { + name: "test sort", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), + }, + // same input but file '2' is in the middle + // test that we still order correctly + TestCase { + name: "test sort with files ordered differently", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), + }, + TestCase { + name: "reverse sort", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + ], + sort: vec![col("value").sort(false, true)], + expected_result: Ok(vec![vec!["1", "0"], vec!["2"]]), + }, + TestCase { + name: "nullable sort columns, nulls last", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + true, + )]), + files: vec![ + File::new_nullable( + "0", + "2023-01-01", + vec![Some((Some(0.00), Some(0.49)))], + ), + File::new_nullable("1", "2023-01-01", vec![Some((Some(0.50), None))]), + File::new_nullable("2", "2023-01-02", vec![Some((Some(0.00), None))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), + }, + TestCase { + name: "nullable sort columns, nulls first", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + true, + )]), + files: vec![ + File::new_nullable("0", "2023-01-01", vec![Some((None, Some(0.49)))]), + File::new_nullable( + "1", + "2023-01-01", + vec![Some((Some(0.50), Some(1.00)))], + ), + File::new_nullable("2", "2023-01-02", vec![Some((None, Some(1.00)))]), + ], + sort: vec![col("value").sort(true, true)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), + }, + TestCase { + name: "all three non-overlapping", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 0.99))]), + File::new("2", "2023-01-02", vec![Some((1.00, 1.49))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1", "2"]]), + }, + TestCase { + name: "all three overlapping", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("2", "2023-01-02", vec![Some((0.00, 0.49))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0"], vec!["1"], vec!["2"]]), + }, + TestCase { + name: "empty input", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![]), + }, + TestCase { + name: "one file missing statistics", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("2", "2023-01-02", vec![None]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Err( + "construct min/max statistics for split_groups_by_statistics\ncaused by\ncollect min/max values\ncaused by\nget min/max for column: 'value'\ncaused by\nError during planning: statistics not found", + ), + }, + ]; + + for case in cases { + let table_schema = Arc::new(Schema::new( + case.file_schema + .fields() + .clone() + .into_iter() + .cloned() + .chain(Some(Arc::new(Field::new( + "date".to_string(), + DataType::Utf8, + false, + )))) + .collect::>(), + )); + let Some(sort_order) = LexOrdering::new( + case.sort + .into_iter() + .map(|expr| { + create_physical_sort_expr( + &expr, + &DFSchema::try_from(Arc::clone(&table_schema))?, + &ExecutionProps::default(), + ) + }) + .collect::>>()?, + ) else { + return internal_err!("This test should always use an ordering"); + }; + + let partitioned_files = FileGroup::new( + case.files.into_iter().map(From::from).collect::>(), + ); + let result = FileScanConfig::split_groups_by_statistics( + &table_schema, + std::slice::from_ref(&partitioned_files), + &sort_order, + ); + let results_by_name = result + .as_ref() + .map(|file_groups| { + file_groups + .iter() + .map(|file_group| { + file_group + .iter() + .map(|file| { + partitioned_files + .iter() + .find_map(|f| { + if f.object_meta == file.object_meta { + Some( + f.object_meta + .location + .as_ref() + .rsplit('/') + .next() + .unwrap() + .trim_end_matches(".parquet"), + ) + } else { + None + } + }) + .unwrap() + }) + .collect::>() + }) + .collect::>() + }) + .map_err(|e| e.strip_backtrace().leak() as &'static str); + + assert_eq!(results_by_name, case.expected_result, "{}", case.name); + } + + return Ok(()); + + impl From for PartitionedFile { + fn from(file: File) -> Self { + let object_meta = ObjectMeta { + location: Path::from(format!( + "data/date={}/{}.parquet", + file.date, file.name + )), + last_modified: chrono::Utc.timestamp_nanos(0), + size: 0, + e_tag: None, + version: None, + }; + let statistics = Arc::new(Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: file + .statistics + .into_iter() + .map(|stats| { + stats + .map(|(min, max)| ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Float64( + min, + )), + max_value: Precision::Exact(ScalarValue::Float64( + max, + )), + ..Default::default() + }) + .unwrap_or_default() + }) + .collect::>(), + }); + PartitionedFile::new_from_meta(object_meta) + .with_partition_values(vec![ScalarValue::from(file.date)]) + .with_statistics(statistics) + } + } + } + + // sets default for configs that play no role in projections + fn config_for_projection( + file_schema: SchemaRef, + projection: Option>, + statistics: Statistics, + table_partition_cols: Vec, + ) -> FileScanConfig { + let table_schema = TableSchema::builder(file_schema) + .with_table_partition_cols( + table_partition_cols + .into_iter() + .map(Arc::new) + .collect::(), + ) + .build(); + FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::new(MockSource::new(table_schema.clone())), + ) + .with_projection_indices(projection) + .unwrap() + .with_statistics(statistics) + .build() + } + + #[test] + fn test_file_scan_config_builder() { + let file_schema = aggr_test_schema(); + let object_store_url = ObjectStoreUrl::parse("test:///").unwrap(); + + let table_schema = TableSchemaBuilder::from(&file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "date", + wrap_partition_type_in_dict(DataType::Utf8), + false, + ))]) + .build(); + + let file_source: Arc = + Arc::new(MockSource::new(table_schema.clone())); + + // Create a builder with required parameters + let builder = FileScanConfigBuilder::new( + object_store_url.clone(), + Arc::clone(&file_source), + ); + + // Build with various configurations + let config = builder + .with_limit(Some(1000)) + .with_projection_indices(Some(vec![0, 1])) + .unwrap() + .with_statistics(Statistics::new_unknown(&file_schema)) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "test.parquet".to_string(), + 1024, + )])]) + .with_output_ordering(vec![ + [PhysicalSortExpr::new_default(Arc::new(Column::new( + "date", 0, + )))] + .into(), + ]) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(); + + // Verify the built config has all the expected values + assert_eq!(config.object_store_url, object_store_url); + assert_eq!(*config.file_schema(), file_schema); + assert_eq!(config.limit, Some(1000)); + assert_eq!( + config + .file_source + .projection() + .as_ref() + .map(|p| p.column_indices()), + Some(vec![0, 1]) + ); + assert_eq!(config.table_partition_cols().len(), 1); + assert_eq!(config.table_partition_cols()[0].name(), "date"); + assert_eq!(config.file_groups.len(), 1); + assert_eq!(config.file_groups[0].len(), 1); + assert_eq!( + config.file_groups[0][0].object_meta.location.as_ref(), + "test.parquet" + ); + assert_eq!( + config.file_compression_type, + FileCompressionType::UNCOMPRESSED + ); + assert_eq!(config.output_ordering.len(), 1); + } + + #[test] + fn equivalence_properties_after_schema_change() { + let file_schema = aggr_test_schema(); + let object_store_url = ObjectStoreUrl::parse("test:///").unwrap(); + + let table_schema = TableSchema::from(&file_schema); + + // Create a file source with a filter + let file_source: Arc = Arc::new( + MockSource::new(table_schema.clone()).with_filter(Arc::new(BinaryExpr::new( + col("c2", &file_schema).unwrap(), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + ))), + ); + + let config = FileScanConfigBuilder::new( + object_store_url.clone(), + Arc::clone(&file_source), + ) + .with_projection_indices(Some(vec![0, 1, 2])) + .unwrap() + .build(); + + // Simulate projection being updated. Since the filter has already been pushed down, + // the new projection won't include the filtered column. + let exprs = ProjectionExprs::new(vec![ProjectionExpr::new( + col("c1", &file_schema).unwrap(), + "c1", + )]); + let data_source = config + .try_swapping_with_projection(&exprs) + .unwrap() + .unwrap(); + + // Gather the equivalence properties from the new data source. There should + // be no equivalence class for column c2 since it was removed by the projection. + let eq_properties = data_source.eq_properties(); + let eq_group = eq_properties.eq_group(); + + for class in eq_group.iter() { + for expr in class.iter() { + if let Some(col) = expr.downcast_ref::() { + assert_ne!( + col.name(), + "c2", + "c2 should not be present in any equivalence class" + ); + } + } + } + } + + #[test] + fn test_file_scan_config_builder_defaults() { + let file_schema = aggr_test_schema(); + let object_store_url = ObjectStoreUrl::parse("test:///").unwrap(); + + let table_schema = TableSchema::from(&file_schema); + + let file_source: Arc = + Arc::new(MockSource::new(table_schema.clone())); + + // Create a builder with only required parameters and build without any additional configurations + let config = FileScanConfigBuilder::new( + object_store_url.clone(), + Arc::clone(&file_source), + ) + .build(); + + // Verify default values + assert_eq!(config.object_store_url, object_store_url); + assert_eq!(*config.file_schema(), file_schema); + assert_eq!(config.limit, None); + // When no projection is specified, the file source should have an unprojected projection + // (i.e., all columns) + let expected_projection: Vec = (0..file_schema.fields().len()).collect(); + assert_eq!( + config + .file_source + .projection() + .as_ref() + .map(|p| p.column_indices()), + Some(expected_projection) + ); + assert!(config.table_partition_cols().is_empty()); + assert!(config.file_groups.is_empty()); + assert_eq!( + config.file_compression_type, + FileCompressionType::UNCOMPRESSED + ); + assert!(config.output_ordering.is_empty()); + assert!(config.constraints.is_empty()); + + // Verify statistics are set to unknown + assert_eq!(config.statistics().num_rows, Precision::Absent); + assert_eq!(config.statistics().total_byte_size, Precision::Absent); + assert_eq!( + config.statistics().column_statistics.len(), + file_schema.fields().len() + ); + for stat in config.statistics().column_statistics { + assert_eq!(stat.distinct_count, Precision::Absent); + assert_eq!(stat.min_value, Precision::Absent); + assert_eq!(stat.max_value, Precision::Absent); + assert_eq!(stat.null_count, Precision::Absent); + } + } + + #[test] + fn test_file_scan_config_builder_new_from() { + let schema = aggr_test_schema(); + let object_store_url = ObjectStoreUrl::parse("test:///").unwrap(); + let partition_cols = vec![Field::new( + "date", + wrap_partition_type_in_dict(DataType::Utf8), + false, + )]; + let file = PartitionedFile::new("test_file.parquet", 100); + + let table_schema = TableSchemaBuilder::from(&schema) + .with_table_partition_cols( + partition_cols + .iter() + .map(|f| Arc::new(f.clone())) + .collect::(), + ) + .build(); + + let file_source: Arc = + Arc::new(MockSource::new(table_schema.clone())); + + // Create a config with non-default values + let original_config = FileScanConfigBuilder::new( + object_store_url.clone(), + Arc::clone(&file_source), + ) + .with_projection_indices(Some(vec![0, 2])) + .unwrap() + .with_limit(Some(10)) + .with_file(file.clone()) + .with_constraints(Constraints::default()) + .build(); + + // Create a new builder from the config + let new_builder = FileScanConfigBuilder::from(original_config); + + // Build a new config from this builder + let new_config = new_builder.build(); + + // Verify properties match + let partition_cols = partition_cols.into_iter().map(Arc::new).collect::>(); + assert_eq!(new_config.object_store_url, object_store_url); + assert_eq!(*new_config.file_schema(), schema); + assert_eq!( + new_config + .file_source + .projection() + .as_ref() + .map(|p| p.column_indices()), + Some(vec![0, 2]) + ); + assert_eq!(new_config.limit, Some(10)); + assert_eq!( + *new_config.table_partition_cols(), + Fields::from(partition_cols) + ); + assert_eq!(new_config.file_groups.len(), 1); + assert_eq!(new_config.file_groups[0].len(), 1); + assert_eq!( + new_config.file_groups[0][0].object_meta.location.as_ref(), + "test_file.parquet" + ); + assert_eq!(new_config.constraints, Constraints::default()); + } + + #[test] + fn test_split_groups_by_statistics_with_target_partitions() -> Result<()> { + use datafusion_common::DFSchema; + use datafusion_expr::{col, execution_props::ExecutionProps}; + + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Float64, + false, + )])); + + // Setup sort expression + let exec_props = ExecutionProps::new(); + let df_schema = DFSchema::try_from_qualified_schema("test", schema.as_ref())?; + let sort_expr = [col("value").sort(true, false)]; + let sort_ordering = sort_expr + .map(|expr| { + create_physical_sort_expr(&expr, &df_schema, &exec_props).unwrap() + }) + .into(); + + // Test case parameters + struct TestCase { + name: String, + file_count: usize, + overlap_factor: f64, + target_partitions: usize, + expected_partition_count: usize, + } + + let test_cases = vec![ + // Basic cases + TestCase { + name: "no_overlap_10_files_4_partitions".to_string(), + file_count: 10, + overlap_factor: 0.0, + target_partitions: 4, + expected_partition_count: 4, + }, + TestCase { + name: "medium_overlap_20_files_5_partitions".to_string(), + file_count: 20, + overlap_factor: 0.5, + target_partitions: 5, + expected_partition_count: 5, + }, + TestCase { + name: "high_overlap_30_files_3_partitions".to_string(), + file_count: 30, + overlap_factor: 0.8, + target_partitions: 3, + expected_partition_count: 7, + }, + // Edge cases + TestCase { + name: "fewer_files_than_partitions".to_string(), + file_count: 3, + overlap_factor: 0.0, + target_partitions: 10, + expected_partition_count: 3, // Should only create as many partitions as files + }, + TestCase { + name: "single_file".to_string(), + file_count: 1, + overlap_factor: 0.0, + target_partitions: 5, + expected_partition_count: 1, // Should create only one partition + }, + TestCase { + name: "empty_files".to_string(), + file_count: 0, + overlap_factor: 0.0, + target_partitions: 3, + expected_partition_count: 0, // Empty result for empty input + }, + ]; + + for case in test_cases { + println!("Running test case: {}", case.name); + + // Generate files using bench utility function + let file_groups = generate_test_files(case.file_count, case.overlap_factor); + + // Call the function under test + let result = + FileScanConfig::split_groups_by_statistics_with_target_partitions( + &schema, + &file_groups, + &sort_ordering, + case.target_partitions, + )?; + + // Verify results + println!( + "Created {} partitions (target was {})", + result.len(), + case.target_partitions + ); + + // Check partition count + assert_eq!( + result.len(), + case.expected_partition_count, + "Case '{}': Unexpected partition count", + case.name + ); + + // Verify sort integrity + assert!( + verify_sort_integrity(&result), + "Case '{}': Files within partitions are not properly ordered", + case.name + ); + + // Distribution check for partitions + if case.file_count > 1 && case.expected_partition_count > 1 { + let group_sizes: Vec = result.iter().map(FileGroup::len).collect(); + let max_size = *group_sizes.iter().max().unwrap(); + let min_size = *group_sizes.iter().min().unwrap(); + + // Check partition balancing - difference shouldn't be extreme + let avg_files_per_partition = + case.file_count as f64 / case.expected_partition_count as f64; + assert!( + (max_size as f64) < 2.0 * avg_files_per_partition, + "Case '{}': Unbalanced distribution. Max partition size {} exceeds twice the average {}", + case.name, + max_size, + avg_files_per_partition + ); + + println!("Distribution - min files: {min_size}, max files: {max_size}"); + } + } + + // Test error case: zero target partitions + let empty_groups: Vec = vec![]; + let err = FileScanConfig::split_groups_by_statistics_with_target_partitions( + &schema, + &empty_groups, + &sort_ordering, + 0, + ) + .unwrap_err(); + + assert!( + err.to_string() + .contains("target_partitions must be greater than 0"), + "Expected error for zero target partitions" + ); + + Ok(()) + } + + #[test] + fn test_partition_statistics_projection() { + // This test verifies that partition_statistics applies projection correctly. + // The old implementation had a bug where it returned file group statistics + // without applying the projection, returning all column statistics instead + // of just the projected ones. + + use crate::source::DataSourceExec; + use datafusion_physical_plan::ExecutionPlan; + + // Create a schema with 4 columns + let schema = Arc::new(Schema::new(vec![ + Field::new("col0", DataType::Int32, false), + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + Field::new("col3", DataType::Int32, false), + ])); + + // Create statistics for all 4 columns + let file_group_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(1024), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + null_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + null_count: Precision::Exact(10), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + null_count: Precision::Exact(15), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + // Create a file group with statistics + let file_group = FileGroup::new(vec![PartitionedFile::new("test.parquet", 1024)]) + .with_statistics(Arc::new(file_group_stats)); + + let table_schema = TableSchema::from(&schema); + + // Create a FileScanConfig with projection: only keep columns 0 and 2 + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::new(MockSource::new(table_schema.clone())), + ) + .with_projection_indices(Some(vec![0, 2])) + .unwrap() // Only project columns 0 and 2 + .with_file_groups(vec![file_group]) + .build(); + + // Create a DataSourceExec from the config + let exec = DataSourceExec::from_data_source(config); + + // Get statistics for partition 0 + let partition_stats = exec.partition_statistics(Some(0)).unwrap(); + + // Verify that only 2 columns are in the statistics (the projected ones) + assert_eq!( + partition_stats.column_statistics.len(), + 2, + "Expected 2 column statistics (projected), but got {}", + partition_stats.column_statistics.len() + ); + + // Verify the column statistics are for columns 0 and 2 + assert_eq!( + partition_stats.column_statistics[0].null_count, + Precision::Exact(0), + "First projected column should be col0 with 0 nulls" + ); + assert_eq!( + partition_stats.column_statistics[1].null_count, + Precision::Exact(10), + "Second projected column should be col2 with 10 nulls" + ); + + // Verify row count and byte size + assert_eq!(partition_stats.num_rows, Precision::Exact(100)); + assert_eq!(partition_stats.total_byte_size, Precision::Exact(800)); + } + + /// Regression test for reusing a `DataSourceExec` after its execution-local + /// shared work queue has been drained. + /// + /// This test uses a single file group with two files so the scan creates a + /// shared unopened-file queue. Executing after `reset_state` must recreate + /// the shared queue and return the same rows again. + #[tokio::test] + async fn reset_state_recreates_shared_work_source() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + let file_source = Arc::new( + MockSource::new(Arc::clone(&schema)) + .with_file_opener(Arc::new(ResetStateTestFileOpener { schema })), + ); + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_group(FileGroup::new(vec![ + PartitionedFile::new("file1.parquet", 100), + PartitionedFile::new("file2.parquet", 100), + ])) + .build(); + + let exec: Arc = DataSourceExec::from_data_source(config); + let task_ctx = Arc::new(TaskContext::default()); + + // Running the same scan after resetting the state, should + // produce the same answer. + let first_run = collect(Arc::clone(&exec), Arc::clone(&task_ctx)).await?; + let reset_exec = exec.reset_state()?; + let second_run = collect(reset_exec, task_ctx).await?; + + let expected = [ + "+-------+", + "| value |", + "+-------+", + "| 1 |", + "| 2 |", + "+-------+", + ]; + assert_batches_eq!(expected, &first_run); + assert_batches_eq!(expected, &second_run); + + Ok(()) + } + + /// Test-only `FileOpener` that turns file names like `file1.parquet` into a + /// single-batch stream containing that numeric value + #[derive(Debug)] + struct ResetStateTestFileOpener { + schema: SchemaRef, + } + + impl crate::file_stream::FileOpener for ResetStateTestFileOpener { + fn open( + &self, + file: PartitionedFile, + ) -> Result { + let value = file + .object_meta + .location + .as_ref() + .trim_start_matches("file") + .trim_end_matches(".parquet") + .parse::() + .expect("invalid test file name"); + let schema = Arc::clone(&self.schema); + Ok(async move { + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(Int32Array::from(vec![value]))], + ) + .expect("test batch should be valid"); + Ok(stream::iter(vec![Ok(batch)]).boxed()) + } + .boxed()) + } + } + + #[test] + fn test_output_partitioning_not_partitioned_by_file_group() { + let file_schema = aggr_test_schema(); + let partition_col = + Field::new("date", wrap_partition_type_in_dict(DataType::Utf8), false); + + let config = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + vec![partition_col], + ); + + // partitioned_by_file_group defaults to false + let partitioning = config.output_partitioning(); + assert!(matches!(partitioning, Partitioning::UnknownPartitioning(_))); + } + + #[test] + fn test_output_partitioning_no_partition_columns() { + let file_schema = aggr_test_schema(); + let mut config = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + vec![], // No partition columns + ); + config.partitioned_by_file_group = true; + + let partitioning = config.output_partitioning(); + assert!(matches!(partitioning, Partitioning::UnknownPartitioning(_))); + } + + #[test] + fn test_output_partitioning_with_partition_columns() { + let file_schema = aggr_test_schema(); + + // Test single partition column + let single_partition_col = vec![Field::new( + "date", + wrap_partition_type_in_dict(DataType::Utf8), + false, + )]; + + let mut config = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + single_partition_col, + ); + config.partitioned_by_file_group = true; + config.file_groups = vec![ + FileGroup::new(vec![PartitionedFile::new("f1.parquet".to_string(), 1024)]), + FileGroup::new(vec![PartitionedFile::new("f2.parquet".to_string(), 1024)]), + FileGroup::new(vec![PartitionedFile::new("f3.parquet".to_string(), 1024)]), + ]; + + let partitioning = config.output_partitioning(); + match partitioning { + Partitioning::Hash(exprs, num_partitions) => { + assert_eq!(num_partitions, 3); + assert_eq!(exprs.len(), 1); + assert_eq!(exprs[0].downcast_ref::().unwrap().name(), "date"); + } + _ => panic!("Expected Hash partitioning"), + } + + // Test multiple partition columns + let multiple_partition_cols = vec![ + Field::new("year", wrap_partition_type_in_dict(DataType::Utf8), false), + Field::new("month", wrap_partition_type_in_dict(DataType::Utf8), false), + ]; + + config = config_for_projection( + Arc::clone(&file_schema), + None, + Statistics::new_unknown(&file_schema), + multiple_partition_cols, + ); + config.partitioned_by_file_group = true; + config.file_groups = vec![ + FileGroup::new(vec![PartitionedFile::new("f1.parquet".to_string(), 1024)]), + FileGroup::new(vec![PartitionedFile::new("f2.parquet".to_string(), 1024)]), + ]; + + let partitioning = config.output_partitioning(); + match partitioning { + Partitioning::Hash(exprs, num_partitions) => { + assert_eq!(num_partitions, 2); + assert_eq!(exprs.len(), 2); + let col_names: Vec<_> = exprs + .iter() + .map(|e| e.downcast_ref::().unwrap().name()) + .collect(); + assert_eq!(col_names, vec!["year", "month"]); + } + _ => panic!("Expected Hash partitioning"), + } + } + + #[test] + fn try_pushdown_sort_reverses_file_groups_only_when_requested_is_reverse() + -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(InexactSortPushdownSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + PartitionedFile::new("file1", 1), + PartitionedFile::new("file2", 1), + ])]; + + let sort_expr_asc = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr_asc.clone()]).unwrap(), + ]) + .build(); + + let requested_asc = vec![sort_expr_asc.clone()]; + let result = config.try_pushdown_sort(&requested_asc)?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let pushed_files = pushed_config.file_groups[0].files(); + assert_eq!(pushed_files[0].object_meta.location.as_ref(), "file1"); + assert_eq!(pushed_files[1].object_meta.location.as_ref(), "file2"); + + let requested_desc = vec![sort_expr_asc.reverse()]; + let result = config.try_pushdown_sort(&requested_desc)?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let pushed_files = pushed_config.file_groups[0].files(); + assert_eq!(pushed_files[0].object_meta.location.as_ref(), "file2"); + assert_eq!(pushed_files[1].object_meta.location.as_ref(), "file1"); + + Ok(()) + } + + fn make_file_with_stats(name: &str, min: f64, max: f64) -> PartitionedFile { + PartitionedFile::new(name.to_string(), 1024).with_statistics(Arc::new( + Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(1024), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + min_value: Precision::Exact(ScalarValue::Float64(Some(min))), + max_value: Precision::Exact(ScalarValue::Float64(Some(max))), + ..Default::default() + }], + }, + )) + } + + #[derive(Clone)] + struct ExactSortPushdownSource { + metrics: ExecutionPlanMetricsSet, + table_schema: TableSchema, + } + + impl ExactSortPushdownSource { + fn new(table_schema: TableSchema) -> Self { + Self { + metrics: ExecutionPlanMetricsSet::new(), + table_schema, + } + } + } + + impl FileSource for ExactSortPushdownSource { + fn create_file_opener( + &self, + _object_store: Arc, + _base_config: &FileScanConfig, + _partition: usize, + ) -> Result> { + unimplemented!() + } + + fn table_schema(&self) -> &TableSchema { + &self.table_schema + } + + fn with_batch_size(&self, _batch_size: usize) -> Arc { + Arc::new(self.clone()) + } + + fn metrics(&self) -> &ExecutionPlanMetricsSet { + &self.metrics + } + + fn file_type(&self) -> &str { + "mock_exact" + } + + fn try_pushdown_sort( + &self, + _order: &[PhysicalSortExpr], + _eq_properties: &EquivalenceProperties, + ) -> Result>> { + Ok(SortOrderPushdownResult::Exact { + inner: Arc::new(self.clone()) as Arc, + }) + } + } + + #[test] + fn sort_pushdown_unsupported_source_files_get_sorted() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file3", 20.0, 30.0), + make_file_with_stats("file1", 0.0, 9.0), + make_file_with_stats("file2", 10.0, 19.0), + ])]; + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result, got {result:?}"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let files = pushed_config.file_groups[0].files(); + assert_eq!(files[0].object_meta.location.as_ref(), "file1"); + assert_eq!(files[1].object_meta.location.as_ref(), "file2"); + assert_eq!(files[2].object_meta.location.as_ref(), "file3"); + assert!(pushed_config.output_ordering.is_empty()); + Ok(()) + } + + #[test] + fn sort_pushdown_unsupported_source_already_sorted() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file1", 0.0, 9.0), + make_file_with_stats("file2", 10.0, 19.0), + make_file_with_stats("file3", 20.0, 30.0), + ])]; + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + assert!(matches!(result, SortOrderPushdownResult::Unsupported)); + Ok(()) + } + + #[test] + fn sort_pushdown_unsupported_source_descending_sort() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file1", 0.0, 9.0), + make_file_with_stats("file3", 20.0, 30.0), + make_file_with_stats("file2", 10.0, 19.0), + ])]; + + let sort_expr = PhysicalSortExpr::new( + Arc::new(Column::new("a", 0)), + arrow::compute::SortOptions { + descending: true, + nulls_first: true, + }, + ); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let files = pushed_config.file_groups[0].files(); + assert_eq!(files[0].object_meta.location.as_ref(), "file3"); + assert_eq!(files[1].object_meta.location.as_ref(), "file2"); + assert_eq!(files[2].object_meta.location.as_ref(), "file1"); + Ok(()) + } + + #[test] + fn sort_pushdown_exact_source_non_overlapping_returns_exact() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(ExactSortPushdownSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file1", 0.0, 9.0), + make_file_with_stats("file2", 10.0, 19.0), + make_file_with_stats("file3", 20.0, 30.0), + ])]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Exact { inner } = result else { + panic!("Expected Exact result, got {result:?}"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + assert!(!pushed_config.output_ordering.is_empty()); + Ok(()) + } + + #[test] + fn sort_pushdown_exact_source_overlapping_downgraded_to_inexact() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(ExactSortPushdownSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file1", 0.0, 15.0), + make_file_with_stats("file2", 10.0, 25.0), + make_file_with_stats("file3", 20.0, 30.0), + ])]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact (downgraded), got {result:?}"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + assert!(pushed_config.output_ordering.is_empty()); + Ok(()) + } + + #[test] + fn sort_pushdown_exact_source_out_of_order_returns_exact() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(ExactSortPushdownSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file3", 20.0, 30.0), + make_file_with_stats("file1", 0.0, 9.0), + make_file_with_stats("file2", 10.0, 19.0), + ])]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Exact { inner } = result else { + panic!("Expected Exact result, got {result:?}"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let files = pushed_config.file_groups[0].files(); + assert_eq!(files[0].object_meta.location.as_ref(), "file1"); + assert_eq!(files[1].object_meta.location.as_ref(), "file2"); + assert_eq!(files[2].object_meta.location.as_ref(), "file3"); + assert!(!pushed_config.output_ordering.is_empty()); + Ok(()) + } + + #[test] + fn sort_pushdown_unsupported_source_single_file_groups() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let file_groups = vec![ + FileGroup::new(vec![make_file_with_stats("file1", 0.0, 9.0)]), + FileGroup::new(vec![make_file_with_stats("file2", 10.0, 19.0)]), + ]; + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + assert!( + matches!(result, SortOrderPushdownResult::Unsupported), + "Expected Unsupported for single-file groups" + ); + Ok(()) + } + + #[test] + fn sort_pushdown_unsupported_source_multiple_groups() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let file_groups = vec![ + FileGroup::new(vec![ + make_file_with_stats("file_b", 10.0, 19.0), + make_file_with_stats("file_a", 0.0, 9.0), + ]), + FileGroup::new(vec![ + make_file_with_stats("file_d", 30.0, 39.0), + make_file_with_stats("file_c", 20.0, 29.0), + ]), + ]; + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let files0 = pushed_config.file_groups[0].files(); + assert_eq!(files0[0].object_meta.location.as_ref(), "file_a"); + assert_eq!(files0[1].object_meta.location.as_ref(), "file_b"); + let files1 = pushed_config.file_groups[1].files(); + assert_eq!(files1[0].object_meta.location.as_ref(), "file_c"); + assert_eq!(files1[1].object_meta.location.as_ref(), "file_d"); + Ok(()) + } + + #[test] + fn sort_pushdown_unsupported_source_partial_statistics() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let file_groups = vec![ + FileGroup::new(vec![ + make_file_with_stats("file_b", 10.0, 19.0), + make_file_with_stats("file_a", 0.0, 9.0), + ]), + FileGroup::new(vec![ + PartitionedFile::new("file_d".to_string(), 1024), + PartitionedFile::new("file_c".to_string(), 1024), + ]), + ]; + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let files0 = pushed_config.file_groups[0].files(); + assert_eq!(files0[0].object_meta.location.as_ref(), "file_a"); + assert_eq!(files0[1].object_meta.location.as_ref(), "file_b"); + let files1 = pushed_config.file_groups[1].files(); + assert_eq!(files1[0].object_meta.location.as_ref(), "file_d"); + assert_eq!(files1[1].object_meta.location.as_ref(), "file_c"); + Ok(()) + } + + #[test] + fn sort_pushdown_inexact_source_with_statistics_sorting() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(InexactSortPushdownSource::new(table_schema)); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file2", 10.0, 19.0), + make_file_with_stats("file1", 0.0, 9.0), + ])]; + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact result"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + let files = pushed_config.file_groups[0].files(); + assert_eq!(files[0].object_meta.location.as_ref(), "file1"); + assert_eq!(files[1].object_meta.location.as_ref(), "file2"); + assert!(pushed_config.output_ordering.is_empty()); + Ok(()) + } + + #[test] + fn sort_pushdown_exact_multi_group_preserves_parallelism() -> Result<()> { + // ExactSortPushdownSource + 4 non-overlapping files in 2 interleaved groups. + // Groups should NOT be redistributed — interleaved groups allow SPM to + // pull from both partitions concurrently, keeping parallel I/O active. + // Redistributing consecutively would make SPM read one partition at a + // time (all values in group 0 < group 1), degrading to single-threaded I/O. + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(ExactSortPushdownSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + // 2 groups with interleaved ranges (simulating bin-packing result): + // Group 0: [file_01(0-9), file_03(20-29)] + // Group 1: [file_02(10-19), file_04(30-39)] + let file_groups = vec![ + FileGroup::new(vec![ + make_file_with_stats("file_01", 0.0, 9.0), + make_file_with_stats("file_03", 20.0, 29.0), + ]), + FileGroup::new(vec![ + make_file_with_stats("file_02", 10.0, 19.0), + make_file_with_stats("file_04", 30.0, 39.0), + ]), + ]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + let SortOrderPushdownResult::Exact { inner } = result else { + panic!("Expected Exact result, got {result:?}"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + + // 2 groups preserved (parallelism maintained) + assert_eq!(pushed_config.file_groups.len(), 2); + + // Files within each group are sorted by stats, but groups are NOT + // redistributed — interleaved assignment from bin-packing is kept + let files0 = pushed_config.file_groups[0].files(); + assert_eq!(files0[0].object_meta.location.as_ref(), "file_01"); + assert_eq!(files0[1].object_meta.location.as_ref(), "file_03"); + let files1 = pushed_config.file_groups[1].files(); + assert_eq!(files1[0].object_meta.location.as_ref(), "file_02"); + assert_eq!(files1[1].object_meta.location.as_ref(), "file_04"); + + // output_ordering preserved (Exact, each group internally non-overlapping) + assert!(!pushed_config.output_ordering.is_empty()); + Ok(()) + } + + #[test] + fn sort_pushdown_reverse_preserves_file_order_with_stats() -> Result<()> { + // Reverse scan should reverse file order but NOT apply statistics-based + // sorting (which would undo the reversal). The result is Inexact. + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(InexactSortPushdownSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + // Files with stats, in ASC order. Output ordering is [a ASC]. + let file_groups = vec![FileGroup::new(vec![ + make_file_with_stats("file1", 0.0, 9.0), + make_file_with_stats("file2", 10.0, 19.0), + make_file_with_stats("file3", 20.0, 30.0), + ])]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + // Request DESC → reverse path + let result = config.try_pushdown_sort(&[sort_expr.reverse()])?; + let SortOrderPushdownResult::Inexact { inner } = result else { + panic!("Expected Inexact for reverse scan, got {result:?}"); + }; + let pushed_config = inner + .downcast_ref::() + .expect("Expected FileScanConfig"); + + // Files should be reversed (not re-sorted by stats) + let files = pushed_config.file_groups[0].files(); + assert_eq!(files[0].object_meta.location.as_ref(), "file3"); + assert_eq!(files[1].object_meta.location.as_ref(), "file2"); + assert_eq!(files[2].object_meta.location.as_ref(), "file1"); + + // output_ordering cleared (Inexact) + assert!(pushed_config.output_ordering.is_empty()); + Ok(()) + } + + /// Helper: create a PartitionedFile with stats including null count + fn make_file_with_null_stats( + name: &str, + min: f64, + max: f64, + null_count: usize, + ) -> PartitionedFile { + PartitionedFile::new(name.to_string(), 1024).with_statistics(Arc::new( + Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Exact(1024), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(null_count), + min_value: Precision::Exact(ScalarValue::Float64(Some(min))), + max_value: Precision::Exact(ScalarValue::Float64(Some(max))), + ..Default::default() + }], + }, + )) + } + + #[test] + fn sort_pushdown_unsupported_with_nulls_does_not_upgrade_to_exact() -> Result<()> { + // Files are non-overlapping but one has NULLs. + // Should NOT upgrade to Exact — NULLs would appear in wrong position. + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + // Files in wrong order (high min first) to trigger reordering + let file_groups = vec![FileGroup::new(vec![ + make_file_with_null_stats("b_no_nulls", 10.0, 19.0, 0), + make_file_with_null_stats("a_with_nulls", 0.0, 9.0, 5), // has NULLs + ])]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + // Should be Inexact (not Exact) because of NULLs + assert!( + matches!(result, SortOrderPushdownResult::Inexact { .. }), + "Expected Inexact due to NULLs, got {result:?}" + ); + Ok(()) + } + + #[test] + fn sort_pushdown_unsupported_no_nulls_upgrades_to_exact() -> Result<()> { + // Files are non-overlapping, no NULLs → should upgrade to Exact + let file_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(MockSource::new(table_schema)); + + let sort_expr = PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0))); + + let file_groups = vec![FileGroup::new(vec![ + make_file_with_null_stats("b_high", 10.0, 19.0, 0), + make_file_with_null_stats("a_low", 0.0, 9.0, 0), + ])]; + + let config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(file_groups) + .with_output_ordering(vec![ + LexOrdering::new(vec![sort_expr.clone()]).unwrap(), + ]) + .build(); + + let result = config.try_pushdown_sort(&[sort_expr])?; + assert!( + matches!(result, SortOrderPushdownResult::Exact { .. }), + "Expected Exact (no NULLs), got {result:?}" + ); + Ok(()) + } +} diff --git a/datafusion/datasource/src/file_scan_config/sort_pushdown.rs b/datafusion/datasource/src/file_scan_config/sort_pushdown.rs new file mode 100644 index 0000000000000..3f5beed20fa8d --- /dev/null +++ b/datafusion/datasource/src/file_scan_config/sort_pushdown.rs @@ -0,0 +1,622 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort pushdown helpers for [`FileScanConfig`]. +//! +//! This module contains the statistics-based file sorting, non-overlapping +//! validation, and NULL handling logic used by +//! [`FileScanConfig::try_pushdown_sort`](super::FileScanConfig::try_pushdown_sort). +//! +//! Extracted from `file_scan_config.rs` to keep that module focused on +//! core configuration and data-source plumbing. + +use super::FileScanConfig; +use crate::file::FileSource; +use crate::file_groups::FileGroup; +use crate::source::DataSource; +use crate::statistics::MinMaxStatistics; + +use arrow::datatypes::SchemaRef; +use datafusion_common::Result; +use datafusion_common::stats::Precision; +use datafusion_physical_expr::equivalence::project_orderings; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::projection::ProjectionExprs; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::SortOrderPushdownResult; +use log::debug; +use std::sync::Arc; + +/// Result of sorting files within groups by their min/max statistics. +pub(crate) struct SortedFileGroups { + file_groups: Vec, + any_reordered: bool, + all_non_overlapping: bool, +} + +impl FileScanConfig { + /// + /// This is the core of sort pushdown for file-based sources. It performs + /// three optimizations depending on the pushdown result: + /// + /// ```text + /// ┌─────────────────────────────────────────────────────────────┐ + /// │ rebuild_with_source │ + /// │ │ + /// │ 1. Reverse file groups (if DESC matches reversed ordering) │ + /// │ 2. Sort files within groups by min/max statistics │ + /// │ 3. If Exact + non-overlapping: │ + /// │ Keep output_ordering → SortExec eliminated │ + /// │ Otherwise: clear output_ordering → SortExec stays │ + /// └─────────────────────────────────────────────────────────────┘ + /// ``` + /// + /// # Why sort files by statistics? + /// + /// Files within a partition (file group) are read sequentially. By sorting + /// them so that file_i.max <= file_{i+1}.min, the combined output stream + /// is already in order — no SortExec needed for that partition. + /// + /// Even when files overlap (Inexact), statistics-based ordering helps + /// TopK/LIMIT queries: reading low-value files first lets dynamic filters + /// prune high-value files earlier. + pub(crate) fn rebuild_with_source( + &self, + new_file_source: Arc, + is_exact: bool, + order: &[PhysicalSortExpr], + ) -> Result { + let mut new_config = self.clone(); + + // Reverse file order (within each group) if the caller is requesting a reversal of this + // scan's declared output ordering. + let reverse_file_groups = if self.output_ordering.is_empty() { + false + } else if let Some(requested) = LexOrdering::new(order.iter().cloned()) { + let projected_schema = self.projected_schema()?; + let orderings = project_orderings(&self.output_ordering, &projected_schema); + orderings + .iter() + .any(|ordering| ordering.is_reverse(&requested)) + } else { + false + }; + + if reverse_file_groups { + new_config.file_groups = new_config + .file_groups + .into_iter() + .map(|group| { + let mut files = group.into_inner(); + files.reverse(); + files.into() + }) + .collect(); + } + + new_config.file_source = new_file_source; + + // Sort files within groups by statistics when not reversing + let all_non_overlapping = if !reverse_file_groups { + if let Some(sort_order) = LexOrdering::new(order.iter().cloned()) { + let projected_schema = new_config.projected_schema()?; + let projection_indices = new_config + .file_source + .projection() + .as_ref() + .and_then(|p| ordered_column_indices_from_projection(p)); + let result = sort_files_within_groups_by_statistics( + &new_config.file_groups, + &sort_order, + &projected_schema, + projection_indices.as_deref(), + ); + new_config.file_groups = result.file_groups; + result.all_non_overlapping + } else { + false + } + } else { + // When reversing, files are already reversed above. We skip + // statistics-based sorting here because it would undo the reversal. + // Note: reverse path is always Inexact, so all_non_overlapping + // is not used (is_exact is false). + false + }; + + // Decide whether to keep `output_ordering` (i.e. let the outer + // pushdown report `Exact` and drop `SortExec`). + // + // Two paths can produce a keep: + // + // 1. `is_exact && all_non_overlapping`: the source already had + // validated ordering and the post-sort files still don't + // overlap — Exact carries through unchanged. + // + // 2. `!is_exact && all_non_overlapping`: source returned + // `Inexact` because pre-sort `validated_output_ordering()` + // stripped the declaration (files were listed out of order + // on disk). After our stats-based sort the files are now + // non-overlapping — re-validate against the new file + // groups and, if it passes, upgrade back to Exact so the + // outer wrapper drops the `SortExec`. Without this, the + // `Inexact` branch stayed Inexact even when reorder + // restored a perfectly valid ordering, leaving an + // unnecessary `SortExec` above the source (regression + // after #21956's `column_in_file_schema` signal pushed + // this scenario into the Inexact branch instead of the + // `try_sort_file_groups_by_statistics` fallback). + // + // We intentionally do NOT redistribute files across groups here. + // The planning-phase bin-packing may interleave file ranges across groups: + // + // Group 0: [f1(1-10), f3(21-30)] ← interleaved with group 1 + // Group 1: [f2(11-20), f4(31-40)] + // + // This interleaving is actually beneficial because SPM pulls from both + // partitions concurrently, keeping parallel I/O active. + let keep_ordering = match (all_non_overlapping, is_exact) { + // Files still overlap after the stats sort — the combined + // stream isn't ordered, so `output_ordering` must be dropped. + (false, _) => false, + // Source already had validated ordering and the post-sort + // files still don't overlap — Exact carries through. + (true, true) => true, + // Source returned `Inexact`; re-validate against the + // reordered file groups to decide whether to upgrade. + // + // Same NULL guard as `try_sort_file_groups_by_statistics`: + // we cannot claim Exact if any non-last file contains + // NULLs in the sort columns. With NULLS LAST those + // NULLs sit after all non-null rows in the file, so + // when the next file's non-nulls are smaller than the + // previous file's max, they'd appear *after* the NULLs + // in the concatenated stream — breaking the ordering. + (true, false) => { + let projected_schema = new_config.projected_schema()?; + let projection_indices = new_config + .file_source + .projection() + .as_ref() + .and_then(|p| ordered_column_indices_from_projection(p)); + if any_file_has_nulls_in_sort_columns( + &new_config.file_groups, + order, + &projected_schema, + projection_indices.as_deref(), + ) { + false + } else { + let new_eq_props = new_config.eq_properties(); + new_eq_props.ordering_satisfy(order.iter().cloned())? + } + } + }; + + if !keep_ordering { + new_config.output_ordering = vec![]; + } + + Ok(new_config) + } + + /// Last-resort optimization when FileSource returns `Unsupported`. + /// + /// FileSource may return `Unsupported` because `eq_properties` had no + /// ordering — which happens when `validated_output_ordering()` stripped + /// the ordering because files were in the wrong order. After sorting + /// files by statistics, the ordering may become valid again. + /// + /// This method: + /// 1. Sorts files within groups by min/max statistics + /// 2. Re-checks if the sorted file order makes `output_ordering` valid + /// 3. If valid AND non-overlapping → `Exact` (SortExec eliminated!) + /// 4. If files were reordered but ordering not valid → `Inexact` + /// 5. If no files were reordered → `Unsupported` + /// + /// This handles the key case where files have correct within-file ordering + /// (e.g., Parquet sorting_columns metadata) but were listed in wrong order + /// (e.g., alphabetical order doesn't match sort key order). + pub(crate) fn try_sort_file_groups_by_statistics( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + let Some(sort_order) = LexOrdering::new(order.iter().cloned()) else { + return Ok(SortOrderPushdownResult::Unsupported); + }; + + let projected_schema = self.projected_schema()?; + let projection_indices = self + .file_source + .projection() + .as_ref() + .and_then(|p| ordered_column_indices_from_projection(p)); + + let result = sort_files_within_groups_by_statistics( + &self.file_groups, + &sort_order, + &projected_schema, + projection_indices.as_deref(), + ); + + if !result.any_reordered { + return Ok(SortOrderPushdownResult::Unsupported); + } + + let mut new_config = self.clone(); + new_config.file_groups = result.file_groups; + + // Re-check: now that files are sorted, does output_ordering become valid? + // This handles the case where validated_output_ordering() previously + // stripped the ordering because files were in the wrong order. + // + // IMPORTANT: We cannot claim Exact if any file in a non-last position + // contains NULLs in the sort columns. With NULLS LAST, NULLs within + // a file are placed after all non-null values. If the next file has + // non-null values smaller than the previous file's max, those values + // would incorrectly appear after the NULLs. Similarly for NULLS FIRST. + // + // Conservative approach: if any file has nulls in the sort columns, + // do not claim Exact. The SortExec will handle NULL ordering correctly. + if result.all_non_overlapping + && !self.output_ordering.is_empty() + && !any_file_has_nulls_in_sort_columns( + &new_config.file_groups, + order, + &projected_schema, + projection_indices.as_deref(), + ) + { + // Files are now non-overlapping, no NULLs in sort columns. + // Re-ask the FileSource if this ordering satisfies the request, + // using eq_properties computed from the NEW (sorted) file groups. + let new_eq_props = new_config.eq_properties(); + if new_eq_props.ordering_satisfy(order.iter().cloned())? { + // The sorted file order makes the ordering valid → Exact! + return Ok(SortOrderPushdownResult::Exact { + inner: Arc::new(new_config), + }); + } + } + + new_config.output_ordering = vec![]; + Ok(SortOrderPushdownResult::Inexact { + inner: Arc::new(new_config), + }) + } +} + +/// Sort files within each file group by their min/max statistics. +/// +/// No files are moved between groups — parallelism and group composition +/// are unchanged. Groups where statistics are unavailable are kept as-is. +/// +/// ```text +/// Before: Group [file_c(20-30), file_a(0-9), file_b(10-19)] +/// After: Group [file_a(0-9), file_b(10-19), file_c(20-30)] +/// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +/// sorted by min value, non-overlapping → Exact +/// ``` +pub(crate) fn sort_files_within_groups_by_statistics( + file_groups: &[FileGroup], + sort_order: &LexOrdering, + projected_schema: &SchemaRef, + projection_indices: Option<&[usize]>, +) -> SortedFileGroups { + let mut any_reordered = false; + let mut confirmed_non_overlapping: usize = 0; + let mut new_groups = Vec::with_capacity(file_groups.len()); + + for group in file_groups { + if group.len() <= 1 { + new_groups.push(group.clone()); + confirmed_non_overlapping += 1; + continue; + } + + let files: Vec<_> = group.iter().collect(); + + let statistics = match MinMaxStatistics::new_from_files( + sort_order, + projected_schema, + projection_indices, + files.iter().copied(), + ) { + Ok(stats) => stats, + Err(e) => { + log::trace!( + "Cannot sort file group by statistics: {e}. Keeping original order." + ); + new_groups.push(group.clone()); + continue; + } + }; + + let sorted_indices = statistics.min_values_sorted(); + + let already_sorted = sorted_indices + .iter() + .enumerate() + .all(|(pos, (idx, _))| pos == *idx); + + let sorted_group: FileGroup = if already_sorted { + group.clone() + } else { + any_reordered = true; + sorted_indices + .iter() + .map(|(idx, _)| files[*idx].clone()) + .collect() + }; + + let sorted_files: Vec<_> = sorted_group.iter().collect(); + let is_non_overlapping = match MinMaxStatistics::new_from_files( + sort_order, + projected_schema, + projection_indices, + sorted_files.iter().copied(), + ) { + Ok(stats) => stats.is_sorted(), + Err(_) => false, + }; + + if is_non_overlapping { + confirmed_non_overlapping += 1; + } + + new_groups.push(sorted_group); + } + + SortedFileGroups { + file_groups: new_groups, + any_reordered, + all_non_overlapping: confirmed_non_overlapping == file_groups.len(), + } +} + +/// Check if any file in any group has nulls in the sort columns. +pub(crate) fn any_file_has_nulls_in_sort_columns( + file_groups: &[FileGroup], + order: &[PhysicalSortExpr], + projected_schema: &SchemaRef, + projection_indices: Option<&[usize]>, +) -> bool { + let Some(sort_columns) = + sort_columns_from_physical_sort_exprs_nullable(order, projected_schema) + else { + return true; // Can't determine, assume nulls exist + }; + + for group in file_groups { + for file in group.iter() { + let Some(stats) = file.statistics.as_ref() else { + return true; // No stats, assume nulls exist + }; + for col in &sort_columns { + let stat_idx = projection_indices + .map(|p| p[col.index()]) + .unwrap_or_else(|| col.index()); + if stat_idx >= stats.column_statistics.len() { + return true; + } + let col_stats = &stats.column_statistics[stat_idx]; + match &col_stats.null_count { + Precision::Exact(0) => {} // No nulls, safe + Precision::Exact(_) => return true, // Has nulls + _ => return true, // Unknown null count, assume nulls + } + } + } + } + false +} + +/// Get the indices of columns in a projection if the projection is a simple +/// list of columns. +/// If there are any expressions other than columns, returns None. +pub(crate) fn ordered_column_indices_from_projection( + projection: &ProjectionExprs, +) -> Option> { + projection + .expr_iter() + .map(|e| { + let index = e.downcast_ref::()?.index(); + Some(index) + }) + .collect::>>() +} + +/// Extract Column references from sort expressions for null checking. +fn sort_columns_from_physical_sort_exprs_nullable( + order: &[PhysicalSortExpr], + _schema: &SchemaRef, +) -> Option> { + order + .iter() + .map(|expr| expr.expr.downcast_ref::().cloned()) + .collect() +} + +/// Check whether a given ordering is valid for all file groups by verifying +/// that files within each group are sorted according to their min/max statistics. +/// +/// For single-file (or empty) groups, the ordering is trivially valid. +/// For multi-file groups, we check that the min/max statistics for the sort +/// columns are in order and non-overlapping (or touching at boundaries). +/// +/// `projection` maps projected column indices back to table-schema indices +/// when validating after projection; pass `None` when validating at +/// table-schema level. +pub(crate) fn is_ordering_valid_for_file_groups( + file_groups: &[FileGroup], + ordering: &LexOrdering, + schema: &SchemaRef, + projection: Option<&[usize]>, +) -> bool { + file_groups.iter().all(|group| { + if group.len() <= 1 { + return true; // single-file groups are trivially sorted + } + match MinMaxStatistics::new_from_files(ordering, schema, projection, group.iter()) + { + Ok(stats) => stats.is_sorted(), + Err(_) => false, // can't prove sorted → reject + } + }) +} + +/// Filters orderings to retain only those valid for all file groups, +/// verified via min/max statistics. +pub(crate) fn validate_orderings( + orderings: &[LexOrdering], + schema: &SchemaRef, + file_groups: &[FileGroup], + projection: Option<&[usize]>, +) -> Vec { + orderings + .iter() + .filter(|ordering| { + is_ordering_valid_for_file_groups(file_groups, ordering, schema, projection) + }) + .cloned() + .collect() +} + +/// The various listing tables do not attempt to read all files +/// concurrently, instead they will read files in sequence within a +/// partition. This is an important property as it allows plans to +/// run against 1000s of files and not try to open them all +/// concurrently. +/// +/// However, it means if we assign more than one file to a partition +/// the output sort order will not be preserved as illustrated in the +/// following diagrams: +/// +/// When only 1 file is assigned to each partition, each partition is +/// correctly sorted on `(A, B, C)` +/// +/// ```text +/// ┏ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ┓ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┃ ┌───────────────┐ ┌──────────────┐ │ ┌──────────────┐ │ ┌─────────────┐ ┃ +/// │ │ 1.parquet │ │ │ │ 2.parquet │ │ │ 3.parquet │ │ │ 4.parquet │ │ +/// ┃ │ Sort: A, B, C │ │Sort: A, B, C │ │ │Sort: A, B, C │ │ │Sort: A, B, C│ ┃ +/// │ └───────────────┘ │ │ └──────────────┘ │ └──────────────┘ │ └─────────────┘ │ +/// ┃ │ │ ┃ +/// │ │ │ │ │ │ +/// ┃ │ │ ┃ +/// │ │ │ │ │ │ +/// ┃ │ │ ┃ +/// │ │ │ │ │ │ +/// ┃ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┃ +/// DataFusion DataFusion DataFusion DataFusion +/// ┃ Partition 1 Partition 2 Partition 3 Partition 4 ┃ +/// ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ +/// +/// DataSourceExec +/// ``` +/// +/// However, when more than 1 file is assigned to each partition, each +/// partition is NOT correctly sorted on `(A, B, C)`. Once the second +/// file is scanned, the same values for A, B and C can be repeated in +/// the same sorted stream +/// +/// ```text +/// ┏ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┃ +/// ┃ ┌───────────────┐ ┌──────────────┐ │ +/// │ │ 1.parquet │ │ │ │ 2.parquet │ ┃ +/// ┃ │ Sort: A, B, C │ │Sort: A, B, C │ │ +/// │ └───────────────┘ │ │ └──────────────┘ ┃ +/// ┃ ┌───────────────┐ ┌──────────────┐ │ +/// │ │ 3.parquet │ │ │ │ 4.parquet │ ┃ +/// ┃ │ Sort: A, B, C │ │Sort: A, B, C │ │ +/// │ └───────────────┘ │ │ └──────────────┘ ┃ +/// ┃ │ +/// │ │ │ ┃ +/// ┃ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// DataFusion DataFusion ┃ +/// ┃ Partition 1 Partition 2 +/// ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ━ ┛ +/// +/// DataSourceExec +/// ``` +/// +/// **Exception**: When files within a partition are **non-overlapping** (verified +/// via min/max statistics) and each file is internally sorted, the combined +/// output is still correctly sorted. Sort pushdown +/// ([`FileScanConfig::try_pushdown_sort`]) detects this case and preserves +/// `output_ordering`, allowing `SortExec` to be eliminated entirely. +/// +/// ```text +/// Partition 1 (files sorted by stats, non-overlapping): +/// ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────┐ +/// │ 1.parquet │ │ 2.parquet │ │ 3.parquet │ +/// │ A: [1..100] │ │ A: [101..200] │ │ A: [201..300] │ +/// │ Sort: A, B, C │ │ Sort: A, B, C │ │ Sort: A, B, C │ +/// └──────────────────┘ └──────────────────┘ └──────────────────┘ +/// max(1) <= min(2) ✓ max(2) <= min(3) ✓ → output_ordering preserved +/// ``` +pub(crate) fn get_projected_output_ordering( + base_config: &FileScanConfig, + projected_schema: &SchemaRef, +) -> Vec { + let projected_orderings = + project_orderings(&base_config.output_ordering, projected_schema); + + let indices = base_config + .file_source + .projection() + .as_ref() + .map(|p| ordered_column_indices_from_projection(p)); + + match indices { + Some(Some(indices)) => { + // Simple column projection — validate with statistics + validate_orderings( + &projected_orderings, + projected_schema, + &base_config.file_groups, + Some(indices.as_slice()), + ) + } + None => { + // No projection — validate with statistics (no remapping needed) + validate_orderings( + &projected_orderings, + projected_schema, + &base_config.file_groups, + None, + ) + } + Some(None) => { + // Complex projection (expressions, not simple columns) — can't + // determine column indices for statistics. Still valid if all + // file groups have at most one file. + if base_config.file_groups.iter().all(|g| g.len() <= 1) { + projected_orderings + } else { + debug!( + "Skipping specified output orderings. \ + Some file groups couldn't be determined to be sorted: {:?}", + base_config.file_groups + ); + vec![] + } + } + } +} diff --git a/datafusion/datasource/src/file_sink_config.rs b/datafusion/datasource/src/file_sink_config.rs index 2968bd1ee0449..1abce86a3565f 100644 --- a/datafusion/datasource/src/file_sink_config.rs +++ b/datafusion/datasource/src/file_sink_config.rs @@ -17,10 +17,10 @@ use std::sync::Arc; +use crate::ListingTableUrl; use crate::file_groups::FileGroup; use crate::sink::DataSink; -use crate::write::demux::{start_demuxer_task, DemuxedStreamReceiver}; -use crate::ListingTableUrl; +use crate::write::demux::{DemuxedStreamReceiver, start_demuxer_task}; use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::Result; @@ -32,6 +32,52 @@ use datafusion_expr::dml::InsertOp; use async_trait::async_trait; use object_store::ObjectStore; +/// Determines how `FileSink` output paths are interpreted. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FileOutputMode { + /// Infer output mode from the output URL (for example, by extension / trailing `/`). + #[default] + Automatic, + /// Write to a single output file at the exact output path. + SingleFile, + /// Write to a directory under the output path with generated filenames. + Directory, +} + +impl FileOutputMode { + /// Resolve this mode into a `single_file_output` boolean for the demuxer. + pub fn single_file_output(self, base_output_path: &ListingTableUrl) -> bool { + match self { + Self::Automatic => { + !base_output_path.is_collection() + && base_output_path.file_extension().is_some() + } + Self::SingleFile => true, + Self::Directory => false, + } + } +} + +impl From> for FileOutputMode { + fn from(value: Option) -> Self { + match value { + None => Self::Automatic, + Some(true) => Self::SingleFile, + Some(false) => Self::Directory, + } + } +} + +impl From for Option { + fn from(value: FileOutputMode) -> Self { + match value { + FileOutputMode::Automatic => None, + FileOutputMode::SingleFile => Some(true), + FileOutputMode::Directory => Some(false), + } + } +} + /// General behaviors for files that do `DataSink` operations #[async_trait] pub trait FileSink: DataSink { @@ -112,6 +158,8 @@ pub struct FileSinkConfig { pub keep_partition_by_columns: bool, /// File extension without a dot(.) pub file_extension: String, + /// Determines how the output path is interpreted. + pub file_output_mode: FileOutputMode, } impl FileSinkConfig { diff --git a/datafusion/datasource/src/file_stream/builder.rs b/datafusion/datasource/src/file_stream/builder.rs new file mode 100644 index 0000000000000..7034e902550a9 --- /dev/null +++ b/datafusion/datasource/src/file_stream/builder.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::file_scan_config::FileScanConfig; +use crate::file_stream::scan_state::ScanState; +use crate::file_stream::work_source::{SharedWorkSource, WorkSource}; +use crate::morsel::{FileOpenerMorselizer, Morselizer}; +use datafusion_common::{Result, internal_err}; +use datafusion_physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; + +use super::metrics::FileStreamMetrics; +use super::{FileOpener, FileStream, FileStreamState, OnError}; + +/// Builder for constructing a [`FileStream`]. +pub struct FileStreamBuilder<'a> { + config: &'a FileScanConfig, + partition: Option, + morselizer: Option>, + metrics: Option<&'a ExecutionPlanMetricsSet>, + on_error: OnError, + shared_work_source: Option, +} + +impl<'a> FileStreamBuilder<'a> { + /// Create a new builder for [`FileStream`]. + pub fn new(config: &'a FileScanConfig) -> Self { + Self { + config, + partition: None, + morselizer: None, + metrics: None, + on_error: OnError::Fail, + shared_work_source: None, + } + } + + /// Configure the partition to scan. + pub fn with_partition(mut self, partition: usize) -> Self { + self.partition = Some(partition); + self + } + + /// Configure the [`FileOpener`] used to open files. + /// + /// This will overwrite any setting from [`Self::with_morselizer`] + pub fn with_file_opener(mut self, file_opener: Arc) -> Self { + self.morselizer = Some(Box::new(FileOpenerMorselizer::new(file_opener))); + self + } + + /// Configure the [`Morselizer`] used to open files. + /// + /// This will overwrite any setting from [`Self::with_file_opener`] + pub fn with_morselizer(mut self, morselizer: Box) -> Self { + self.morselizer = Some(morselizer); + self + } + + /// Configure the metrics set used by the stream. + pub fn with_metrics(mut self, metrics: &'a ExecutionPlanMetricsSet) -> Self { + self.metrics = Some(metrics); + self + } + + /// Configure the behavior when opening or scanning a file fails. + pub fn with_on_error(mut self, on_error: OnError) -> Self { + self.on_error = on_error; + self + } + + /// Configure the [`SharedWorkSource`] for sibling work stealing. + pub(crate) fn with_shared_work_source( + mut self, + shared_work_source: Option, + ) -> Self { + self.shared_work_source = shared_work_source; + self + } + + /// Build the configured [`FileStream`]. + pub fn build(self) -> Result { + let Self { + config, + partition, + morselizer, + metrics, + on_error, + shared_work_source, + } = self; + + let Some(partition) = partition else { + return internal_err!("FileStreamBuilder missing required partition"); + }; + let Some(morselizer) = morselizer else { + return internal_err!("FileStreamBuilder missing required morselizer"); + }; + let Some(metrics) = metrics else { + return internal_err!("FileStreamBuilder missing required metrics"); + }; + let projected_schema = config.projected_schema()?; + let Some(file_group) = config.file_groups.get(partition).cloned() else { + return internal_err!( + "FileStreamBuilder invalid partition index: {partition}" + ); + }; + let work_source = match shared_work_source { + Some(shared) => WorkSource::Shared(shared), + None => WorkSource::Local(file_group.into_inner().into()), + }; + + let file_stream_metrics = FileStreamMetrics::new(metrics, partition); + let scan_state = Box::new(ScanState::new( + work_source, + config.limit, + morselizer, + on_error, + file_stream_metrics, + )); + + Ok(FileStream { + projected_schema, + state: FileStreamState::Scan { scan_state }, + baseline_metrics: BaselineMetrics::new(metrics, partition), + }) + } +} diff --git a/datafusion/datasource/src/file_stream/metrics.rs b/datafusion/datasource/src/file_stream/metrics.rs new file mode 100644 index 0000000000000..5f3894404f408 --- /dev/null +++ b/datafusion/datasource/src/file_stream/metrics.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::instant::Instant; +use datafusion_physical_plan::metrics::{ + Count, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, Time, +}; + +/// A timer that can be started and stopped. +pub struct StartableTime { + pub metrics: Time, + // use for record each part cost time, will eventually add into 'metrics'. + pub start: Option, +} + +impl StartableTime { + pub fn start(&mut self) { + assert!(self.start.is_none()); + self.start = Some(Instant::now()); + } + + pub fn stop(&mut self) { + if let Some(start) = self.start.take() { + self.metrics.add_elapsed(start); + } + } +} + +/// Metrics for [`FileStream`] +/// +/// Note that all of these metrics are in terms of wall clock time +/// (not cpu time) so they include time spent waiting on I/O as well +/// as other operators. +/// +/// [`FileStream`]: +pub struct FileStreamMetrics { + /// Wall clock time elapsed for file opening. + /// + /// Time between when [`FileOpener::open`] is called and when the + /// [`FileStream`] receives a stream for reading. + /// + /// [`FileStream`]: crate::file_stream::FileStream + /// [`FileOpener::open`]: crate::file_stream::FileOpener::open + pub time_opening: StartableTime, + /// Wall clock time elapsed for file scanning + first record batch of decompression + decoding + /// + /// Time between when the [`FileStream`] requests data from the + /// stream and when the first [`RecordBatch`] is produced. + /// + /// [`FileStream`]: crate::file_stream::FileStream + /// [`RecordBatch`]: arrow::record_batch::RecordBatch + pub time_scanning_until_data: StartableTime, + /// Total elapsed wall clock time for scanning + record batch decompression / decoding + /// + /// Sum of time between when the [`FileStream`] requests data from + /// the stream and when a [`RecordBatch`] is produced for all + /// record batches in the stream. Note that this metric also + /// includes the time of the parent operator's execution. + /// + /// [`FileStream`]: crate::file_stream::FileStream + /// [`RecordBatch`]: arrow::record_batch::RecordBatch + pub time_scanning_total: StartableTime, + /// Wall clock time elapsed for data decompression + decoding + /// + /// Time spent waiting for the FileStream's input. + pub time_processing: Time, + /// Count of errors opening file. + /// + /// If using `OnError::Skip` this will provide a count of the number of files + /// which were skipped and will not be included in the scan results. + pub file_open_errors: Count, + /// Count of errors scanning file + /// + /// If using `OnError::Skip` this will provide a count of the number of files + /// which were skipped and will not be included in the scan results. + pub file_scan_errors: Count, + /// Count of files successfully opened or evaluated for processing. + /// At t=end (completion of a query) this is equal to `files_opened`, and both values are equal + /// to the total number of files in the query; unless the query itself fails. + /// This value will always be greater than or equal to `files_open`. + /// Note that this value does *not* mean the file was actually scanned. + /// We increment this value for any processing of a file, even if that processing is + /// discarding it because we hit a `LIMIT` (in this case `files_opened` and `files_processed` are both incremented at the same time). + pub files_opened: Count, + /// Count of files completely processed / closed (opened, pruned, or skipped due to limit). + /// At t=0 (the beginning of a query) this is 0. + /// At t=end (completion of a query) this is equal to `files_opened`, and both values are equal + /// to the total number of files in the query; unless the query itself fails. + /// This value will always be less than or equal to `files_open`. + /// We increment this value for any processing of a file, even if that processing is + /// discarding it because we hit a `LIMIT` (in this case `files_opened` and `files_processed` are both incremented at the same time). + pub files_processed: Count, +} + +impl FileStreamMetrics { + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let time_opening = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_opening", partition), + start: None, + }; + + let time_scanning_until_data = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_scanning_until_data", partition), + start: None, + }; + + let time_scanning_total = StartableTime { + metrics: MetricBuilder::new(metrics) + .subset_time("time_elapsed_scanning_total", partition), + start: None, + }; + + let time_processing = + MetricBuilder::new(metrics).subset_time("time_elapsed_processing", partition); + + let file_open_errors = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("file_open_errors", partition); + + let file_scan_errors = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("file_scan_errors", partition); + + let files_opened = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("files_opened", partition); + + let files_processed = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("files_processed", partition); + + Self { + time_opening, + time_scanning_until_data, + time_scanning_total, + time_processing, + file_open_errors, + file_scan_errors, + files_opened, + files_processed, + } + } +} diff --git a/datafusion/datasource/src/file_stream/mod.rs b/datafusion/datasource/src/file_stream/mod.rs new file mode 100644 index 0000000000000..d976bf955dbb2 --- /dev/null +++ b/datafusion/datasource/src/file_stream/mod.rs @@ -0,0 +1,1678 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A generic stream over file format readers that can be used by +//! any file format that read its files from start to end. +//! +//! Note: Most traits here need to be marked `Sync + Send` to be +//! compliant with the `SendableRecordBatchStream` trait. + +mod builder; +mod metrics; +mod scan_state; +pub(crate) mod work_source; + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::PartitionedFile; +use crate::file_scan_config::FileScanConfig; +use arrow::datatypes::SchemaRef; +use datafusion_common::Result; +use datafusion_execution::RecordBatchStream; +use datafusion_physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; + +use arrow::record_batch::RecordBatch; + +use futures::Stream; +use futures::future::BoxFuture; +use futures::stream::BoxStream; + +use self::scan_state::{ScanAndReturn, ScanState}; + +pub use builder::FileStreamBuilder; +pub use metrics::{FileStreamMetrics, StartableTime}; + +/// A stream that iterates record batch by record batch, file over file. +pub struct FileStream { + /// The stream schema (file schema including partition columns and after + /// projection). + projected_schema: SchemaRef, + /// The stream state + state: FileStreamState, + /// runtime baseline metrics + baseline_metrics: BaselineMetrics, +} + +impl FileStream { + /// Create a new `FileStream` using the give `FileOpener` to scan underlying files + #[deprecated(since = "54.0.0", note = "Use FileStreamBuilder instead")] + pub fn new( + config: &FileScanConfig, + partition: usize, + file_opener: Arc, + metrics: &ExecutionPlanMetricsSet, + ) -> Result { + FileStreamBuilder::new(config) + .with_partition(partition) + .with_file_opener(file_opener) + .with_metrics(metrics) + .build() + } + + /// Specify the behavior when an error occurs opening or scanning a file + /// + /// If `OnError::Skip` the stream will skip files which encounter an error and continue + /// If `OnError:Fail` (default) the stream will fail and stop processing when an error occurs + pub fn with_on_error(mut self, on_error: OnError) -> Self { + match &mut self.state { + FileStreamState::Scan { scan_state } => scan_state.set_on_error(on_error), + FileStreamState::Error | FileStreamState::Done => { + // no effect as there are no more files to process + } + }; + self + } + + fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll>> { + loop { + match &mut self.state { + FileStreamState::Scan { scan_state: queue } => { + let action = queue.poll_scan(cx); + match action { + ScanAndReturn::Continue => continue, + ScanAndReturn::Done(result) => { + self.state = FileStreamState::Done; + return Poll::Ready(result); + } + ScanAndReturn::Error(err) => { + self.state = FileStreamState::Error; + return Poll::Ready(Some(Err(err))); + } + ScanAndReturn::Return(result) => return result, + } + } + FileStreamState::Error | FileStreamState::Done => { + return Poll::Ready(None); + } + } + } + } +} + +impl Stream for FileStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let result = self.poll_inner(cx); + self.baseline_metrics.record_poll(result) + } +} + +impl RecordBatchStream for FileStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.projected_schema) + } +} + +/// A fallible future that resolves to a stream of [`RecordBatch`] +pub type FileOpenFuture = + BoxFuture<'static, Result>>>; + +/// Describes the behavior of the `FileStream` if file opening or scanning fails +#[derive(Default)] +pub enum OnError { + /// Fail the entire stream and return the underlying error + #[default] + Fail, + /// Continue scanning, ignoring the failed file + Skip, +} + +/// Generic API for opening a file using an [`ObjectStore`] and resolving to a +/// stream of [`RecordBatch`] +/// +/// [`ObjectStore`]: object_store::ObjectStore +pub trait FileOpener: Unpin + Send + Sync { + /// Asynchronously open the specified file and return a stream + /// of [`RecordBatch`] + fn open(&self, partitioned_file: PartitionedFile) -> Result; +} + +enum FileStreamState { + /// Actively processing readers, ready morsels, and planner work. + Scan { + /// The ready queues and active reader for the current file. + scan_state: Box, + }, + /// Encountered an error + Error, + /// Finished scanning all requested data, possibly because a limit was reached + Done, +} + +#[cfg(test)] +mod tests { + use crate::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; + use crate::morsel::mocks::{ + IoFutureId, MockMorselizer, MockPlanBuilder, MockPlanner, MorselId, + PendingPlannerBuilder, PollsToResolve, + }; + use crate::source::DataSource; + use crate::tests::make_partition; + use crate::{PartitionedFile, TableSchema}; + use arrow::array::{AsArray, RecordBatch}; + use arrow::datatypes::{DataType, Field, Int32Type, Schema}; + use datafusion_common::DataFusionError; + use datafusion_common::error::Result; + use datafusion_execution::object_store::ObjectStoreUrl; + use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; + use futures::{FutureExt as _, StreamExt as _}; + use std::collections::{BTreeMap, VecDeque}; + use std::sync::Arc; + use std::sync::atomic::{AtomicUsize, Ordering}; + + use crate::file_stream::{ + FileOpenFuture, FileOpener, FileStream, FileStreamBuilder, OnError, + work_source::SharedWorkSource, + }; + use crate::test_util::MockSource; + + use datafusion_common::{assert_batches_eq, exec_err, internal_err}; + + /// Test identifier for one `FileStream` partition. + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct PartitionId(usize); + + /// Test `FileOpener` which will simulate errors during file opening or scanning + #[derive(Default)] + struct TestOpener { + /// Index in stream of files which should throw an error while opening + error_opening_idx: Vec, + /// Index in stream of files which should throw an error while scanning + error_scanning_idx: Vec, + /// Index of last file in stream + current_idx: AtomicUsize, + /// `RecordBatch` to return + records: Vec, + } + + impl FileOpener for TestOpener { + fn open(&self, _partitioned_file: PartitionedFile) -> Result { + let idx = self.current_idx.fetch_add(1, Ordering::SeqCst); + + if self.error_opening_idx.contains(&idx) { + Ok(futures::future::ready(internal_err!("error opening")).boxed()) + } else if self.error_scanning_idx.contains(&idx) { + let error = futures::future::ready(exec_err!("error scanning")); + let stream = futures::stream::once(error).boxed(); + Ok(futures::future::ready(Ok(stream)).boxed()) + } else { + let iterator = self.records.clone().into_iter().map(Ok); + let stream = futures::stream::iter(iterator).boxed(); + Ok(futures::future::ready(Ok(stream)).boxed()) + } + } + } + + #[derive(Default)] + struct FileStreamTest { + /// Number of files in the stream + num_files: usize, + /// Global limit of records emitted by the stream + limit: Option, + /// Error-handling behavior of the stream + on_error: OnError, + /// Mock `FileOpener` + opener: TestOpener, + } + + impl FileStreamTest { + pub fn new() -> Self { + Self::default() + } + + /// Specify the number of files in the stream + pub fn with_num_files(mut self, num_files: usize) -> Self { + self.num_files = num_files; + self + } + + /// Specify the limit + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + /// Specify the index of files in the stream which should + /// throw an error when opening + pub fn with_open_errors(mut self, idx: Vec) -> Self { + self.opener.error_opening_idx = idx; + self + } + + /// Specify the index of files in the stream which should + /// throw an error when scanning + pub fn with_scan_errors(mut self, idx: Vec) -> Self { + self.opener.error_scanning_idx = idx; + self + } + + /// Specify the behavior of the stream when an error occurs + pub fn with_on_error(mut self, on_error: OnError) -> Self { + self.on_error = on_error; + self + } + + /// Specify the record batches that should be returned from each + /// file that is successfully scanned + pub fn with_records(mut self, records: Vec) -> Self { + self.opener.records = records; + self + } + + /// Collect the results of the `FileStream` + pub async fn result(self) -> Result> { + let file_schema = self + .opener + .records + .first() + .map(|batch| batch.schema()) + .unwrap_or_else(|| Arc::new(Schema::empty())); + + // let ctx = SessionContext::new(); + let mock_files: Vec<(String, u64)> = (0..self.num_files) + .map(|idx| (format!("mock_file{idx}"), 10_u64)) + .collect(); + + // let mock_files_ref: Vec<(&str, u64)> = mock_files + // .iter() + // .map(|(name, size)| (name.as_str(), *size)) + // .collect(); + + let file_group = mock_files + .into_iter() + .map(|(name, size)| PartitionedFile::new(name, size)) + .collect(); + + let on_error = self.on_error; + + let table_schema = TableSchema::from(file_schema); + let config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::new(MockSource::new(table_schema)), + ) + .with_file_group(file_group) + .with_limit(self.limit) + .build(); + let metrics_set = ExecutionPlanMetricsSet::new(); + let file_stream = FileStreamBuilder::new(&config) + .with_partition(0) + .with_file_opener(Arc::new(self.opener)) + .with_metrics(&metrics_set) + .with_on_error(on_error) + .build()?; + + file_stream + .collect::>() + .await + .into_iter() + .collect::>>() + } + } + + /// helper that creates a stream of 2 files with the same pair of batches in each ([0,1,2] and [0,1]) + async fn create_and_collect(limit: Option) -> Vec { + FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_limit(limit) + .result() + .await + .expect("error executing stream") + } + + /// Create the smallest valid file scan config for builder validation tests. + fn builder_test_config() -> FileScanConfig { + let table_schema = TableSchema::from(Arc::new(Schema::empty())); + FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::new(MockSource::new(table_schema)), + ) + .with_file(PartitionedFile::new("mock_file", 10)) + .build() + } + + /// Convenience helper to keep builder error assertions focused on the + /// specific missing or invalid input under test. + fn builder_error(builder: FileStreamBuilder<'_>) -> String { + builder.build().err().unwrap().to_string() + } + + #[tokio::test] + async fn on_error_opening() -> Result<()> { + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Skip) + .with_open_errors(vec![0]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Skip) + .with_open_errors(vec![1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Skip) + .with_open_errors(vec![0, 1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "++", + "++", + ], &batches); + + Ok(()) + } + + #[tokio::test] + async fn on_error_scanning_fail() -> Result<()> { + let result = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Fail) + .with_scan_errors(vec![1]) + .result() + .await; + + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn on_error_opening_fail() -> Result<()> { + let result = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Fail) + .with_open_errors(vec![1]) + .result() + .await; + + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn on_error_scanning() -> Result<()> { + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Skip) + .with_scan_errors(vec![0]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Skip) + .with_scan_errors(vec![1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(2) + .with_on_error(OnError::Skip) + .with_scan_errors(vec![0, 1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "++", + "++", + ], &batches); + + Ok(()) + } + + #[tokio::test] + async fn on_error_mixed() -> Result<()> { + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(3) + .with_on_error(OnError::Skip) + .with_open_errors(vec![1]) + .with_scan_errors(vec![0]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(3) + .with_on_error(OnError::Skip) + .with_open_errors(vec![0]) + .with_scan_errors(vec![1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(3) + .with_on_error(OnError::Skip) + .with_open_errors(vec![2]) + .with_scan_errors(vec![0, 1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "++", + "++", + ], &batches); + + let batches = FileStreamTest::new() + .with_records(vec![make_partition(3), make_partition(2)]) + .with_num_files(3) + .with_on_error(OnError::Skip) + .with_open_errors(vec![0, 2]) + .with_scan_errors(vec![1]) + .result() + .await?; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "++", + "++", + ], &batches); + + Ok(()) + } + + #[tokio::test] + async fn without_limit() -> Result<()> { + let batches = create_and_collect(None).await; + + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + Ok(()) + } + + #[tokio::test] + async fn with_limit_between_files() -> Result<()> { + let batches = create_and_collect(Some(5)).await; + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "+---+", + ], &batches); + + Ok(()) + } + + #[tokio::test] + async fn with_limit_at_middle_of_batch() -> Result<()> { + let batches = create_and_collect(Some(6)).await; + #[rustfmt::skip] + assert_batches_eq!(&[ + "+---+", + "| i |", + "+---+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 0 |", + "| 1 |", + "| 0 |", + "+---+", + ], &batches); + + Ok(()) + } + + #[test] + fn builder_requires_partition_file_opener_and_metrics() { + let config = builder_test_config(); + + let err = builder_error(FileStreamBuilder::new(&config)); + assert!(err.contains("FileStreamBuilder missing required partition")); + + let err = builder_error(FileStreamBuilder::new(&config).with_partition(0)); + assert!(err.contains("FileStreamBuilder missing required morselizer")); + + let err = builder_error( + FileStreamBuilder::new(&config) + .with_partition(0) + .with_file_opener(Arc::new(TestOpener::default())), + ); + assert!(err.contains("FileStreamBuilder missing required metrics")); + } + + #[test] + fn builder_errors_on_invalid_partition() { + let config = builder_test_config(); + let metrics = ExecutionPlanMetricsSet::new(); + + let err = builder_error( + FileStreamBuilder::new(&config) + .with_partition(1) + .with_file_opener(Arc::new(TestOpener::default())) + .with_metrics(&metrics), + ); + assert!(err.contains("FileStreamBuilder invalid partition index: 1")); + } + + /// Verifies the simplest morsel-driven flow: one planner produces one + /// morsel immediately, and that morsel is then scanned to completion. + #[tokio::test] + async fn morsel_no_io() -> Result<()> { + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 42)) + .return_none(), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 42 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(10) + morsel_stream_started: MorselId(10) + morsel_stream_batch_produced: MorselId(10), BatchId(42) + morsel_stream_finished: MorselId(10) + "); + + Ok(()) + } + + /// Verifies that a planner can block on one I/O phase and then produce a + /// morsel containing two batches. + #[tokio::test] + async fn morsel_single_io_two_batches() -> Result<()> { + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet") + .add_plan( + PendingPlannerBuilder::new(IoFutureId(1)) + .with_polls_to_resolve(PollsToResolve(1)), + ) + .add_plan( + MockPlanBuilder::new() + .with_morsel_batches(MorselId(10), vec![42, 43]), + ) + .return_none(), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 42 + Batch: 43 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + io_future_created: file1.parquet, IoFutureId(1) + io_future_polled: file1.parquet, IoFutureId(1) + io_future_polled: file1.parquet, IoFutureId(1) + io_future_resolved: file1.parquet, IoFutureId(1) + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(10) + morsel_stream_started: MorselId(10) + morsel_stream_batch_produced: MorselId(10), BatchId(42) + morsel_stream_batch_produced: MorselId(10), BatchId(43) + morsel_stream_finished: MorselId(10) + "); + + Ok(()) + } + + /// Verifies that a planner can traverse two sequential I/O phases before + /// producing one batch, similar to Parquet. + #[tokio::test] + async fn morsel_two_ios_one_batch() -> Result<()> { + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet") + .add_plan(PendingPlannerBuilder::new(IoFutureId(1))) + .add_plan(PendingPlannerBuilder::new(IoFutureId(2))) + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 42)) + .return_none(), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 42 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + io_future_created: file1.parquet, IoFutureId(1) + io_future_polled: file1.parquet, IoFutureId(1) + io_future_resolved: file1.parquet, IoFutureId(1) + planner_called: file1.parquet + io_future_created: file1.parquet, IoFutureId(2) + io_future_polled: file1.parquet, IoFutureId(2) + io_future_resolved: file1.parquet, IoFutureId(2) + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(10) + morsel_stream_started: MorselId(10) + morsel_stream_batch_produced: MorselId(10), BatchId(42) + morsel_stream_finished: MorselId(10) + "); + + Ok(()) + } + + /// Verifies that a planner I/O future can fail and terminate the stream. + #[tokio::test] + async fn morsel_io_error() -> Result<()> { + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet").add_plan( + PendingPlannerBuilder::new(IoFutureId(1)) + .with_error("io failed while opening file"), + ), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Error: io failed while opening file + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + io_future_created: file1.parquet, IoFutureId(1) + io_future_polled: file1.parquet, IoFutureId(1) + io_future_errored: file1.parquet, IoFutureId(1), io failed while opening file + "); + + Ok(()) + } + + /// Verifies that pending planner I/O does not block draining the current + /// morsel stream. + #[tokio::test] + async fn morsel_pending_planner_does_not_block_active_reader() -> Result<()> { + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet") + .add_plan( + MockPlanBuilder::new() + .with_morsel_batches(MorselId(10), vec![41, 42]) + .with_pending_planner(IoFutureId(1), PollsToResolve(3), Ok(())), + ) + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 43)) + .return_none(), + ); + + // The key events are: + // 1. the first `planner_called` produces `MorselId(10)` and creates `IoFutureId(1)` + // 2. `MorselId(10)` continues yielding both batches while that I/O is pending + // 3. after the I/O resolves, planning resumes and yields `MorselId(11)` + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 41 + Batch: 42 + Batch: 43 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(10) + io_future_created: file1.parquet, IoFutureId(1) + io_future_polled: file1.parquet, IoFutureId(1) + morsel_stream_started: MorselId(10) + io_future_polled: file1.parquet, IoFutureId(1) + morsel_stream_batch_produced: MorselId(10), BatchId(41) + io_future_polled: file1.parquet, IoFutureId(1) + morsel_stream_batch_produced: MorselId(10), BatchId(42) + io_future_polled: file1.parquet, IoFutureId(1) + io_future_resolved: file1.parquet, IoFutureId(1) + morsel_stream_finished: MorselId(10) + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(11) + morsel_stream_started: MorselId(11) + morsel_stream_batch_produced: MorselId(11), BatchId(43) + morsel_stream_finished: MorselId(11) + "); + + Ok(()) + } + + /// Verifies that one `plan()` call can return a ready child planner, which + /// is then called to produce the morsel. + #[tokio::test] + async fn morsel_ready_child_planner() -> Result<()> { + let child_planner = MockPlanner::builder("child planner") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 42)) + .return_none(); + + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_ready_planner(child_planner)) + .return_none(), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 42 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + planner_created: child planner + planner_called: child planner + morsel_produced: child planner, MorselId(10) + morsel_stream_started: MorselId(10) + morsel_stream_batch_produced: MorselId(10), BatchId(42) + morsel_stream_finished: MorselId(10) + "); + + Ok(()) + } + + /// Verifies that planning can fail after a successful I/O phase. + #[tokio::test] + async fn morsel_plan_error_after_io() -> Result<()> { + let test = FileStreamMorselTest::new().with_file( + MockPlanner::builder("file1.parquet") + .add_plan(PendingPlannerBuilder::new(IoFutureId(1))) + .return_error("planner failed after io"), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Error: planner failed after io + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + io_future_created: file1.parquet, IoFutureId(1) + io_future_polled: file1.parquet, IoFutureId(1) + io_future_resolved: file1.parquet, IoFutureId(1) + planner_called: file1.parquet + "); + + Ok(()) + } + + /// Verifies that `FileStream` scans multiple files in order. + #[tokio::test] + async fn morsel_multiple_files() -> Result<()> { + let test = FileStreamMorselTest::new() + .with_file( + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 41)) + .return_none(), + ) + .with_file( + MockPlanner::builder("file2.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 42)) + .return_none(), + ); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 41 + Batch: 42 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(10) + morsel_stream_started: MorselId(10) + morsel_stream_batch_produced: MorselId(10), BatchId(41) + morsel_stream_finished: MorselId(10) + morselize_file: file2.parquet + planner_created: file2.parquet + planner_called: file2.parquet + morsel_produced: file2.parquet, MorselId(11) + morsel_stream_started: MorselId(11) + morsel_stream_batch_produced: MorselId(11), BatchId(42) + morsel_stream_finished: MorselId(11) + "); + + Ok(()) + } + + /// Verifies that a global limit can stop the stream before a second file is opened. + #[tokio::test] + async fn morsel_limit_prevents_second_file() -> Result<()> { + let test = FileStreamMorselTest::new() + .with_file( + MockPlanner::builder("file1.parquet") + .add_plan( + MockPlanBuilder::new() + .with_morsel_batches(MorselId(10), vec![41, 42]), + ) + .return_none(), + ) + .with_file( + MockPlanner::builder("file2.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 43)) + .return_none(), + ) + .with_limit(1); + + // Note the snapshot should not ever see planner id2 + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Output Stream ----- + Batch: 41 + Done + ----- File Stream Events ----- + morselize_file: file1.parquet + planner_created: file1.parquet + planner_called: file1.parquet + morsel_produced: file1.parquet, MorselId(10) + morsel_stream_started: MorselId(10) + morsel_stream_batch_produced: MorselId(10), BatchId(41) + "); + + Ok(()) + } + + /// Return a morsel test with two partitions: + /// Partition 0: file1, file2, file3 + /// Partition 1: file4 + /// + /// Partition 1 has only 1 file but it polled first 4 times + fn two_partition_morsel_test() -> FileStreamMorselTest { + FileStreamMorselTest::new() + // Partition 0 has three files + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 101)) + .return_none(), + ) + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file2.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 102)) + .return_none(), + ) + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file3.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(12), 103)) + .return_none(), + ) + // Partition 1 has only one file, but is polled first + .with_file_in_partition( + PartitionId(1), + MockPlanner::builder("file4.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(13), 201)) + .return_none(), + ) + .with_reads(vec![ + PartitionId(1), + PartitionId(1), + PartitionId(1), + PartitionId(1), + PartitionId(1), + ]) + } + + /// Verifies that an idle sibling stream can steal shared files from + /// another stream once it exhausts its own local work. + #[tokio::test] + async fn morsel_shared_files_can_be_stolen() -> Result<()> { + let test = two_partition_morsel_test().with_file_stream_events(false); + + // Partition 0 starts with 3 files, but Partition 1 is polled first. + // Since Partition 1 is polled first, it will run all the files even those + // that were assigned to Partition 0. + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Partition 0 ----- + Done + ----- Partition 1 ----- + Batch: 101 + Batch: 102 + Batch: 103 + Batch: 201 + Done + ----- File Stream Events ----- + (omitted due to with_file_stream_events(false)) + "); + + Ok(()) + } + + /// Verifies that a stream that must preserve order keeps its files local + /// and therefore cannot steal from a sibling shared queue. + #[tokio::test] + async fn morsel_preserve_order_keeps_files_local() -> Result<()> { + // same fixture as `morsel_shared_files_can_be_stolen` but marked as + // preserve-order + let test = two_partition_morsel_test() + .with_preserve_order(true) + .with_file_stream_events(false); + + // Even though that Partition 1 is polled first, it can not steal files + // from partition 0. The three files originally assigned to Partition 0 + // must be evaluated by Partition 0. + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Partition 0 ----- + Batch: 101 + Batch: 102 + Batch: 103 + Done + ----- Partition 1 ----- + Batch: 201 + Done + ----- File Stream Events ----- + (omitted due to with_file_stream_events(false)) + "); + + Ok(()) + } + + /// Verifies that `partitioned_by_file_group` disables shared work stealing. + #[tokio::test] + async fn morsel_partitioned_by_file_group_keeps_files_local() -> Result<()> { + // same fixture as `morsel_shared_files_can_be_stolen` but marked as + // preserve-partitioned + let test = two_partition_morsel_test() + .with_partitioned_by_file_group(true) + .with_file_stream_events(false); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Partition 0 ----- + Batch: 101 + Batch: 102 + Batch: 103 + Done + ----- Partition 1 ----- + Batch: 201 + Done + ----- File Stream Events ----- + (omitted due to with_file_stream_events(false)) + "); + + Ok(()) + } + + /// Verifies that an empty sibling can immediately steal shared files when + /// it is polled before the stream that originally owned them. + #[tokio::test] + async fn morsel_empty_sibling_can_steal() -> Result<()> { + let test = FileStreamMorselTest::new() + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 101)) + .return_none(), + ) + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file2.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 102)) + .return_none(), + ) + // Poll the empty sibling first so it steals both files. + .with_reads(vec![PartitionId(1), PartitionId(1), PartitionId(1)]) + .with_file_stream_events(false); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Partition 0 ----- + Done + ----- Partition 1 ----- + Batch: 101 + Batch: 102 + Done + ----- File Stream Events ----- + (omitted due to with_file_stream_events(false)) + "); + + Ok(()) + } + + /// Ensures that if a sibling is built and polled + /// before another sibling has been built and contributed its files to the + /// shared queue, the first sibling does not finish prematurely. + #[tokio::test] + async fn morsel_empty_sibling_can_finish_before_shared_work_exists() -> Result<()> { + let test = FileStreamMorselTest::new() + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 101)) + .return_none(), + ) + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file2.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 102)) + .return_none(), + ) + // Build streams lazily so partition 1 can poll the shared queue + // before partition 0 has contributed its files. Once partition 0 + // is built, a later poll of partition 1 can still steal one of + // them from the shared queue. + .with_build_streams_on_first_read(true) + .with_reads(vec![PartitionId(1), PartitionId(0), PartitionId(1)]) + .with_file_stream_events(false); + + // Partition 1 polls too early once, then later steals one file after + // partition 0 has populated the shared queue. + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Partition 0 ----- + Batch: 102 + Done + ----- Partition 1 ----- + Batch: 101 + Done + ----- File Stream Events ----- + (omitted due to with_file_stream_events(false)) + "); + + Ok(()) + } + + /// Verifies that a sibling hitting its limit does not count shared files + /// left in the queue as already processed by that stream. + #[tokio::test] + async fn morsel_shared_limit_does_not_double_count_files_processed() -> Result<()> { + let test = two_partition_morsel_test(); + let unlimited_config = test.test_config(); + let limited_config = test.clone().with_limit(1).test_config(); + let shared_work_source = limited_config + .create_sibling_state() + .and_then(|state| state.as_ref().downcast_ref::().cloned()) + .expect("shared work source"); + let limited_metrics = ExecutionPlanMetricsSet::new(); + let unlimited_metrics = ExecutionPlanMetricsSet::new(); + + let limited_stream = FileStreamBuilder::new(&limited_config) + .with_partition(1) + .with_shared_work_source(Some(shared_work_source.clone())) + .with_morselizer(Box::new(test.morselizer.clone())) + .with_metrics(&limited_metrics) + .build()?; + + let unlimited_stream = FileStreamBuilder::new(&unlimited_config) + .with_partition(0) + .with_shared_work_source(Some(shared_work_source)) + .with_morselizer(Box::new(test.morselizer)) + .with_metrics(&unlimited_metrics) + .build()?; + + let limited_output = drain_stream_output(limited_stream).await?; + let unlimited_output = drain_stream_output(unlimited_stream).await?; + + insta::assert_snapshot!(format!( + "----- Limited Stream -----\n{limited_output}\n----- Unlimited Stream -----\n{unlimited_output}" + ), @r" + ----- Limited Stream ----- + Batch: 101 + ----- Unlimited Stream ----- + Batch: 102 + Batch: 103 + Batch: 201 + "); + + assert_eq!( + metric_count(&limited_metrics, "files_opened"), + 1, + "the limited stream should only open the file that produced its output" + ); + assert_eq!( + metric_count(&limited_metrics, "files_processed"), + 1, + "the limited stream should only mark its own file as processed" + ); + assert_eq!( + metric_count(&unlimited_metrics, "files_opened"), + 3, + "the draining stream should open the remaining shared files" + ); + assert_eq!( + metric_count(&unlimited_metrics, "files_processed"), + 3, + "the draining stream should process exactly the files it opened" + ); + + Ok(()) + } + + /// Verifies that one fast sibling can drain shared files that originated + /// in more than one other partition. + #[tokio::test] + async fn morsel_one_sibling_can_drain_multiple_siblings() -> Result<()> { + let test = FileStreamMorselTest::new() + .with_file_in_partition( + PartitionId(0), + MockPlanner::builder("file1.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(10), 101)) + .return_none(), + ) + // Partition 1 has two files + .with_file_in_partition( + PartitionId(1), + MockPlanner::builder("file2.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(11), 102)) + .return_none(), + ) + .with_file_in_partition( + PartitionId(1), + MockPlanner::builder("file3.parquet") + .add_plan(MockPlanBuilder::new().with_morsel(MorselId(12), 103)) + .return_none(), + ) + // Partition 2 starts empty but is polled first, so it should drain + // the shared queue across both sibling partitions. + .with_reads(vec![ + PartitionId(2), + PartitionId(2), + PartitionId(1), + PartitionId(2), + ]) + .with_file_stream_events(false); + + insta::assert_snapshot!(test.run().await.unwrap(), @r" + ----- Partition 0 ----- + Done + ----- Partition 1 ----- + Batch: 103 + Done + ----- Partition 2 ----- + Batch: 101 + Batch: 102 + Done + ----- File Stream Events ----- + (omitted due to with_file_stream_events(false)) + "); + + Ok(()) + } + + /// Tests how one or more `FileStream`s consume morselized file work. + #[derive(Clone)] + struct FileStreamMorselTest { + morselizer: MockMorselizer, + partition_files: BTreeMap>, + preserve_order: bool, + partitioned_by_file_group: bool, + file_stream_events: bool, + build_streams_on_first_read: bool, + reads: Vec, + limit: Option, + } + + impl FileStreamMorselTest { + /// Creates an empty test harness. + fn new() -> Self { + Self { + morselizer: MockMorselizer::new(), + partition_files: BTreeMap::new(), + preserve_order: false, + partitioned_by_file_group: false, + file_stream_events: true, + build_streams_on_first_read: false, + reads: vec![], + limit: None, + } + } + + /// Adds one file and its root planner to partition 0. + fn with_file(self, planner: impl Into) -> Self { + self.with_file_in_partition(PartitionId(0), planner) + } + + /// Adds one file and its root planner to the specified input partition. + fn with_file_in_partition( + mut self, + partition: PartitionId, + planner: impl Into, + ) -> Self { + let planner = planner.into(); + let file_path = planner.file_path().to_string(); + self.morselizer = self.morselizer.with_planner(planner); + self.partition_files + .entry(partition) + .or_default() + .push(file_path); + self + } + + /// Marks the stream (and all partitions) to preserve the specified file + /// order. + fn with_preserve_order(mut self, preserve_order: bool) -> Self { + self.preserve_order = preserve_order; + self + } + + /// Marks the test scan as pre-partitioned by file group, which should + /// force each stream to keep its own files local. + fn with_partitioned_by_file_group( + mut self, + partitioned_by_file_group: bool, + ) -> Self { + self.partitioned_by_file_group = partitioned_by_file_group; + self + } + + /// Controls whether scheduler events are included in the snapshot. + /// + /// When disabled, `run()` still includes the event section header but + /// replaces the trace with a fixed placeholder so tests can focus only + /// on the output batches. + fn with_file_stream_events(mut self, file_stream_events: bool) -> Self { + self.file_stream_events = file_stream_events; + self + } + + /// Controls whether streams are all built up front or lazily on their + /// first read. + /// + /// The default builds all streams before polling begins, which matches + /// normal execution. Tests may enable lazy creation to model races + /// where one sibling polls before another has contributed its files to + /// the shared queue. + fn with_build_streams_on_first_read( + mut self, + build_streams_on_first_read: bool, + ) -> Self { + self.build_streams_on_first_read = build_streams_on_first_read; + self + } + + /// Sets the partition polling order. + /// + /// `run()` polls these partitions in the listed order first. After + /// those explicit reads are exhausted, it completes to round + /// robin across all configured partitions, skipping any streams that + /// have already finished. + /// + /// This allows testing early scheduling decisions explicit in a test + /// while avoiding a fully scripted poll trace for the remainder. + fn with_reads(mut self, reads: Vec) -> Self { + self.reads = reads; + self + } + + /// Sets a global output limit for all streams created by this test. + fn with_limit(mut self, limit: usize) -> Self { + self.limit = Some(limit); + self + } + + /// Runs the test and returns combined stream output and scheduler + /// trace text. + async fn run(self) -> Result { + let observer = self.morselizer.observer().clone(); + observer.clear(); + + let metrics_set = ExecutionPlanMetricsSet::new(); + let partition_count = self.num_partitions(); + + let mut partitions = (0..partition_count) + .map(|_| PartitionState::new()) + .collect::>(); + + let mut build_order = Vec::new(); + for partition in self.reads.iter().map(|partition| partition.0) { + if !build_order.contains(&partition) { + build_order.push(partition); + } + } + for partition in 0..partition_count { + if !build_order.contains(&partition) { + build_order.push(partition); + } + } + + let config = self.test_config(); + // `DataSourceExec::execute` creates one execution-local shared + // state object via `create_sibling_state()` and then passes it + // to `open_with_sibling_state(...)`. These tests build + // `FileStream`s directly, bypassing `DataSourceExec`, so they must + // perform the same setup explicitly when exercising sibling-stream + // work stealing. + let shared_work_source = config.create_sibling_state().and_then(|state| { + state.as_ref().downcast_ref::().cloned() + }); + if !self.build_streams_on_first_read { + for partition in build_order { + let stream = FileStreamBuilder::new(&config) + .with_partition(partition) + .with_shared_work_source(shared_work_source.clone()) + .with_morselizer(Box::new(self.morselizer.clone())) + .with_metrics(&metrics_set) + .build()?; + partitions[partition].set_stream(stream); + } + } + + let mut initial_reads: VecDeque<_> = self.reads.into(); + let mut next_round_robin = 0; + + while !initial_reads.is_empty() + || partitions.iter().any(PartitionState::is_active) + { + let partition = if let Some(partition) = initial_reads.pop_front() { + partition.0 + } else { + let partition = next_round_robin; + next_round_robin = (next_round_robin + 1) % partition_count.max(1); + partition + }; + + let partition_state = &mut partitions[partition]; + + if self.build_streams_on_first_read && !partition_state.built { + let stream = FileStreamBuilder::new(&config) + .with_partition(partition) + .with_shared_work_source(shared_work_source.clone()) + .with_morselizer(Box::new(self.morselizer.clone())) + .with_metrics(&metrics_set) + .build()?; + partition_state.set_stream(stream); + } + + let Some(stream) = partition_state.stream.as_mut() else { + continue; + }; + + match stream.next().await { + Some(result) => partition_state.push_output(format_result(result)), + None => partition_state.finish(), + } + } + + let output_text = if partition_count == 1 { + format!( + "----- Output Stream -----\n{}", + partitions[0].output.join("\n") + ) + } else { + partitions + .into_iter() + .enumerate() + .map(|(partition, state)| { + format!( + "----- Partition {} -----\n{}", + partition, + state.output.join("\n") + ) + }) + .collect::>() + .join("\n") + }; + + let file_stream_events = if self.file_stream_events { + observer.format_events() + } else { + "(omitted due to with_file_stream_events(false))".to_string() + }; + + Ok(format!( + "{output_text}\n----- File Stream Events -----\n{file_stream_events}", + )) + } + + /// Returns the number of configured partitions, including empty ones + /// that appear only in the explicit read schedule. + fn num_partitions(&self) -> usize { + self.partition_files + .keys() + .map(|partition| partition.0 + 1) + .chain(self.reads.iter().map(|partition| partition.0 + 1)) + .max() + .unwrap_or(1) + } + + /// Builds a `FileScanConfig` covering every configured partition. + fn test_config(&self) -> FileScanConfig { + let file_groups = (0..self.num_partitions()) + .map(|partition| { + self.partition_files + .get(&PartitionId(partition)) + .into_iter() + .flat_map(|files| files.iter()) + .map(|name| PartitionedFile::new(name, 10)) + .collect::>() + .into() + }) + .collect::>(); + + let table_schema = + TableSchema::from(Arc::new(Schema::new(vec![Field::new( + "i", + DataType::Int32, + false, + )]))); + FileScanConfigBuilder::new( + ObjectStoreUrl::parse("test:///").unwrap(), + Arc::new(MockSource::new(table_schema)), + ) + .with_file_groups(file_groups) + .with_limit(self.limit) + .with_preserve_order(self.preserve_order) + .with_partitioned_by_file_group(self.partitioned_by_file_group) + .build() + } + } + + /// Formats one stream poll result into a stable snapshot line. + fn format_result(result: Result) -> String { + match result { + Ok(batch) => { + let col = batch.column(0).as_primitive::(); + let batch_id = col.value(0); + format!("Batch: {batch_id}") + } + Err(e) => { + // Pull the actual message for external errors rather than + // relying on DataFusionError formatting, which changes if + // backtraces are enabled, etc. + let message = if let DataFusionError::External(generic) = e { + generic.to_string() + } else { + e.to_string() + }; + format!("Error: {message}") + } + } + } + + async fn drain_stream_output(stream: FileStream) -> Result { + let output = stream + .collect::>() + .await + .into_iter() + .map(|result| result.map(|batch| format_result(Ok(batch)))) + .collect::>>()?; + Ok(output.join("\n")) + } + + fn metric_count(metrics: &ExecutionPlanMetricsSet, name: &str) -> usize { + metrics + .clone_inner() + .sum_by_name(name) + .unwrap_or_else(|| panic!("missing metric: {name}")) + .as_usize() + } + + /// Test-only state for one stream partition in [`FileStreamMorselTest`]. + struct PartitionState { + /// Whether the `FileStream` for this partition has been built yet. + built: bool, + /// The live stream, if this partition has not finished yet. + stream: Option, + /// Snapshot lines produced by this partition. + output: Vec, + } + + impl PartitionState { + /// Create an unbuilt partition with no output yet. + fn new() -> Self { + Self { + built: false, + stream: None, + output: vec![], + } + } + + /// Returns true if this partition might still produce output. + fn is_active(&self) -> bool { + !self.built || self.stream.is_some() + } + + /// Records that this partition's stream has been built. + fn set_stream(&mut self, stream: FileStream) { + self.stream = Some(stream); + self.built = true; + } + + /// Records one formatted output line for this partition. + fn push_output(&mut self, line: String) { + self.output.push(line); + } + + /// Marks this partition as finished. + fn finish(&mut self) { + self.push_output("Done".to_string()); + self.stream = None; + } + } +} diff --git a/datafusion/datasource/src/file_stream/scan_state.rs b/datafusion/datasource/src/file_stream/scan_state.rs new file mode 100644 index 0000000000000..21125cd08896c --- /dev/null +++ b/datafusion/datasource/src/file_stream/scan_state.rs @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::internal_datafusion_err; +use std::collections::VecDeque; +use std::task::{Context, Poll}; + +use crate::morsel::{Morsel, MorselPlanner, Morselizer, PendingMorselPlanner}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{DataFusionError, Result}; +use datafusion_physical_plan::metrics::ScopedTimerGuard; +use futures::stream::BoxStream; +use futures::{FutureExt as _, StreamExt as _}; + +use super::work_source::WorkSource; +use super::{FileStreamMetrics, OnError}; + +/// State [`FileStreamState::Scan`]. +/// +/// There is one `ScanState` per `FileStream`, and thus per output partition. +/// +/// It groups together the lifecycle of scanning that partition's files: +/// unopened files, CPU-ready planners, pending planner I/O, ready morsels, +/// the active reader, and the metrics associated with processing that work. +/// +/// # I/O +/// +/// To avoid challenges controlling buffering, the ScanState only ever has a +/// single I/O outstanding at any time. +/// +/// # State Transitions +/// +/// ```text +/// work_source +/// | +/// v +/// morselizer.plan_file(file) +/// | +/// v +/// ready_planners ---> plan() ---> ready_morsels ---> into_stream() ---> reader ---> RecordBatches +/// ^ | +/// | v +/// | pending_planner +/// | | +/// | v +/// +-------- poll until ready +/// ``` +/// +/// [`FileStreamState::Scan`]: super::FileStreamState::Scan +pub(super) struct ScanState { + /// Unopened files that still need to be planned for this stream. + work_source: WorkSource, + /// Remaining row limit, if any. + remain: Option, + /// The morselizer used to plan files. + morselizer: Box, + /// Behavior if opening or scanning a file fails. + on_error: OnError, + /// CPU-ready planners for the current file. + ready_planners: VecDeque>, + /// Ready morsels for the current file. + ready_morsels: VecDeque>, + /// The active reader, if any. + reader: Option>>, + /// The single planner currently blocked on I/O, if any. + /// + /// Once the I/O completes, yields the next planner and is pushed back + /// onto `ready_planners`. + pending_planner: Option, + /// Metrics for the active scan queues. + metrics: FileStreamMetrics, +} + +impl ScanState { + pub(super) fn new( + work_source: WorkSource, + remain: Option, + morselizer: Box, + on_error: OnError, + metrics: FileStreamMetrics, + ) -> Self { + Self { + work_source, + remain, + morselizer, + on_error, + ready_planners: Default::default(), + ready_morsels: Default::default(), + reader: None, + pending_planner: None, + metrics, + } + } + + /// Updates how scan errors are handled while the stream is still active. + pub(super) fn set_on_error(&mut self, on_error: OnError) { + self.on_error = on_error; + } + + /// Drives one iteration of the active scan state. + /// + /// Work is attempted in this order: + /// 1. resolve any pending planner I/O + /// 2. poll the active reader + /// 3. turn a ready morsel into the active reader + /// 4. run CPU planning on a ready planner + /// 5. morselize the next unopened file + /// + /// The return [`ScanAndReturn`] tells `poll_inner` how to update the + /// outer `FileStreamState`. + pub(super) fn poll_scan(&mut self, cx: &mut Context<'_>) -> ScanAndReturn { + let _processing_timer: ScopedTimerGuard<'_> = + self.metrics.time_processing.timer(); + + // Try and resolve outstanding IO first. If it is still pending, check + // the current reader or ready morsels before yielding. New planning + // work must still wait for this I/O to resolve. + if let Some(mut pending_planner) = self.pending_planner.take() { + match pending_planner.poll_unpin(cx) { + // IO is still pending + Poll::Pending => { + self.pending_planner = Some(pending_planner); + } + // IO resolved, and the planner is ready for CPU work + Poll::Ready(Ok(planner)) => { + self.ready_planners.push_back(planner); + } + // IO Error + Poll::Ready(Err(err)) => { + self.metrics.file_open_errors.add(1); + self.metrics.time_opening.stop(); + return match self.on_error { + OnError::Skip => { + self.metrics.files_processed.add(1); + ScanAndReturn::Continue + } + OnError::Fail => ScanAndReturn::Error(err), + }; + } + } + } + + // Next try and get the next batch from the active reader, if any. + if let Some(reader) = self.reader.as_mut() { + match reader.poll_next_unpin(cx) { + // Morsels should ideally only expose ready-to-decode streams, + // but tolerate pending readers here. + Poll::Pending => return ScanAndReturn::Return(Poll::Pending), + Poll::Ready(Some(Ok(batch))) => { + self.metrics.time_scanning_until_data.stop(); + self.metrics.time_scanning_total.stop(); + // Apply any remaining row limit. + let (batch, finished) = match &mut self.remain { + Some(remain) => { + if *remain > batch.num_rows() { + *remain -= batch.num_rows(); + self.metrics.time_scanning_total.start(); + (batch, false) + } else { + let batch = batch.slice(0, *remain); + let done = 1 + self.work_source.skipped_on_limit(); + self.metrics.files_processed.add(done); + *remain = 0; + (batch, true) + } + } + None => { + self.metrics.time_scanning_total.start(); + (batch, false) + } + }; + return if finished { + ScanAndReturn::Done(Some(Ok(batch))) + } else { + ScanAndReturn::Return(Poll::Ready(Some(Ok(batch)))) + }; + } + Poll::Ready(Some(Err(err))) => { + self.reader = None; + self.metrics.file_scan_errors.add(1); + self.metrics.time_scanning_until_data.stop(); + self.metrics.time_scanning_total.stop(); + return match self.on_error { + OnError::Skip => { + self.metrics.files_processed.add(1); + ScanAndReturn::Continue + } + OnError::Fail => ScanAndReturn::Error(err), + }; + } + Poll::Ready(None) => { + self.reader = None; + self.metrics.files_processed.add(1); + self.metrics.time_scanning_until_data.stop(); + self.metrics.time_scanning_total.stop(); + return ScanAndReturn::Continue; + } + } + } + + // No active reader, but a morsel is ready to become the reader. + if let Some(morsel) = self.ready_morsels.pop_front() { + self.metrics.time_opening.stop(); + self.metrics.time_scanning_until_data.start(); + self.metrics.time_scanning_total.start(); + self.reader = Some(morsel.into_stream()); + return ScanAndReturn::Continue; + } + + // Do not start CPU planning or open another file while planner I/O is + // still outstanding because they may need additional IO and ScanState + // currently only permits a single outstanding IO + if self.pending_planner.is_some() { + return ScanAndReturn::Return(Poll::Pending); + } + + // No reader or morsel, so try to produce more work via CPU planning. + if let Some(planner) = self.ready_planners.pop_front() { + return match planner.plan() { + Ok(Some(mut plan)) => { + // Queue any newly-ready morsels, planners, or planner I/O. + self.ready_morsels.extend(plan.take_morsels()); + self.ready_planners.extend(plan.take_ready_planners()); + if let Some(pending_planner) = plan.take_pending_planner() { + // should not have planned if we have outstanding I/O + if self.pending_planner.is_some() { + return ScanAndReturn::Error(internal_datafusion_err!( + "Conflicting pending planner state in FileStream ScanState" + )); + } + self.pending_planner = Some(pending_planner); + } + ScanAndReturn::Continue + } + Ok(None) => { + self.metrics.files_processed.add(1); + self.metrics.time_opening.stop(); + ScanAndReturn::Continue + } + Err(err) => { + self.metrics.file_open_errors.add(1); + self.metrics.time_opening.stop(); + match self.on_error { + OnError::Skip => { + self.metrics.files_processed.add(1); + ScanAndReturn::Continue + } + OnError::Fail => ScanAndReturn::Error(err), + } + } + }; + } + + // No outstanding work remains, so begin planning the next unopened file. + let part_file = match self.work_source.pop_front() { + Some(part_file) => part_file, + None => return ScanAndReturn::Done(None), + }; + + self.metrics.time_opening.start(); + match self.morselizer.plan_file(part_file) { + Ok(planner) => { + self.metrics.files_opened.add(1); + self.ready_planners.push_back(planner); + ScanAndReturn::Continue + } + Err(err) => match self.on_error { + OnError::Skip => { + self.metrics.file_open_errors.add(1); + self.metrics.time_opening.stop(); + self.metrics.files_processed.add(1); + ScanAndReturn::Continue + } + OnError::Fail => ScanAndReturn::Error(err), + }, + } + } +} + +/// What should be done on the next iteration of [`ScanState::poll_scan`]? +pub(super) enum ScanAndReturn { + /// Poll again. + Continue, + /// Return the provided result without changing the outer state. + Return(Poll>>), + /// Update the outer `FileStreamState` to `Done` and return the provided result. + Done(Option>), + /// Update the outer `FileStreamState` to `Error` and return the provided error. + Error(DataFusionError), +} diff --git a/datafusion/datasource/src/file_stream/work_source.rs b/datafusion/datasource/src/file_stream/work_source.rs new file mode 100644 index 0000000000000..c00048453b304 --- /dev/null +++ b/datafusion/datasource/src/file_stream/work_source.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::VecDeque; +use std::sync::Arc; + +use crate::PartitionedFile; +use crate::file_groups::FileGroup; +use crate::file_scan_config::FileScanConfig; +use parking_lot::Mutex; + +/// Source of work for `ScanState`. +/// +/// Streams that may share work across siblings use [`WorkSource::Shared`], +/// while streams that can not share work (e.g. because they must preserve file +/// order) use [`WorkSource::Local`]. +#[derive(Debug, Clone)] +pub(super) enum WorkSource { + /// Files this stream will plan locally without sharing them. + Local(VecDeque), + /// Files shared with sibling streams. + Shared(SharedWorkSource), +} + +impl WorkSource { + /// Pop the next file to plan from this work source. + pub(super) fn pop_front(&mut self) -> Option { + match self { + Self::Local(files) => files.pop_front(), + Self::Shared(shared) => shared.pop_front(), + } + } + + /// Return how many queued files should be counted as already processed + /// when this stream stops early after hitting a global limit. + pub(super) fn skipped_on_limit(&self) -> usize { + match self { + Self::Local(files) => files.len(), + Self::Shared(_) => 0, + } + } +} + +/// Shared source of work for sibling `FileStream`s +/// +/// The queue is created once per execution and shared by all reorderable +/// sibling streams for that execution. Whichever stream becomes idle first may +/// take the next unopened file from the front of the queue. +/// +/// It uses a [`Mutex`] internally to provide thread-safe access +/// to the shared file queue. +#[derive(Debug, Clone)] +pub(crate) struct SharedWorkSource { + inner: Arc, +} + +#[derive(Debug, Default)] +pub(super) struct SharedWorkSourceInner { + files: Mutex>, +} + +impl SharedWorkSource { + /// Create a shared work source containing the provided unopened files. + pub(crate) fn new(files: impl IntoIterator) -> Self { + let files = files.into_iter().collect(); + Self { + inner: Arc::new(SharedWorkSourceInner { + files: Mutex::new(files), + }), + } + } + + /// Create a shared work source for the unopened files in `config`. + /// + /// Files are reordered by the file source (e.g. by statistics for TopK) + /// before being placed in the shared queue, so the most promising files + /// are processed first across all partitions. + pub(crate) fn from_config(config: &FileScanConfig) -> Self { + let files: Vec<_> = config + .file_groups + .iter() + .flat_map(FileGroup::iter) + .cloned() + .collect(); + let files = config.file_source.reorder_files(files); + Self::new(files) + } + + /// Pop the next file from the shared work queue. + /// + /// Returns `None` if the queue is empty + fn pop_front(&self) -> Option { + self.inner.files.lock().pop_front() + } +} diff --git a/datafusion/datasource/src/memory.rs b/datafusion/datasource/src/memory.rs index 595b1bf6d4268..f073b09c5463e 100644 --- a/datafusion/datasource/src/memory.rs +++ b/datafusion/datasource/src/memory.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::cmp::Ordering; use std::collections::BinaryHeap; use std::fmt; @@ -30,19 +29,20 @@ use crate::source::{DataSource, DataSourceExec}; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::{ - assert_or_internal_err, plan_err, project_schema, Result, ScalarValue, + Result, ScalarValue, assert_or_internal_err, plan_err, project_schema, }; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::project_orderings; +use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use datafusion_physical_plan::memory::MemoryStream; use datafusion_physical_plan::projection::{ - all_alias_free_columns, new_projections_for_columns, ProjectionExpr, + all_alias_free_columns, new_projections_for_columns, }; use datafusion_physical_plan::{ - common, ColumnarValue, DisplayAs, DisplayFormatType, Partitioning, PhysicalExpr, - SendableRecordBatchStream, Statistics, + ColumnarValue, DisplayAs, DisplayFormatType, Partitioning, PhysicalExpr, + SendableRecordBatchStream, Statistics, common, }; use async_trait::async_trait; @@ -90,10 +90,6 @@ impl DataSource for MemorySourceConfig { ))) } - fn as_any(&self) -> &dyn Any { - self - } - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { @@ -119,10 +115,10 @@ impl DataSource for MemorySourceConfig { .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( - f, - "partitions={}, partition_sizes={partition_sizes:?}{limit}{output_ordering}{constraints}", - partition_sizes.len(), - ) + f, + "partitions={}, partition_sizes={partition_sizes:?}{limit}{output_ordering}{constraints}", + partition_sizes.len(), + ) } else { write!( f, @@ -195,26 +191,26 @@ impl DataSource for MemorySourceConfig { SchedulingType::Cooperative } - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if let Some(partition) = partition { // Compute statistics for a specific partition if let Some(batches) = self.partitions.get(partition) { - Ok(common::compute_record_batch_statistics( + Ok(Arc::new(common::compute_record_batch_statistics( from_ref(batches), &self.schema, self.projection.clone(), - )) + ))) } else { // Invalid partition index - Ok(Statistics::new_unknown(&self.projected_schema)) + Ok(Arc::new(Statistics::new_unknown(&self.projected_schema))) } } else { // Compute statistics across all partitions - Ok(common::compute_record_batch_statistics( + Ok(Arc::new(common::compute_record_batch_statistics( &self.partitions, &self.schema, self.projection.clone(), - )) + ))) } } @@ -229,24 +225,34 @@ impl DataSource for MemorySourceConfig { fn try_swapping_with_projection( &self, - projection: &[ProjectionExpr], + projection: &ProjectionExprs, ) -> Result>> { // If there is any non-column or alias-carrier expression, Projection should not be removed. // This process can be moved into MemoryExec, but it would be an overlap of their responsibility. - all_alias_free_columns(projection) + let exprs = projection.iter().cloned().collect_vec(); + all_alias_free_columns(exprs.as_slice()) .then(|| { let all_projections = (0..self.schema.fields().len()).collect(); let new_projections = new_projections_for_columns( - projection, + &exprs, self.projection().as_ref().unwrap_or(&all_projections), ); - - MemorySourceConfig::try_new( - self.partitions(), - self.original_schema(), - Some(new_projections), - ) - .map(|s| Arc::new(s) as Arc) + let projected_schema = + project_schema(&self.schema, Some(&new_projections)); + + projected_schema.map(|projected_schema| { + // Clone self to preserve all metadata (fetch, sort_information, + // show_sizes, etc.) then update only the projection-related fields. + let mut new_source = self.clone(); + new_source.projection = Some(new_projections); + new_source.projected_schema = projected_schema; + // Project sort information to match the new projection + new_source.sort_information = project_orderings( + &new_source.sort_information, + &new_source.projected_schema, + ); + Arc::new(new_source) as Arc + }) }) .transpose() } @@ -746,10 +752,6 @@ impl MemSink { #[async_trait] impl DataSink for MemSink { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> &SchemaRef { &self.schema } @@ -773,7 +775,7 @@ impl DataSink for MemSink { } // write the outputs into the batches - for (target, mut batches) in self.batches.iter().zip(new_batches.into_iter()) { + for (target, mut batches) in self.batches.iter().zip(new_batches) { // Append all the new batches in one go to minimize locking overhead target.write().await.append(&mut batches); } @@ -853,7 +855,6 @@ mod tests { use datafusion_physical_plan::expressions::lit; use datafusion_physical_plan::ExecutionPlan; - use futures::StreamExt; #[tokio::test] async fn exec_with_limit() -> Result<()> { @@ -881,6 +882,39 @@ mod tests { Ok(()) } + /// Test that `try_swapping_with_projection` preserves the `fetch` limit. + /// Regression test for + #[test] + fn try_swapping_with_projection_preserves_fetch() { + use datafusion_physical_expr::projection::ProjectionExprs; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Int64, false), + ])); + let partitions: Vec> = vec![vec![batch(10)]]; + let source = MemorySourceConfig::try_new(&partitions, schema.clone(), None) + .unwrap() + .with_limit(Some(5)); + + assert_eq!(source.fetch, Some(5)); + + // Create a projection that reorders columns: [c, a] (indices 2, 0) + let projection = ProjectionExprs::from_indices(&[2, 0], &schema); + let swapped = source + .try_swapping_with_projection(&projection) + .unwrap() + .unwrap(); + let new_source = swapped.downcast_ref::().unwrap(); + + assert_eq!( + new_source.fetch, + Some(5), + "fetch limit must be preserved after projection pushdown" + ); + } + #[tokio::test] async fn values_empty_case() -> Result<()> { let schema = aggr_test_schema(); @@ -951,7 +985,7 @@ mod tests { let values = MemorySourceConfig::try_new_as_values(schema, data)?; assert_eq!( - values.partition_statistics(None)?, + *values.partition_statistics(None)?, Statistics { num_rows: Precision::Exact(rows), total_byte_size: Precision::Exact(8), // not important @@ -961,6 +995,7 @@ mod tests { max_value: Precision::Absent, min_value: Precision::Absent, sum_value: Precision::Absent, + byte_size: Precision::Absent, },], } ); @@ -1081,8 +1116,7 @@ mod tests { let actual = partitioned_datasrc .map(|datasrc| datasrc.output_partitioning().partition_count()); assert_eq!( - actual, - partition_cnt, + actual, partition_cnt, "partitioned datasrc does not match expected, we expected {should_exist}, instead found {actual:?}" ); } @@ -1200,9 +1234,8 @@ mod tests { // Starting = batch(100_000), batch(10_000), batch(100), batch(1). // It should have split as p1=batch(100_000), p2=[batch(10_000), batch(100), batch(1)] let partitioned_datasrc = partitioned_datasrc.unwrap(); - let Some(mem_src_config) = partitioned_datasrc - .as_any() - .downcast_ref::() + let Some(mem_src_config) = + partitioned_datasrc.downcast_ref::() else { unreachable!() }; @@ -1268,8 +1301,8 @@ mod tests { } #[test] - fn test_repartition_no_sort_information_no_output_ordering_lopsized_batches( - ) -> Result<()> { + fn test_repartition_no_sort_information_no_output_ordering_lopsized_batches() + -> Result<()> { let no_sort = vec![]; let no_output_ordering = None; @@ -1399,9 +1432,8 @@ mod tests { // Starting = batch(100_000), batch(1), batch(100), batch(10_000). // It should have split as p1=batch(100_000), p2=[batch(1), batch(100), batch(10_000)] let partitioned_datasrc = partitioned_datasrc.unwrap(); - let Some(mem_src_config) = partitioned_datasrc - .as_any() - .downcast_ref::() + let Some(mem_src_config) = + partitioned_datasrc.downcast_ref::() else { unreachable!() }; diff --git a/datafusion/datasource/src/mod.rs b/datafusion/datasource/src/mod.rs index 2c7d40d2fb3b9..82030e545a42e 100644 --- a/datafusion/datasource/src/mod.rs +++ b/datafusion/datasource/src/mod.rs @@ -23,9 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![cfg_attr(not(test), deny(clippy::clone_on_ref_ptr))] -// Enforce lint rule to prevent needless pass by value -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! A table that uses the `ObjectStore` listing capability @@ -41,6 +38,8 @@ pub mod file_scan_config; pub mod file_sink_config; pub mod file_stream; pub mod memory; +pub mod morsel; +pub mod projection; pub mod schema_adapter; pub mod sink; pub mod source; @@ -57,21 +56,35 @@ pub use self::url::ListingTableUrl; use crate::file_groups::FileGroup; use chrono::TimeZone; use datafusion_common::stats::Precision; -use datafusion_common::{exec_datafusion_err, ColumnStatistics, Result}; +use datafusion_common::{ColumnStatistics, Result, TableReference, exec_datafusion_err}; use datafusion_common::{ScalarValue, Statistics}; +use datafusion_physical_expr::LexOrdering; use futures::{Stream, StreamExt}; -use object_store::{path::Path, ObjectMeta}; use object_store::{GetOptions, GetRange, ObjectStore}; -pub use table_schema::TableSchema; +use object_store::{ObjectMeta, path::Path}; +pub use table_schema::{TableSchema, TableSchemaBuilder}; // Remove when add_row_stats is remove -#[allow(deprecated)] +use arrow::datatypes::SchemaRef; +#[expect(deprecated)] pub use statistics::add_row_stats; pub use statistics::compute_all_files_statistics; +use std::any::Any; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; +/// User-defined per-file extension data, keyed by concrete Rust type. +/// +/// Re-exported from [`datafusion_common::extensions::Extensions`]; the same +/// type backs `SessionConfig::extensions`, `ExtendedStatistics::extensions`, +/// and other extension fields throughout DataFusion. +pub type FileExtensions = datafusion_common::extensions::Extensions; + /// Stream of files get listed from object store +#[deprecated( + since = "54.0.0", + note = "This type is unused and will be removed in a future release" +)] pub type PartitionedFileStream = Pin> + Send + Sync + 'static>>; @@ -96,6 +109,19 @@ impl FileRange { #[derive(Debug, Clone)] /// A single file or part of a file that should be read, along with its schema, statistics /// and partition column values that need to be appended to each row. +/// +/// # Statistics +/// +/// The [`Self::statistics`] field contains statistics for the **full table schema**, +/// which includes both file columns and partition columns. When statistics are set via +/// [`Self::with_statistics`], exact statistics for partition columns are automatically +/// computed from [`Self::partition_values`]: +/// +/// - `min = max = partition_value` (all rows in a file share the same partition value) +/// - `null_count = 0` (partition values extracted from paths are never null) +/// - `distinct_count = 1` (single distinct value per file for each partition column) +/// +/// This enables query optimizers to use partition column bounds for pruning and planning. pub struct PartitionedFile { /// Path for the file (e.g. URL, filesystem path, etc) pub object_meta: ObjectMeta, @@ -116,17 +142,45 @@ pub struct PartitionedFile { /// /// DataFusion relies on these statistics for planning (in particular to sort file groups), /// so if they are incorrect, incorrect answers may result. + /// + /// These statistics cover the full table schema: file columns plus partition columns. + /// When set via [`Self::with_statistics`], partition column statistics are automatically + /// computed from [`Self::partition_values`] with exact min/max/null_count/distinct_count. pub statistics: Option>, - /// An optional field for user defined per object metadata - pub extensions: Option>, + /// The known lexicographical ordering of the rows in this file, if any. + /// + /// This describes how the data within the file is sorted with respect to one or more + /// columns, and is used by the optimizer for planning operations that depend on input + /// ordering (e.g. merges, sorts, and certain aggregations). + /// + /// When available, this is typically inferred from file-level metadata exposed by the + /// underlying format (for example, Parquet `sorting_columns`), but it may also be set + /// explicitly via [`Self::with_ordering`]. + pub ordering: Option, + /// User-defined per-file metadata, keyed by Rust type. Multiple + /// independent components can each attach their own data here without + /// conflict — see [`FileExtensions`]. + pub extensions: FileExtensions, /// The estimated size of the parquet metadata, in bytes pub metadata_size_hint: Option, + pub table_reference: Option, + /// A user-provided physical Arrow schema for this file. + /// + /// This schema describes only the columns stored in the file. It must not + /// include partition columns; those are represented separately by + /// [`Self::partition_values`] and the scan's table partition columns. + /// + /// When provided, this field will be used by the Parquet reader to avoid + /// parsing the Arrow schema from the `ARROW:schema` metadata key. Other + /// built-in file sources ignore it for now. + pub arrow_schema: Option, } impl PartitionedFile { /// Create a simple file without metadata or partition pub fn new(path: impl Into, size: u64) -> Self { Self { + arrow_schema: None, object_meta: ObjectMeta { location: Path::from(path.into()), last_modified: chrono::Utc.timestamp_nanos(0), @@ -137,14 +191,32 @@ impl PartitionedFile { partition_values: vec![], range: None, statistics: None, - extensions: None, + ordering: None, + extensions: FileExtensions::new(), + metadata_size_hint: None, + table_reference: None, + } + } + + /// Create a file from a known ObjectMeta without partition + pub fn new_from_meta(object_meta: ObjectMeta) -> Self { + Self { + arrow_schema: None, + object_meta, + partition_values: vec![], + range: None, + statistics: None, + ordering: None, + extensions: FileExtensions::new(), metadata_size_hint: None, + table_reference: None, } } /// Create a file range without metadata or partition pub fn new_with_range(path: String, size: u64, start: i64, end: i64) -> Self { Self { + arrow_schema: None, object_meta: ObjectMeta { location: Path::from(path), last_modified: chrono::Utc.timestamp_nanos(0), @@ -155,12 +227,56 @@ impl PartitionedFile { partition_values: vec![], range: Some(FileRange { start, end }), statistics: None, - extensions: None, + ordering: None, + extensions: FileExtensions::new(), metadata_size_hint: None, + table_reference: None, } .with_range(start, end) } + /// Provide a physical Arrow schema for this file. + /// + /// The schema must describe only columns stored in the file and must not + /// include partition columns. See [`Self::arrow_schema`] for details. + pub fn with_arrow_schema(mut self, schema: SchemaRef) -> Self { + self.arrow_schema = Some(schema); + self + } + + /// Attach partition values to this file. + /// This replaces any existing partition values. + pub fn with_partition_values(mut self, partition_values: Vec) -> Self { + self.partition_values = partition_values; + self + } + + pub fn with_table_reference( + mut self, + table_reference: Option, + ) -> Self { + self.table_reference = table_reference; + self + } + + /// Size of the file to be scanned (taking into account the range, if present). + pub fn effective_size(&self) -> u64 { + if let Some(range) = &self.range { + (range.end - range.start) as u64 + } else { + self.object_meta.size + } + } + + /// Effective range of the file to be scanned. + pub fn range(&self) -> (u64, u64) { + if let Some(range) = &self.range { + (range.start as u64, range.end as u64) + } else { + (0, self.object_meta.size) + } + } + /// Provide a hint to the size of the file metadata. If a hint is provided /// the reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. /// Without an appropriate hint, two read may be required to fetch the metadata. @@ -186,20 +302,69 @@ impl PartitionedFile { self } - /// Update the user defined extensions for this file. + /// Attach a typed user-defined extension to this file. Multiple + /// independent extensions can be attached, each keyed by its concrete + /// Rust type. Inserting a value of a type that already has an extension + /// replaces the previous one. /// - /// This can be used to pass reader specific information. - pub fn with_extensions( - mut self, - extensions: Arc, - ) -> Self { - self.extensions = Some(extensions); + /// This can be used to pass reader-specific information (e.g. a + /// `ParquetAccessPlan`, or a custom index entry). + pub fn with_extension(mut self, value: T) -> Self { + self.extensions.insert(value); + self + } + + /// Borrow the extension of type `T`, if one is attached. + pub fn extension(&self) -> Option<&T> { + self.extensions.get::() + } + + /// Attach a type-erased extension to this file. + /// + /// Kept as a backwards-compatible shim; prefer [`Self::with_extension`] + /// which keys the extension by its concrete Rust type at the call site. + #[deprecated( + since = "54.0.0", + note = "use `with_extension`; the extension is keyed by its concrete type" + )] + pub fn with_extensions(mut self, extensions: Arc) -> Self { + #[expect(deprecated)] + self.extensions.insert_dyn(extensions); self } - // Update the statistics for this file. - pub fn with_statistics(mut self, statistics: Arc) -> Self { - self.statistics = Some(statistics); + /// Update the statistics for this file. + /// + /// The provided `statistics` should cover only the file schema columns. + /// This method will automatically append exact statistics for partition columns + /// based on `partition_values`: + /// - `min = max = partition_value` (all rows have the same value) + /// - `null_count = 0` (partition values from paths are never null) + /// - `distinct_count = 1` (all rows have the same partition value) + pub fn with_statistics(mut self, file_statistics: Arc) -> Self { + if self.partition_values.is_empty() { + // No partition columns, use stats as-is + self.statistics = Some(file_statistics); + } else { + // Extend stats with exact partition column statistics + let mut stats = Arc::unwrap_or_clone(file_statistics); + for partition_value in &self.partition_values { + let col_stats = ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(partition_value.clone()), + min_value: Precision::Exact(partition_value.clone()), + distinct_count: Precision::Exact(1), + sum_value: Precision::Absent, + byte_size: partition_value + .data_type() + .primitive_width() + .map(|w| stats.num_rows.multiply(&Precision::Exact(w))) + .unwrap_or_else(|| Precision::Absent), + }; + stats.column_statistics.push(col_stats); + } + self.statistics = Some(Arc::new(stats)); + } self } @@ -219,17 +384,29 @@ impl PartitionedFile { false } } + + /// Set the known ordering of data in this file. + /// + /// The ordering represents the lexicographical sort order of the data, + /// typically inferred from file metadata (e.g., Parquet sorting_columns). + pub fn with_ordering(mut self, ordering: Option) -> Self { + self.ordering = ordering; + self + } } impl From for PartitionedFile { fn from(object_meta: ObjectMeta) -> Self { PartitionedFile { object_meta, + arrow_schema: None, partition_values: vec![], range: None, statistics: None, - extensions: None, + ordering: None, + extensions: FileExtensions::new(), metadata_size_hint: None, + table_reference: None, } } } @@ -287,6 +464,10 @@ pub async fn calculate_range( 0 }; + if start + start_delta > end { + return Ok(RangeCalculation::TerminateEarly); + } + let end_delta = if end != file_size { find_first_newline(store, location, end - 1, file_size, newline).await? } else { @@ -295,7 +476,7 @@ pub async fn calculate_range( let range = start + start_delta..end + end_delta; - if range.start == range.end { + if range.start >= range.end { return Ok(RangeCalculation::TerminateEarly); } @@ -399,6 +580,7 @@ pub fn generate_test_files(num_files: usize, overlap_factor: f64) -> Vec Vec + #[tokio::test] + async fn test_calculate_range_single_line_file() { + use super::{PartitionedFile, RangeCalculation, calculate_range}; + use object_store::ObjectStore; + use object_store::memory::InMemory; + + let content = r#"{"id":1,"data":"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}"#; + let file_size = content.len() as u64; + + let store: Arc = Arc::new(InMemory::new()); + let path = Path::from("test.json"); + store.put(&path, content.into()).await.unwrap(); + + let mid = file_size / 2; + let partitioned_file = PartitionedFile::new_with_range( + path.to_string(), + file_size, + mid as i64, + file_size as i64, + ); + + let result = calculate_range(&partitioned_file, &store, None).await; + + assert!(matches!(result, Ok(RangeCalculation::TerminateEarly))); + } } diff --git a/datafusion/datasource/src/morsel/adapters.rs b/datafusion/datasource/src/morsel/adapters.rs new file mode 100644 index 0000000000000..6fa6d4916771d --- /dev/null +++ b/datafusion/datasource/src/morsel/adapters.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::PartitionedFile; +use crate::file_stream::FileOpener; +use crate::morsel::{Morsel, MorselPlan, MorselPlanner, Morselizer}; +use arrow::array::RecordBatch; +use datafusion_common::Result; +use futures::FutureExt; +use futures::stream::BoxStream; +use std::fmt::Debug; +use std::sync::Arc; + +/// Adapt a legacy [`FileOpener`] to the morsel API. +/// +/// This preserves backwards compatibility for file formats that have not yet +/// implemented a native [`Morselizer`]. +pub struct FileOpenerMorselizer { + file_opener: Arc, +} + +impl Debug for FileOpenerMorselizer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileOpenerMorselizer") + .field("file_opener", &"...") + .finish() + } +} + +impl FileOpenerMorselizer { + pub fn new(file_opener: Arc) -> Self { + Self { file_opener } + } +} + +impl Morselizer for FileOpenerMorselizer { + fn plan_file(&self, file: PartitionedFile) -> Result> { + Ok(Box::new(FileOpenFutureMorselPlanner::new( + Arc::clone(&self.file_opener), + file, + ))) + } +} + +enum FileOpenFutureMorselPlanner { + Unopened { + file_opener: Arc, + file: Box, + }, + ReadyStream(BoxStream<'static, Result>), +} + +impl Debug for FileOpenFutureMorselPlanner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Unopened { .. } => f + .debug_tuple("FileOpenFutureMorselPlanner::Unopened") + .finish(), + Self::ReadyStream(_) => f + .debug_tuple("FileOpenFutureMorselPlanner::ReadyStream") + .finish(), + } + } +} + +impl FileOpenFutureMorselPlanner { + fn new(file_opener: Arc, file: PartitionedFile) -> Self { + Self::Unopened { + file_opener, + file: Box::new(file), + } + } +} + +impl MorselPlanner for FileOpenFutureMorselPlanner { + fn plan(self: Box) -> Result> { + match *self { + Self::Unopened { file_opener, file } => { + let io_future = async move { + let stream = file_opener.open(*file)?.await?; + Ok(Box::new(Self::ReadyStream(stream)) as Box) + } + .boxed(); + Ok(Some(MorselPlan::new().with_pending_planner(io_future))) + } + Self::ReadyStream(stream) => Ok(Some( + MorselPlan::new() + .with_morsels(vec![Box::new(FileStreamMorsel { stream })]), + )), + } + } +} + +struct FileStreamMorsel { + stream: BoxStream<'static, Result>, +} + +impl Debug for FileStreamMorsel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FileStreamMorsel").finish_non_exhaustive() + } +} + +impl Morsel for FileStreamMorsel { + fn into_stream(self: Box) -> BoxStream<'static, Result> { + self.stream + } +} diff --git a/datafusion/datasource/src/morsel/mocks.rs b/datafusion/datasource/src/morsel/mocks.rs new file mode 100644 index 0000000000000..ceb0e720691a7 --- /dev/null +++ b/datafusion/datasource/src/morsel/mocks.rs @@ -0,0 +1,746 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test-only mocks for exercising the morsel-driven `FileStream` scheduler. + +use std::collections::{HashMap, VecDeque}; +use std::fmt::{Display, Formatter}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use crate::PartitionedFile; +use crate::morsel::{Morsel, MorselPlan, MorselPlanner, Morselizer}; +use arrow::array::{Int32Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::{DataFusionError, Result, internal_datafusion_err}; +use futures::stream::BoxStream; +use futures::{Future, FutureExt}; + +// Use thin wrappers around usize so the test setups are more explicit + +/// Identifier for a mock morsel in scheduler snapshots. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct MorselId(pub usize); + +/// Identifier for a produced batch in scheduler snapshots. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct BatchId(pub usize); + +/// Identifier for a mock I/O future in scheduler snapshots. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct IoFutureId(pub usize); + +/// Number of pending polls before a mock I/O future resolves. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct PollsToResolve(pub usize); + +/// Error message returned by a mock planner or I/O future. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct MockError(pub String); + +impl Display for MockError { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for MockError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +/// Scheduler-visible event captured by the mock morsel test harness. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum MorselEvent { + MorselizeFile { + path: String, + }, + PlannerCreated { + planner_name: String, + }, + PlannerCalled { + planner_name: String, + }, + IoFutureCreated { + planner_name: String, + io_future_id: IoFutureId, + }, + IoFuturePolled { + planner_name: String, + io_future_id: IoFutureId, + }, + IoFutureResolved { + planner_name: String, + io_future_id: IoFutureId, + }, + IoFutureErrored { + planner_name: String, + io_future_id: IoFutureId, + message: String, + }, + MorselProduced { + planner_name: String, + morsel_id: MorselId, + }, + MorselStreamStarted { + morsel_id: MorselId, + }, + MorselStreamBatchProduced { + morsel_id: MorselId, + batch_id: BatchId, + }, + MorselStreamFinished { + morsel_id: MorselId, + }, +} + +impl Display for MorselEvent { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + MorselEvent::MorselizeFile { path } => { + write!(f, "morselize_file: {path}") + } + MorselEvent::PlannerCreated { planner_name } => { + write!(f, "planner_created: {planner_name}") + } + MorselEvent::PlannerCalled { planner_name } => { + write!(f, "planner_called: {planner_name}") + } + MorselEvent::IoFutureCreated { + planner_name, + io_future_id, + } => write!(f, "io_future_created: {planner_name}, {io_future_id:?}"), + MorselEvent::IoFuturePolled { + planner_name, + io_future_id, + } => write!(f, "io_future_polled: {planner_name}, {io_future_id:?}"), + MorselEvent::IoFutureResolved { + planner_name, + io_future_id, + } => write!(f, "io_future_resolved: {planner_name}, {io_future_id:?}"), + MorselEvent::IoFutureErrored { + planner_name, + io_future_id, + message, + } => write!( + f, + "io_future_errored: {planner_name}, {io_future_id:?}, {message}" + ), + MorselEvent::MorselProduced { + planner_name, + morsel_id, + } => write!(f, "morsel_produced: {planner_name}, {morsel_id:?}"), + MorselEvent::MorselStreamStarted { morsel_id } => { + write!(f, "morsel_stream_started: {morsel_id:?}") + } + MorselEvent::MorselStreamBatchProduced { + morsel_id, + batch_id, + } => write!( + f, + "morsel_stream_batch_produced: {morsel_id:?}, {batch_id:?}" + ), + MorselEvent::MorselStreamFinished { morsel_id } => { + write!(f, "morsel_stream_finished: {morsel_id:?}") + } + } + } +} + +/// Shared observer that records scheduler events for snapshot tests. +#[derive(Debug, Default, Clone)] +pub(crate) struct MorselObserver { + events: Arc>>, +} + +impl MorselObserver { + /// Clears any previously recorded events. + pub(crate) fn clear(&self) { + self.events.lock().unwrap().clear(); + } + + /// Records one new scheduler event. + pub(crate) fn push(&self, event: MorselEvent) { + self.events.lock().unwrap().push(event); + } + + /// Formats all recorded events into a stable, snapshot-friendly trace. + pub(crate) fn format_events(&self) -> String { + self.events + .lock() + .unwrap() + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n") + } +} + +/// Declarative planner spec used by the mock morselizer. +#[derive(Debug, Clone)] +pub(crate) struct MockPlanner { + file_path: String, + steps: VecDeque, +} + +impl MockPlanner { + /// Creates a fluent builder for one mock planner. + pub(crate) fn builder(file_path: impl Into) -> MockPlannerBuilder { + MockPlannerBuilder { + file_path: file_path.into(), + ..Default::default() + } + } + + /// Returns the file path associated with this planner. + pub(crate) fn file_path(&self) -> &str { + &self.file_path + } +} + +/// One scheduler-visible result from calling `MorselPlanner::plan`. +#[derive(Debug, Clone)] +enum PlannerStep { + Plan { + morsels: Vec, + ready_planners: Vec, + pending_planner: Option, + }, + Error { + error: MockError, + }, + None, +} + +/// One mock morsel returned from a planning step. +#[derive(Debug, Clone)] +struct MockMorselSpec { + morsel_id: MorselId, + batch_ids: Vec, +} + +/// One pending planner I/O future returned from a planning step. +#[derive(Debug, Clone)] +struct MockPendingPlanner { + io_future_id: IoFutureId, + polls_to_resolve: PollsToResolve, + result: std::result::Result<(), MockError>, +} + +/// Builder for one mock `PlannerStep::Plan`. +#[derive(Debug, Default)] +pub(crate) struct MockPlanBuilder { + morsels: Vec, + ready_planners: Vec, + pending_planner: Option, +} + +impl MockPlanBuilder { + /// Create an empty mock plan. + pub(crate) fn new() -> Self { + Self::default() + } + + /// Add one ready morsel with a single batch. + pub(crate) fn with_morsel(mut self, morsel_id: MorselId, batch_id: i32) -> Self { + self.morsels.push(MockMorselSpec { + morsel_id, + batch_ids: vec![batch_id], + }); + self + } + + /// Add one ready morsel with multiple batches. + pub(crate) fn with_morsel_batches( + mut self, + morsel_id: MorselId, + batch_ids: Vec, + ) -> Self { + self.morsels.push(MockMorselSpec { + morsel_id, + batch_ids, + }); + self + } + + /// Add a pending planner I/O future produced by this planning step. + pub(crate) fn with_pending_planner( + mut self, + io_future_id: IoFutureId, + polls_to_resolve: PollsToResolve, + result: std::result::Result<(), MockError>, + ) -> Self { + self.pending_planner = Some(MockPendingPlanner { + io_future_id, + polls_to_resolve, + result, + }); + self + } + + /// Add a ready child planner + pub(crate) fn with_ready_planner( + self, + ready_planner: impl Into, + ) -> Self { + self.with_ready_planners(vec![ready_planner.into()]) + } + + /// Add ready child planners produced by this planning step. + pub(crate) fn with_ready_planners( + mut self, + ready_planners: Vec, + ) -> Self { + self.ready_planners.extend(ready_planners); + self + } + + /// Build the planner step. + fn build(self) -> PlannerStep { + PlannerStep::Plan { + morsels: self.morsels, + ready_planners: self.ready_planners, + pending_planner: self.pending_planner, + } + } +} + +/// Builder for a planning step that only returns a pending planner. +#[derive(Debug, Clone)] +pub(crate) struct PendingPlannerBuilder { + io_future_id: IoFutureId, + polls_to_resolve: PollsToResolve, + result: std::result::Result<(), MockError>, +} + +impl From for MockPlanBuilder { + fn from(builder: PendingPlannerBuilder) -> Self { + builder.build() + } +} + +impl PendingPlannerBuilder { + /// Create a pending-planner step with a successful I/O future. + pub(crate) fn new(io_future_id: IoFutureId) -> Self { + Self { + io_future_id, + polls_to_resolve: PollsToResolve(0), + result: Ok(()), + } + } + + /// Configure how many pending polls occur before the I/O future resolves. + pub(crate) fn with_polls_to_resolve( + mut self, + polls_to_resolve: PollsToResolve, + ) -> Self { + self.polls_to_resolve = polls_to_resolve; + self + } + + /// Configure a failing I/O future for this pending planner. + pub(crate) fn with_error(mut self, message: impl Into) -> Self { + self.result = Err(MockError(message.into())); + self + } + + /// Build a `MockPlanBuilder` containing only this pending planner. + pub(crate) fn build(self) -> MockPlanBuilder { + MockPlanBuilder::new().with_pending_planner( + self.io_future_id, + self.polls_to_resolve, + self.result, + ) + } +} + +/// Fluent builder for [`MockPlanner`] test specs. +#[derive(Debug, Default)] +pub(crate) struct MockPlannerBuilder { + file_path: String, + steps: Vec, +} + +impl From for MockPlanner { + fn from(value: MockPlannerBuilder) -> Self { + value.build() + } +} + +impl MockPlannerBuilder { + pub(crate) fn add_plan(mut self, builder: impl Into) -> Self { + let builder = builder.into(); + self.steps.push(builder.build()); + self + } + + /// Adds one planning step that reports the planner is exhausted. + pub(crate) fn return_none(mut self) -> Self { + self.steps.push(PlannerStep::None); + self + } + + /// Adds one planning step that fails during CPU planning. + pub(crate) fn return_error(mut self, message: impl Into) -> Self { + self.steps.push(PlannerStep::Error { + error: MockError(message.into()), + }); + self + } + + /// Finalizes the configured mock planner. + pub(crate) fn build(self) -> MockPlanner { + let Self { file_path, steps } = self; + + MockPlanner { + file_path, + steps: VecDeque::from(steps), + } + } +} + +/// Mock [`Morselizer`] that maps file paths to fixed planner specs. +#[derive(Debug, Clone, Default)] +pub(crate) struct MockMorselizer { + observer: MorselObserver, + files: HashMap, +} + +impl MockMorselizer { + /// Creates an empty mock morselizer. + pub(crate) fn new() -> Self { + Self::default() + } + + /// Returns the shared event observer for this test harness. + pub(crate) fn observer(&self) -> &MorselObserver { + &self.observer + } + + /// Specify the return planner for the specified file_path + pub(crate) fn with_planner(mut self, planner: impl Into) -> Self { + let planner = planner.into(); + self.files.insert(planner.file_path.clone(), planner); + self + } +} + +impl Morselizer for MockMorselizer { + fn plan_file(&self, file: PartitionedFile) -> Result> { + let path = file.object_meta.location.to_string(); + self.observer + .push(MorselEvent::MorselizeFile { path: path.clone() }); + + let planner = self.files.get(&path).cloned().ok_or_else(|| { + internal_datafusion_err!("No mock planner configured for file: {path}") + })?; + + self.observer.push(MorselEvent::PlannerCreated { + planner_name: planner.file_path.clone(), + }); + + Ok(Box::new(MockMorselPlanner::new( + self.observer.clone(), + planner, + ))) + } +} + +/// Concrete mock planner that executes one predefined step per `plan()` call. +#[derive(Debug)] +struct MockMorselPlanner { + observer: MorselObserver, + planner_name: String, + steps: VecDeque, +} + +impl MockMorselPlanner { + /// Creates a concrete planner from its declarative test spec. + fn new(observer: MorselObserver, planner: MockPlanner) -> Self { + Self { + observer, + planner_name: planner.file_path, + steps: planner.steps, + } + } +} + +/// Rebuilds the mock planner continuation after one step completes. +fn current_planner_continuation( + observer: MorselObserver, + planner_name: String, + steps: VecDeque, +) -> Vec> { + let only_none_remaining = + matches!(steps.front(), Some(PlannerStep::None)) && steps.len() == 1; + + if steps.is_empty() || only_none_remaining { + Vec::new() + } else { + vec![Box::new(MockMorselPlanner { + observer, + planner_name, + steps, + }) as Box] + } +} + +/// Create any child planners produced by this planning step. +fn child_planners( + observer: MorselObserver, + ready_planners: Vec, +) -> Vec> { + ready_planners + .into_iter() + .map(|planner| { + observer.push(MorselEvent::PlannerCreated { + planner_name: planner.file_path.clone(), + }); + Box::new(MockMorselPlanner::new(observer.clone(), planner)) + as Box + }) + .collect() +} + +impl MorselPlanner for MockMorselPlanner { + fn plan(self: Box) -> Result> { + let Self { + observer, + planner_name, + mut steps, + } = *self; + + observer.push(MorselEvent::PlannerCalled { + planner_name: planner_name.clone(), + }); + + let Some(step) = steps.pop_front() else { + return Ok(None); + }; + + match step { + PlannerStep::Plan { + morsels, + ready_planners, + pending_planner, + } => { + let mut ready_morsels = Vec::new(); + for MockMorselSpec { + morsel_id, + batch_ids, + } in morsels + { + observer.push(MorselEvent::MorselProduced { + planner_name: planner_name.clone(), + morsel_id, + }); + ready_morsels.push(Box::new(MockMorsel::new( + observer.clone(), + morsel_id, + batch_ids, + )) as Box); + } + + let mut planners = child_planners(observer.clone(), ready_planners); + if pending_planner.is_none() { + planners.extend(current_planner_continuation( + observer.clone(), + planner_name.clone(), + steps.clone(), + )); + } + + let mut plan = MorselPlan::new() + .with_morsels(ready_morsels) + .with_planners(planners); + + if let Some(MockPendingPlanner { + io_future_id, + polls_to_resolve, + result, + }) = pending_planner + { + observer.push(MorselEvent::IoFutureCreated { + planner_name: planner_name.clone(), + io_future_id, + }); + let io_future = MockIoFuture::new( + observer.clone(), + planner_name.clone(), + io_future_id, + polls_to_resolve, + result, + ) + .map(move |result| { + result?; + Ok(Box::new(MockMorselPlanner { + observer, + planner_name, + steps, + }) as Box) + }) + .boxed(); + plan = plan.with_pending_planner(io_future); + } + + Ok(Some(plan)) + } + PlannerStep::Error { error } => { + Err(DataFusionError::External(Box::new(error))) + } + PlannerStep::None => Ok(None), + } + } +} + +/// Concrete morsel used by the mock scheduler tests. +#[derive(Debug)] +pub(crate) struct MockMorsel { + observer: MorselObserver, + morsel_id: MorselId, + batch_ids: Vec, +} + +impl MockMorsel { + /// Creates a mock morsel with a deterministic sequence of batches. + fn new(observer: MorselObserver, morsel_id: MorselId, batch_ids: Vec) -> Self { + Self { + observer, + morsel_id, + batch_ids, + } + } +} + +impl Morsel for MockMorsel { + fn into_stream(self: Box) -> BoxStream<'static, Result> { + self.observer.push(MorselEvent::MorselStreamStarted { + morsel_id: self.morsel_id, + }); + Box::pin(MockMorselStream { + observer: self.observer.clone(), + morsel_id: self.morsel_id, + batch_ids: self.batch_ids.into(), + finished: false, + }) + } +} + +/// Stream returned by [`MockMorsel::into_stream`]. +struct MockMorselStream { + observer: MorselObserver, + morsel_id: MorselId, + batch_ids: VecDeque, + finished: bool, +} + +impl futures::Stream for MockMorselStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if let Some(batch_id) = self.batch_ids.pop_front() { + self.observer.push(MorselEvent::MorselStreamBatchProduced { + morsel_id: self.morsel_id, + batch_id: BatchId(batch_id as usize), + }); + return Poll::Ready(Some(Ok(single_value_batch(batch_id)))); + } + + if !self.finished { + self.finished = true; + self.observer.push(MorselEvent::MorselStreamFinished { + morsel_id: self.morsel_id, + }); + } + + Poll::Ready(None) + } +} + +/// Deterministic future used to simulate planner I/O in tests. +struct MockIoFuture { + observer: MorselObserver, + planner_name: String, + io_future_id: IoFutureId, + pending_polls_remaining: usize, + result: std::result::Result<(), MockError>, +} + +impl MockIoFuture { + /// Creates a future that resolves after `io_polls` pending polls. + fn new( + observer: MorselObserver, + planner_name: String, + io_future_id: IoFutureId, + polls_to_resolve: PollsToResolve, + result: std::result::Result<(), MockError>, + ) -> Self { + Self { + observer, + planner_name, + io_future_id, + pending_polls_remaining: polls_to_resolve.0, + result, + } + } +} + +impl Future for MockIoFuture { + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.observer.push(MorselEvent::IoFuturePolled { + planner_name: self.planner_name.clone(), + io_future_id: self.io_future_id, + }); + + if self.pending_polls_remaining > 0 { + self.pending_polls_remaining -= 1; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + match &self.result { + Ok(()) => { + self.observer.push(MorselEvent::IoFutureResolved { + planner_name: self.planner_name.clone(), + io_future_id: self.io_future_id, + }); + Poll::Ready(Ok(())) + } + Err(e) => { + self.observer.push(MorselEvent::IoFutureErrored { + planner_name: self.planner_name.clone(), + io_future_id: self.io_future_id, + message: e.0.clone(), + }); + Poll::Ready(Err(DataFusionError::External(Box::new(e.clone())))) + } + } + } +} + +/// Creates a one-row batch so snapshot output stays compact and readable. +fn single_value_batch(value: i32) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![value]))]).unwrap() +} diff --git a/datafusion/datasource/src/morsel/mod.rs b/datafusion/datasource/src/morsel/mod.rs new file mode 100644 index 0000000000000..7b5066ca07a26 --- /dev/null +++ b/datafusion/datasource/src/morsel/mod.rs @@ -0,0 +1,234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Structures for Morsel Driven IO. +//! +//! NOTE: As of DataFusion 54.0.0, these are experimental APIs that may change +//! substantially. +//! +//! Morsel Driven IO is a technique for parallelizing the reading of large files +//! by dividing them into smaller "morsels" that are processed independently. +//! +//! It is inspired by the paper [Morsel-Driven Parallelism: A NUMA-Aware Query +//! Evaluation Framework for the Many-Core Age](https://db.in.tum.de/~leis/papers/morsels.pdf). + +mod adapters; +#[cfg(test)] +pub(crate) mod mocks; + +use crate::PartitionedFile; +pub(crate) use adapters::FileOpenerMorselizer; +use arrow::array::RecordBatch; +use datafusion_common::Result; +use futures::FutureExt; +use futures::future::BoxFuture; +use futures::stream::BoxStream; +use std::fmt::Debug; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// A Morsel of work ready to resolve to a stream of [`RecordBatch`]es. +/// +/// This represents a single morsel of work that is ready to be processed. It +/// has all data necessary (does not need any I/O) and is ready to be turned +/// into a stream of [`RecordBatch`]es for processing by the execution engine. +pub trait Morsel: Send + Debug { + /// Consume this morsel and produce a stream of [`RecordBatch`]es for processing. + /// + /// Note: This may do CPU work to decode already-loaded data, but should not + /// do any I/O work such as reading from the file. + fn into_stream(self: Box) -> BoxStream<'static, Result>; +} + +/// A Morselizer takes a single [`PartitionedFile`] and creates the initial planner +/// for that file. +/// +/// This is the entry point for morsel driven I/O. +pub trait Morselizer: Send + Sync + Debug { + /// Return the initial [`MorselPlanner`] for this file. + /// + /// Morselizing a file may involve CPU work, such as parsing parquet + /// metadata and evaluating pruning predicates. It should NOT do any I/O + /// work, such as reading from the file. Any needed I/O should be done using + /// [`MorselPlan::with_pending_planner`]. + fn plan_file(&self, file: PartitionedFile) -> Result>; +} + +/// A Morsel Planner is responsible for creating morsels for a given scan. +/// +/// The [`MorselPlanner`] is the unit of I/O. There is only ever a single I/O +/// outstanding for a specific planner. DataFusion may run +/// multiple planners in parallel, which corresponds to multiple parallel +/// I/O requests. +/// +/// It is not a Rust `Stream` so that it can explicitly separate CPU bound +/// work from I/O work. +/// +/// The design is similar to `ParquetPushDecoder`: when `plan` is called, it +/// should do CPU work to produce the next morsels or discover the next I/O +/// phase. +/// +/// Best practice is to spawn I/O in a Tokio task on a separate runtime to +/// ensure that CPU work doesn't block or slow down I/O work, but this is not +/// strictly required by the API. +pub trait MorselPlanner: Send + Debug { + /// Attempt to plan morsels. This may involve CPU work, such as parsing + /// parquet metadata and evaluating pruning predicates. + /// + /// It should NOT do any I/O work, such as reading from the file. If I/O is + /// required, the returned [`MorselPlan`] should contain a pending planner + /// future that the caller polls to drive the I/O work to completion. Once + /// that future resolves, it yields a planner ready for work. + /// + /// Note this function is **not async** to make it explicitly clear that if + /// I/O is required, it should be done in the returned `io_future`. + /// + /// Returns `None` if the planner has no more work to do. + /// + /// # Empty Morsel Plans + /// + /// It may return `None`, which means no batches will be read from the file + /// (e.g. due to late-pruning based on statistics). + /// + /// # Output Ordering + /// + /// See the comments on [`MorselPlan`] for the logical output order. + fn plan(self: Box) -> Result>; +} + +/// Return result of [`MorselPlanner::plan`]. +/// +/// # Logical Ordering +/// +/// For plans where the output order of rows is maintained, the output order of +/// a [`MorselPlanner`] is logically defined as follows: +/// 1. All morsels that are directly produced +/// 2. Recursively, all morsels produced by the returned `planners` +#[derive(Default)] +pub struct MorselPlan { + /// Morsels ready for CPU work + morsels: Vec>, + /// Planners that are ready for CPU work. + ready_planners: Vec>, + /// A future with planner I/O that resolves to a CPU ready planner. + /// + /// DataFusion will poll this future occasionally to drive the I/O work to + /// completion. Once it resolves, planning continues with the returned + /// planner. + pending_planner: Option, +} + +impl MorselPlan { + /// Create an empty morsel plan. + pub fn new() -> Self { + Self::default() + } + + /// Set the ready morsels. + pub fn with_morsels(mut self, morsels: Vec>) -> Self { + self.morsels = morsels; + self + } + + /// Set the ready child planners. + pub fn with_planners(mut self, planners: Vec>) -> Self { + self.ready_planners = planners; + self + } + + /// Set the pending planner for an I/O phase. + pub fn with_pending_planner(mut self, io_future: F) -> Self + where + F: Future>> + Send + 'static, + { + self.pending_planner = Some(PendingMorselPlanner::new(io_future)); + self + } + + /// Set the pending planner for an I/O phase. + pub fn set_pending_planner(&mut self, io_future: F) + where + F: Future>> + Send + 'static, + { + self.pending_planner = Some(PendingMorselPlanner::new(io_future)); + } + + /// Take the ready morsels. + pub fn take_morsels(&mut self) -> Vec> { + std::mem::take(&mut self.morsels) + } + + /// Take the ready child planners. + pub fn take_ready_planners(&mut self) -> Vec> { + std::mem::take(&mut self.ready_planners) + } + + /// Take the pending I/O future, if any. + pub fn take_pending_planner(&mut self) -> Option { + self.pending_planner.take() + } + + /// Returns `true` if this plan contains an I/O future. + pub fn has_io_future(&self) -> bool { + self.pending_planner.is_some() + } +} + +/// Wrapper for I/O that must complete before planning can continue. +pub struct PendingMorselPlanner { + future: BoxFuture<'static, Result>>, +} + +impl PendingMorselPlanner { + /// Create a new pending planner future. + /// + /// Example + /// ``` + /// # use datafusion_common::DataFusionError; + /// # use datafusion_datasource::morsel::{MorselPlanner, PendingMorselPlanner}; + /// let work = async move { + /// let planner: Box = { + /// // Do I/O work here, then return the next planner to run. + /// # unimplemented!(); + /// }; + /// Ok(planner) as Result<_, DataFusionError>; + /// }; + /// let pending_io = PendingMorselPlanner::new(work); + /// ``` + pub fn new(future: F) -> Self + where + F: Future>> + Send + 'static, + { + Self { + future: future.boxed(), + } + } + + /// Consume this wrapper and return the underlying future. + pub fn into_future(self) -> BoxFuture<'static, Result>> { + self.future + } +} + +/// Forwards polling to the underlying future. +impl Future for PendingMorselPlanner { + type Output = Result>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // forward request to inner + self.future.as_mut().poll(cx) + } +} diff --git a/datafusion/datasource/src/projection.rs b/datafusion/datasource/src/projection.rs new file mode 100644 index 0000000000000..ac33a96ca8321 --- /dev/null +++ b/datafusion/datasource/src/projection.rs @@ -0,0 +1,630 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{Schema, SchemaRef}; +use datafusion_common::{ + Result, ScalarValue, + tree_node::{Transformed, TransformedResult, TreeNode}, +}; +use datafusion_physical_expr::{ + expressions::{Column, Literal}, + projection::{ProjectionExpr, ProjectionExprs}, +}; +use futures::{FutureExt, StreamExt}; +use itertools::Itertools; + +use crate::{ + PartitionedFile, TableSchema, + file_stream::{FileOpenFuture, FileOpener}, +}; + +/// A file opener that handles applying a projection on top of an inner opener. +/// +/// This includes handling partition columns. +/// +/// Any projection pushed down will be split up into: +/// - Simple column indices / column selection +/// - A remainder projection that this opener applies on top of it +/// +/// This is meant to simplify projection pushdown for sources like CSV +/// that can only handle "simple" column selection. +pub struct ProjectionOpener { + inner: Arc, + projection: ProjectionExprs, + input_schema: SchemaRef, + partition_columns: Vec, +} + +impl ProjectionOpener { + pub fn try_new( + projection: SplitProjection, + inner: Arc, + file_schema: &Schema, + ) -> Result> { + Ok(Arc::new(ProjectionOpener { + inner, + projection: projection.remapped_projection, + input_schema: Arc::new(file_schema.project(&projection.file_indices)?), + partition_columns: projection.partition_columns, + })) + } +} + +impl FileOpener for ProjectionOpener { + fn open(&self, partitioned_file: PartitionedFile) -> Result { + let partition_values = partitioned_file.partition_values.clone(); + // Modify any references to partition columns in the projection expressions + // and substitute them with literal values from PartitionedFile.partition_values + let projection = if self.partition_columns.is_empty() { + self.projection.clone() + } else { + inject_partition_columns_into_projection( + &self.projection, + &self.partition_columns, + partition_values, + ) + }; + let projector = projection.make_projector(&self.input_schema)?; + + let inner = self.inner.open(partitioned_file)?; + + Ok(async move { + let stream = inner.await?; + let stream = stream.map(move |batch| { + let batch = batch?; + let batch = projector.project_batch(&batch)?; + Ok(batch) + }); + Ok(stream.boxed()) + } + .boxed()) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct PartitionColumnIndex { + /// The index of this partition column in the remainder projection (>= num_file_columns) + pub in_remainder_projection: usize, + /// The index of this partition column in the partition_values array + pub in_partition_values: usize, +} + +fn inject_partition_columns_into_projection( + projection: &ProjectionExprs, + partition_columns: &[PartitionColumnIndex], + partition_values: Vec, +) -> ProjectionExprs { + // Pre-create all literals for partition columns to avoid cloning ScalarValues multiple times. + let partition_literals: Vec> = partition_values + .into_iter() + .map(|value| Arc::new(Literal::new(value))) + .collect(); + + let projections = projection + .iter() + .map(|projection| { + let expr = Arc::clone(&projection.expr) + .transform(|expr| { + let original_expr = Arc::clone(&expr); + if let Some(column) = expr.downcast_ref::() { + // Check if this column index corresponds to a partition column + if let Some(pci) = partition_columns + .iter() + .find(|pci| pci.in_remainder_projection == column.index()) + { + let literal = + Arc::clone(&partition_literals[pci.in_partition_values]); + return Ok(Transformed::yes(literal)); + } + } + Ok(Transformed::no(original_expr)) + }) + .data() + .expect("infallible transform"); + ProjectionExpr::new(expr, projection.alias.clone()) + }) + .collect_vec(); + ProjectionExprs::new(projections) +} + +/// At a high level the goal of SplitProjection is to take a ProjectionExprs meant to be applied to the table schema +/// and split that into: +/// - The projection indices into the file schema (file_indices) +/// - The projection indices into the partition values (partition_value_indices), which pre-compute both the index into the table schema +/// and the index into the partition values array +/// - A remapped projection that can be applied after the file projection is applied +/// This remapped projection has the following properties: +/// - Column indices referring to file columns are remapped to [0..file_indices.len()) +/// - Column indices referring to partition columns are remapped to [file_indices.len()..) +/// +/// This allows the ProjectionOpener to easily identify which columns in the remapped projection +/// refer to partition columns and substitute them with literals from the partition values. +#[derive(Debug, Clone)] +pub struct SplitProjection { + /// The original projection this [`SplitProjection`] was derived from + pub source: ProjectionExprs, + /// Column indices to read from file (public for file sources) + pub file_indices: Vec, + /// Pre-computed partition column mappings (internal, used by ProjectionOpener) + pub(crate) partition_columns: Vec, + /// The remapped projection (internal, used by ProjectionOpener) + pub(crate) remapped_projection: ProjectionExprs, +} + +impl SplitProjection { + pub fn unprojected(table_schema: &TableSchema) -> Self { + let projection = ProjectionExprs::from_indices( + &(0..table_schema.table_schema().fields().len()).collect_vec(), + table_schema.table_schema(), + ); + Self::new(table_schema.file_schema(), &projection) + } + + /// Creates a new [`SplitProjection`] by splitting a projection into + /// simple file column indices and a remainder projection that is applied after reading the file. + /// + /// In other words: we get a `Vec` projection that is meant to be applied on top of `file_schema` + /// and a remainder projection that is applied to the result of that first projection. + /// + /// Here `file_schema` is expected to be the *logical* schema of the file, that is the + /// table schema minus any partition columns. + /// Partition columns are always expected to be at the end of the table schema. + /// Note that `file_schema` is *not* the physical schema of the file. + pub fn new(logical_file_schema: &Schema, projection: &ProjectionExprs) -> Self { + let num_file_schema_columns = logical_file_schema.fields().len(); + + // Collect all unique columns and classify as file or partition + let mut file_columns = Vec::new(); + let mut partition_columns = Vec::new(); + let mut all_columns = std::collections::HashMap::new(); + + // Extract all unique column references (index -> name) + for proj_expr in projection { + proj_expr + .expr + .apply(|expr| { + if let Some(column) = expr.downcast_ref::() { + all_columns + .entry(column.index()) + .or_insert_with(|| column.name().to_string()); + } + Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue) + }) + .expect("infallible apply"); + } + + // Sort by index and classify into file vs partition columns + let mut sorted_columns: Vec<_> = all_columns + .into_iter() + .map(|(idx, name)| (name, idx)) + .collect(); + sorted_columns.sort_by_key(|(_, idx)| *idx); + + // Separate file and partition columns, assigning final indices + // Pre-create all remapped columns to avoid duplicate Arc'd expressions + let mut column_mapping = std::collections::HashMap::new(); + let mut file_idx = 0; + let mut partition_idx = 0; + + for (name, original_index) in sorted_columns { + let new_index = if original_index < num_file_schema_columns { + // File column: gets index [0..num_file_columns) + file_columns.push(original_index); + let idx = file_idx; + file_idx += 1; + idx + } else { + // Partition column: gets index [num_file_columns..) + partition_columns.push(original_index); + let idx = file_idx + partition_idx; + partition_idx += 1; + idx + }; + + // Pre-create the remapped column so all references can share the same Arc + let new_column: Arc = + Arc::new(Column::new(&name, new_index)); + column_mapping.insert(original_index, new_column); + } + + // Single tree transformation: remap all column references using pre-created columns + let remapped_projection = projection + .iter() + .map(|proj_expr| { + let expr = Arc::clone(&proj_expr.expr) + .transform(|expr| { + let original_expr = Arc::clone(&expr); + if let Some(column) = expr.downcast_ref::() + && let Some(new_column) = column_mapping.get(&column.index()) + { + return Ok(Transformed::yes(Arc::clone(new_column))); + } + Ok(Transformed::no(original_expr)) + }) + .data() + .expect("infallible transform"); + ProjectionExpr::new(expr, proj_expr.alias.clone()) + }) + .collect_vec(); + + // Pre-compute partition column mappings for ProjectionOpener + let num_file_columns = file_columns.len(); + let partition_column_mappings = partition_columns + .iter() + .enumerate() + .map(|(partition_idx, &table_index)| PartitionColumnIndex { + in_remainder_projection: num_file_columns + partition_idx, + in_partition_values: table_index - num_file_schema_columns, + }) + .collect_vec(); + + Self { + source: projection.clone(), + file_indices: file_columns, + partition_columns: partition_column_mappings, + remapped_projection: ProjectionExprs::from(remapped_projection), + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::AsArray; + use arrow::datatypes::{DataType, SchemaRef}; + use datafusion_common::{DFSchema, ScalarValue, record_batch}; + use datafusion_expr::{Expr, col, execution_props::ExecutionProps}; + use datafusion_physical_expr::{create_physical_exprs, projection::ProjectionExpr}; + use itertools::Itertools; + + use super::*; + + fn create_projection_exprs<'a>( + exprs: impl IntoIterator, + schema: &SchemaRef, + ) -> ProjectionExprs { + let df_schema = DFSchema::try_from(Arc::clone(schema)).unwrap(); + let physical_exprs = + create_physical_exprs(exprs, &df_schema, &ExecutionProps::default()).unwrap(); + let projection_exprs = physical_exprs + .into_iter() + .enumerate() + .map(|(i, e)| ProjectionExpr::new(Arc::clone(&e), format!("col{i}"))) + .collect_vec(); + ProjectionExprs::from(projection_exprs) + } + + #[test] + fn test_split_projection_with_partition_columns() { + use arrow::array::AsArray; + use arrow::datatypes::Field; + // Simulate the avro_exec_with_partition test scenario: + // file_schema has 3 fields + let file_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("bool_col", DataType::Boolean, false), + Field::new("tinyint_col", DataType::Int8, false), + ])); + + // table_schema has 4 fields (3 file + 1 partition) + let table_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("bool_col", DataType::Boolean, false), + Field::new("tinyint_col", DataType::Int8, false), + Field::new("date", DataType::Utf8, false), // partition column at index 3 + ])); + + // projection indices: [0, 1, 3, 2] + // This should select: id (0), bool_col (1), date (3-partition), tinyint_col (2) + let projection_indices = vec![0, 1, 3, 2]; + + // Create projection expressions from indices using the table schema + let projection = + ProjectionExprs::from_indices(&projection_indices, &table_schema); + + // Call SplitProjection to separate file and partition columns + let split = SplitProjection::new(&file_schema, &projection); + + // The file_indices should be [0, 1, 2] (all file columns needed) + assert_eq!(split.file_indices, vec![0, 1, 2]); + + // Should have 1 partition column at in_partition_values index 0 + assert_eq!(split.partition_columns.len(), 1); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + + // Now create a batch with only the file columns + let file_batch = record_batch!( + ("id", Int32, vec![4]), + ("bool_col", Boolean, vec![true]), + ("tinyint_col", Int8, vec![0]) + ) + .unwrap(); + + // After the fix, the remainder projection should have remapped indices: + // - File columns: [0, 1, 2] (unchanged since they're already in order) + // - Partition column: [3] (stays at index 3, which is >= num_file_columns) + // So the remainder expects input columns [0, 1, 2] and references column [3] for partition + + // Verify that we can inject partition columns and apply the projection + let partition_values = vec![ScalarValue::from("2021-10-26")]; + + // Create partition column mapping + let partition_columns = vec![PartitionColumnIndex { + in_remainder_projection: 3, // partition column is at index 3 in remainder + in_partition_values: 0, // first partition value + }]; + + // Inject partition columns (replaces Column(3) with Literal) + let injected_projection = inject_partition_columns_into_projection( + &split.remapped_projection, + &partition_columns, + partition_values, + ); + + // Now the projection should work on the file batch + let projector = injected_projection + .make_projector(&file_batch.schema()) + .unwrap(); + let result = projector.project_batch(&file_batch).unwrap(); + + // Verify the output has the correct column order: id, bool_col, date, tinyint_col + assert_eq!(result.num_columns(), 4); + assert_eq!( + result + .column(0) + .as_primitive::() + .value(0), + 4 + ); + assert!(result.column(1).as_boolean().value(0)); + assert_eq!(result.column(2).as_string::().value(0), "2021-10-26"); + assert_eq!( + result + .column(3) + .as_primitive::() + .value(0), + 0 + ); + } + + // ======================================================================== + // Comprehensive Test Suite for SplitProjection + // ======================================================================== + + // Helper to create test schemas with file and partition columns + fn create_test_schemas( + file_cols: usize, + partition_cols: usize, + ) -> (SchemaRef, SchemaRef) { + use arrow::datatypes::Field; + + let file_fields: Vec<_> = (0..file_cols) + .map(|i| Field::new(format!("col_{i}"), DataType::Int32, false)) + .collect(); + + let mut table_fields = file_fields.clone(); + table_fields.extend( + (0..partition_cols) + .map(|i| Field::new(format!("part_{i}"), DataType::Utf8, false)), + ); + + ( + Arc::new(Schema::new(file_fields)), + Arc::new(Schema::new(table_fields)), + ) + } + + // ======================================================================== + // Partition Column Handling Tests + // ======================================================================== + + #[test] + fn test_split_projection_only_file_columns() { + let (file_schema, table_schema) = create_test_schemas(3, 2); + // Select only file columns [0, 1, 2] + let projection = ProjectionExprs::from_indices(&[0, 1, 2], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, vec![0, 1, 2]); + assert_eq!(split.partition_columns.len(), 0); + } + + #[test] + fn test_split_projection_only_partition_columns() { + let (file_schema, table_schema) = create_test_schemas(3, 2); + // Select only partition columns [3, 4] + let projection = ProjectionExprs::from_indices(&[3, 4], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, Vec::::new()); + assert_eq!(split.partition_columns.len(), 2); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + assert_eq!(split.partition_columns[1].in_partition_values, 1); + } + + #[test] + fn test_split_projection_multiple_partition_columns() { + let (file_schema, table_schema) = create_test_schemas(2, 3); + // File cols: 0, 1; Partition cols: 2, 3, 4 + // Select: [0, 2, 4, 1, 3] (mixed file and partition) + let projection = ProjectionExprs::from_indices(&[0, 2, 4, 1, 3], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, vec![0, 1]); + assert_eq!(split.partition_columns.len(), 3); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + assert_eq!(split.partition_columns[1].in_partition_values, 1); + assert_eq!(split.partition_columns[2].in_partition_values, 2); + + // Verify remapped projection has correct indices + // File columns should be at [0, 1], partition columns at [2, 3, 4] + assert_eq!(split.remapped_projection.iter().count(), 5); + } + + #[test] + fn test_split_projection_partition_columns_reverse_order() { + let (file_schema, table_schema) = create_test_schemas(2, 2); + // File cols: 0, 1; Partition cols: 2, 3 + // Select: [3, 2] (partitions in reverse) + let projection = ProjectionExprs::from_indices(&[3, 2], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, Vec::::new()); + assert_eq!(split.partition_columns.len(), 2); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + assert_eq!(split.partition_columns[1].in_partition_values, 1); + } + + #[test] + fn test_split_projection_interleaved_file_and_partition() { + let (file_schema, table_schema) = create_test_schemas(3, 3); + // File cols: 0, 1, 2; Partition cols: 3, 4, 5 + // Select: [0, 3, 1, 4, 2, 5] (alternating) + let projection = + ProjectionExprs::from_indices(&[0, 3, 1, 4, 2, 5], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, vec![0, 1, 2]); + assert_eq!(split.partition_columns.len(), 3); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + assert_eq!(split.partition_columns[1].in_partition_values, 1); + assert_eq!(split.partition_columns[2].in_partition_values, 2); + } + + #[test] + fn test_split_projection_expression_with_file_and_partition_columns() { + use arrow::datatypes::Field; + + // Create schemas: 2 file columns, 1 partition column + let file_schema = Arc::new(Schema::new(vec![ + Field::new("file_a", DataType::Int32, false), + Field::new("file_b", DataType::Int32, false), + ])); + let table_schema = Arc::new(Schema::new(vec![ + Field::new("file_a", DataType::Int32, false), + Field::new("file_b", DataType::Int32, false), + Field::new("part_c", DataType::Int32, false), + ])); + + // Create expression: file_a + part_c + let exprs = [col("file_a") + col("part_c")]; + let projection = create_projection_exprs(exprs.iter(), &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + // Should extract both columns + assert_eq!(split.file_indices, vec![0]); + assert_eq!(split.partition_columns.len(), 1); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + } + + // ======================================================================== + // Category 4: Boundary Conditions + // ======================================================================== + + #[test] + fn test_split_projection_boundary_last_file_column() { + let (file_schema, table_schema) = create_test_schemas(3, 2); + // Last file column is index 2 + let projection = ProjectionExprs::from_indices(&[2], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, vec![2]); + assert_eq!(split.partition_columns.len(), 0); + } + + #[test] + fn test_split_projection_boundary_first_partition_column() { + let (file_schema, table_schema) = create_test_schemas(3, 2); + // First partition column is index 3 + let projection = ProjectionExprs::from_indices(&[3], &table_schema); + + let split = SplitProjection::new(&file_schema, &projection); + + assert_eq!(split.file_indices, Vec::::new()); + assert_eq!(split.partition_columns.len(), 1); + assert_eq!(split.partition_columns[0].in_partition_values, 0); + } + + // ======================================================================== + // Category 6: Integration Tests + // ======================================================================== + + #[test] + fn test_inject_partition_columns_multiple_partitions() { + let data = + record_batch!(("col_0", Int32, vec![1]), ("col_1", Int32, vec![2])).unwrap(); + + // Create projection that references file columns and partition columns + let (file_schema, table_schema) = create_test_schemas(2, 2); + // Projection: [0, 2, 1, 3] = [file_0, part_0, file_1, part_1] + let projection = ProjectionExprs::from_indices(&[0, 2, 1, 3], &table_schema); + let split = SplitProjection::new(&file_schema, &projection); + + // Create partition column mappings + let partition_columns = vec![ + PartitionColumnIndex { + in_remainder_projection: 2, // First partition column at index 2 + in_partition_values: 0, + }, + PartitionColumnIndex { + in_remainder_projection: 3, // Second partition column at index 3 + in_partition_values: 1, + }, + ]; + + let partition_values = + vec![ScalarValue::from("part_a"), ScalarValue::from("part_b")]; + + let injected = inject_partition_columns_into_projection( + &split.remapped_projection, + &partition_columns, + partition_values, + ); + + // Apply projection + let projector = injected.make_projector(&data.schema()).unwrap(); + let result = projector.project_batch(&data).unwrap(); + + assert_eq!(result.num_columns(), 4); + assert_eq!( + result + .column(0) + .as_primitive::() + .value(0), + 1 + ); + assert_eq!(result.column(1).as_string::().value(0), "part_a"); + assert_eq!( + result + .column(2) + .as_primitive::() + .value(0), + 2 + ); + assert_eq!(result.column(3).as_string::().value(0), "part_b"); + } +} diff --git a/datafusion/datasource/src/schema_adapter.rs b/datafusion/datasource/src/schema_adapter.rs index 4c7b37113d58d..c995fa58d6c89 100644 --- a/datafusion/datasource/src/schema_adapter.rs +++ b/datafusion/datasource/src/schema_adapter.rs @@ -15,49 +15,47 @@ // specific language governing permissions and limitations // under the License. -//! [`SchemaAdapter`] and [`SchemaAdapterFactory`] to adapt file-level record batches to a table schema. +//! Deprecated: [`SchemaAdapter`] and [`SchemaAdapterFactory`] have been removed. //! -//! Adapter provides a method of translating the RecordBatches that come out of the -//! physical format into how they should be used by DataFusion. For instance, a schema -//! can be stored external to a parquet file that maps parquet logical types to arrow types. -use arrow::{ - array::{new_null_array, ArrayRef, RecordBatch, RecordBatchOptions}, - compute::can_cast_types, - datatypes::{DataType, Field, Schema, SchemaRef}, -}; -use datafusion_common::{ - format::DEFAULT_CAST_OPTIONS, - nested_struct::{cast_column, validate_struct_compatibility}, - plan_err, ColumnStatistics, -}; -use std::{fmt::Debug, sync::Arc}; -/// Function used by [`SchemaMapping`] to adapt a column from the file schema to -/// the table schema. -pub type CastColumnFn = dyn Fn( - &ArrayRef, - &Field, - &arrow::compute::CastOptions, - ) -> datafusion_common::Result +//! Use [`PhysicalExprAdapterFactory`] instead. See `upgrading.md` for more details. +//! +//! [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory + +#![allow(deprecated)] + +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{Field, Schema, SchemaRef}; +use datafusion_common::{ColumnStatistics, Result, not_impl_err}; +use log::warn; +use std::fmt::Debug; +use std::sync::Arc; + +/// Deprecated: Function type for casting columns. +/// +/// This type has been removed. Use [`PhysicalExprAdapterFactory`] instead. +/// See `upgrading.md` for more details. +/// +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +#[deprecated( + since = "52.0.0", + note = "SchemaAdapter has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." +)] +pub type CastColumnFn = dyn Fn(&ArrayRef, &Field, &arrow::compute::CastOptions) -> Result + Send + Sync; -/// Factory for creating [`SchemaAdapter`] +/// Deprecated: Factory for creating [`SchemaAdapter`]. /// -/// This interface provides a way to implement custom schema adaptation logic -/// for DataSourceExec (for example, to fill missing columns with default value -/// other than null). +/// This trait has been removed. Use [`PhysicalExprAdapterFactory`] instead. +/// See `upgrading.md` for more details. /// -/// Most users should use [`DefaultSchemaAdapterFactory`]. See that struct for -/// more details and examples. +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +#[deprecated( + since = "52.0.0", + note = "SchemaAdapter has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." +)] pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { /// Create a [`SchemaAdapter`] - /// - /// Arguments: - /// - /// * `projected_table_schema`: The schema for the table, projected to - /// include only the fields being output (projected) by the this mapping. - /// - /// * `table_schema`: The entire table schema for the table fn create( &self, projected_table_schema: SchemaRef, @@ -65,9 +63,6 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { ) -> Box; /// Create a [`SchemaAdapter`] using only the projected table schema. - /// - /// This is a convenience method for cases where the table schema and the - /// projected table schema are the same. fn create_with_projected_schema( &self, projected_table_schema: SchemaRef, @@ -76,971 +71,162 @@ pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { } } -/// Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table -/// schema, which may have a schema obtained from merging multiple file-level -/// schemas. +/// Deprecated: Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table schema. /// -/// This is useful for implementing schema evolution in partitioned datasets. +/// This trait has been removed. Use [`PhysicalExprAdapterFactory`] instead. +/// See `upgrading.md` for more details. /// -/// See [`DefaultSchemaAdapterFactory`] for more details and examples. +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +#[deprecated( + since = "52.0.0", + note = "SchemaAdapter has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." +)] pub trait SchemaAdapter: Send + Sync { - /// Map a column index in the table schema to a column index in a particular - /// file schema - /// - /// This is used while reading a file to push down projections by mapping - /// projected column indexes from the table schema to the file schema - /// - /// Panics if index is not in range for the table schema + /// Map a column index in the table schema to a column index in a particular file schema. fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option; - /// Creates a mapping for casting columns from the file schema to the table - /// schema. - /// - /// This is used after reading a record batch. The returned [`SchemaMapper`]: - /// - /// 1. Maps columns to the expected columns indexes - /// 2. Handles missing values (e.g. fills nulls or a default value) for - /// columns in the in the table schema not in the file schema - /// 2. Handles different types: if the column in the file schema has a - /// different type than `table_schema`, the mapper will resolve this - /// difference (e.g. by casting to the appropriate type) - /// - /// Returns: - /// * a [`SchemaMapper`] - /// * an ordered list of columns to project from the file + /// Creates a mapping for casting columns from the file schema to the table schema. fn map_schema( &self, file_schema: &Schema, - ) -> datafusion_common::Result<(Arc, Vec)>; + ) -> Result<(Arc, Vec)>; } -/// Maps, columns from a specific file schema to the table schema. +/// Deprecated: Maps columns from a specific file schema to the table schema. +/// +/// This trait has been removed. Use [`PhysicalExprAdapterFactory`] instead. +/// See `upgrading.md` for more details. /// -/// See [`DefaultSchemaAdapterFactory`] for more details and examples. +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +#[deprecated( + since = "52.0.0", + note = "SchemaMapper has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." +)] pub trait SchemaMapper: Debug + Send + Sync { - /// Adapts a `RecordBatch` to match the `table_schema` - fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; + /// Adapts a `RecordBatch` to match the `table_schema`. + fn map_batch(&self, batch: RecordBatch) -> Result; - /// Adapts file-level column `Statistics` to match the `table_schema` + /// Adapts file-level column `Statistics` to match the `table_schema`. fn map_column_statistics( &self, file_col_statistics: &[ColumnStatistics], - ) -> datafusion_common::Result>; + ) -> Result>; } -/// Default [`SchemaAdapterFactory`] for mapping schemas. -/// -/// This can be used to adapt file-level record batches to a table schema and -/// implement schema evolution. -/// -/// Given an input file schema and a table schema, this factory returns -/// [`SchemaAdapter`] that return [`SchemaMapper`]s that: -/// -/// 1. Reorder columns -/// 2. Cast columns to the correct type -/// 3. Fill missing columns with nulls -/// -/// # Errors: -/// -/// * If a column in the table schema is non-nullable but is not present in the -/// file schema (i.e. it is missing), the returned mapper tries to fill it with -/// nulls resulting in a schema error. -/// -/// # Illustration of Schema Mapping -/// -/// ```text -/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ -/// ┌───────┐ ┌───────┐ │ ┌───────┐ ┌───────┐ ┌───────┐ │ -/// ││ 1.0 │ │ "foo" │ ││ NULL │ │ "foo" │ │ "1.0" │ -/// ├───────┤ ├───────┤ │ Schema mapping ├───────┤ ├───────┤ ├───────┤ │ -/// ││ 2.0 │ │ "bar" │ ││ NULL │ │ "bar" │ │ "2.0" │ -/// └───────┘ └───────┘ │────────────────▶ └───────┘ └───────┘ └───────┘ │ -/// │ │ -/// column "c" column "b"│ column "a" column "b" column "c"│ -/// │ Float64 Utf8 │ Int32 Utf8 Utf8 -/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ -/// Input Record Batch Output Record Batch -/// -/// Schema { Schema { -/// "c": Float64, "a": Int32, -/// "b": Utf8, "b": Utf8, -/// } "c": Utf8, -/// } -/// ``` -/// -/// # Example of using the `DefaultSchemaAdapterFactory` to map [`RecordBatch`]s -/// -/// Note `SchemaMapping` also supports mapping partial batches, which is used as -/// part of predicate pushdown. -/// -/// ``` -/// # use std::sync::Arc; -/// # use arrow::datatypes::{DataType, Field, Schema}; -/// # use datafusion_datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapterFactory}; -/// # use datafusion_common::record_batch; -/// // Table has fields "a", "b" and "c" -/// let table_schema = Schema::new(vec![ -/// Field::new("a", DataType::Int32, true), -/// Field::new("b", DataType::Utf8, true), -/// Field::new("c", DataType::Utf8, true), -/// ]); -/// -/// // create an adapter to map the table schema to the file schema -/// let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); +/// Deprecated: Default [`SchemaAdapterFactory`] for mapping schemas. /// -/// // The file schema has fields "c" and "b" but "b" is stored as an 'Float64' -/// // instead of 'Utf8' -/// let file_schema = Schema::new(vec![ -/// Field::new("c", DataType::Utf8, true), -/// Field::new("b", DataType::Float64, true), -/// ]); +/// This struct has been removed. /// -/// // Get a mapping from the file schema to the table schema -/// let (mapper, _indices) = adapter.map_schema(&file_schema).unwrap(); +/// Use [`PhysicalExprAdapterFactory`] instead to customize scans via +/// [`FileScanConfigBuilder`], i.e. if you had implemented a custom [`SchemaAdapter`] +/// and passed that into [`FileScanConfigBuilder`] / [`ParquetSource`]. +/// Use [`BatchAdapter`] if you want to map a stream of [`RecordBatch`]es +/// between one schema and another, i.e. if you were calling [`SchemaMapper::map_batch`] manually. /// -/// let file_batch = record_batch!( -/// ("c", Utf8, vec!["foo", "bar"]), -/// ("b", Float64, vec![1.0, 2.0]) -/// ).unwrap(); +/// See `upgrading.md` for more details. /// -/// let mapped_batch = mapper.map_batch(file_batch).unwrap(); -/// -/// // the mapped batch has the correct schema and the "b" column has been cast to Utf8 -/// let expected_batch = record_batch!( -/// ("a", Int32, vec![None, None]), // missing column filled with nulls -/// ("b", Utf8, vec!["1.0", "2.0"]), // b was cast to string and order was changed -/// ("c", Utf8, vec!["foo", "bar"]) -/// ).unwrap(); -/// assert_eq!(mapped_batch, expected_batch); -/// ``` +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +/// [`FileScanConfigBuilder`]: crate::file_scan_config::FileScanConfigBuilder +/// [`ParquetSource`]: https://docs.rs/datafusion-datasource-parquet/latest/datafusion_datasource_parquet/source/struct.ParquetSource.html +/// [`BatchAdapter`]: datafusion_physical_expr_adapter::BatchAdapter +#[deprecated( + since = "52.0.0", + note = "DefaultSchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." +)] #[derive(Clone, Debug, Default)] pub struct DefaultSchemaAdapterFactory; -impl DefaultSchemaAdapterFactory { - /// Create a new factory for mapping batches from a file schema to a table - /// schema. - /// - /// This is a convenience for [`DefaultSchemaAdapterFactory::create`] with - /// the same schema for both the projected table schema and the table - /// schema. - pub fn from_schema(table_schema: SchemaRef) -> Box { - Self.create(Arc::clone(&table_schema), table_schema) - } -} - impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { fn create( &self, projected_table_schema: SchemaRef, _table_schema: SchemaRef, ) -> Box { - Box::new(DefaultSchemaAdapter { - projected_table_schema, + Box::new(DeprecatedSchemaAdapter { + _projected_table_schema: projected_table_schema, }) } } -/// This SchemaAdapter requires both the table schema and the projected table -/// schema. See [`SchemaMapping`] for more details -#[derive(Clone, Debug)] -pub(crate) struct DefaultSchemaAdapter { - /// The schema for the table, projected to include only the fields being output (projected) by the - /// associated ParquetSource - projected_table_schema: SchemaRef, +impl DefaultSchemaAdapterFactory { + /// Deprecated: Create a new factory for mapping batches from a file schema to a table schema. + #[deprecated( + since = "52.0.0", + note = "DefaultSchemaAdapterFactory has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + )] + pub fn from_schema(table_schema: SchemaRef) -> Box { + // Note: this method did not return an error thus the errors are raised from the returned adapter + warn!( + "DefaultSchemaAdapterFactory::from_schema is deprecated. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." + ); + Box::new(DeprecatedSchemaAdapter { + _projected_table_schema: table_schema, + }) + } } -/// Checks if a file field can be cast to a table field -/// -/// Returns Ok(true) if casting is possible, or an error explaining why casting is not possible -pub(crate) fn can_cast_field( - file_field: &Field, - table_field: &Field, -) -> datafusion_common::Result { - match (file_field.data_type(), table_field.data_type()) { - (DataType::Struct(source_fields), DataType::Struct(target_fields)) => { - // validate_struct_compatibility returns Result<()>; on success we can cast structs - validate_struct_compatibility(source_fields, target_fields)?; - Ok(true) - } - _ => { - if can_cast_types(file_field.data_type(), table_field.data_type()) { - Ok(true) - } else { - plan_err!( - "Cannot cast file schema field {} of type {} to table schema field of type {}", - file_field.name(), - file_field.data_type(), - table_field.data_type() - ) - } - } - } +/// Internal deprecated adapter that returns errors when methods are called. +struct DeprecatedSchemaAdapter { + _projected_table_schema: SchemaRef, } -impl SchemaAdapter for DefaultSchemaAdapter { - /// Map a column index in the table schema to a column index in a particular - /// file schema - /// - /// Panics if index is not in range for the table schema - fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { - let field = self.projected_table_schema.field(index); - Some(file_schema.fields.find(field.name())?.0) +impl SchemaAdapter for DeprecatedSchemaAdapter { + fn map_column_index(&self, _index: usize, _file_schema: &Schema) -> Option { + None // Safe no-op } - /// Creates a `SchemaMapping` for casting or mapping the columns from the - /// file schema to the table schema. - /// - /// If the provided `file_schema` contains columns of a different type to - /// the expected `table_schema`, the method will attempt to cast the array - /// data from the file schema to the table schema where possible. - /// - /// Returns a [`SchemaMapping`] that can be applied to the output batch - /// along with an ordered list of columns to project from the file fn map_schema( &self, - file_schema: &Schema, - ) -> datafusion_common::Result<(Arc, Vec)> { - let (field_mappings, projection) = create_field_mapping( - file_schema, - &self.projected_table_schema, - can_cast_field, - )?; - - Ok(( - Arc::new(SchemaMapping::new( - Arc::clone(&self.projected_table_schema), - field_mappings, - Arc::new( - |array: &ArrayRef, - field: &Field, - opts: &arrow::compute::CastOptions| { - cast_column(array, field, opts) - }, - ), - )), - projection, - )) + _file_schema: &Schema, + ) -> Result<(Arc, Vec)> { + not_impl_err!( + "SchemaAdapter has been removed. Use PhysicalExprAdapterFactory instead. \ + See upgrading.md for more details." + ) } } -/// Helper function that creates field mappings between file schema and table schema +/// Deprecated: The SchemaMapping struct held a mapping from the file schema to the table schema. /// -/// Maps columns from the file schema to their corresponding positions in the table schema, -/// applying type compatibility checking via the provided predicate function. +/// This struct has been removed. /// -/// Returns field mappings (for column reordering) and a projection (for field selection). -pub(crate) fn create_field_mapping( - file_schema: &Schema, - projected_table_schema: &SchemaRef, - can_map_field: F, -) -> datafusion_common::Result<(Vec>, Vec)> -where - F: Fn(&Field, &Field) -> datafusion_common::Result, -{ - let mut projection = Vec::with_capacity(file_schema.fields().len()); - let mut field_mappings = vec![None; projected_table_schema.fields().len()]; - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if let Some((table_idx, table_field)) = - projected_table_schema.fields().find(file_field.name()) - { - if can_map_field(file_field, table_field)? { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } - } - } - - Ok((field_mappings, projection)) -} - -/// The SchemaMapping struct holds a mapping from the file schema to the table -/// schema and any necessary type conversions. +/// Use [`PhysicalExprAdapterFactory`] instead to customize scans via +/// [`FileScanConfigBuilder`], i.e. if you had implemented a custom [`SchemaAdapter`] +/// and passed that into [`FileScanConfigBuilder`] / [`ParquetSource`]. +/// Use [`BatchAdapter`] if you want to map a stream of [`RecordBatch`]es +/// between one schema and another, i.e. if you were calling [`SchemaMapper::map_batch`] manually. /// -/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which -/// has the projected schema, since that's the schema which is supposed to come -/// out of the execution of this query. Thus `map_batch` uses -/// `projected_table_schema` as it can only operate on the projected fields. +/// See `upgrading.md` for more details. /// -/// [`map_batch`]: Self::map_batch +/// [`PhysicalExprAdapterFactory`]: datafusion_physical_expr_adapter::PhysicalExprAdapterFactory +/// [`FileScanConfigBuilder`]: crate::file_scan_config::FileScanConfigBuilder +/// [`ParquetSource`]: https://docs.rs/datafusion-datasource-parquet/latest/datafusion_datasource_parquet/source/struct.ParquetSource.html +/// [`BatchAdapter`]: datafusion_physical_expr_adapter::BatchAdapter +#[deprecated( + since = "52.0.0", + note = "SchemaMapping has been removed. Use PhysicalExprAdapterFactory instead. See upgrading.md for more details." +)] +#[derive(Debug)] pub struct SchemaMapping { - /// The schema of the table. This is the expected schema after conversion - /// and it should match the schema of the query result. - projected_table_schema: SchemaRef, - /// Mapping from field index in `projected_table_schema` to index in - /// projected file_schema. - /// - /// They are Options instead of just plain `usize`s because the table could - /// have fields that don't exist in the file. - field_mappings: Vec>, - /// Function used to adapt a column from the file schema to the table schema - /// when it exists in both schemas - cast_column: Arc, -} - -impl Debug for SchemaMapping { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SchemaMapping") - .field("projected_table_schema", &self.projected_table_schema) - .field("field_mappings", &self.field_mappings) - .field("cast_column", &"") - .finish() - } -} - -impl SchemaMapping { - /// Creates a new SchemaMapping instance - /// - /// Initializes the field mappings needed to transform file data to the projected table schema - pub fn new( - projected_table_schema: SchemaRef, - field_mappings: Vec>, - cast_column: Arc, - ) -> Self { - Self { - projected_table_schema, - field_mappings, - cast_column, - } - } + // Private fields removed - this is a skeleton for deprecation purposes only + _private: (), } impl SchemaMapper for SchemaMapping { - /// Adapts a `RecordBatch` to match the `projected_table_schema` using the stored mapping and - /// conversions. - /// The produced RecordBatch has a schema that contains only the projected columns. - fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result { - let (_old_schema, batch_cols, batch_rows) = batch.into_parts(); - - let cols = self - .projected_table_schema - // go through each field in the projected schema - .fields() - .iter() - // and zip it with the index that maps fields from the projected table schema to the - // projected file schema in `batch` - .zip(&self.field_mappings) - // and for each one... - .map(|(field, file_idx)| { - file_idx.map_or_else( - // If this field only exists in the table, and not in the file, then we know - // that it's null, so just return that. - || Ok(new_null_array(field.data_type(), batch_rows)), - // However, if it does exist in both, use the cast_column function - // to perform any necessary conversions - |batch_idx| { - (self.cast_column)( - &batch_cols[batch_idx], - field, - &DEFAULT_CAST_OPTIONS, - ) - }, - ) - }) - .collect::, _>>()?; - - // Necessary to handle empty batches - let options = RecordBatchOptions::new().with_row_count(Some(batch_rows)); - - let schema = Arc::clone(&self.projected_table_schema); - let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; - Ok(record_batch) + fn map_batch(&self, _batch: RecordBatch) -> Result { + not_impl_err!( + "SchemaMapping has been removed. Use PhysicalExprAdapterFactory instead. \ + See upgrading.md for more details." + ) } - /// Adapts file-level column `Statistics` to match the `table_schema` fn map_column_statistics( &self, - file_col_statistics: &[ColumnStatistics], - ) -> datafusion_common::Result> { - let mut table_col_statistics = vec![]; - - // Map the statistics for each field in the file schema to the corresponding field in the - // table schema, if a field is not present in the file schema, we need to fill it with `ColumnStatistics::new_unknown` - for (_, file_col_idx) in self - .projected_table_schema - .fields() - .iter() - .zip(&self.field_mappings) - { - if let Some(file_col_idx) = file_col_idx { - table_col_statistics.push( - file_col_statistics - .get(*file_col_idx) - .cloned() - .unwrap_or_default(), - ); - } else { - table_col_statistics.push(ColumnStatistics::new_unknown()); - } - } - - Ok(table_col_statistics) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::{ - array::{Array, ArrayRef, StringBuilder, StructArray, TimestampMillisecondArray}, - compute::cast, - datatypes::{DataType, Field, TimeUnit}, - record_batch::RecordBatch, - }; - use datafusion_common::{stats::Precision, Result, ScalarValue, Statistics}; - - #[test] - fn test_schema_mapping_map_statistics_basic() { - // Create table schema (a, b, c) - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), - ])); - - // Create file schema (b, a) - different order, missing c - let file_schema = Schema::new(vec![ - Field::new("b", DataType::Utf8, true), - Field::new("a", DataType::Int32, true), - ]); - - // Create SchemaAdapter - let adapter = DefaultSchemaAdapter { - projected_table_schema: Arc::clone(&table_schema), - }; - - // Get mapper and projection - let (mapper, projection) = adapter.map_schema(&file_schema).unwrap(); - - // Should project columns 0,1 from file - assert_eq!(projection, vec![0, 1]); - - // Create file statistics - let mut file_stats = Statistics::default(); - - // Statistics for column b (index 0 in file) - let b_stats = ColumnStatistics { - null_count: Precision::Exact(5), - ..Default::default() - }; - - // Statistics for column a (index 1 in file) - let a_stats = ColumnStatistics { - null_count: Precision::Exact(10), - ..Default::default() - }; - - file_stats.column_statistics = vec![b_stats, a_stats]; - - // Map statistics - let table_col_stats = mapper - .map_column_statistics(&file_stats.column_statistics) - .unwrap(); - - // Verify stats - assert_eq!(table_col_stats.len(), 3); - assert_eq!(table_col_stats[0].null_count, Precision::Exact(10)); // a from file idx 1 - assert_eq!(table_col_stats[1].null_count, Precision::Exact(5)); // b from file idx 0 - assert_eq!(table_col_stats[2].null_count, Precision::Absent); // c (unknown) - } - - #[test] - fn test_schema_mapping_map_statistics_empty() { - // Create schemas - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ])); - let file_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ]); - - let adapter = DefaultSchemaAdapter { - projected_table_schema: Arc::clone(&table_schema), - }; - let (mapper, _) = adapter.map_schema(&file_schema).unwrap(); - - // Empty file statistics - let file_stats = Statistics::default(); - let table_col_stats = mapper - .map_column_statistics(&file_stats.column_statistics) - .unwrap(); - - // All stats should be unknown - assert_eq!(table_col_stats.len(), 2); - assert_eq!(table_col_stats[0], ColumnStatistics::new_unknown(),); - assert_eq!(table_col_stats[1], ColumnStatistics::new_unknown(),); - } - - #[test] - fn test_can_cast_field() { - // Same type should work - let from_field = Field::new("col", DataType::Int32, true); - let to_field = Field::new("col", DataType::Int32, true); - assert!(can_cast_field(&from_field, &to_field).unwrap()); - - // Casting Int32 to Float64 is allowed - let from_field = Field::new("col", DataType::Int32, true); - let to_field = Field::new("col", DataType::Float64, true); - assert!(can_cast_field(&from_field, &to_field).unwrap()); - - // Casting Float64 to Utf8 should work (converts to string) - let from_field = Field::new("col", DataType::Float64, true); - let to_field = Field::new("col", DataType::Utf8, true); - assert!(can_cast_field(&from_field, &to_field).unwrap()); - - // Binary to Utf8 is not supported - this is an example of a cast that should fail - // Note: We use Binary instead of Utf8->Int32 because Arrow actually supports that cast - let from_field = Field::new("col", DataType::Binary, true); - let to_field = Field::new("col", DataType::Decimal128(10, 2), true); - let result = can_cast_field(&from_field, &to_field); - assert!(result.is_err()); - let error_msg = result.unwrap_err().to_string(); - assert!(error_msg.contains("Cannot cast file schema field col")); - } - - #[test] - fn test_create_field_mapping() { - // Define the table schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), - ])); - - // Define file schema: different order, missing column c, and b has different type - let file_schema = Schema::new(vec![ - Field::new("b", DataType::Float64, true), // Different type but castable to Utf8 - Field::new("a", DataType::Int32, true), // Same type - Field::new("d", DataType::Boolean, true), // Not in table schema - ]); - - // Custom can_map_field function that allows all mappings for testing - let allow_all = |_: &Field, _: &Field| Ok(true); - - // Test field mapping - let (field_mappings, projection) = - create_field_mapping(&file_schema, &table_schema, allow_all).unwrap(); - - // Expected: - // - field_mappings[0] (a) maps to projection[1] - // - field_mappings[1] (b) maps to projection[0] - // - field_mappings[2] (c) is None (not in file) - assert_eq!(field_mappings, vec![Some(1), Some(0), None]); - assert_eq!(projection, vec![0, 1]); // Projecting file columns b, a - - // Test with a failing mapper - let fails_all = |_: &Field, _: &Field| Ok(false); - let (field_mappings, projection) = - create_field_mapping(&file_schema, &table_schema, fails_all).unwrap(); - - // Should have no mappings or projections if all cast checks fail - assert_eq!(field_mappings, vec![None, None, None]); - assert_eq!(projection, Vec::::new()); - - // Test with error-producing mapper - let error_mapper = |_: &Field, _: &Field| plan_err!("Test error"); - let result = create_field_mapping(&file_schema, &table_schema, error_mapper); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Test error")); - } - - #[test] - fn test_schema_mapping_new() { - // Define the projected table schema - let projected_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - ])); - - // Define field mappings from table to file - let field_mappings = vec![Some(1), Some(0)]; - - // Create SchemaMapping manually - let mapping = SchemaMapping::new( - Arc::clone(&projected_schema), - field_mappings.clone(), - Arc::new( - |array: &ArrayRef, field: &Field, opts: &arrow::compute::CastOptions| { - cast_column(array, field, opts) - }, - ), - ); - - // Check that fields were set correctly - assert_eq!(*mapping.projected_table_schema, *projected_schema); - assert_eq!(mapping.field_mappings, field_mappings); - - // Test with a batch to ensure it works properly - let batch = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("b_file", DataType::Utf8, true), - Field::new("a_file", DataType::Int32, true), - ])), - vec![ - Arc::new(arrow::array::StringArray::from(vec!["hello", "world"])), - Arc::new(arrow::array::Int32Array::from(vec![1, 2])), - ], - ) - .unwrap(); - - // Test that map_batch works with our manually created mapping - let mapped_batch = mapping.map_batch(batch).unwrap(); - - // Verify the mapped batch has the correct schema and data - assert_eq!(*mapped_batch.schema(), *projected_schema); - assert_eq!(mapped_batch.num_columns(), 2); - assert_eq!(mapped_batch.column(0).len(), 2); // a column - assert_eq!(mapped_batch.column(1).len(), 2); // b column - } - - #[test] - fn test_map_schema_error_path() { - // Define the table schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Decimal128(10, 2), true), // Use Decimal which has stricter cast rules - ])); - - // Define file schema with incompatible type for column c - let file_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Float64, true), // Different but castable - Field::new("c", DataType::Binary, true), // Not castable to Decimal128 - ]); - - // Create DefaultSchemaAdapter - let adapter = DefaultSchemaAdapter { - projected_table_schema: Arc::clone(&table_schema), - }; - - // map_schema should error due to incompatible types - let result = adapter.map_schema(&file_schema); - assert!(result.is_err()); - let error_msg = result.unwrap_err().to_string(); - assert!(error_msg.contains("Cannot cast file schema field c")); - } - - #[test] - fn test_map_schema_happy_path() { - // Define the table schema - let table_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Decimal128(10, 2), true), - ])); - - // Create DefaultSchemaAdapter - let adapter = DefaultSchemaAdapter { - projected_table_schema: Arc::clone(&table_schema), - }; - - // Define compatible file schema (missing column c) - let compatible_file_schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), // Can be cast to Int32 - Field::new("b", DataType::Float64, true), // Can be cast to Utf8 - ]); - - // Test successful schema mapping - let (mapper, projection) = adapter.map_schema(&compatible_file_schema).unwrap(); - - // Verify field_mappings and projection created correctly - assert_eq!(projection, vec![0, 1]); // Projecting a and b - - // Verify the SchemaMapping works with actual data - let file_batch = RecordBatch::try_new( - Arc::new(compatible_file_schema.clone()), - vec![ - Arc::new(arrow::array::Int64Array::from(vec![100, 200])), - Arc::new(arrow::array::Float64Array::from(vec![1.5, 2.5])), - ], + _file_col_statistics: &[ColumnStatistics], + ) -> Result> { + not_impl_err!( + "SchemaMapping has been removed. Use PhysicalExprAdapterFactory instead. \ + See upgrading.md for more details." ) - .unwrap(); - - let mapped_batch = mapper.map_batch(file_batch).unwrap(); - - // Verify correct schema mapping - assert_eq!(*mapped_batch.schema(), *table_schema); - assert_eq!(mapped_batch.num_columns(), 3); // a, b, c - - // Column c should be null since it wasn't in the file schema - let c_array = mapped_batch.column(2); - assert_eq!(c_array.len(), 2); - assert_eq!(c_array.null_count(), 2); - } - - #[test] - fn test_adapt_struct_with_added_nested_fields() -> Result<()> { - let (file_schema, table_schema) = create_test_schemas_with_nested_fields(); - let batch = create_test_batch_with_struct_data(&file_schema)?; - - let adapter = DefaultSchemaAdapter { - projected_table_schema: Arc::clone(&table_schema), - }; - let (mapper, _) = adapter.map_schema(file_schema.as_ref())?; - let mapped_batch = mapper.map_batch(batch)?; - - verify_adapted_batch_with_nested_fields(&mapped_batch, &table_schema)?; - Ok(()) - } - - #[test] - fn test_map_column_statistics_struct() -> Result<()> { - let (file_schema, table_schema) = create_test_schemas_with_nested_fields(); - - let adapter = DefaultSchemaAdapter { - projected_table_schema: Arc::clone(&table_schema), - }; - let (mapper, _) = adapter.map_schema(file_schema.as_ref())?; - - let file_stats = vec![ - create_test_column_statistics( - 0, - 100, - Some(ScalarValue::Int32(Some(1))), - Some(ScalarValue::Int32(Some(100))), - Some(ScalarValue::Int32(Some(5100))), - ), - create_test_column_statistics(10, 50, None, None, None), - ]; - - let table_stats = mapper.map_column_statistics(&file_stats)?; - assert_eq!(table_stats.len(), 1); - verify_column_statistics( - &table_stats[0], - Some(0), - Some(100), - Some(ScalarValue::Int32(Some(1))), - Some(ScalarValue::Int32(Some(100))), - Some(ScalarValue::Int32(Some(5100))), - ); - let missing_stats = mapper.map_column_statistics(&[])?; - assert_eq!(missing_stats.len(), 1); - assert_eq!(missing_stats[0], ColumnStatistics::new_unknown()); - Ok(()) - } - - fn create_test_schemas_with_nested_fields() -> (SchemaRef, SchemaRef) { - let file_schema = Arc::new(Schema::new(vec![Field::new( - "info", - DataType::Struct( - vec![ - Field::new("location", DataType::Utf8, true), - Field::new( - "timestamp_utc", - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - true, - ), - ] - .into(), - ), - true, - )])); - - let table_schema = Arc::new(Schema::new(vec![Field::new( - "info", - DataType::Struct( - vec![ - Field::new("location", DataType::Utf8, true), - Field::new( - "timestamp_utc", - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), - true, - ), - Field::new( - "reason", - DataType::Struct( - vec![ - Field::new("_level", DataType::Float64, true), - Field::new( - "details", - DataType::Struct( - vec![ - Field::new("rurl", DataType::Utf8, true), - Field::new("s", DataType::Float64, true), - Field::new("t", DataType::Utf8, true), - ] - .into(), - ), - true, - ), - ] - .into(), - ), - true, - ), - ] - .into(), - ), - true, - )])); - - (file_schema, table_schema) - } - - fn create_test_batch_with_struct_data( - file_schema: &SchemaRef, - ) -> Result { - let mut location_builder = StringBuilder::new(); - location_builder.append_value("San Francisco"); - location_builder.append_value("New York"); - - let timestamp_array = TimestampMillisecondArray::from(vec![ - Some(1640995200000), - Some(1641081600000), - ]); - - let timestamp_type = - DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); - let timestamp_array = cast(×tamp_array, ×tamp_type)?; - - let info_struct = StructArray::from(vec![ - ( - Arc::new(Field::new("location", DataType::Utf8, true)), - Arc::new(location_builder.finish()) as ArrayRef, - ), - ( - Arc::new(Field::new("timestamp_utc", timestamp_type, true)), - timestamp_array, - ), - ]); - - Ok(RecordBatch::try_new( - Arc::clone(file_schema), - vec![Arc::new(info_struct)], - )?) - } - - fn verify_adapted_batch_with_nested_fields( - mapped_batch: &RecordBatch, - table_schema: &SchemaRef, - ) -> Result<()> { - assert_eq!(mapped_batch.schema(), *table_schema); - assert_eq!(mapped_batch.num_rows(), 2); - - let info_col = mapped_batch.column(0); - let info_array = info_col - .as_any() - .downcast_ref::() - .expect("Expected info column to be a StructArray"); - - verify_preserved_fields(info_array)?; - verify_reason_field_structure(info_array)?; - Ok(()) - } - - fn verify_preserved_fields(info_array: &StructArray) -> Result<()> { - let location_col = info_array - .column_by_name("location") - .expect("Expected location field in struct"); - let location_array = location_col - .as_any() - .downcast_ref::() - .expect("Expected location to be a StringArray"); - assert_eq!(location_array.value(0), "San Francisco"); - assert_eq!(location_array.value(1), "New York"); - - let timestamp_col = info_array - .column_by_name("timestamp_utc") - .expect("Expected timestamp_utc field in struct"); - let timestamp_array = timestamp_col - .as_any() - .downcast_ref::() - .expect("Expected timestamp_utc to be a TimestampMillisecondArray"); - assert_eq!(timestamp_array.value(0), 1640995200000); - assert_eq!(timestamp_array.value(1), 1641081600000); - Ok(()) - } - - fn verify_reason_field_structure(info_array: &StructArray) -> Result<()> { - let reason_col = info_array - .column_by_name("reason") - .expect("Expected reason field in struct"); - let reason_array = reason_col - .as_any() - .downcast_ref::() - .expect("Expected reason to be a StructArray"); - assert_eq!(reason_array.fields().len(), 2); - assert!(reason_array.column_by_name("_level").is_some()); - assert!(reason_array.column_by_name("details").is_some()); - - let details_col = reason_array - .column_by_name("details") - .expect("Expected details field in reason struct"); - let details_array = details_col - .as_any() - .downcast_ref::() - .expect("Expected details to be a StructArray"); - assert_eq!(details_array.fields().len(), 3); - assert!(details_array.column_by_name("rurl").is_some()); - assert!(details_array.column_by_name("s").is_some()); - assert!(details_array.column_by_name("t").is_some()); - for i in 0..2 { - assert!(reason_array.is_null(i), "reason field should be null"); - } - Ok(()) - } - - fn verify_column_statistics( - stats: &ColumnStatistics, - expected_null_count: Option, - expected_distinct_count: Option, - expected_min: Option, - expected_max: Option, - expected_sum: Option, - ) { - if let Some(count) = expected_null_count { - assert_eq!( - stats.null_count, - Precision::Exact(count), - "Null count should match expected value" - ); - } - if let Some(count) = expected_distinct_count { - assert_eq!( - stats.distinct_count, - Precision::Exact(count), - "Distinct count should match expected value" - ); - } - if let Some(min) = expected_min { - assert_eq!( - stats.min_value, - Precision::Exact(min), - "Min value should match expected value" - ); - } - if let Some(max) = expected_max { - assert_eq!( - stats.max_value, - Precision::Exact(max), - "Max value should match expected value" - ); - } - if let Some(sum) = expected_sum { - assert_eq!( - stats.sum_value, - Precision::Exact(sum), - "Sum value should match expected value" - ); - } - } - - fn create_test_column_statistics( - null_count: usize, - distinct_count: usize, - min_value: Option, - max_value: Option, - sum_value: Option, - ) -> ColumnStatistics { - ColumnStatistics { - null_count: Precision::Exact(null_count), - distinct_count: Precision::Exact(distinct_count), - min_value: min_value.map_or_else(|| Precision::Absent, Precision::Exact), - max_value: max_value.map_or_else(|| Precision::Absent, Precision::Exact), - sum_value: sum_value.map_or_else(|| Precision::Absent, Precision::Exact), - } } } diff --git a/datafusion/datasource/src/sink.rs b/datafusion/datasource/src/sink.rs index f66fbc408c68b..e3df1ad6381f4 100644 --- a/datafusion/datasource/src/sink.rs +++ b/datafusion/datasource/src/sink.rs @@ -24,15 +24,15 @@ use std::sync::Arc; use arrow::array::{ArrayRef, RecordBatch, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::{assert_eq_or_internal_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{Distribution, EquivalenceProperties}; use datafusion_physical_expr_common::sort_expr::{LexRequirement, OrderingRequirements}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{ - execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, - ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PlanProperties, SendableRecordBatchStream, execute_input_stream, }; use async_trait::async_trait; @@ -45,11 +45,7 @@ use futures::StreamExt; /// The `Display` impl is used to format the sink for explain plan /// output. #[async_trait] -pub trait DataSink: DisplayAs + Debug + Send + Sync { - /// Returns the data sink as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - +pub trait DataSink: Any + DisplayAs + Debug + Send + Sync { /// Return a snapshot of the [MetricsSet] for this /// [DataSink]. /// @@ -76,6 +72,18 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync { ) -> Result; } +impl dyn DataSink { + /// Returns true if the inner type is `T`. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Returns a reference to the inner value as the type `T` if it is of that type. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} + /// Execution plan for writing record batches to a [`DataSink`] /// /// Returns a single row with the number of values written @@ -89,12 +97,12 @@ pub struct DataSinkExec { count_schema: SchemaRef, /// Optional required sort order for output data. sort_order: Option, - cache: PlanProperties, + cache: Arc, } impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "DataSinkExec schema: {:?}", self.count_schema) + write!(f, "DataSinkExec schema: {}", self.count_schema) } } @@ -117,7 +125,7 @@ impl DataSinkExec { sink, count_schema: make_count_schema(), sort_order, - cache, + cache: Arc::new(cache), } } @@ -170,11 +178,7 @@ impl ExecutionPlan for DataSinkExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/datasource/src/source.rs b/datafusion/datasource/src/source.rs index de79512a41017..af4bc09504937 100644 --- a/datafusion/datasource/src/source.rs +++ b/datafusion/datasource/src/source.rs @@ -20,26 +20,31 @@ use std::any::Any; use std::fmt; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; +use datafusion_physical_expr::projection::ProjectionExprs; use datafusion_physical_plan::execution_plan::{ Boundedness, EmissionType, SchedulingType, }; use datafusion_physical_plan::metrics::SplitMetrics; -use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::stream::BatchSplitStream; use datafusion_physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, }; use itertools::Itertools; +use crate::file::FileSource; use crate::file_scan_config::FileScanConfig; use datafusion_common::config::ConfigOptions; use datafusion_common::{Constraints, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr}; -use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_plan::SortOrderPushdownResult; use datafusion_physical_plan::filter_pushdown::{ ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown, }; @@ -72,8 +77,8 @@ use datafusion_physical_plan::filter_pushdown::{ /// ```text /// ┌─────────────────────┐ -----► execute path /// │ │ ┄┄┄┄┄► init path -/// │ DataSourceExec │ -/// │ │ +/// │ DataSourceExec │ +/// │ │ /// └───────▲─────────────┘ /// ┊ │ /// ┊ │ @@ -117,13 +122,22 @@ use datafusion_physical_plan::filter_pushdown::{ /// │ │ /// └─────────────────────┘ /// ``` -pub trait DataSource: Send + Sync + Debug { +pub trait DataSource: Any + Send + Sync + Debug { + /// Open the specified output partition and return its stream of + /// [`RecordBatch`]es. + /// + /// This should be used by data sources that do not need any sibling + /// coordination. Data sources that want to use per-execution shared state + /// (for example, to reorder work across partitions at runtime) should + /// implement [`Self::open_with_args`] instead. + /// + /// [`RecordBatch`]: arrow::record_batch::RecordBatch fn open( &self, partition: usize, context: Arc, ) -> Result; - fn as_any(&self) -> &dyn Any; + /// Format this source for display in explain plans fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; @@ -154,17 +168,7 @@ pub trait DataSource: Send + Sync + Debug { /// Returns statistics for a specific partition, or aggregate statistics /// across all partitions if `partition` is `None`. - fn partition_statistics(&self, partition: Option) -> Result; - - /// Returns aggregate statistics across all partitions. - /// - /// # Deprecated - /// Use [`Self::partition_statistics`] instead, which provides more fine-grained - /// control over statistics retrieval (per-partition or aggregate). - #[deprecated(since = "51.0.0", note = "Use partition_statistics instead")] - fn statistics(&self) -> Result { - self.partition_statistics(None) - } + fn partition_statistics(&self, partition: Option) -> Result>; /// Return a copy of this DataSource with a new fetch limit fn with_fetch(&self, _limit: Option) -> Option>; @@ -174,9 +178,15 @@ pub trait DataSource: Send + Sync + Debug { } fn try_swapping_with_projection( &self, - _projection: &[ProjectionExpr], + _projection: &ProjectionExprs, ) -> Result>>; + /// Try to push down filters into this DataSource. + /// + /// These filters are in terms of the output schema of this DataSource (e.g. + /// [`Self::eq_properties`] and output of any projections pushed into the + /// source), not the original table schema. + /// /// See [`ExecutionPlan::handle_child_pushdown_result`] for more details. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: datafusion_physical_plan::ExecutionPlan::handle_child_pushdown_result @@ -189,6 +199,105 @@ pub trait DataSource: Send + Sync + Debug { vec![PushedDown::No; filters.len()], )) } + + /// Try to create a new DataSource that produces data in the specified sort order. + /// + /// # Arguments + /// * `order` - The desired output ordering + /// + /// # Returns + /// * `Ok(SortOrderPushdownResult::Exact { .. })` - Created a source that guarantees exact ordering + /// * `Ok(SortOrderPushdownResult::Inexact { .. })` - Created a source optimized for the ordering + /// * `Ok(SortOrderPushdownResult::Unsupported)` - Cannot optimize for this ordering + /// * `Err(e)` - Error occurred + /// + /// Default implementation returns `Unsupported`. + fn try_pushdown_sort( + &self, + _order: &[PhysicalSortExpr], + ) -> Result>> { + Ok(SortOrderPushdownResult::Unsupported) + } + + /// Returns a variant of this `DataSource` that is aware of order-sensitivity. + fn with_preserve_order(&self, _preserve_order: bool) -> Option> { + None + } + + /// Injects arbitrary run-time state into this DataSource, returning a new instance + /// that incorporates that state *if* it is relevant to the concrete DataSource implementation. + /// + /// This is a generic entry point: the `state` can be any type wrapped in + /// `Arc`. A data source that cares about the state should + /// down-cast it to the concrete type it expects and, if successful, return a + /// modified copy of itself that captures the provided value. If the state is + /// not applicable, the default behaviour is to return `None` so that parent + /// nodes can continue propagating the attempt further down the plan tree. + fn with_new_state( + &self, + _state: Arc, + ) -> Option> { + None + } + + /// Create per execution state to share across sibling instances of this + /// data source during one execution. + /// + /// Returns `None` (the default) if this data source has + /// no sibling-shared execution state. + fn create_sibling_state(&self) -> Option> { + None + } + + /// Open a partition using optional sibling-shared execution state. + /// + /// The default implementation ignores the additional state and delegates to + /// [`Self::open`]. + fn open_with_args(&self, args: OpenArgs) -> Result { + self.open(args.partition, args.context) + } +} + +/// Arguments for [`DataSource::open_with_args`] +#[derive(Debug, Clone)] +pub struct OpenArgs { + /// Which partition to open + pub partition: usize, + /// The task context for execution + pub context: Arc, + /// Optional sibling-shared execution state, see + /// [`DataSource::create_sibling_state`] for details. + pub sibling_state: Option>, +} + +impl OpenArgs { + /// Create a new OpenArgs with required arguments + pub fn new(partition: usize, context: Arc) -> Self { + Self { + partition, + context, + sibling_state: None, + } + } + + /// Set sibling shared state + pub fn with_shared_state( + mut self, + sibling_state: Option>, + ) -> Self { + self.sibling_state = sibling_state; + self + } +} + +impl dyn DataSource { + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } } /// [`ExecutionPlan`] that reads one or more files @@ -208,7 +317,13 @@ pub struct DataSourceExec { /// The source of the data -- for example, `FileScanConfig` or `MemorySourceConfig` data_source: Arc, /// Cached plan properties such as sort order - cache: PlanProperties, + cache: Arc, + /// Per execution state shared across partitions of this plan. + /// + /// Created by [`DataSource::create_sibling_state`] + /// and then passed to + /// [`DataSource::open_with_args`]. + execution_state: Arc>>>, } impl DisplayAs for DataSourceExec { @@ -228,11 +343,7 @@ impl ExecutionPlan for DataSourceExec { "DataSourceExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -262,17 +373,15 @@ impl ExecutionPlan for DataSourceExec { self.properties().eq_properties.output_ordering(), )?; - if let Some(source) = data_source { + Ok(data_source.map(|source| { let output_partitioning = source.output_partitioning(); let plan = self .clone() .with_data_source(source) // Changing source partitioning may invalidate output partitioning. Update it also .with_partitioning(output_partitioning); - Ok(Some(Arc::new(plan))) - } else { - Ok(Some(Arc::new(self.clone()))) - } + Arc::new(plan) as _ + })) } fn execute( @@ -280,8 +389,15 @@ impl ExecutionPlan for DataSourceExec { partition: usize, context: Arc, ) -> Result { - let stream = self.data_source.open(partition, Arc::clone(&context))?; + let shared_state = self + .execution_state + .get_or_init(|| self.data_source.create_sibling_state()) + .clone(); + let args = OpenArgs::new(partition, Arc::clone(&context)) + .with_shared_state(shared_state); + let stream = self.data_source.open_with_args(args)?; let batch_size = context.session_config().batch_size(); + log::debug!( "Batch splitting enabled for partition {partition}: batch_size={batch_size}" ); @@ -295,18 +411,35 @@ impl ExecutionPlan for DataSourceExec { } fn metrics(&self) -> Option { - Some(self.data_source.metrics().clone_inner()) + let mut metrics = self.data_source.metrics().clone_inner(); + + // Add `output_rows_skew` metric to the metrics set. + // Done here because it's a derived metric from output_rows metric. + if let Some(file_scan_config) = self.data_source.downcast_ref::() + && file_scan_config.file_source().file_type() == "parquet" + && let Some(output_rows_skew) = + BaselineMetrics::output_rows_skew_metric(&metrics) + { + metrics.push(output_rows_skew); + } + + Some(metrics) } - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { self.data_source.partition_statistics(partition) } fn with_fetch(&self, limit: Option) -> Option> { let data_source = self.data_source.with_fetch(limit)?; - let cache = self.cache.clone(); - - Some(Arc::new(Self { data_source, cache })) + let cache = Arc::clone(&self.cache); + let execution_state = Arc::new(OnceLock::new()); + + Some(Arc::new(Self { + data_source, + cache, + execution_state, + })) } fn fetch(&self) -> Option { @@ -319,7 +452,7 @@ impl ExecutionPlan for DataSourceExec { ) -> Result>> { match self .data_source - .try_swapping_with_projection(projection.expr())? + .try_swapping_with_projection(projection.projection_expr())? { Some(new_data_source) => { Ok(Some(Arc::new(DataSourceExec::new(new_data_source)))) @@ -342,13 +475,14 @@ impl ExecutionPlan for DataSourceExec { .collect_vec(); let res = self .data_source - .try_pushdown_filters(parent_filters.clone(), config)?; + .try_pushdown_filters(parent_filters, config)?; match res.updated_node { Some(data_source) => { let mut new_node = self.clone(); new_node.data_source = data_source; // Re-compute properties since we have new filters which will impact equivalence info - new_node.cache = Self::compute_properties(&new_node.data_source); + new_node.cache = + Arc::new(Self::compute_properties(&new_node.data_source)); Ok(FilterPushdownPropagation { filters: res.filters, @@ -361,6 +495,49 @@ impl ExecutionPlan for DataSourceExec { }), } } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // Delegate to the data source and wrap result with DataSourceExec + self.data_source + .try_pushdown_sort(order)? + .try_map(|new_data_source| { + let new_exec = self.clone().with_data_source(new_data_source); + Ok(Arc::new(new_exec) as Arc) + }) + } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.data_source + .with_preserve_order(preserve_order) + .map(|new_data_source| { + Arc::new(self.clone().with_data_source(new_data_source)) + as Arc + }) + } + + fn with_new_state( + &self, + state: Arc, + ) -> Option> { + self.data_source + .with_new_state(state) + .map(|new_data_source| { + Arc::new(self.clone().with_data_source(new_data_source)) + as Arc + }) + } + + fn reset_state(self: Arc) -> Result> { + let mut new_exec = Arc::unwrap_or_clone(self); + new_exec.execution_state = Arc::new(OnceLock::new()); + Ok(Arc::new(new_exec)) + } } impl DataSourceExec { @@ -371,7 +548,11 @@ impl DataSourceExec { // Default constructor for `DataSourceExec`, setting the `cooperative` flag to `true`. pub fn new(data_source: Arc) -> Self { let cache = Self::compute_properties(&data_source); - Self { data_source, cache } + Self { + data_source, + cache: Arc::new(cache), + execution_state: Arc::new(OnceLock::new()), + } } /// Return the source object @@ -380,20 +561,21 @@ impl DataSourceExec { } pub fn with_data_source(mut self, data_source: Arc) -> Self { - self.cache = Self::compute_properties(&data_source); + self.cache = Arc::new(Self::compute_properties(&data_source)); self.data_source = data_source; + self.execution_state = Arc::new(OnceLock::new()); self } /// Assign constraints pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.cache = self.cache.with_constraints(constraints); + Arc::make_mut(&mut self.cache).set_constraints(constraints); self } /// Assign output partitioning pub fn with_partitioning(mut self, partitioning: Partitioning) -> Self { - self.cache = self.cache.with_partitioning(partitioning); + Arc::make_mut(&mut self.cache).partitioning = partitioning; self } @@ -412,14 +594,14 @@ impl DataSourceExec { /// Returns `None` if /// 1. the datasource is not scanning files (`FileScanConfig`) /// 2. The [`FileScanConfig::file_source`] is not of type `T` - pub fn downcast_to_file_source(&self) -> Option<(&FileScanConfig, &T)> { + pub fn downcast_to_file_source( + &self, + ) -> Option<(&FileScanConfig, &T)> { self.data_source() - .as_any() .downcast_ref::() .and_then(|file_scan_conf| { file_scan_conf .file_source() - .as_any() .downcast_ref::() .map(|source| (file_scan_conf, source)) }) diff --git a/datafusion/datasource/src/statistics.rs b/datafusion/datasource/src/statistics.rs index 980677e488b81..6abfafe9d39d4 100644 --- a/datafusion/datasource/src/statistics.rs +++ b/datafusion/datasource/src/statistics.rs @@ -22,16 +22,16 @@ use std::sync::Arc; -use crate::file_groups::FileGroup; use crate::PartitionedFile; +use crate::file_groups::FileGroup; use arrow::array::RecordBatch; use arrow::compute::SortColumn; use arrow::datatypes::SchemaRef; use arrow::row::{Row, Rows}; -use datafusion_common::stats::Precision; +use datafusion_common::stats::{NdvFallback, Precision}; use datafusion_common::{ - plan_datafusion_err, plan_err, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, plan_datafusion_err, plan_err, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -50,13 +50,13 @@ pub(crate) struct MinMaxStatistics { impl MinMaxStatistics { /// Sort order used to sort the statistics - #[allow(unused)] + #[expect(unused)] pub fn sort_order(&self) -> &LexOrdering { &self.sort_order } /// Min value at index - #[allow(unused)] + #[expect(unused)] pub fn min(&'_ self, idx: usize) -> Row<'_> { self.min_by_sort_order.row(idx) } @@ -261,16 +261,17 @@ impl MinMaxStatistics { /// Return a sorted list of the min statistics together with the original indices pub fn min_values_sorted(&self) -> Vec<(usize, Row<'_>)> { let mut sort: Vec<_> = self.min_by_sort_order.iter().enumerate().collect(); - sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); + sort.sort_unstable_by_key(|(_, row)| *row); sort } /// Check if the min/max statistics are in order and non-overlapping + /// (or touching at boundaries) pub fn is_sorted(&self) -> bool { self.max_by_sort_order .iter() .zip(self.min_by_sort_order.iter().skip(1)) - .all(|(max, next_min)| max < next_min) + .all(|(max, next_min)| max <= next_min) } } @@ -279,10 +280,84 @@ fn sort_columns_from_physical_sort_exprs( ) -> Option> { sort_order .iter() - .map(|expr| expr.expr.as_any().downcast_ref::()) + .map(|expr| expr.expr.downcast_ref::()) .collect() } +fn seed_summary_statistics(summary_statistics: &mut Statistics, file_stats: &Statistics) { + summary_statistics.num_rows = file_stats.num_rows; + summary_statistics.total_byte_size = file_stats.total_byte_size; + + for (summary_col_stats, file_col_stats) in summary_statistics + .column_statistics + .iter_mut() + .zip(file_stats.column_statistics.iter()) + { + summary_col_stats.null_count = file_col_stats.null_count; + summary_col_stats.max_value = file_col_stats.max_value.clone(); + summary_col_stats.min_value = file_col_stats.min_value.clone(); + summary_col_stats.sum_value = file_col_stats.sum_value.cast_to_sum_type(); + summary_col_stats.byte_size = file_col_stats.byte_size; + } +} + +fn merge_summary_statistics( + summary_statistics: &mut Statistics, + file_stats: &Statistics, +) { + summary_statistics.num_rows = summary_statistics.num_rows.add(&file_stats.num_rows); + summary_statistics.total_byte_size = summary_statistics + .total_byte_size + .add(&file_stats.total_byte_size); + + for (summary_col_stats, file_col_stats) in summary_statistics + .column_statistics + .iter_mut() + .zip(file_stats.column_statistics.iter()) + { + let ColumnStatistics { + null_count: file_nc, + max_value: file_max, + min_value: file_min, + sum_value: file_sum, + distinct_count: _, + byte_size: file_sbs, + } = file_col_stats; + + summary_col_stats.null_count = summary_col_stats.null_count.add(file_nc); + summary_col_stats.max_value = summary_col_stats.max_value.max(file_max); + summary_col_stats.min_value = summary_col_stats.min_value.min(file_min); + summary_col_stats.sum_value = summary_col_stats.sum_value.add_for_sum(file_sum); + summary_col_stats.byte_size = summary_col_stats.byte_size.add(file_sbs); + } +} + +fn seed_first_file_statistics( + limit_num_rows: &mut Precision, + summary_statistics: &mut Statistics, + file_stats: &Statistics, + collect_stats: bool, +) { + *limit_num_rows = file_stats.num_rows; + + if collect_stats { + seed_summary_statistics(summary_statistics, file_stats); + } +} + +fn merge_file_statistics( + limit_num_rows: &mut Precision, + summary_statistics: &mut Statistics, + file_stats: &Statistics, + collect_stats: bool, +) { + *limit_num_rows = limit_num_rows.add(&file_stats.num_rows); + + if collect_stats { + merge_summary_statistics(summary_statistics, file_stats); + } +} + /// Get all files as well as the file level summary statistics (no statistic for partition columns). /// If the optional `limit` is provided, includes only sufficient files. Needed to read up to /// `limit` number of rows. `collect_stats` is passed down from the configuration parameter on @@ -292,7 +367,7 @@ fn sort_columns_from_physical_sort_exprs( since = "47.0.0", note = "Please use `get_files_with_limit` and `compute_all_files_statistics` instead" )] -#[allow(unused)] +#[cfg_attr(not(test), expect(unused))] pub async fn get_statistics_with_limit( all_files: impl Stream)>>, file_schema: SchemaRef, @@ -307,9 +382,14 @@ pub async fn get_statistics_with_limit( // - zero for summations, and // - neutral element for extreme points. let size = file_schema.fields().len(); - let mut col_stats_set = vec![ColumnStatistics::default(); size]; - let mut num_rows = Precision::::Absent; - let mut total_byte_size = Precision::::Absent; + let mut summary_statistics = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics::default(); size], + }; + // Keep limit pruning separate from the returned summary so `collect_stats=false` + // can still stop early using known file row counts. + let mut limit_num_rows = Precision::::Absent; // Fusing the stream allows us to call next safely even once it is finished. let mut all_files = Box::pin(all_files.fuse()); @@ -319,23 +399,18 @@ pub async fn get_statistics_with_limit( file.statistics = Some(Arc::clone(&file_stats)); result_files.push(file); - // First file, we set them directly from the file statistics. - num_rows = file_stats.num_rows; - total_byte_size = file_stats.total_byte_size; - for (index, file_column) in - file_stats.column_statistics.clone().into_iter().enumerate() - { - col_stats_set[index].null_count = file_column.null_count; - col_stats_set[index].max_value = file_column.max_value; - col_stats_set[index].min_value = file_column.min_value; - col_stats_set[index].sum_value = file_column.sum_value; - } + seed_first_file_statistics( + &mut limit_num_rows, + &mut summary_statistics, + &file_stats, + collect_stats, + ); // If the number of rows exceeds the limit, we can stop processing // files. This only applies when we know the number of rows. It also // currently ignores tables that have no statistics regarding the // number of rows. - let conservative_num_rows = match num_rows { + let conservative_num_rows = match limit_num_rows { Precision::Exact(nr) => nr, _ => usize::MIN, }; @@ -344,42 +419,18 @@ pub async fn get_statistics_with_limit( let (mut file, file_stats) = current?; file.statistics = Some(Arc::clone(&file_stats)); result_files.push(file); - if !collect_stats { - continue; - } - - // We accumulate the number of rows, total byte size and null - // counts across all the files in question. If any file does not - // provide any information or provides an inexact value, we demote - // the statistic precision to inexact. - num_rows = num_rows.add(&file_stats.num_rows); - - total_byte_size = total_byte_size.add(&file_stats.total_byte_size); - - for (file_col_stats, col_stats) in file_stats - .column_statistics - .iter() - .zip(col_stats_set.iter_mut()) - { - let ColumnStatistics { - null_count: file_nc, - max_value: file_max, - min_value: file_min, - sum_value: file_sum, - distinct_count: _, - } = file_col_stats; - - col_stats.null_count = col_stats.null_count.add(file_nc); - col_stats.max_value = col_stats.max_value.max(file_max); - col_stats.min_value = col_stats.min_value.min(file_min); - col_stats.sum_value = col_stats.sum_value.add(file_sum); - } + merge_file_statistics( + &mut limit_num_rows, + &mut summary_statistics, + &file_stats, + collect_stats, + ); // If the number of rows exceeds the limit, we can stop processing // files. This only applies when we know the number of rows. It also // currently ignores tables that have no statistics regarding the // number of rows. - if num_rows.get_value().unwrap_or(&usize::MIN) + if limit_num_rows.get_value().unwrap_or(&usize::MIN) > &limit.unwrap_or(usize::MAX) { break; @@ -388,11 +439,7 @@ pub async fn get_statistics_with_limit( } }; - let mut statistics = Statistics { - num_rows, - total_byte_size, - column_statistics: col_stats_set, - }; + let mut statistics = summary_statistics; if all_files.next().await.is_some() { // If we still have files in the stream, it means that the limit kicked // in, and the statistic could have been different had we processed the @@ -432,7 +479,11 @@ pub fn compute_file_group_statistics( let stats = file.statistics.as_ref()?; Some(stats.as_ref()) }); - let statistics = Statistics::try_merge_iter(file_group_stats, &file_schema)?; + let statistics = Statistics::try_merge_iter_with_ndv_fallback( + file_group_stats, + &file_schema, + NdvFallback::Max, + )?; Ok(file_group.with_statistics(Arc::new(statistics))) } @@ -477,8 +528,11 @@ pub fn compute_all_files_statistics( .iter() .filter_map(|file_group| file_group.file_statistics(None)); - let mut statistics = - Statistics::try_merge_iter(file_groups_statistics, &table_schema)?; + let mut statistics = Statistics::try_merge_iter_with_ndv_fallback( + file_groups_statistics, + &table_schema, + NdvFallback::Max, + )?; if inexact_stats { statistics = statistics.to_inexact() @@ -494,3 +548,346 @@ pub fn add_row_stats( ) -> Precision { file_num_rows.add(&num_rows) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::PartitionedFile; + use crate::file_groups::FileGroup; + use arrow::datatypes::{DataType, Field, Schema}; + use futures::stream; + + fn file_stats(sum: u32) -> Statistics { + Statistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(4), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::UInt32(Some(sum))), + min_value: Precision::Exact(ScalarValue::UInt32(Some(sum))), + sum_value: Precision::Exact(ScalarValue::UInt32(Some(sum))), + distinct_count: Precision::Exact(1), + byte_size: Precision::Exact(4), + }], + } + } + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])) + } + + fn make_file_stats( + num_rows: usize, + total_byte_size: usize, + col_stats: ColumnStatistics, + ) -> Arc { + Arc::new(Statistics { + num_rows: Precision::Exact(num_rows), + total_byte_size: Precision::Exact(total_byte_size), + column_statistics: vec![col_stats], + }) + } + + fn rich_col_stats( + null_count: usize, + min: i64, + max: i64, + sum: i64, + byte_size: usize, + ) -> ColumnStatistics { + ColumnStatistics { + null_count: Precision::Exact(null_count), + max_value: Precision::Exact(ScalarValue::Int64(Some(max))), + min_value: Precision::Exact(ScalarValue::Int64(Some(min))), + distinct_count: Precision::Absent, + sum_value: Precision::Exact(ScalarValue::Int64(Some(sum))), + byte_size: Precision::Exact(byte_size), + } + } + + fn utf8_file_stats(ndv: usize, min: &str, max: &str) -> Statistics { + Statistics { + num_rows: Precision::Exact(1), + total_byte_size: Precision::Exact(16), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some(max.to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some(min.to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Exact(ndv), + byte_size: Precision::Exact(16), + }], + } + } + + fn file_with_stats(path: &str, stats: Statistics) -> PartitionedFile { + PartitionedFile::new(path, 1).with_statistics(Arc::new(stats)) + } + #[tokio::test] + #[expect(deprecated)] + async fn test_get_statistics_with_limit_casts_first_file_sum_to_sum_type() + -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)])); + + let files = stream::iter(vec![Ok(( + PartitionedFile::new("f1.parquet", 1), + Arc::new(file_stats(100)), + ))]); + + let (_group, stats) = + get_statistics_with_limit(files, schema, None, true).await?; + + assert_eq!( + stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(100))) + ); + + Ok(()) + } + + #[tokio::test] + #[expect(deprecated)] + async fn test_get_statistics_with_limit_merges_sum_with_unsigned_widening() + -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::UInt32, true)])); + + let files = stream::iter(vec![ + Ok(( + PartitionedFile::new("f1.parquet", 1), + Arc::new(file_stats(100)), + )), + Ok(( + PartitionedFile::new("f2.parquet", 1), + Arc::new(file_stats(200)), + )), + ]); + + let (_group, stats) = + get_statistics_with_limit(files, schema, None, true).await?; + + assert_eq!( + stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(300))) + ); + + Ok(()) + } + + #[tokio::test] + #[expect(deprecated)] + async fn get_statistics_with_limit_collect_stats_false_returns_bare_statistics() { + let all_files = stream::iter(vec![ + Ok(( + PartitionedFile::new("first.parquet", 10), + make_file_stats(0, 0, rich_col_stats(1, 1, 9, 15, 64)), + )), + Ok(( + PartitionedFile::new("second.parquet", 20), + make_file_stats(10, 100, rich_col_stats(2, 10, 99, 300, 128)), + )), + ]); + + let (_files, statistics) = + get_statistics_with_limit(all_files, test_schema(), None, false) + .await + .unwrap(); + + assert_eq!(statistics.num_rows, Precision::Absent); + assert_eq!(statistics.total_byte_size, Precision::Absent); + assert_eq!(statistics.column_statistics.len(), 1); + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Absent + ); + assert_eq!(statistics.column_statistics[0].max_value, Precision::Absent); + assert_eq!(statistics.column_statistics[0].min_value, Precision::Absent); + assert_eq!(statistics.column_statistics[0].sum_value, Precision::Absent); + assert_eq!(statistics.column_statistics[0].byte_size, Precision::Absent); + } + + #[tokio::test] + #[expect(deprecated)] + async fn get_statistics_with_limit_collect_stats_false_uses_row_counts_for_limit() { + let all_files = stream::iter(vec![ + Ok(( + PartitionedFile::new("first.parquet", 10), + make_file_stats(3, 30, rich_col_stats(1, 1, 9, 15, 64)), + )), + Ok(( + PartitionedFile::new("second.parquet", 20), + make_file_stats(3, 30, rich_col_stats(2, 10, 99, 300, 128)), + )), + Ok(( + PartitionedFile::new("third.parquet", 30), + make_file_stats(3, 30, rich_col_stats(0, 100, 199, 450, 256)), + )), + ]); + + let (files, statistics) = + get_statistics_with_limit(all_files, test_schema(), Some(4), false) + .await + .unwrap(); + + assert_eq!(files.len(), 2); + assert_eq!(statistics.num_rows, Precision::Absent); + assert_eq!(statistics.total_byte_size, Precision::Absent); + } + + #[tokio::test] + #[expect(deprecated)] + async fn get_statistics_with_limit_collect_stats_true_aggregates_statistics() { + let all_files = stream::iter(vec![ + Ok(( + PartitionedFile::new("first.parquet", 10), + make_file_stats(5, 50, rich_col_stats(1, 1, 9, 15, 64)), + )), + Ok(( + PartitionedFile::new("second.parquet", 20), + make_file_stats(10, 100, rich_col_stats(2, 10, 99, 300, 128)), + )), + ]); + + let (_files, statistics) = + get_statistics_with_limit(all_files, test_schema(), None, true) + .await + .unwrap(); + + assert_eq!(statistics.num_rows, Precision::Exact(15)); + assert_eq!(statistics.total_byte_size, Precision::Exact(150)); + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Exact(3) + ); + assert_eq!( + statistics.column_statistics[0].min_value, + Precision::Exact(ScalarValue::Int64(Some(1))) + ); + assert_eq!( + statistics.column_statistics[0].max_value, + Precision::Exact(ScalarValue::Int64(Some(99))) + ); + assert_eq!( + statistics.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::Int64(Some(315))) + ); + assert_eq!( + statistics.column_statistics[0].byte_size, + Precision::Exact(192) + ); + } + + #[tokio::test] + #[expect(deprecated)] + async fn get_statistics_with_limit_collect_stats_true_limit_marks_inexact() { + let all_files = stream::iter(vec![ + Ok(( + PartitionedFile::new("first.parquet", 10), + make_file_stats(5, 50, rich_col_stats(0, 1, 5, 15, 64)), + )), + Ok(( + PartitionedFile::new("second.parquet", 20), + make_file_stats(5, 50, rich_col_stats(1, 6, 10, 40, 64)), + )), + Ok(( + PartitionedFile::new("third.parquet", 20), + make_file_stats(5, 50, rich_col_stats(2, 11, 15, 65, 64)), + )), + ]); + + let (files, statistics) = + get_statistics_with_limit(all_files, test_schema(), Some(8), true) + .await + .unwrap(); + + assert_eq!(files.len(), 2); + assert_eq!(statistics.num_rows, Precision::Inexact(10)); + assert_eq!(statistics.total_byte_size, Precision::Inexact(100)); + assert_eq!( + statistics.column_statistics[0].min_value, + Precision::Inexact(ScalarValue::Int64(Some(1))) + ); + assert_eq!( + statistics.column_statistics[0].max_value, + Precision::Inexact(ScalarValue::Int64(Some(10))) + ); + assert_eq!( + statistics.column_statistics[0].sum_value, + Precision::Inexact(ScalarValue::Int64(Some(55))) + ); + assert_eq!( + statistics.column_statistics[0].byte_size, + Precision::Inexact(128) + ); + } + + #[test] + fn test_compute_file_group_statistics_uses_max_ndv_fallback() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); + let file_group = FileGroup::new(vec![ + file_with_stats("f1.parquet", utf8_file_stats(5, "a", "x")), + file_with_stats("f2.parquet", utf8_file_stats(8, "b", "z")), + ]); + + let file_group = + compute_file_group_statistics(file_group, Arc::clone(&schema), true)?; + let stats = file_group.file_statistics(None).unwrap(); + + assert_eq!( + stats.column_statistics[0].distinct_count, + Precision::Inexact(8) + ); + assert_eq!( + stats.column_statistics[0].min_value, + Precision::Exact(ScalarValue::Utf8(Some("a".to_string()))) + ); + assert_eq!( + stats.column_statistics[0].max_value, + Precision::Exact(ScalarValue::Utf8(Some("z".to_string()))) + ); + + Ok(()) + } + + #[test] + fn test_compute_all_files_statistics_uses_max_ndv_fallback() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); + let file_groups = vec![ + FileGroup::new(vec![ + file_with_stats("f1.parquet", utf8_file_stats(5, "a", "x")), + file_with_stats("f2.parquet", utf8_file_stats(8, "b", "z")), + ]), + FileGroup::new(vec![ + file_with_stats("f3.parquet", utf8_file_stats(3, "c", "w")), + file_with_stats("f4.parquet", utf8_file_stats(6, "d", "y")), + ]), + ]; + + let (file_groups, stats) = + compute_all_files_statistics(file_groups, schema, true, false)?; + + assert_eq!( + file_groups[0] + .file_statistics(None) + .unwrap() + .column_statistics[0] + .distinct_count, + Precision::Inexact(8) + ); + assert_eq!( + file_groups[1] + .file_statistics(None) + .unwrap() + .column_statistics[0] + .distinct_count, + Precision::Inexact(6) + ); + assert_eq!( + stats.column_statistics[0].distinct_count, + Precision::Inexact(8) + ); + + Ok(()) + } +} diff --git a/datafusion/datasource/src/table_schema.rs b/datafusion/datasource/src/table_schema.rs index ff0e788018875..f1cb86ed7413d 100644 --- a/datafusion/datasource/src/table_schema.rs +++ b/datafusion/datasource/src/table_schema.rs @@ -17,16 +17,23 @@ //! Helper struct to manage table schemas with partition columns -use arrow::datatypes::{FieldRef, SchemaBuilder, SchemaRef}; +use arrow::datatypes::{FieldRef, Fields, SchemaBuilder, SchemaRef}; use std::sync::Arc; -/// Helper to hold table schema information for partitioned data sources. +/// The overall schema for potentially partitioned data sources. /// -/// When reading partitioned data (such as Hive-style partitioning), a table's schema -/// consists of two parts: +/// When reading partitioned data (such as Hive-style partitioning), a [`TableSchema`] +/// consists of up to three parts: /// 1. **File schema**: The schema of the actual data files on disk -/// 2. **Partition columns**: Columns that are encoded in the directory structure, -/// not stored in the files themselves +/// 2. **Partition columns**: Columns whose values are encoded in the directory structure, +/// but not stored in the files themselves +/// 3. **Virtual columns**: Columns produced by the file reader (e.g. Parquet +/// `row_number`) that are not stored in the files +/// +/// The full table schema is composed in that order: file columns, then +/// partition columns, then virtual columns. Consumers that need a different +/// output ordering should use a projection on top of +/// [`TableSchema::table_schema`]. /// /// # Example: Partitioned Table /// @@ -70,30 +77,47 @@ pub struct TableSchema { /// /// These columns are NOT present in the data files but are appended to each /// row during query execution based on the file's location. - table_partition_cols: Vec, + /// + /// Stored as [`Fields`] (an immutable `Arc<[FieldRef]>`) so that cloning a + /// `TableSchema` is cheap and the partition columns can be shared zero-copy + /// with an existing schema. + table_partition_cols: Fields, + + /// Virtual columns that are generated by the reader rather than read from + /// the data files or the directory structure. + /// + /// For example, a Parquet reader may inject a `row_number` column whose + /// values are produced per file by the reader. Virtual column fields must + /// carry an arrow extension type (e.g. `RowNumber`, `RowGroupIndex`) so the + /// file reader can recognize them. + /// + /// Virtual columns are appended at the end of the table schema, after the + /// file columns and any partition columns (layout: `[file, partition, + /// virtual]`). + virtual_columns: Fields, - /// The complete table schema: file_schema columns followed by partition columns. + /// The complete table schema: file_schema columns, followed by partition + /// columns, followed by virtual columns. /// - /// This is pre-computed during construction by concatenating `file_schema` - /// and `table_partition_cols`, so it can be returned as a cheap reference. + /// This is pre-computed during construction by concatenating the three + /// parts, so it can be returned as a cheap reference. table_schema: SchemaRef, + + /// Schema of file + partition columns, excluding virtual columns. + /// + /// Pre-computed during construction so [`Self::schema_without_virtual_columns`] + /// can return a cheap reference. When there are no virtual columns this + /// shares the same `Arc` as `table_schema`. + schema_without_virtual_columns: SchemaRef, } impl TableSchema { - /// Create a new TableSchema from a file schema and partition columns. + /// Start building a [`TableSchema`] from its (required) file schema. /// - /// The table schema is automatically computed by appending the partition columns - /// to the file schema. - /// - /// You should prefer calling this method over - /// chaining [`TableSchema::from_file_schema`] and [`TableSchema::with_table_partition_cols`] - /// if you have both the file schema and partition columns available at construction time - /// since it avoids re-computing the table schema. - /// - /// # Arguments - /// - /// * `file_schema` - Schema of the data files (without partition columns) - /// * `table_partition_cols` - Partition columns to append to each row + /// Partition columns are optional and added with + /// [`TableSchemaBuilder::with_table_partition_cols`]; the full table schema + /// is computed once by [`TableSchemaBuilder::build`]. This is the preferred + /// way to construct a `TableSchema`. /// /// # Example /// @@ -106,45 +130,54 @@ impl TableSchema { /// Field::new("amount", DataType::Float64, false), /// ])); /// - /// let partition_cols = vec![ - /// Arc::new(Field::new("date", DataType::Utf8, false)), - /// Arc::new(Field::new("region", DataType::Utf8, false)), - /// ]; - /// - /// let table_schema = TableSchema::new(file_schema, partition_cols); + /// let table_schema = TableSchema::builder(file_schema) + /// .with_table_partition_cols(vec![ + /// Arc::new(Field::new("date", DataType::Utf8, false)), + /// Arc::new(Field::new("region", DataType::Utf8, false)), + /// ]) + /// .build(); /// /// // Table schema will have 4 columns: user_id, amount, date, region /// assert_eq!(table_schema.table_schema().fields().len(), 4); /// ``` + pub fn builder(file_schema: SchemaRef) -> TableSchemaBuilder { + TableSchemaBuilder::new(file_schema) + } + + /// Create a new TableSchema from a file schema and partition columns. + /// + /// This is a convenience for + /// `TableSchema::builder(file_schema).with_table_partition_cols(cols).build()`. + #[deprecated( + since = "55.0.0", + note = "use TableSchema::builder(file_schema).with_table_partition_cols(cols).build() (or TableSchema::from(file_schema) for no partition columns)" + )] pub fn new(file_schema: SchemaRef, table_partition_cols: Vec) -> Self { - let mut builder = SchemaBuilder::from(file_schema.as_ref()); - builder.extend(table_partition_cols.iter().cloned()); - Self { - file_schema, - table_partition_cols, - table_schema: Arc::new(builder.finish()), - } + TableSchemaBuilder::new(file_schema) + .with_table_partition_cols(table_partition_cols) + .build() } /// Create a new TableSchema with no partition columns. - /// - /// You should prefer calling [`TableSchema::new`] if you have partition columns at - /// construction time since it avoids re-computing the table schema. + #[deprecated( + since = "55.0.0", + note = "use TableSchema::from(file_schema) / file_schema.into()" + )] pub fn from_file_schema(file_schema: SchemaRef) -> Self { - Self::new(file_schema, vec![]) + TableSchemaBuilder::new(file_schema).build() } - /// Add partition columns to an existing TableSchema, returning a new instance. - /// - /// You should prefer calling [`TableSchema::new`] instead of chaining [`TableSchema::from_file_schema`] - /// into [`TableSchema::with_table_partition_cols`] if you have partition columns at construction time - /// since it avoids re-computing the table schema. - pub fn with_table_partition_cols(mut self, partition_cols: Vec) -> Self { - self.table_partition_cols = partition_cols; - let mut builder = SchemaBuilder::from(self.file_schema.as_ref()); - builder.extend(self.table_partition_cols.iter().cloned()); - self.table_schema = Arc::new(builder.finish()); - self + /// Return a new `TableSchema` with `partition_cols` as its partition columns, + /// replacing any existing ones. Existing virtual columns are preserved. + #[deprecated( + since = "55.0.0", + note = "use TableSchema::builder(file_schema).with_table_partition_cols(cols).build()" + )] + pub fn with_table_partition_cols(self, partition_cols: Vec) -> Self { + TableSchemaBuilder::new(self.file_schema) + .with_table_partition_cols(partition_cols) + .with_virtual_columns(self.virtual_columns) + .build() } /// Get the file schema (without partition columns). @@ -158,21 +191,420 @@ impl TableSchema { /// /// These are the columns derived from the directory structure that /// will be appended to each row during query execution. - pub fn table_partition_cols(&self) -> &Vec { + pub fn table_partition_cols(&self) -> &Fields { &self.table_partition_cols } - /// Get the full table schema (file schema + partition columns). + /// Get the virtual columns. /// - /// This is the complete schema that will be seen by queries, combining - /// both the columns from the files and the partition columns. + /// Virtual columns are produced by the file reader (e.g. Parquet + /// `row_number`) and are not stored in the data files or derived from + /// partition paths. + pub fn virtual_columns(&self) -> &Fields { + &self.virtual_columns + } + + /// Get the full table schema (file schema + partition columns + virtual columns). + /// + /// This is the complete schema that will be seen by queries. Fields appear + /// in the order: file columns, partition columns, virtual columns. pub fn table_schema(&self) -> &SchemaRef { &self.table_schema } + + /// Schema of columns that can be referenced by predicates pushed into the + /// file reader: file columns plus partition columns, excluding virtual + /// columns. + /// + /// Virtual columns are produced by the reader itself (e.g. Parquet + /// `row_number`) and cannot be referenced inside the reader's row filter, + /// so predicates that reference them must stay above the scan. Callers + /// deciding which filters to push down should check against this schema + /// rather than [`Self::table_schema`]. + /// + /// When there are no virtual columns this returns the same schema as + /// [`Self::table_schema`]. + pub fn schema_without_virtual_columns(&self) -> &SchemaRef { + &self.schema_without_virtual_columns + } } impl From for TableSchema { fn from(schema: SchemaRef) -> Self { - Self::from_file_schema(schema) + TableSchemaBuilder::new(schema).build() + } +} + +impl From<&SchemaRef> for TableSchema { + fn from(schema: &SchemaRef) -> Self { + TableSchemaBuilder::new(Arc::clone(schema)).build() + } +} + +/// Builder for [`TableSchema`]. +/// +/// The file schema is the only required input; partition columns and virtual +/// columns are optional. Unlike calling [`TableSchema`]'s setters repeatedly, +/// the builder computes the concatenated table schema exactly once, in +/// [`TableSchemaBuilder::build`]. +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::{Schema, Field, DataType}; +/// # use datafusion_datasource::TableSchemaBuilder; +/// # let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)])); +/// let table_schema = TableSchemaBuilder::new(file_schema) +/// .with_table_partition_cols(vec![Arc::new(Field::new("date", DataType::Utf8, false))]) +/// .build(); +/// assert_eq!(table_schema.table_partition_cols().len(), 1); +/// ``` +#[derive(Debug, Clone)] +pub struct TableSchemaBuilder { + file_schema: SchemaRef, + table_partition_cols: Fields, + virtual_columns: Fields, +} + +impl TableSchemaBuilder { + /// Create a builder for a `TableSchema` over the given file schema, with no + /// partition or virtual columns yet. + pub fn new(file_schema: SchemaRef) -> Self { + Self { + file_schema, + table_partition_cols: Fields::empty(), + virtual_columns: Fields::empty(), + } + } + + /// Set the partition columns, replacing any previously set. + /// + /// Accepts anything convertible into [`Fields`] (e.g. `Vec` or an + /// existing schema's `Fields`, which is shared zero-copy). + pub fn with_table_partition_cols( + mut self, + table_partition_cols: impl Into, + ) -> Self { + self.table_partition_cols = table_partition_cols.into(); + self + } + + /// Set the virtual columns, replacing any previously set. + /// + /// Virtual columns are produced by the file reader (e.g. Parquet + /// `row_number`) and appended at the end of the table schema. Each field + /// must carry an arrow virtual extension type so the reader can recognize + /// it. + /// + /// Accepts anything convertible into [`Fields`] (e.g. `Vec`). + pub fn with_virtual_columns(mut self, virtual_columns: impl Into) -> Self { + self.virtual_columns = virtual_columns.into(); + self + } + + /// Build the [`TableSchema`], computing the full + /// `file + partition + virtual` schema once. + pub fn build(self) -> TableSchema { + debug_assert!( + self.virtual_columns.iter().enumerate().all(|(i, v)| { + let name = v.name(); + !self.file_schema.fields().iter().any(|f| f.name() == name) + && !self.table_partition_cols.iter().any(|p| p.name() == name) + && !self.virtual_columns[..i].iter().any(|w| w.name() == name) + }), + "virtual column name collides with an existing file, partition, or virtual column" + ); + + let mut builder = SchemaBuilder::from(self.file_schema.as_ref()); + builder.extend(self.table_partition_cols.iter().cloned()); + let (table_schema, schema_without_virtual_columns) = + if self.virtual_columns.is_empty() { + let schema = Arc::new(builder.finish()); + (Arc::clone(&schema), schema) + } else { + let without_virtual = Arc::new(builder.finish()); + let mut builder = SchemaBuilder::from(without_virtual.as_ref()); + builder.extend(self.virtual_columns.iter().cloned()); + (Arc::new(builder.finish()), without_virtual) + }; + TableSchema { + file_schema: self.file_schema, + table_partition_cols: self.table_partition_cols, + virtual_columns: self.virtual_columns, + table_schema, + schema_without_virtual_columns, + } + } +} + +impl From for TableSchemaBuilder { + fn from(schema: SchemaRef) -> Self { + TableSchemaBuilder::new(schema) + } +} + +impl From<&SchemaRef> for TableSchemaBuilder { + fn from(schema: &SchemaRef) -> Self { + TableSchemaBuilder::new(Arc::clone(schema)) + } +} + +#[cfg(test)] +mod tests { + use super::{TableSchema, TableSchemaBuilder}; + use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; + + #[test] + fn test_table_schema_creation() { + let file_schema = Arc::new(Schema::new(vec![ + Field::new("user_id", DataType::Int64, false), + Field::new("amount", DataType::Float64, false), + ])); + + let partition_cols = vec![ + Arc::new(Field::new("date", DataType::Utf8, false)), + Arc::new(Field::new("region", DataType::Utf8, false)), + ]; + + let table_schema = TableSchema::builder(file_schema.clone()) + .with_table_partition_cols(partition_cols.clone()) + .build(); + + // Verify file schema + assert_eq!(table_schema.file_schema().as_ref(), file_schema.as_ref()); + + // Verify partition columns + assert_eq!(table_schema.table_partition_cols().len(), 2); + assert_eq!(table_schema.table_partition_cols()[0], partition_cols[0]); + assert_eq!(table_schema.table_partition_cols()[1], partition_cols[1]); + + // Verify full table schema + let expected_fields = vec![ + Field::new("user_id", DataType::Int64, false), + Field::new("amount", DataType::Float64, false), + Field::new("date", DataType::Utf8, false), + Field::new("region", DataType::Utf8, false), + ]; + let expected_schema = Schema::new(expected_fields); + assert_eq!(table_schema.table_schema().as_ref(), &expected_schema); + } + + #[test] + fn test_builder_with_partition_cols() { + let file_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let table_schema = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_table_partition_cols(vec![ + Arc::new(Field::new("country", DataType::Utf8, false)), + Arc::new(Field::new("year", DataType::Int32, false)), + ]) + .build(); + + // File schema is preserved and the partition columns are appended. + assert_eq!(table_schema.file_schema().as_ref(), file_schema.as_ref()); + assert_eq!(table_schema.table_partition_cols().len(), 2); + assert_eq!(table_schema.table_partition_cols()[0].name(), "country"); + assert_eq!(table_schema.table_partition_cols()[1].name(), "year"); + + let expected_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("country", DataType::Utf8, false), + Field::new("year", DataType::Int32, false), + ]); + assert_eq!(table_schema.table_schema().as_ref(), &expected_schema); + } + + #[test] + fn test_builder_with_table_partition_cols_replaces() { + // Calling the setter more than once replaces rather than appends. + let file_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + + let table_schema = TableSchemaBuilder::new(file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "country", + DataType::Utf8, + false, + ))]) + .with_table_partition_cols(vec![Arc::new(Field::new( + "city", + DataType::Utf8, + false, + ))]) + .build(); + + assert_eq!(table_schema.table_partition_cols().len(), 1); + assert_eq!(table_schema.table_partition_cols()[0].name(), "city"); + } + + #[test] + fn test_builder_accepts_fields_zero_copy() { + // `with_table_partition_cols` accepts an existing schema's `Fields` + // directly (shared via `Arc`, no `Vec` round-trip). + let file_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let partition_schema = + Schema::new(vec![Field::new("date", DataType::Utf8, false)]); + + let table_schema = TableSchemaBuilder::new(file_schema) + .with_table_partition_cols(partition_schema.fields().clone()) + .build(); + + assert_eq!(table_schema.table_partition_cols().len(), 1); + assert_eq!(table_schema.table_partition_cols()[0].name(), "date"); + } + + #[test] + #[expect(deprecated)] + fn test_deprecated_with_table_partition_cols_replaces() { + // The deprecated setter still works and replaces the partition columns. + // It is safe on a shared clone because partition columns are immutable. + let file_schema = + Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let original = TableSchema::builder(file_schema) + .with_table_partition_cols(vec![Arc::new(Field::new( + "country", + DataType::Utf8, + false, + ))]) + .build(); + + let replaced = + original + .clone() + .with_table_partition_cols(vec![Arc::new(Field::new( + "city", + DataType::Utf8, + false, + ))]); + + assert_eq!(replaced.table_partition_cols().len(), 1); + assert_eq!(replaced.table_partition_cols()[0].name(), "city"); + + // The original is untouched. + assert_eq!(original.table_partition_cols().len(), 1); + assert_eq!(original.table_partition_cols()[0].name(), "country"); + } + + #[test] + fn test_builder_with_virtual_columns_layout() { + let file_schema = Arc::new(Schema::new(vec![ + Field::new("user_id", DataType::Int64, false), + Field::new("amount", DataType::Float64, false), + ])); + + let virtual_cols = + vec![Arc::new(Field::new("row_number", DataType::Int64, true))]; + + let partition_cols = vec![Arc::new(Field::new("date", DataType::Utf8, false))]; + + // Apply virtual columns and partition columns in either order on the + // builder; the resulting table schema should always be + // [file, partition, virtual]. + let built_virtual_first = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_virtual_columns(virtual_cols.clone()) + .with_table_partition_cols(partition_cols.clone()) + .build(); + + let built_partition_first = TableSchemaBuilder::new(Arc::clone(&file_schema)) + .with_table_partition_cols(partition_cols.clone()) + .with_virtual_columns(virtual_cols.clone()) + .build(); + + let expected = Schema::new(vec![ + Field::new("user_id", DataType::Int64, false), + Field::new("amount", DataType::Float64, false), + Field::new("date", DataType::Utf8, false), + Field::new("row_number", DataType::Int64, true), + ]); + + for ts in [built_virtual_first, built_partition_first] { + assert_eq!(ts.table_schema().as_ref(), &expected); + assert_eq!(ts.virtual_columns().len(), 1); + assert_eq!(ts.virtual_columns()[0].name(), "row_number"); + assert_eq!(ts.table_partition_cols().len(), 1); + assert_eq!(ts.file_schema().fields().len(), 2); + } + } + + #[test] + #[should_panic(expected = "virtual column name collides")] + #[cfg(debug_assertions)] + fn test_virtual_column_collides_with_file_schema_panics_in_debug() { + let file_schema = Arc::new(Schema::new(vec![Field::new( + "row_number", + DataType::Int64, + false, + )])); + let _ = TableSchemaBuilder::new(file_schema) + .with_virtual_columns(vec![Arc::new(Field::new( + "row_number", + DataType::Int64, + true, + ))]) + .build(); + } + + #[test] + #[should_panic(expected = "virtual column name collides")] + #[cfg(debug_assertions)] + fn test_virtual_column_collides_with_partition_panics_in_debug() { + let file_schema = Arc::new(Schema::new(vec![Field::new( + "user_id", + DataType::Int64, + false, + )])); + let partition_cols = + vec![Arc::new(Field::new("row_number", DataType::Utf8, false))]; + let _ = TableSchemaBuilder::new(file_schema) + .with_table_partition_cols(partition_cols) + .with_virtual_columns(vec![Arc::new(Field::new( + "row_number", + DataType::Int64, + true, + ))]) + .build(); + } + + #[test] + #[should_panic(expected = "virtual column name collides")] + #[cfg(debug_assertions)] + fn test_duplicate_virtual_columns_panic_in_debug() { + let file_schema = Arc::new(Schema::new(vec![Field::new( + "user_id", + DataType::Int64, + false, + )])); + let _ = TableSchemaBuilder::new(file_schema) + .with_virtual_columns(vec![ + Arc::new(Field::new("vc", DataType::Int64, true)), + Arc::new(Field::new("vc", DataType::Int64, true)), + ]) + .build(); + } + + #[test] + #[should_panic(expected = "virtual column name collides")] + #[cfg(debug_assertions)] + fn test_partition_column_added_after_colliding_virtual_panics_in_debug() { + // Builder order is irrelevant: collision check runs in build(). + let file_schema = Arc::new(Schema::new(vec![Field::new( + "user_id", + DataType::Int64, + false, + )])); + let _ = TableSchemaBuilder::new(file_schema) + .with_virtual_columns(vec![Arc::new(Field::new( + "row_number", + DataType::Int64, + true, + ))]) + .with_table_partition_cols(vec![Arc::new(Field::new( + "row_number", + DataType::Utf8, + false, + ))]) + .build(); } } diff --git a/datafusion/datasource/src/test_util.rs b/datafusion/datasource/src/test_util.rs index 5d5b277dcf046..d35ed5feb51de 100644 --- a/datafusion/datasource/src/test_util.rs +++ b/datafusion/datasource/src/test_util.rs @@ -17,14 +17,13 @@ use crate::{ file::FileSource, file_scan_config::FileScanConfig, file_stream::FileOpener, - schema_adapter::SchemaAdapterFactory, }; use std::sync::Arc; use arrow::datatypes::Schema; use datafusion_common::Result; -use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use datafusion_physical_expr::{PhysicalExpr, expressions::Column}; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use object_store::ObjectStore; @@ -32,32 +31,35 @@ use object_store::ObjectStore; #[derive(Clone)] pub(crate) struct MockSource { metrics: ExecutionPlanMetricsSet, - schema_adapter_factory: Option>, filter: Option>, table_schema: crate::table_schema::TableSchema, + projection: crate::projection::SplitProjection, + file_opener: Option>, } impl Default for MockSource { fn default() -> Self { + let table_schema = + crate::table_schema::TableSchema::from(Arc::new(Schema::empty())); Self { metrics: ExecutionPlanMetricsSet::new(), - schema_adapter_factory: None, filter: None, - table_schema: crate::table_schema::TableSchema::new( - Arc::new(Schema::empty()), - vec![], - ), + projection: crate::projection::SplitProjection::unprojected(&table_schema), + table_schema, + file_opener: None, } } } impl MockSource { pub fn new(table_schema: impl Into) -> Self { + let table_schema = table_schema.into(); Self { metrics: ExecutionPlanMetricsSet::new(), - schema_adapter_factory: None, filter: None, - table_schema: table_schema.into(), + projection: crate::projection::SplitProjection::unprojected(&table_schema), + table_schema, + file_opener: None, } } @@ -65,6 +67,11 @@ impl MockSource { self.filter = Some(filter); self } + + pub fn with_file_opener(mut self, file_opener: Arc) -> Self { + self.file_opener = Some(file_opener); + self + } } impl FileSource for MockSource { @@ -73,12 +80,10 @@ impl FileSource for MockSource { _object_store: Arc, _base_config: &FileScanConfig, _partition: usize, - ) -> Arc { - unimplemented!() - } - - fn as_any(&self) -> &dyn std::any::Any { - self + ) -> Result> { + self.file_opener.clone().ok_or_else(|| { + datafusion_common::internal_datafusion_err!("MockSource missing FileOpener") + }) } fn filter(&self) -> Option> { @@ -89,10 +94,6 @@ impl FileSource for MockSource { Arc::new(Self { ..self.clone() }) } - fn with_projection(&self, _config: &FileScanConfig) -> Arc { - Arc::new(Self { ..self.clone() }) - } - fn metrics(&self) -> &ExecutionPlanMetricsSet { &self.metrics } @@ -101,22 +102,28 @@ impl FileSource for MockSource { "mock" } - fn with_schema_adapter_factory( - &self, - schema_adapter_factory: Arc, - ) -> Result> { - Ok(Arc::new(Self { - schema_adapter_factory: Some(schema_adapter_factory), - ..self.clone() - })) + fn table_schema(&self) -> &crate::table_schema::TableSchema { + &self.table_schema } - fn schema_adapter_factory(&self) -> Option> { - self.schema_adapter_factory.clone() + fn try_pushdown_projection( + &self, + projection: &datafusion_physical_plan::projection::ProjectionExprs, + ) -> Result>> { + let mut source = self.clone(); + let new_projection = self.projection.source.try_merge(projection)?; + let split_projection = crate::projection::SplitProjection::new( + self.table_schema.file_schema(), + &new_projection, + ); + source.projection = split_projection; + Ok(Some(Arc::new(source))) } - fn table_schema(&self) -> &crate::table_schema::TableSchema { - &self.table_schema + fn projection( + &self, + ) -> Option<&datafusion_physical_plan::projection::ProjectionExprs> { + Some(&self.projection.source) } } diff --git a/datafusion/datasource/src/url.rs b/datafusion/datasource/src/url.rs index 1307a4c8b1eb1..4bf99fc325e2c 100644 --- a/datafusion/datasource/src/url.rs +++ b/datafusion/datasource/src/url.rs @@ -17,7 +17,9 @@ use std::sync::Arc; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, TableReference}; +use datafusion_execution::cache::TableScopedPath; +use datafusion_execution::cache::cache_manager::CachedFileList; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_session::Session; @@ -26,9 +28,9 @@ use futures::{StreamExt, TryStreamExt}; use glob::Pattern; use itertools::Itertools; use log::debug; -use object_store::path::Path; use object_store::path::DELIMITER; -use object_store::{ObjectMeta, ObjectStore}; +use object_store::path::Path; +use object_store::{ObjectMeta, ObjectStore, ObjectStoreExt}; use url::Url; /// A parsed URL identifying files for a listing table, see [`ListingTableUrl::parse`] @@ -41,6 +43,8 @@ pub struct ListingTableUrl { prefix: Path, /// An optional glob expression used to filter files glob: Option, + /// Optional table reference for the table this url belongs to + table_ref: Option, } impl ListingTableUrl { @@ -145,7 +149,12 @@ impl ListingTableUrl { /// to create a [`ListingTableUrl`]. pub fn try_new(url: Url, glob: Option) -> Result { let prefix = Path::from_url_path(url.path())?; - Ok(Self { url, prefix, glob }) + Ok(Self { + url, + prefix, + glob, + table_ref: None, + }) } /// Returns the URL scheme @@ -209,12 +218,12 @@ impl ListingTableUrl { /// assert_eq!(url.file_extension(), None); /// ``` pub fn file_extension(&self) -> Option<&str> { - if let Some(mut segments) = self.url.path_segments() { - if let Some(last_segment) = segments.next_back() { - if last_segment.contains(".") && !last_segment.ends_with(".") { - return last_segment.split('.').next_back(); - } - } + if let Some(mut segments) = self.url.path_segments() + && let Some(last_segment) = segments.next_back() + && last_segment.contains(".") + && !last_segment.ends_with(".") + { + return last_segment.split('.').next_back(); } None @@ -245,25 +254,40 @@ impl ListingTableUrl { let exec_options = &ctx.config_options().execution; let ignore_subdirectory = exec_options.listing_table_ignore_subdirectory; - let prefix = if let Some(prefix) = prefix { - let mut p = self.prefix.parts().collect::>(); - p.extend(prefix.parts()); - Path::from_iter(p.into_iter()) + // Build full_prefix for non-cached path and head() calls + let full_prefix = if let Some(ref p) = prefix { + let mut parts = self.prefix.parts().collect::>(); + parts.extend(p.parts()); + Path::from_iter(parts) } else { self.prefix.clone() }; let list: BoxStream<'a, Result> = if self.is_collection() { - list_with_cache(ctx, store, &prefix).await? + list_with_cache( + ctx, + store, + self.table_ref.as_ref(), + &self.prefix, + prefix.as_ref(), + ) + .await? } else { - match store.head(&prefix).await { + match store.head(&full_prefix).await { Ok(meta) => futures::stream::once(async { Ok(meta) }) .map_err(|e| DataFusionError::ObjectStore(Box::new(e))) .boxed(), // If the head command fails, it is likely that object doesn't exist. // Retry as though it were a prefix (aka a collection) Err(object_store::Error::NotFound { .. }) => { - list_with_cache(ctx, store, &prefix).await? + list_with_cache( + ctx, + store, + self.table_ref.as_ref(), + &self.prefix, + prefix.as_ref(), + ) + .await? } Err(e) => return Err(e.into()), } @@ -317,36 +341,93 @@ impl ListingTableUrl { } /// Returns a copy of current [`ListingTableUrl`] with a specified `glob` - pub fn with_glob(self, glob: &str) -> Result { - let glob = - Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?; - Self::try_new(self.url, Some(glob)) + pub fn with_glob(mut self, glob: &str) -> Result { + self.glob = + Some(Pattern::new(glob).map_err(|e| DataFusionError::External(Box::new(e)))?); + Ok(self) + } + + /// Set the table reference for this [`ListingTableUrl`] + pub fn with_table_ref(mut self, table_ref: TableReference) -> Self { + self.table_ref = Some(table_ref); + self + } + + /// Return the table reference for this [`ListingTableUrl`] + pub fn get_table_ref(&self) -> &Option { + &self.table_ref } } +/// Lists files with cache support, using prefix-aware lookups. +/// +/// # Arguments +/// * `ctx` - The session context +/// * `store` - The object store to list from +/// * `table_base_path` - The table's base path (the stable cache key) +/// * `prefix` - Optional prefix relative to table base for filtering results +/// +/// # Cache Behavior: +/// The cache key is always `table_base_path`. When a prefix-filtered listing +/// is requested via `prefix`, the cache: +/// - Looks up `table_base_path` in the cache +/// - Filters results to match `table_base_path/prefix` +/// - Returns filtered results without a storage call +/// +/// On cache miss, the full table is always listed and cached, ensuring +/// subsequent prefix queries can be served from cache. async fn list_with_cache<'b>( ctx: &'b dyn Session, store: &'b dyn ObjectStore, - prefix: &Path, + table_ref: Option<&TableReference>, + table_base_path: &Path, + prefix: Option<&Path>, ) -> Result>> { + // Build the full listing path (table_base + prefix) + let full_prefix = match prefix { + Some(p) => { + let mut parts: Vec<_> = table_base_path.parts().collect(); + parts.extend(p.parts()); + Path::from_iter(parts) + } + None => table_base_path.clone(), + }; + match ctx.runtime_env().cache_manager.get_list_files_cache() { None => Ok(store - .list(Some(prefix)) + .list(Some(&full_prefix)) .map(|res| res.map_err(|e| DataFusionError::ObjectStore(Box::new(e)))) .boxed()), Some(cache) => { - let vec = if let Some(res) = cache.get(prefix) { - debug!("Hit list all files cache"); - res.as_ref().clone() + // Build the filter prefix (only Some if prefix was requested) + let filter_prefix = prefix.is_some().then(|| full_prefix.clone()); + + let table_scoped_base_path = TableScopedPath { + table: table_ref.cloned(), + path: table_base_path.clone(), + }; + + // Try cache lookup - get returns CachedFileList + let vec = if let Some(cached) = cache.get(&table_scoped_base_path) { + debug!("Hit list files cache"); + cached.files_matching_prefix(&filter_prefix) } else { - let vec = store - .list(Some(prefix)) + // Cache miss - always list and cache the full table + // This ensures we have complete data for future prefix queries + let mut vec = store + .list(Some(table_base_path)) .try_collect::>() .await?; - cache.put(prefix, Arc::new(vec.clone())); - vec + vec.shrink_to_fit(); // Right-size before caching + let cached: CachedFileList = vec.into(); + let result = cached.files_matching_prefix(&filter_prefix); + cache.put(&table_scoped_base_path, cached); + result }; - Ok(futures::stream::iter(vec.into_iter().map(Ok)).boxed()) + Ok( + futures::stream::iter(Arc::unwrap_or_clone(vec).into_iter().map(Ok)) + .boxed(), + ) } } } @@ -430,18 +511,21 @@ mod tests { use super::*; use async_trait::async_trait; use bytes::Bytes; - use datafusion_common::config::TableOptions; use datafusion_common::DFSchema; + use datafusion_common::config::TableOptions; + use datafusion_execution::TaskContext; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; + use datafusion_expr::registry::ExtensionTypeRegistryRef; + use datafusion_expr::{ + AggregateUDF, Expr, HigherOrderUDF, LogicalPlan, ScalarUDF, WindowUDF, + }; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use object_store::{ - GetOptions, GetResult, ListResult, MultipartUpload, PutMultipartOptions, - PutPayload, + CopyOptions, GetOptions, GetResult, ListResult, MultipartUpload, + PutMultipartOptions, PutPayload, }; use std::any::Any; use std::collections::HashMap; @@ -454,7 +538,7 @@ mod tests { let root = root.to_string_lossy(); let url = ListingTableUrl::parse(root).unwrap(); - let child = url.prefix.child("partition").child("file"); + let child = url.prefix.clone().join("partition").join("file"); let prefix: Vec<_> = url.strip_prefix(&child).unwrap().collect(); assert_eq!(prefix, vec!["partition", "file"]); @@ -754,6 +838,191 @@ mod tests { Ok(()) } + /// Tests that the cached code path produces identical results to the non-cached path. + /// + /// This is critical: the cache is a transparent optimization, so both paths + /// MUST return the same files. Note: order is not guaranteed by ObjectStore::list, + /// so we sort results before comparison. + #[tokio::test] + async fn test_cache_path_equivalence() -> Result<()> { + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let store = MockObjectStore { + in_mem: object_store::memory::InMemory::new(), + forbidden_paths: vec![], + }; + + // Create test files with partition-style paths + create_file(&store, "/table/year=2023/data1.parquet").await; + create_file(&store, "/table/year=2023/month=01/data2.parquet").await; + create_file(&store, "/table/year=2024/data3.parquet").await; + create_file(&store, "/table/year=2024/month=06/data4.parquet").await; + create_file(&store, "/table/year=2024/month=12/data5.parquet").await; + + // Session WITHOUT cache + let session_no_cache = MockSession::new(); + + // Session WITH cache - use RuntimeEnvBuilder with cache limit (no TTL needed for this test) + let runtime_with_cache = RuntimeEnvBuilder::new() + .with_object_list_cache_limit(1024 * 1024) // 1MB limit + .build_arc()?; + let session_with_cache = MockSession::with_runtime_env(runtime_with_cache); + + // Test cases: (url, prefix, description) + let test_cases = vec![ + ("/table/", None, "full table listing"), + ( + "/table/", + Some(Path::from("year=2023")), + "single partition filter", + ), + ( + "/table/", + Some(Path::from("year=2024")), + "different partition filter", + ), + ( + "/table/", + Some(Path::from("year=2024/month=06")), + "nested partition filter", + ), + ( + "/table/", + Some(Path::from("year=2025")), + "non-existent partition", + ), + ]; + + for (url_str, prefix, description) in test_cases { + let url = ListingTableUrl::parse(url_str)?; + + // Get results WITHOUT cache (sorted for comparison) + let mut results_no_cache: Vec = url + .list_prefixed_files(&session_no_cache, &store, prefix.clone(), "parquet") + .await? + .try_collect::>() + .await? + .into_iter() + .map(|m| m.location.to_string()) + .collect(); + results_no_cache.sort(); + + // Get results WITH cache (first call - cache miss, sorted for comparison) + let mut results_with_cache_miss: Vec = url + .list_prefixed_files( + &session_with_cache, + &store, + prefix.clone(), + "parquet", + ) + .await? + .try_collect::>() + .await? + .into_iter() + .map(|m| m.location.to_string()) + .collect(); + results_with_cache_miss.sort(); + + // Get results WITH cache (second call - cache hit, sorted for comparison) + let mut results_with_cache_hit: Vec = url + .list_prefixed_files(&session_with_cache, &store, prefix, "parquet") + .await? + .try_collect::>() + .await? + .into_iter() + .map(|m| m.location.to_string()) + .collect(); + results_with_cache_hit.sort(); + + // All three should contain the same files + assert_eq!( + results_no_cache, results_with_cache_miss, + "Cache miss path should match non-cached path for: {description}" + ); + assert_eq!( + results_no_cache, results_with_cache_hit, + "Cache hit path should match non-cached path for: {description}" + ); + } + + Ok(()) + } + + /// Tests that prefix queries can be served from a cached full-table listing + #[tokio::test] + async fn test_cache_serves_partition_from_full_listing() -> Result<()> { + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let store = MockObjectStore { + in_mem: object_store::memory::InMemory::new(), + forbidden_paths: vec![], + }; + + // Create test files + create_file(&store, "/sales/region=US/q1.parquet").await; + create_file(&store, "/sales/region=US/q2.parquet").await; + create_file(&store, "/sales/region=EU/q1.parquet").await; + + // Create session with cache (no TTL needed for this test) + let runtime = RuntimeEnvBuilder::new() + .with_object_list_cache_limit(1024 * 1024) // 1MB limit + .build_arc()?; + let session = MockSession::with_runtime_env(runtime); + + let url = ListingTableUrl::parse("/sales/")?; + + // First: query full table (populates cache) + let full_results: Vec = url + .list_prefixed_files(&session, &store, None, "parquet") + .await? + .try_collect::>() + .await? + .into_iter() + .map(|m| m.location.to_string()) + .collect(); + assert_eq!(full_results.len(), 3); + + // Second: query with prefix (should be served from cache) + let mut us_results: Vec = url + .list_prefixed_files( + &session, + &store, + Some(Path::from("region=US")), + "parquet", + ) + .await? + .try_collect::>() + .await? + .into_iter() + .map(|m| m.location.to_string()) + .collect(); + us_results.sort(); + + assert_eq!( + us_results, + vec!["sales/region=US/q1.parquet", "sales/region=US/q2.parquet"] + ); + + // Third: different prefix (also from cache) + let eu_results: Vec = url + .list_prefixed_files( + &session, + &store, + Some(Path::from("region=EU")), + "parquet", + ) + .await? + .try_collect::>() + .await? + .into_iter() + .map(|m| m.location.to_string()) + .collect(); + + assert_eq!(eu_results, vec!["sales/region=EU/q1.parquet"]); + + Ok(()) + } + /// Creates a file with "hello world" content at the specified path async fn create_file(object_store: &dyn ObjectStore, path: &str) { object_store @@ -841,7 +1110,14 @@ mod tests { location: &Path, options: GetOptions, ) -> object_store::Result { - self.in_mem.get_opts(location, options).await + if options.head && self.forbidden_paths.contains(location) { + Err(object_store::Error::PermissionDenied { + path: location.to_string(), + source: "forbidden".into(), + }) + } else { + self.in_mem.get_opts(location, options).await + } } async fn get_ranges( @@ -852,19 +1128,11 @@ mod tests { self.in_mem.get_ranges(location, ranges).await } - async fn head(&self, location: &Path) -> object_store::Result { - if self.forbidden_paths.contains(location) { - Err(object_store::Error::PermissionDenied { - path: location.to_string(), - source: "forbidden".into(), - }) - } else { - self.in_mem.head(location).await - } - } - - async fn delete(&self, location: &Path) -> object_store::Result<()> { - self.in_mem.delete(location).await + fn delete_stream( + &self, + locations: BoxStream<'static, object_store::Result>, + ) -> BoxStream<'static, object_store::Result> { + self.in_mem.delete_stream(locations) } fn list( @@ -881,16 +1149,13 @@ mod tests { self.in_mem.list_with_delimiter(prefix).await } - async fn copy(&self, from: &Path, to: &Path) -> object_store::Result<()> { - self.in_mem.copy(from, to).await - } - - async fn copy_if_not_exists( + async fn copy_opts( &self, from: &Path, to: &Path, + options: CopyOptions, ) -> object_store::Result<()> { - self.in_mem.copy_if_not_exists(from, to).await + self.in_mem.copy_opts(from, to, options).await } } @@ -906,6 +1171,14 @@ mod tests { runtime_env: Arc::new(RuntimeEnv::default()), } } + + /// Create a MockSession with a custom RuntimeEnv (for cache testing) + fn with_runtime_env(runtime_env: Arc) -> Self { + Self { + config: SessionConfig::new(), + runtime_env, + } + } } #[async_trait::async_trait] @@ -937,6 +1210,10 @@ mod tests { unimplemented!() } + fn higher_order_functions(&self) -> &HashMap> { + unimplemented!() + } + fn aggregate_functions(&self) -> &HashMap> { unimplemented!() } @@ -945,6 +1222,10 @@ mod tests { unimplemented!() } + fn extension_type_registry(&self) -> &ExtensionTypeRegistryRef { + unimplemented!() + } + fn runtime_env(&self) -> &Arc { &self.runtime_env } diff --git a/datafusion/datasource/src/write/demux.rs b/datafusion/datasource/src/write/demux.rs index 3fe6149b58b2b..acc6435acf371 100644 --- a/datafusion/datasource/src/write/demux.rs +++ b/datafusion/datasource/src/write/demux.rs @@ -28,15 +28,15 @@ use datafusion_common::error::Result; use datafusion_physical_plan::SendableRecordBatchStream; use arrow::array::{ - builder::UInt64Builder, cast::AsArray, downcast_dictionary_array, ArrayAccessor, - RecordBatch, StringArray, StructArray, + ArrayAccessor, RecordBatch, StringArray, StructArray, builder::UInt64Builder, + cast::AsArray, downcast_dictionary_array, }; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::{ as_boolean_array, as_date32_array, as_date64_array, as_float16_array, - as_float32_array, as_float64_array, as_int16_array, as_int32_array, as_int64_array, - as_int8_array, as_string_array, as_string_view_array, as_uint16_array, - as_uint32_array, as_uint64_array, as_uint8_array, + as_float32_array, as_float64_array, as_int8_array, as_int16_array, as_int32_array, + as_int64_array, as_large_string_array, as_string_array, as_string_view_array, + as_uint8_array, as_uint16_array, as_uint32_array, as_uint64_array, }; use datafusion_common::{exec_datafusion_err, internal_datafusion_err, not_impl_err}; use datafusion_common_runtime::SpawnedTask; @@ -106,8 +106,9 @@ pub(crate) fn start_demuxer_task( let file_extension = config.file_extension.clone(); let base_output_path = config.table_paths[0].clone(); let task = if config.table_partition_cols.is_empty() { - let single_file_output = !base_output_path.is_collection() - && base_output_path.file_extension().is_some(); + let single_file_output = config + .file_output_mode + .single_file_output(&base_output_path); SpawnedTask::spawn(async move { row_count_demuxer( tx, @@ -191,7 +192,11 @@ async fn row_count_demuxer( part_idx += 1; } + let schema = input.schema(); + let mut is_batch_received = false; + while let Some(rb) = input.next().await.transpose()? { + is_batch_received = true; // ensure we have at least minimum_parallel_files open if open_file_streams.len() < minimum_parallel_files { open_file_streams.push(create_new_file_stream( @@ -228,6 +233,19 @@ async fn row_count_demuxer( next_send_steam = (next_send_steam + 1) % minimum_parallel_files; } + + // if there is no batch send but with a single file, send an empty batch + if single_file_output && !is_batch_received { + open_file_streams + .first_mut() + .ok_or_else(|| internal_datafusion_err!("Expected a single output file"))? + .send(RecordBatch::new_empty(schema)) + .await + .map_err(|_| { + exec_datafusion_err!("Error sending empty RecordBatch to file stream!") + })?; + } + Ok(()) } @@ -242,7 +260,8 @@ fn generate_file_path( if !single_file_output { base_output_path .prefix() - .child(format!("{write_id}_{part_idx}.{file_extension}")) + .clone() + .join(format!("{write_id}_{part_idx}.{file_extension}")) } else { base_output_path.prefix().to_owned() } @@ -380,6 +399,12 @@ fn compute_partition_keys_by_row<'a>( partition_values.push(Cow::from(array.value(i))); } } + DataType::LargeUtf8 => { + let array = as_large_string_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i))); + } + } DataType::Utf8View => { let array = as_string_view_array(col_array)?; for i in 0..rb.num_rows() { @@ -502,9 +527,9 @@ fn compute_partition_keys_by_row<'a>( } _ => { return not_impl_err!( - "it is not yet supported to write to hive partitions with datatype {}", - dtype - ) + "it is not yet supported to write to hive partitions with datatype {}", + dtype + ); } } @@ -564,8 +589,8 @@ fn compute_hive_style_file_path( ) -> Path { let mut file_path = base_output_path.prefix().clone(); for j in 0..part_key.len() { - file_path = file_path.child(format!("{}={}", partition_by[j].0, part_key[j])); + file_path = file_path.join(format!("{}={}", partition_by[j].0, part_key[j])); } - file_path.child(format!("{write_id}.{file_extension}")) + file_path.join(format!("{write_id}.{file_extension}")) } diff --git a/datafusion/datasource/src/write/mod.rs b/datafusion/datasource/src/write/mod.rs index 85832f81bc185..e8d2d17da8ee8 100644 --- a/datafusion/datasource/src/write/mod.rs +++ b/datafusion/datasource/src/write/mod.rs @@ -28,9 +28,9 @@ use datafusion_common::error::Result; use arrow::array::RecordBatch; use arrow::datatypes::Schema; use bytes::Bytes; +use object_store::ObjectStore; use object_store::buffered::BufWriter; use object_store::path::Path; -use object_store::ObjectStore; use tokio::io::AsyncWrite; pub mod demux; @@ -131,6 +131,8 @@ pub struct ObjectWriterBuilder { object_store: Arc, /// The size of the buffer for the object writer. buffer_size: Option, + /// The compression level for the object writer. + compression_level: Option, } impl ObjectWriterBuilder { @@ -145,6 +147,7 @@ impl ObjectWriterBuilder { location: location.clone(), object_store, buffer_size: None, + compression_level: None, } } @@ -202,6 +205,22 @@ impl ObjectWriterBuilder { self.buffer_size } + /// Set compression level for object writer. + pub fn set_compression_level(&mut self, compression_level: Option) { + self.compression_level = compression_level; + } + + /// Set compression level for object writer, returning the builder. + pub fn with_compression_level(mut self, compression_level: Option) -> Self { + self.compression_level = compression_level; + self + } + + /// Currently specified compression level. + pub fn get_compression_level(&self) -> Option { + self.compression_level + } + /// Return a writer object that writes to the object store location. /// /// If a buffer size has not been set, the default buffer buffer size will @@ -215,6 +234,7 @@ impl ObjectWriterBuilder { location, object_store, buffer_size, + compression_level, } = self; let buf_writer = match buffer_size { @@ -222,6 +242,7 @@ impl ObjectWriterBuilder { None => BufWriter::new(object_store, location), }; - file_compression_type.convert_async_writer(buf_writer) + file_compression_type + .convert_async_writer_with_level(buf_writer, compression_level) } } diff --git a/datafusion/datasource/src/write/orchestration.rs b/datafusion/datasource/src/write/orchestration.rs index ab836b7b7f388..39c91a1c0d676 100644 --- a/datafusion/datasource/src/write/orchestration.rs +++ b/datafusion/datasource/src/write/orchestration.rs @@ -28,7 +28,7 @@ use datafusion_common::error::Result; use arrow::array::RecordBatch; use datafusion_common::{ - exec_datafusion_err, internal_datafusion_err, internal_err, DataFusionError, + DataFusionError, exec_datafusion_err, internal_datafusion_err, internal_err, }; use datafusion_common_runtime::{JoinSet, SpawnedTask}; use datafusion_execution::TaskContext; @@ -120,7 +120,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( return SerializedRecordBatchResult::failure( None, exec_datafusion_err!("Error writing to object store: {e}"), - ) + ); } }; row_count += cnt; @@ -148,7 +148,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store( return SerializedRecordBatchResult::failure( Some(writer), internal_datafusion_err!("Unknown error writing to object store"), - ) + ); } } SerializedRecordBatchResult::success(writer, row_count) @@ -216,12 +216,20 @@ pub(crate) async fn stateless_serialize_and_write_files( } if any_errors { - match any_abort_errors{ - true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), + match any_abort_errors { + true => { + return internal_err!( + "Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written." + ); + } false => match triggering_error { Some(e) => return Err(e), - None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers successfully aborted.") - } + None => { + return internal_err!( + "Unknown Error encountered during writing to ObjectStore. All writers successfully aborted." + ); + } + }, } } @@ -240,6 +248,7 @@ pub async fn spawn_writer_tasks_and_join( context: &Arc, serializer: Arc, compression: FileCompressionType, + compression_level: Option, object_store: Arc, demux_task: SpawnedTask>, mut file_stream_rx: DemuxedStreamReceiver, @@ -265,6 +274,7 @@ pub async fn spawn_writer_tasks_and_join( .execution .objectstore_writer_buffer_size, )) + .with_compression_level(compression_level) .build()?; if tx_file_bundle diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml index 0c4fa2c211cf1..072c8f14da503 100644 --- a/datafusion/expr-common/Cargo.toml +++ b/datafusion/expr-common/Cargo.toml @@ -45,4 +45,6 @@ arrow = { workspace = true } datafusion-common = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } -paste = "^1.0" + +[dev-dependencies] +insta = { workspace = true } diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 2829a9416f033..59fb6a595206a 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -18,7 +18,7 @@ //! Accumulator module contains the trait definition for aggregation function's accumulators. use arrow::array::ArrayRef; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, internal_err}; use std::fmt::Debug; /// Tracks an aggregate function's state. @@ -48,7 +48,7 @@ use std::fmt::Debug; /// [`evaluate`]: Self::evaluate /// [`merge_batch`]: Self::merge_batch /// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -pub trait Accumulator: Send + Sync + Debug { +pub trait Accumulator: Send + Sync + Debug + std::any::Any { /// Updates the accumulator's state from its input. /// /// `values` contains the arguments to this aggregate function. @@ -58,17 +58,30 @@ pub trait Accumulator: Send + Sync + Debug { /// running sum. fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; - /// Returns the final aggregate value, consuming the internal state. + /// Returns the final aggregate value. /// /// For example, the `SUM` accumulator maintains a running sum, /// and `evaluate` will produce that running sum as its output. /// - /// This function should not be called twice, otherwise it will - /// result in potentially non-deterministic behavior. - /// /// This function gets `&mut self` to allow for the accumulator to build /// arrow-compatible internal state that can be returned without copying - /// when possible (for example distinct strings) + /// when possible (for example distinct strings). + /// + /// ## Correctness + /// + /// This function must not consume the internal state, as it is also used in window + /// aggregate functions where it can be executed multiple times depending on the + /// current window frame. Consuming the internal state can cause the next invocation + /// to have incorrect results. + /// + /// - Even if this accumulator doesn't implement [`retract_batch`] it may still be used + /// in window aggregate functions where the window frame is + /// `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` + /// + /// It is fine to modify the state (e.g. re-order elements within internal state vec) so long + /// as this doesn't cause an incorrect computation on the next call of evaluate. + /// + /// [`retract_batch`]: Self::retract_batch fn evaluate(&mut self) -> Result; /// Returns the allocated size required for this accumulator, in diff --git a/datafusion/expr-common/src/casts.rs b/datafusion/expr-common/src/casts.rs index 8939ff1371bb9..d18c3d4f043eb 100644 --- a/datafusion/expr-common/src/casts.rs +++ b/datafusion/expr-common/src/casts.rs @@ -24,15 +24,38 @@ use std::cmp::Ordering; use arrow::datatypes::{ - DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, - MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION, - MIN_DECIMAL64_FOR_EACH_PRECISION, + DataType, MAX_DECIMAL32_FOR_EACH_PRECISION, MAX_DECIMAL64_FOR_EACH_PRECISION, + MAX_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL32_FOR_EACH_PRECISION, + MIN_DECIMAL64_FOR_EACH_PRECISION, MIN_DECIMAL128_FOR_EACH_PRECISION, TimeUnit, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::ScalarValue; -/// Convert a literal value from one data type to another +/// Convert a literal [`ScalarValue`] to `target_type`, preserving the exact value. +/// +/// Returns `None` if the value cannot be represented in `target_type` +/// *exactly*. +/// +/// This is a restricted, value-preserving cast used to rewrite comparison +/// predicates of the form `CAST(col AS target_type) literal` into +/// `col try_cast_literal_to_type(literal, col_type)`. That rewrite is +/// only valid when the cast cannot change the comparison result. +/// +/// # Supported Casts +/// * numeric → numeric, including integers, decimals, `Date32`/`Date64` and +/// `Timestamp`s, rejecting values outside the target's range or that would +/// lose decimal digits +/// * string → string between `Utf8`, `LargeUtf8` and `Utf8View` +/// * wrapping a value into, or unwrapping it out of, a `Dictionary` whose value +/// type matches the literal's type +/// * `Binary` → `FixedSizeBinary` of the matching length +/// * `Timestamp` → `Timestamp` cast between different time units is allowed even +/// though it can truncate (for example nanoseconds → seconds), and a unit +/// conversion that overflows yields a `NULL` literal rather than `None`. +/// +/// # See Also +/// - [`ScalarValue::cast_to`]: a general-purpose cast that can lose information +/// or change a value's meaning. pub fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, @@ -59,7 +82,28 @@ pub fn is_supported_type(data_type: &DataType) -> bool { || is_supported_binary_type(data_type) } -/// Returns true if unwrap_cast_in_comparison support this numeric type +fn is_date_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Date32 | DataType::Date64) +} + +/// Returns true when unwrapping a date/timestamp cast could change comparison +/// semantics. +/// +/// A `Date` stores only a calendar day, while a `Timestamp` stores a specific +/// instant or wall-clock time. `Timestamp -> Date` is lossy because it drops the +/// time-of-day. `Date -> Timestamp` is also lossy in this optimizer context +/// because there is no unique inverse: converting a date to a timestamp has to +/// invent a time component such as midnight. +/// +/// For example, `CAST(ts AS DATE) = DATE '2024-01-01'` means "any timestamp +/// during that day", but unwrapping it to `ts = TIMESTAMP '2024-01-01 +/// 00:00:00'` matches only midnight. +fn is_lossy_temporal_cast(from_type: &DataType, to_type: &DataType) -> bool { + (is_date_type(from_type) && to_type.is_temporal()) + || (is_date_type(to_type) && from_type.is_temporal()) +} + +/// Returns true if unwrap_cast_in_comparison supports this numeric type fn is_supported_numeric_type(data_type: &DataType) -> bool { matches!( data_type, @@ -71,6 +115,8 @@ fn is_supported_numeric_type(data_type: &DataType) -> bool { | DataType::Int16 | DataType::Int32 | DataType::Int64 + | DataType::Date32 + | DataType::Date64 | DataType::Decimal32(_, _) | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) @@ -108,6 +154,10 @@ fn try_cast_numeric_literal( return None; } + if is_lossy_temporal_cast(&lit_data_type, target_type) { + return None; + } + let mul = match target_type { DataType::UInt8 | DataType::UInt16 @@ -116,7 +166,9 @@ fn try_cast_numeric_literal( | DataType::Int8 | DataType::Int16 | DataType::Int32 - | DataType::Int64 => 1_i128, + | DataType::Int64 + | DataType::Date32 + | DataType::Date64 => 1_i128, DataType::Timestamp(_, _) => 1_i128, DataType::Decimal32(_, scale) => 10_i128.pow(*scale as u32), DataType::Decimal64(_, scale) => 10_i128.pow(*scale as u32), @@ -130,8 +182,8 @@ fn try_cast_numeric_literal( DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Int32 | DataType::Date32 => (i32::MIN as i128, i32::MAX as i128), + DataType::Int64 | DataType::Date64 => (i64::MIN as i128, i64::MAX as i128), DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), DataType::Decimal32(precision, _) => ( // Different precision for decimal32 can store different range of value. @@ -165,6 +217,8 @@ fn try_cast_numeric_literal( ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Date32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Date64(Some(v)) => (*v as i128).checked_mul(mul), ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), @@ -242,6 +296,8 @@ fn try_cast_numeric_literal( DataType::Int16 => ScalarValue::Int16(Some(value as i16)), DataType::Int32 => ScalarValue::Int32(Some(value as i32)), DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::Date32 => ScalarValue::Date32(Some(value as i32)), + DataType::Date64 => ScalarValue::Date64(Some(value as i64)), DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), @@ -382,8 +438,8 @@ fn try_cast_binary( #[cfg(test)] mod tests { use super::*; - use arrow::compute::{cast_with_options, CastOptions}; - use arrow::datatypes::{Field, Fields, TimeUnit}; + use arrow::compute::{CastOptions, cast_with_options}; + use arrow::datatypes::{Field, Fields}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -701,6 +757,33 @@ mod tests { } } + #[test] + fn test_try_cast_to_type_date_timestamp_lossy_not_allowed() { + expect_cast( + ScalarValue::Date32(Some(1)), + DataType::Timestamp(TimeUnit::Second, None), + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::Date64(Some(86_400_000)), + DataType::Timestamp(TimeUnit::Millisecond, None), + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::TimestampSecond(Some(86_400), None), + DataType::Date32, + ExpectedCast::NoValue, + ); + + expect_cast( + ScalarValue::TimestampMillisecond(Some(86_400_000), None), + DataType::Date64, + ExpectedCast::NoValue, + ); + } + #[test] fn test_try_cast_to_type_unsupported() { // int64 to list diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 585b47a9800d2..caeb3f10da752 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -18,17 +18,23 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. use arrow::{ - array::{Array, ArrayRef, Date32Array, Date64Array, NullArray}, - compute::{kernels, max, min, CastOptions}, - datatypes::DataType, + array::{ + Array, ArrayRef, Date32Array, Date64Array, NullArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + }, + compute::{CastOptions, kernels, max, min}, + datatypes::{DataType, TimeUnit}, util::pretty::pretty_format_columns, }; use datafusion_common::internal_datafusion_err; use datafusion_common::{ + Result, ScalarValue, format::DEFAULT_CAST_OPTIONS, internal_err, - scalar::{date_to_timestamp_multiplier, ensure_timestamp_in_bounds}, - Result, ScalarValue, + scalar::{ + date_to_timestamp_multiplier, ensure_timestamp_in_bounds, + timestamp_to_timestamp_multiplier, + }, }; use std::fmt; use std::sync::Arc; @@ -274,7 +280,17 @@ impl ColumnarValue { Ok(args) } - /// Cast's this [ColumnarValue] to the specified `DataType` + /// Cast this [ColumnarValue] to the specified `DataType` + /// + /// # Struct Casting Behavior + /// + /// When casting struct types, fields are matched **by name** rather than position: + /// - Source fields are matched to target fields using case-sensitive name comparison + /// - Fields are reordered to match the target schema + /// - Missing target fields are filled with null arrays + /// - Extra source fields are ignored + /// + /// For non-struct types, uses Arrow's standard positional casting. pub fn cast_to( &self, cast_type: &DataType, @@ -283,12 +299,8 @@ impl ColumnarValue { let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); match self { ColumnarValue::Array(array) => { - ensure_date_array_timestamp_bounds(array, cast_type)?; - Ok(ColumnarValue::Array(kernels::cast::cast_with_options( - array, - cast_type, - &cast_options, - )?)) + let casted = cast_array_by_name(array, cast_type, &cast_options)?; + Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( scalar.cast_to_with_options(cast_type, &cast_options)?, @@ -297,12 +309,39 @@ impl ColumnarValue { } } -fn ensure_date_array_timestamp_bounds( +fn cast_array_by_name( + array: &ArrayRef, + cast_type: &DataType, + cast_options: &CastOptions<'static>, +) -> Result { + // If types are already equal, no cast needed + if array.data_type() == cast_type { + return Ok(Arc::clone(array)); + } + + if datafusion_common::nested_struct::requires_nested_struct_cast( + array.data_type(), + cast_type, + ) { + datafusion_common::nested_struct::cast_column(array, cast_type, cast_options) + } else { + ensure_temporal_array_timestamp_bounds(array, cast_type)?; + Ok(kernels::cast::cast_with_options( + array, + cast_type, + cast_options, + )?) + } +} + +fn ensure_temporal_array_timestamp_bounds( array: &ArrayRef, cast_type: &DataType, ) -> Result<()> { let source_type = array.data_type().clone(); - let Some(multiplier) = date_to_timestamp_multiplier(&source_type, cast_type) else { + let Some(multiplier) = date_to_timestamp_multiplier(&source_type, cast_type) + .or_else(|| timestamp_to_timestamp_multiplier(&source_type, cast_type)) + else { return Ok(()); }; @@ -336,7 +375,55 @@ fn ensure_date_array_timestamp_bounds( })?; (min(arr), max(arr)) } - _ => return Ok(()), // Not a date type, nothing to do + DataType::Timestamp(TimeUnit::Second, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampSecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampMillisecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampMicrosecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampNanosecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + _ => return Ok(()), // Not a temporal type that needs checking. }; // Only validate the min and max values instead of all elements @@ -378,8 +465,8 @@ impl fmt::Display for ColumnarValue { mod tests { use super::*; use arrow::{ - array::{Date64Array, Int32Array}, - datatypes::TimeUnit, + array::{Date64Array, Int32Array, StructArray}, + datatypes::{Field, Fields, TimeUnit}, }; #[test] @@ -553,6 +640,102 @@ mod tests { ); } + #[test] + fn cast_struct_by_field_name() { + let source_fields = Fields::from(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Int32, true), + ]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + source_fields, + vec![ + Arc::new(Int32Array::from(vec![Some(3)])), + Arc::new(Int32Array::from(vec![Some(4)])), + ], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_a = struct_array + .column_by_name("a") + .expect("expected field a in cast result"); + let field_b = struct_array + .column_by_name("b") + .expect("expected field b in cast result"); + + assert_eq!( + field_a + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 4 + ); + assert_eq!( + field_b + .as_any() + .downcast_ref::() + .expect("expected Int32 array") + .value(0), + 3 + ); + } + + #[test] + fn cast_struct_missing_field_inserts_nulls() { + let source_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]); + + let target_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let struct_array = StructArray::new( + source_fields, + vec![Arc::new(Int32Array::from(vec![Some(5)]))], + None, + ); + + let value = ColumnarValue::Array(Arc::new(struct_array)); + let casted = value + .cast_to(&DataType::Struct(target_fields.clone()), None) + .expect("struct cast should succeed"); + + let ColumnarValue::Array(arr) = casted else { + panic!("expected array after cast"); + }; + + let struct_array = arr + .as_any() + .downcast_ref::() + .expect("expected StructArray"); + + let field_b = struct_array + .column_by_name("b") + .expect("expected missing field to be added"); + + assert!(field_b.is_null(0)); + } + #[test] fn cast_date64_array_to_timestamp_overflow() { let overflow_value = i64::MAX / 1_000_000 + 1; @@ -567,4 +750,20 @@ mod tests { "unexpected error: {err}" ); } + + #[test] + fn cast_timestamp_array_to_timestamp_overflow() { + let overflow_value = i64::MAX / 1_000_000_000 + 1; + let array: ArrayRef = + Arc::new(TimestampSecondArray::from(vec![Some(overflow_value)])); + let value = ColumnarValue::Array(array); + let result = + value.cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None), None); + let err = result.expect_err("expected overflow to be detected"); + assert!( + err.to_string() + .contains("converted value exceeds the representable i64 range"), + "unexpected error: {err}" + ); + } } diff --git a/datafusion/expr-common/src/dyn_eq.rs b/datafusion/expr-common/src/dyn_eq.rs index e0ebcae4879d6..75d9c06d67f56 100644 --- a/datafusion/expr-common/src/dyn_eq.rs +++ b/datafusion/expr-common/src/dyn_eq.rs @@ -28,7 +28,7 @@ use std::hash::{Hash, Hasher}; /// /// Note: This trait should not be implemented directly. Implement `Eq` and `Any` and use /// the blanket implementation. -#[allow(private_bounds)] +#[expect(private_bounds)] pub trait DynEq: private::EqSealed { fn dyn_eq(&self, other: &dyn Any) -> bool; } @@ -45,7 +45,7 @@ impl DynEq for T { /// /// Note: This trait should not be implemented directly. Implement `Hash` and `Any` and use /// the blanket implementation. -#[allow(private_bounds)] +#[expect(private_bounds)] pub trait DynHash: private::HashSealed { fn dyn_hash(&self, _state: &mut dyn Hasher); } diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 9bcc1edff8824..da5da384c7b4e 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -18,7 +18,7 @@ //! Vectorized [`GroupsAccumulator`] use arrow::array::{ArrayRef, BooleanArray}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{Result, not_impl_err, utils::split_vec_min_alloc}; /// Describes how many rows should be emitted during grouping. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -45,13 +45,7 @@ impl EmitTo { // Take the entire vector, leave new (empty) vector std::mem::take(v) } - Self::First(n) => { - // get end n+1,.. values into t - let mut t = v.split_off(*n); - // leave n+1,.. in v - std::mem::swap(v, &mut t); - t - } + Self::First(n) => split_vec_min_alloc(v, *n), } } } @@ -89,6 +83,9 @@ impl EmitTo { /// optional and is harder to implement than `Accumulator`, but can be much /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. +/// For more background, please also see the [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog] +/// +/// [Aggregating Millions of Groups Fast in Apache Arrow DataFusion 28.0.0 blog]: https://datafusion.apache.org/blog/2023/08/05/datafusion_fast_grouping /// /// [`NullState`] can help keep the state for groups that have not seen any /// values and produce the correct output for those groups. @@ -105,7 +102,7 @@ impl EmitTo { /// /// [`Accumulator`]: crate::accumulator::Accumulator /// [Aggregating Millions of Groups Fast blog]: https://arrow.apache.org/blog/2023/08/05/datafusion_fast_grouping/ -pub trait GroupsAccumulator: Send { +pub trait GroupsAccumulator: Send + std::any::Any { /// Updates the accumulator's state from its arguments, encoded as /// a vector of [`ArrayRef`]s. /// @@ -251,3 +248,53 @@ pub trait GroupsAccumulator: Send { /// compute, not `O(num_groups)` fn size(&self) -> usize; } + +#[cfg(test)] +mod tests { + use super::EmitTo; + + /// When `n` is small relative to `len`, the old `split_off(n) + swap` pattern had + /// two allocation problems: + /// + /// 1. The returned Vec kept the original large backing allocation even though it + /// only contains `n` elements (wasted capacity on a short-lived value). + /// 2. `split_off` allocated a fresh Vec for the `len - n` remaining elements, + /// even though that side is much larger than `n` — the expensive side to + /// allocate. + /// + /// `split_vec_min_alloc` fixes both: when `n * 2 <= len` it uses + /// `drain(0..n).collect()`, allocating only `n` elements for the emitted prefix + /// and keeping the original large backing in the remaining accumulator. + #[test] + fn take_needed_first_small_n_allocates_minimally() { + let mut v: Vec = Vec::with_capacity(128); + v.extend(0..20i32); + let original_capacity = v.capacity(); // 128 + + // n=4, n*2=8 <= len=20 -> drain branch in split_vec_min_alloc + let emitted = EmitTo::First(4).take_needed(&mut v); + + assert_eq!(emitted, vec![0, 1, 2, 3]); + assert_eq!(v, (4..20i32).collect::>()); + + // The emitted prefix must NOT carry the original large allocation. + // Old split_off+swap returned a Vec with capacity=128 for only 4 elements. + assert!( + emitted.capacity() <= 4, + "emitted prefix capacity {} should be ~n=4, not the original {}", + emitted.capacity(), + original_capacity, + ); + + // The remaining accumulator must retain the original large allocation so + // that incoming groups don't immediately force a realloc. + // Old split_off+swap left the remaining vec with a small fresh allocation. + assert_eq!( + v.capacity(), + original_capacity, + "remaining vec capacity {} should equal original {}", + v.capacity(), + original_capacity, + ); + } +} diff --git a/datafusion/expr-common/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs index b9f8102f341ac..51858be538f5a 100644 --- a/datafusion/expr-common/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -22,22 +22,22 @@ use std::fmt::{self, Display, Formatter}; use std::ops::{AddAssign, SubAssign}; use crate::operator::Operator; -use crate::type_coercion::binary::{comparison_coercion_numeric, BinaryTypeCoercer}; +use crate::type_coercion::binary::{BinaryTypeCoercer, comparison_coercion}; -use arrow::compute::{cast_with_options, CastOptions}; +use arrow::compute::{CastOptions, cast_with_options}; use arrow::datatypes::{ - DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, TimeUnit, + DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, MAX_DECIMAL256_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, MIN_DECIMAL256_FOR_EACH_PRECISION, TimeUnit, }; use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; use datafusion_common::{ - assert_eq_or_internal_err, assert_or_internal_err, internal_err, DataFusionError, - Result, ScalarValue, + DataFusionError, Result, ScalarValue, assert_eq_or_internal_err, + assert_or_internal_err, internal_err, }; macro_rules! get_extreme_value { - ($extreme:ident, $value:expr) => { + ($extreme:ident, $DECIMAL128_ARRAY:ident, $DECIMAL256_ARRAY:ident, $value:expr) => { match $value { DataType::UInt8 => ScalarValue::UInt8(Some(u8::$extreme)), DataType::UInt16 => ScalarValue::UInt16(Some(u16::$extreme)), @@ -49,6 +49,8 @@ macro_rules! get_extreme_value { DataType::Int64 => ScalarValue::Int64(Some(i64::$extreme)), DataType::Float32 => ScalarValue::Float32(Some(f32::$extreme)), DataType::Float64 => ScalarValue::Float64(Some(f64::$extreme)), + DataType::Date32 => ScalarValue::Date32(Some(i32::$extreme)), + DataType::Date64 => ScalarValue::Date64(Some(i64::$extreme)), DataType::Duration(TimeUnit::Second) => { ScalarValue::DurationSecond(Some(i64::$extreme)) } @@ -83,18 +85,12 @@ macro_rules! get_extreme_value { ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::$extreme)) } DataType::Decimal128(precision, scale) => ScalarValue::Decimal128( - Some( - paste::paste! {[<$extreme _DECIMAL128_FOR_EACH_PRECISION>]} - [*precision as usize], - ), + Some($DECIMAL128_ARRAY[*precision as usize]), *precision, *scale, ), DataType::Decimal256(precision, scale) => ScalarValue::Decimal256( - Some( - paste::paste! {[<$extreme _DECIMAL256_FOR_EACH_PRECISION>]} - [*precision as usize], - ), + Some($DECIMAL256_ARRAY[*precision as usize]), *precision, *scale, ), @@ -650,31 +646,25 @@ impl Interval { /// Compute the intersection of this interval with the given interval. /// If the intersection is empty, return `None`. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common comparison type via [`comparison_coercion`] before computing the + /// intersection. pub fn intersect>(&self, other: T) -> Result> { let rhs = other.borrow(); - let lhs_type = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - lhs_type, - rhs_type, - "Only intervals with the same data type are intersectable, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); + let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?; + let lhs = lhs_owned.as_ref().unwrap_or(self); + let rhs = rhs_owned.as_ref().unwrap_or(rhs); // If it is evident that the result is an empty interval, short-circuit // and directly return `None`. - if (!(self.lower.is_null() || rhs.upper.is_null()) && self.lower > rhs.upper) - || (!(self.upper.is_null() || rhs.lower.is_null()) && self.upper < rhs.lower) + if (!(lhs.lower.is_null() || rhs.upper.is_null()) && lhs.lower > rhs.upper) + || (!(lhs.upper.is_null() || rhs.lower.is_null()) && lhs.upper < rhs.lower) { return Ok(None); } - let lower = max_of_bounds(&self.lower, &rhs.lower); - let upper = min_of_bounds(&self.upper, &rhs.upper); + let lower = max_of_bounds(&lhs.lower, &rhs.lower); + let upper = min_of_bounds(&lhs.upper, &rhs.upper); // New lower and upper bounds must always construct a valid interval. debug_assert!( @@ -687,35 +677,27 @@ impl Interval { /// Compute the union of this interval with the given interval. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common comparison type via [`comparison_coercion`] before computing the + /// union. pub fn union>(&self, other: T) -> Result { let rhs = other.borrow(); - let lhs_type = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - lhs_type, - rhs_type, - "Cannot calculate the union of intervals with different data types, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); + let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?; + let lhs = lhs_owned.as_ref().unwrap_or(self); + let rhs = rhs_owned.as_ref().unwrap_or(rhs); - let lower = if self.lower.is_null() - || (!rhs.lower.is_null() && self.lower <= rhs.lower) - { - self.lower.clone() - } else { - rhs.lower.clone() - }; - let upper = if self.upper.is_null() - || (!rhs.upper.is_null() && self.upper >= rhs.upper) - { - self.upper.clone() - } else { - rhs.upper.clone() - }; + let lower = + if lhs.lower.is_null() || (!rhs.lower.is_null() && lhs.lower <= rhs.lower) { + lhs.lower.clone() + } else { + rhs.lower.clone() + }; + let upper = + if lhs.upper.is_null() || (!rhs.upper.is_null() && lhs.upper >= rhs.upper) { + lhs.upper.clone() + } else { + rhs.upper.clone() + }; // New lower and upper bounds must always construct a valid interval. debug_assert!( @@ -734,7 +716,7 @@ impl Interval { (self.lower.clone(), self.upper.clone(), rhs.clone()) } else { let maybe_common_type = - comparison_coercion_numeric(&self.data_type(), &rhs.data_type()); + comparison_coercion(&self.data_type(), &rhs.data_type()); assert_or_internal_err!( maybe_common_type.is_some(), "Data types must be compatible for containment checks, lhs:{}, rhs:{}", @@ -758,22 +740,16 @@ impl Interval { /// disjoint with `other` by returning `[true, true]`, `[false, true]` or /// `[false, false]` respectively. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common comparison type via [`comparison_coercion`] before checking + /// containment. pub fn contains>(&self, other: T) -> Result { let rhs = other.borrow(); - let lhs_type = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - lhs_type, - rhs_type, - "Interval data types must match for containment checks, lhs:{}, rhs:{}", - self.data_type(), - rhs.data_type() - ); + let (lhs_owned, rhs_owned) = coerce_for_comparison(self, rhs)?; + let lhs = lhs_owned.as_ref().unwrap_or(self); + let rhs = rhs_owned.as_ref().unwrap_or(rhs); - match self.intersect(rhs)? { + match lhs.intersect(rhs)? { Some(intersection) => { if &intersection == rhs { Ok(Self::TRUE) @@ -834,36 +810,29 @@ impl Interval { /// Note that this represents all possible values the product can take if /// one can choose single values arbitrarily from each of the operands. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common type via [`BinaryTypeCoercer`] before computing the product. pub fn mul>(&self, other: T) -> Result { let rhs = other.borrow(); - let dt = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - dt.clone(), - rhs_type.clone(), - "Intervals must have the same data type for multiplication, lhs:{}, rhs:{}", - dt.clone(), - rhs_type.clone() - ); + let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Multiply)?; + let lhs_ref = lhs_owned.as_ref().unwrap_or(self); + let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs); let zero = ScalarValue::new_zero(&dt)?; let result = match ( - self.contains_value(&zero)?, - rhs.contains_value(&zero)?, + lhs_ref.contains_value(&zero)?, + rhs_ref.contains_value(&zero)?, dt.is_unsigned_integer(), ) { - (true, true, false) => mul_helper_multi_zero_inclusive(&dt, self, rhs), + (true, true, false) => mul_helper_multi_zero_inclusive(&dt, lhs_ref, rhs_ref), (true, false, false) => { - mul_helper_single_zero_inclusive(&dt, self, rhs, &zero) + mul_helper_single_zero_inclusive(&dt, lhs_ref, rhs_ref, &zero) } (false, true, false) => { - mul_helper_single_zero_inclusive(&dt, rhs, self, &zero) + mul_helper_single_zero_inclusive(&dt, rhs_ref, lhs_ref, &zero) } - _ => mul_helper_zero_exclusive(&dt, self, rhs, &zero), + _ => mul_helper_zero_exclusive(&dt, lhs_ref, rhs_ref, &zero), }; Ok(result) } @@ -874,23 +843,16 @@ impl Interval { /// all possible values the quotient can take if one can choose single values /// arbitrarily from each of the operands. /// - /// NOTE: This function only works with intervals of the same data type. - /// Attempting to compare intervals of different data types will lead - /// to an error. + /// If the two intervals have different data types, both are coerced to a + /// common type via [`BinaryTypeCoercer`] before computing the quotient. /// /// **TODO**: Once interval sets are supported, cases where the divisor contains /// zero should result in an interval set, not the universal set. pub fn div>(&self, other: T) -> Result { let rhs = other.borrow(); - let dt = self.data_type(); - let rhs_type = rhs.data_type(); - assert_eq_or_internal_err!( - dt.clone(), - rhs_type.clone(), - "Intervals must have the same data type for division, lhs:{}, rhs:{}", - dt.clone(), - rhs_type.clone() - ); + let (lhs_owned, rhs_owned, dt) = coerce_operands(self, rhs, &Operator::Divide)?; + let lhs_ref = lhs_owned.as_ref().unwrap_or(self); + let rhs_ref = rhs_owned.as_ref().unwrap_or(rhs); let zero = ScalarValue::new_zero(&dt)?; // We want 0 to be approachable from both negative and positive sides. @@ -901,15 +863,27 @@ impl Interval { // Exit early with an unbounded interval if zero is strictly inside the // right hand side: - if rhs.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() { + if rhs_ref.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() { Self::make_unbounded(&dt) } // At this point, we know that only one endpoint of the right hand side // can be zero. - else if self.contains(&zero_point)? == Self::TRUE && !dt.is_unsigned_integer() { - Ok(div_helper_lhs_zero_inclusive(&dt, self, rhs, &zero_point)) + else if lhs_ref.contains(&zero_point)? == Self::TRUE + && !dt.is_unsigned_integer() + { + Ok(div_helper_lhs_zero_inclusive( + &dt, + lhs_ref, + rhs_ref, + &zero_point, + )) } else { - Ok(div_helper_zero_exclusive(&dt, self, rhs, &zero_point)) + Ok(div_helper_zero_exclusive( + &dt, + lhs_ref, + rhs_ref, + &zero_point, + )) } } @@ -933,7 +907,12 @@ impl Interval { /// when the calculated cardinality does not fit in an `u64`. pub fn cardinality(&self) -> Option { let data_type = self.data_type(); - if data_type.is_integer() { + if data_type.is_integer() + || matches!( + data_type, + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) + ) + { self.upper.distance(&self.lower).map(|diff| diff as u64) } else if data_type.is_floating() { // Negative numbers are sorted in the reverse order. To @@ -965,7 +944,7 @@ impl Interval { // Cardinality calculations are not implemented for this data type yet: None } - .map(|result| result + 1) + .and_then(|result| result.checked_add(1)) } /// Reflects an [`Interval`] around the point zero. @@ -999,6 +978,70 @@ impl From<&ScalarValue> for Interval { } } +/// Coerces two intervals to a common comparison type so that lower/upper +/// bounds from each can be compared directly. +/// +/// Returns `(coerced_lhs, coerced_rhs)` where each is `Some(...)` if a cast +/// was required and `None` otherwise. Returns an internal error if the two +/// types cannot be unified for comparison. +fn coerce_for_comparison( + lhs: &Interval, + rhs: &Interval, +) -> Result<(Option, Option)> { + let lhs_type = lhs.data_type(); + let rhs_type = rhs.data_type(); + if lhs_type == rhs_type { + return Ok((None, None)); + } + let maybe_common = comparison_coercion(&lhs_type, &rhs_type); + assert_or_internal_err!( + maybe_common.is_some(), + "Data types must be compatible for interval comparison, lhs:{}, rhs:{}", + lhs_type, + rhs_type + ); + let common = maybe_common.expect("checked for Some"); + let cast_options = CastOptions::default(); + let new_lhs = (lhs_type != common) + .then(|| lhs.cast_to(&common, &cast_options)) + .transpose()?; + let new_rhs = (rhs_type != common) + .then(|| rhs.cast_to(&common, &cast_options)) + .transpose()?; + Ok((new_lhs, new_rhs)) +} + +/// Coerces two intervals to a common type for the given binary `op` so that +/// downstream interval helpers can operate on a single, consistent data type. +/// +/// Returns `(coerced_lhs, coerced_rhs, common_type)`. Each `coerced_*` is +/// `Some(...)` when a cast was required, and `None` when the original interval +/// already had the common type (the caller should use the original in that +/// case). The returned `common_type` is the type both (possibly cast) operands +/// share, taken from [`BinaryTypeCoercer::get_result_type`] — this mirrors +/// what arrow's numeric kernels would produce when computing the operation. +fn coerce_operands( + lhs: &Interval, + rhs: &Interval, + op: &Operator, +) -> Result<(Option, Option, DataType)> { + let lhs_type = lhs.data_type(); + let rhs_type = rhs.data_type(); + if lhs_type == rhs_type { + return Ok((None, None, lhs_type)); + } + let common_type = + BinaryTypeCoercer::new(&lhs_type, op, &rhs_type).get_result_type()?; + let cast_options = CastOptions::default(); + let new_lhs = (lhs_type != common_type) + .then(|| lhs.cast_to(&common_type, &cast_options)) + .transpose()?; + let new_rhs = (rhs_type != common_type) + .then(|| rhs.cast_to(&common_type, &cast_options)) + .transpose()?; + Ok((new_lhs, new_rhs, common_type)) +} + /// Applies the given binary operator the `lhs` and `rhs` arguments. pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> Result { match *op { @@ -1162,10 +1205,20 @@ fn handle_overflow( match (UPPER, positive_sign) { (true, true) | (false, false) => ScalarValue::try_from(dt).unwrap(), (true, false) => { - get_extreme_value!(MIN, dt) + get_extreme_value!( + MIN, + MIN_DECIMAL128_FOR_EACH_PRECISION, + MIN_DECIMAL256_FOR_EACH_PRECISION, + dt + ) } (false, true) => { - get_extreme_value!(MAX, dt) + get_extreme_value!( + MAX, + MAX_DECIMAL128_FOR_EACH_PRECISION, + MAX_DECIMAL256_FOR_EACH_PRECISION, + dt + ) } } } @@ -2202,7 +2255,7 @@ impl NullableInterval { mod tests { use crate::{ interval_arithmetic::{ - handle_overflow, next_value, prev_value, satisfy_greater, Interval, + Interval, handle_overflow, next_value, prev_value, satisfy_greater, }, operator::Operator, }; @@ -2248,10 +2301,12 @@ mod tests { ScalarValue::Float64(Some(1e-6)), ]; values.into_iter().zip(eps).for_each(|(value, eps)| { - assert!(next_value(value.clone()) - .sub(value.clone()) - .unwrap() - .lt(&eps)); + assert!( + next_value(value.clone()) + .sub(value.clone()) + .unwrap() + .lt(&eps) + ); assert!(value.sub(prev_value(value.clone())).unwrap().lt(&eps)); assert_ne!(next_value(value.clone()), value); assert_ne!(prev_value(value.clone()), value); @@ -2841,18 +2896,26 @@ mod tests { // not contain `null`. #[test] fn test_uncertain_boolean_interval() { - assert!(Interval::TRUE_OR_FALSE - .contains_value(ScalarValue::Boolean(Some(true))) - .unwrap()); - assert!(Interval::TRUE_OR_FALSE - .contains_value(ScalarValue::Boolean(Some(false))) - .unwrap()); - assert!(!Interval::TRUE_OR_FALSE - .contains_value(ScalarValue::Boolean(None)) - .unwrap()); - assert!(!Interval::TRUE_OR_FALSE - .contains_value(ScalarValue::Null) - .unwrap()); + assert!( + Interval::TRUE_OR_FALSE + .contains_value(ScalarValue::Boolean(Some(true))) + .unwrap() + ); + assert!( + Interval::TRUE_OR_FALSE + .contains_value(ScalarValue::Boolean(Some(false))) + .unwrap() + ); + assert!( + !Interval::TRUE_OR_FALSE + .contains_value(ScalarValue::Boolean(None)) + .unwrap() + ); + assert!( + !Interval::TRUE_OR_FALSE + .contains_value(ScalarValue::Null) + .unwrap() + ); } #[test] @@ -3773,6 +3836,134 @@ mod tests { Ok(()) } + #[test] + fn test_mul_div_mismatched_operand_types() -> Result<()> { + // Regression test: previously `Interval::div` and `Interval::mul` + // asserted that both operands had identical data types. That broke + // interval propagation for queries like `numeric / count(*)` where + // the operands end up as different `Decimal128` precisions/scales. + // Now both operations coerce to a common type via `BinaryTypeCoercer`. + + // `Decimal128(38, 10)` / `Decimal128(20, 0)` — the shape produced when + // dividing an unqualified `NUMERIC` by an `Int64` (e.g. `count(*)`). + let lhs = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(100_000_000_000), 38, 10), // 10.0 + )?; + let rhs = Interval::try_new( + ScalarValue::Decimal128(Some(1), 20, 0), + ScalarValue::Decimal128(Some(10), 20, 0), + )?; + let div_result = lhs.div(&rhs)?; + assert!(matches!(div_result.data_type(), DataType::Decimal128(_, _))); + let mul_result = lhs.mul(&rhs)?; + assert!(matches!(mul_result.data_type(), DataType::Decimal128(_, _))); + + // Cross-type Decimal128 / Int64 also goes through coercion. + let int_rhs = Interval::make(Some(1_i64), Some(10_i64))?; + let div_int = lhs.div(&int_rhs)?; + assert!(matches!(div_int.data_type(), DataType::Decimal128(_, _))); + let mul_int = lhs.mul(&int_rhs)?; + assert!(matches!(mul_int.data_type(), DataType::Decimal128(_, _))); + + Ok(()) + } + + #[test] + fn test_intersect_mismatched_decimal_types() -> Result<()> { + // Regression test: previously `Interval::intersect` asserted that both + // operands had identical data types. Now it coerces via + // `comparison_coercion`, which for `Decimal128(38, 10)` and + // `Decimal128(20, 0)` produces `Decimal128(38, 10)`. + + // Overlapping intervals: [0.0, 10.0] ∩ [5, 20] = [5.0, 10.0] + let lhs = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(100_000_000_000), 38, 10), // 10.0 + )?; + let rhs = Interval::try_new( + ScalarValue::Decimal128(Some(5), 20, 0), + ScalarValue::Decimal128(Some(20), 20, 0), + )?; + let intersected = lhs.intersect(&rhs)?.expect("intervals overlap"); + let expected = Interval::try_new( + ScalarValue::Decimal128(Some(50_000_000_000), 38, 10), // 5.0 + ScalarValue::Decimal128(Some(100_000_000_000), 38, 10), // 10.0 + )?; + assert_eq!(intersected, expected); + assert_eq!(intersected.data_type(), DataType::Decimal128(38, 10)); + + // Disjoint intervals across mismatched precisions: [0.0, 3.0] ∩ [5, 20] = ∅ + let lhs_disjoint = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(30_000_000_000), 38, 10), // 3.0 + )?; + assert_eq!(lhs_disjoint.intersect(&rhs)?, None); + + Ok(()) + } + + #[test] + fn test_union_mismatched_decimal_types() -> Result<()> { + // [0.0, 3.0] ∪ [5, 20] (mismatched precision/scale) = [0.0, 20.0] + let lhs = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(30_000_000_000), 38, 10), // 3.0 + )?; + let rhs = Interval::try_new( + ScalarValue::Decimal128(Some(5), 20, 0), + ScalarValue::Decimal128(Some(20), 20, 0), + )?; + let unioned = lhs.union(&rhs)?; + let expected = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(200_000_000_000), 38, 10), // 20.0 + )?; + assert_eq!(unioned, expected); + assert_eq!(unioned.data_type(), DataType::Decimal128(38, 10)); + + Ok(()) + } + + #[test] + fn test_contains_mismatched_decimal_types() -> Result<()> { + // `contains` should return TRUE when the lhs is a strict superset of + // rhs after coercion, TRUE_OR_FALSE when they merely overlap, and + // FALSE when they are disjoint — even with mismatched Decimal128 + // precision/scale. + let rhs = Interval::try_new( + ScalarValue::Decimal128(Some(5), 20, 0), + ScalarValue::Decimal128(Some(10), 20, 0), + )?; + + // Superset: [0.0, 20.0] ⊇ [5, 10] → TRUE + let superset = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(200_000_000_000), 38, 10), // 20.0 + )?; + assert_eq!(superset.contains(&rhs)?, Interval::TRUE); + + // Overlap (not superset): [0.0, 7.0] ∩ [5, 10] = [5.0, 7.0] → TRUE_OR_FALSE + let overlap = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(70_000_000_000), 38, 10), // 7.0 + )?; + assert_eq!(overlap.contains(&rhs)?, Interval::TRUE_OR_FALSE); + + // Disjoint: [0.0, 3.0] ∩ [5, 10] = ∅ → FALSE + let disjoint = Interval::try_new( + ScalarValue::Decimal128(Some(0), 38, 10), + ScalarValue::Decimal128(Some(30_000_000_000), 38, 10), // 3.0 + )?; + assert_eq!(disjoint.contains(&rhs)?, Interval::FALSE); + + // Cross-type with Int64: [0.0, 20.0] ⊇ [5, 10] → TRUE + let int_rhs = Interval::make(Some(5_i64), Some(10_i64))?; + assert_eq!(superset.contains(&int_rhs)?, Interval::TRUE); + + Ok(()) + } + #[test] fn test_overflow_handling() -> Result<()> { // Test integer overflow handling: @@ -3925,7 +4116,7 @@ mod tests { assert_eq!(interval.cardinality().unwrap(), 9178336040581070850); let interval = Interval::try_new( - ScalarValue::UInt64(Some(u64::MIN + 1)), + ScalarValue::UInt64(Some(1)), ScalarValue::UInt64(Some(u64::MAX)), )?; assert_eq!(interval.cardinality().unwrap(), u64::MAX); @@ -3942,6 +4133,46 @@ mod tests { )?; assert_eq!(interval.cardinality().unwrap(), 2); + // Temporal types + let interval = Interval::try_new( + ScalarValue::Date32(Some(0)), + ScalarValue::Date32(Some(10)), + )?; + assert_eq!(interval.cardinality().unwrap(), 11); + + let interval = Interval::try_new( + ScalarValue::Date64(Some(1000)), + ScalarValue::Date64(Some(5000)), + )?; + assert_eq!(interval.cardinality().unwrap(), 4001); + + let interval = Interval::try_new( + ScalarValue::TimestampSecond(Some(100), None), + ScalarValue::TimestampSecond(Some(200), None), + )?; + assert_eq!(interval.cardinality().unwrap(), 101); + + let interval = Interval::try_new( + ScalarValue::TimestampNanosecond(Some(1_000_000_000), None), + ScalarValue::TimestampNanosecond(Some(2_000_000_000), None), + )?; + assert_eq!(interval.cardinality().unwrap(), 1_000_000_001); + Ok(()) + } + + #[test] + fn test_cardinality_full_integer_range_does_not_overflow() -> Result<()> { + let interval = Interval::try_new( + ScalarValue::Int64(Some(i64::MIN)), + ScalarValue::Int64(Some(i64::MAX)), + )?; + assert_eq!(interval.cardinality(), None); + + let interval = Interval::try_new( + ScalarValue::UInt64(Some(0)), + ScalarValue::UInt64(Some(u64::MAX)), + )?; + assert_eq!(interval.cardinality(), None); Ok(()) } @@ -4208,12 +4439,8 @@ mod tests { } macro_rules! capture_mode_change { - ($TYPE:ty) => { - paste::item! { - capture_mode_change_helper!([], - [], - $TYPE); - } + ($TYPE:ty, $TEST_FN_NAME:ident, $CREATE_FN_NAME:ident) => { + capture_mode_change_helper!($TEST_FN_NAME, $CREATE_FN_NAME, $TYPE); }; } @@ -4241,8 +4468,8 @@ mod tests { }; } - capture_mode_change!(f32); - capture_mode_change!(f64); + capture_mode_change!(f32, capture_mode_change_f32, create_interval_f32); + capture_mode_change!(f64, capture_mode_change_f64, create_interval_f64); #[cfg(all( any(target_arch = "x86_64", target_arch = "aarch64"), diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 5323c3cb18359..c9a95fd294503 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -31,8 +31,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] pub mod accumulator; @@ -42,7 +40,10 @@ pub mod dyn_eq; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod operator; +pub mod placement; pub mod signature; pub mod sort_properties; pub mod statistics; pub mod type_coercion; + +pub use placement::ExpressionPlacement; diff --git a/datafusion/expr-common/src/operator.rs b/datafusion/expr-common/src/operator.rs index 33512b0c354d6..b15e770802799 100644 --- a/datafusion/expr-common/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -36,15 +36,15 @@ pub enum Operator { Plus, /// Subtraction Minus, - /// Multiplication operator, like `*` + /// Multiplication Multiply, - /// Division operator, like `/` + /// Division Divide, - /// Remainder operator, like `%` + /// Remainder Modulo, - /// Logical AND, like `&&` + /// Logical AND And, - /// Logical OR, like `||` + /// Logical OR Or, /// `IS DISTINCT FROM` (see [`distinct`]) /// @@ -80,20 +80,20 @@ pub enum Operator { BitwiseShiftRight, /// Bitwise left, like `<<` BitwiseShiftLeft, - /// String concat + /// String concatenation, like `||` StringConcat, /// At arrow, like `@>`. /// /// Currently only supported to be used with lists: /// ```sql - /// select [1,3] <@ [1,2,3] + /// select [1,2,3] @> [1,3] /// ``` AtArrow, /// Arrow at, like `<@`. /// /// Currently only supported to be used with lists: /// ```sql - /// select [1,2,3] @> [1,3] + /// select [1,3] <@ [1,2,3] /// ``` ArrowAt, /// Arrow, like `->`. @@ -120,7 +120,7 @@ pub enum Operator { /// /// Not implemented in DataFusion yet. IntegerDivide, - /// Hash Minis, like `#-` + /// Hash Minus, like `#-` /// /// Not implemented in DataFusion yet. HashMinus, @@ -140,6 +140,10 @@ pub enum Operator { /// /// Not implemented in DataFusion yet. QuestionPipe, + /// Colon operator, like `:` + /// + /// Not implemented in DataFusion yet. + Colon, } impl Operator { @@ -159,6 +163,10 @@ impl Operator { Operator::ILikeMatch => Some(Operator::NotILikeMatch), Operator::NotLikeMatch => Some(Operator::LikeMatch), Operator::NotILikeMatch => Some(Operator::ILikeMatch), + Operator::RegexMatch => Some(Operator::RegexNotMatch), + Operator::RegexIMatch => Some(Operator::RegexNotIMatch), + Operator::RegexNotMatch => Some(Operator::RegexMatch), + Operator::RegexNotIMatch => Some(Operator::RegexIMatch), Operator::Plus | Operator::Minus | Operator::Multiply @@ -166,10 +174,6 @@ impl Operator { | Operator::Modulo | Operator::And | Operator::Or - | Operator::RegexMatch - | Operator::RegexIMatch - | Operator::RegexNotMatch - | Operator::RegexNotIMatch | Operator::BitwiseAnd | Operator::BitwiseOr | Operator::BitwiseXor @@ -188,7 +192,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => None, + | Operator::QuestionPipe + | Operator::Colon => None, } } @@ -250,9 +255,9 @@ impl Operator { Operator::GtEq => Some(Operator::LtEq), Operator::AtArrow => Some(Operator::ArrowAt), Operator::ArrowAt => Some(Operator::AtArrow), - Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom - | Operator::Plus + Operator::IsDistinctFrom => Some(Operator::IsDistinctFrom), + Operator::IsNotDistinctFrom => Some(Operator::IsNotDistinctFrom), + Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Divide @@ -283,7 +288,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => None, + | Operator::QuestionPipe + | Operator::Colon => None, } } @@ -323,7 +329,8 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => 30, + | Operator::QuestionPipe + | Operator::Colon => 30, Operator::Plus | Operator::Minus => 40, Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } @@ -369,7 +376,9 @@ impl Operator { | Operator::AtQuestion | Operator::Question | Operator::QuestionAnd - | Operator::QuestionPipe => true, + | Operator::QuestionPipe + | Operator::Colon + | Operator::StringConcat => true, // E.g. `TRUE OR NULL` is `TRUE` Operator::Or @@ -377,11 +386,53 @@ impl Operator { | Operator::And // IS DISTINCT FROM and IS NOT DISTINCT FROM always return a TRUE/FALSE value, never NULL | Operator::IsDistinctFrom - | Operator::IsNotDistinctFrom - // DataFusion string concatenation operator treats NULL as an empty string - | Operator::StringConcat => false, + | Operator::IsNotDistinctFrom => false, } } + + /// Parse an `Operator` from the string name `datafusion-proto` uses on the + /// wire (the `Debug` name of the variant, e.g. `"Eq"`). + /// + /// Returns `None` for names with no binary-operator counterpart. This is + /// the canonical proto-string mapping, shared by `datafusion-proto` + /// (logical plans) and `PhysicalExpr` decoders such as `BinaryExpr`, so the + /// mapping is not duplicated across crates. + pub fn from_proto_name(name: &str) -> Option { + Some(match name { + "And" => Operator::And, + "Or" => Operator::Or, + "Eq" => Operator::Eq, + "NotEq" => Operator::NotEq, + "LtEq" => Operator::LtEq, + "Lt" => Operator::Lt, + "Gt" => Operator::Gt, + "GtEq" => Operator::GtEq, + "Plus" => Operator::Plus, + "Minus" => Operator::Minus, + "Multiply" => Operator::Multiply, + "Divide" => Operator::Divide, + "Modulo" => Operator::Modulo, + "IsDistinctFrom" => Operator::IsDistinctFrom, + "IsNotDistinctFrom" => Operator::IsNotDistinctFrom, + "BitwiseAnd" => Operator::BitwiseAnd, + "BitwiseOr" => Operator::BitwiseOr, + "BitwiseXor" => Operator::BitwiseXor, + "BitwiseShiftLeft" => Operator::BitwiseShiftLeft, + "BitwiseShiftRight" => Operator::BitwiseShiftRight, + "RegexIMatch" => Operator::RegexIMatch, + "RegexMatch" => Operator::RegexMatch, + "RegexNotIMatch" => Operator::RegexNotIMatch, + "RegexNotMatch" => Operator::RegexNotMatch, + "LikeMatch" => Operator::LikeMatch, + "ILikeMatch" => Operator::ILikeMatch, + "NotLikeMatch" => Operator::NotLikeMatch, + "NotILikeMatch" => Operator::NotILikeMatch, + "StringConcat" => Operator::StringConcat, + "AtArrow" => Operator::AtArrow, + "ArrowAt" => Operator::ArrowAt, + _ => return None, + }) + } } impl fmt::Display for Operator { @@ -429,6 +480,7 @@ impl fmt::Display for Operator { Operator::Question => "?", Operator::QuestionAnd => "?&", Operator::QuestionPipe => "?|", + Operator::Colon => ":", }; write!(f, "{display}") } diff --git a/datafusion/expr-common/src/placement.rs b/datafusion/expr-common/src/placement.rs new file mode 100644 index 0000000000000..8212ba618e322 --- /dev/null +++ b/datafusion/expr-common/src/placement.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression placement information for optimization decisions. + +/// Describes where an expression should be placed in the query plan for +/// optimal execution. This is used by optimizers to make decisions about +/// expression placement, such as whether to push expressions down through +/// projections. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExpressionPlacement { + /// A constant literal value. + Literal, + /// A simple column reference. + Column, + /// A cheap expression that can be pushed to leaf nodes in the plan. + /// Examples include `get_field` for struct field access. + /// Pushing these expressions down in the plan can reduce data early + /// at low compute cost. + /// See [`ExpressionPlacement::should_push_to_leaves`] for details. + MoveTowardsLeafNodes, + /// An expensive expression that should stay where it is in the plan. + /// Examples include complex scalar functions or UDFs. + KeepInPlace, +} + +impl ExpressionPlacement { + /// Returns true if the expression can be pushed down to leaf nodes + /// in the query plan. + /// + /// This returns true for: + /// - [`ExpressionPlacement::Column`]: Simple column references can be pushed down. They do no compute and do not increase or + /// decrease the amount of data being processed. + /// A projection that reduces the number of columns can eliminate unnecessary data early, + /// but this method only considers one expression at a time, not a projection as a whole. + /// - [`ExpressionPlacement::MoveTowardsLeafNodes`]: Cheap expressions can be pushed down to leaves to take advantage of + /// early computation and potential optimizations at the data source level. + /// For example `struct_col['field']` is cheap to compute (just an Arc clone of the nested array for `'field'`) + /// and thus can reduce data early in the plan at very low compute cost. + /// It may even be possible to eliminate the expression entirely if the data source can project only the needed field + /// (as e.g. Parquet can). + pub fn should_push_to_leaves(&self) -> bool { + matches!( + self, + ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes + ) + } +} diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 6ee1c4a2a40c6..3e941f00c2ee3 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,12 +19,15 @@ use std::fmt::Display; use std::hash::Hash; +use std::sync::Arc; -use crate::type_coercion::aggregates::NUMERICS; -use arrow::datatypes::{DataType, Decimal128Type, DecimalType, IntervalUnit, TimeUnit}; +use arrow::datatypes::{ + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, DataType, + Decimal128Type, DecimalType, Field, IntervalUnit, TimeUnit, +}; use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; use indexmap::IndexSet; use itertools::Itertools; @@ -154,7 +157,7 @@ pub enum Arity { pub enum TypeSignature { /// One or more arguments of a common type out of a list of valid types. /// - /// For functions that take no arguments (e.g. `random()` see [`TypeSignature::Nullary`]). + /// For functions that take no arguments (e.g. `random()`), see [`TypeSignature::Nullary`]. /// /// # Examples /// @@ -180,7 +183,7 @@ pub enum TypeSignature { Uniform(usize, Vec), /// One or more arguments with exactly the specified types in order. /// - /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. Exact(Vec), /// One or more arguments belonging to the [`TypeSignatureClass`], in order. /// @@ -188,12 +191,12 @@ pub enum TypeSignature { /// casts. For example, if you expect a function has string type, but you /// also allow it to be casted from binary type. /// - /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), see [`TypeSignature::Nullary`]. Coercible(Vec), /// One or more arguments coercible to a single, comparable type. /// /// Each argument will be coerced to a single type using the - /// coercion rules described in [`comparison_coercion_numeric`]. + /// coercion rules described in [`comparison_coercion`]. /// /// # Examples /// @@ -201,17 +204,18 @@ pub enum TypeSignature { /// the types will both be coerced to `i64` before the function is invoked. /// /// If the `nullif('1', 2)` function is called with `Utf8` and `i64` arguments - /// the types will both be coerced to `Utf8` before the function is invoked. + /// the types will both be coerced to `Int64` before the function is invoked + /// (numeric is preferred over string). /// /// Note: - /// - For functions that take no arguments (e.g. `random()` see [`TypeSignature::Nullary`]). + /// - For functions that take no arguments (e.g. `random()`), see [`TypeSignature::Nullary`]. /// - If all arguments have type [`DataType::Null`], they are coerced to `Utf8` /// - /// [`comparison_coercion_numeric`]: crate::type_coercion::binary::comparison_coercion_numeric + /// [`comparison_coercion`]: crate::type_coercion::binary::comparison_coercion Comparable(usize), /// One or more arguments of arbitrary types. /// - /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. Any(usize), /// Matches exactly one of a list of [`TypeSignature`]s. /// @@ -229,7 +233,7 @@ pub enum TypeSignature { /// /// See [`NativeType::is_numeric`] to know which type is considered numeric /// - /// For functions that take no arguments (e.g. `random()`) use [`TypeSignature::Nullary`]. + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. /// /// [`NativeType::is_numeric`]: datafusion_common::types::NativeType::is_numeric Numeric(usize), @@ -242,7 +246,7 @@ pub enum TypeSignature { /// For example, if a function is called with (utf8, large_utf8), all /// arguments will be coerced to `LargeUtf8` /// - /// For functions that take no arguments (e.g. `random()` use [`TypeSignature::Nullary`]). + /// For functions that take no arguments (e.g. `random()`), use [`TypeSignature::Nullary`]. String(usize), /// No arguments Nullary, @@ -318,6 +322,43 @@ impl TypeSignature { } } +impl Display for TypeSignature { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeSignature::Variadic(types) => { + write!(f, "Variadic({})", types.iter().join(", ")) + } + TypeSignature::UserDefined => write!(f, "UserDefined"), + TypeSignature::VariadicAny => write!(f, "VariadicAny"), + TypeSignature::Uniform(count, types) => { + write!(f, "Uniform({count}, [{}])", types.iter().join(", ")) + } + TypeSignature::Exact(types) => { + write!(f, "Exact({})", types.iter().join(", ")) + } + TypeSignature::Coercible(coercions) => { + write!(f, "Coercible({})", coercions.iter().join(", ")) + } + TypeSignature::Comparable(count) => write!(f, "Comparable({count})"), + TypeSignature::Any(count) => write!(f, "Any({count})"), + TypeSignature::OneOf(sigs) => { + write!(f, "OneOf(")?; + for (i, sig) in sigs.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{sig}")?; + } + write!(f, ")") + } + TypeSignature::ArraySignature(sig) => write!(f, "ArraySignature({sig})"), + TypeSignature::Numeric(count) => write!(f, "Numeric({count})"), + TypeSignature::String(count) => write!(f, "String({count})"), + TypeSignature::Nullary => write!(f, "Nullary"), + } + } +} + /// Represents the class of types that can be used in a function signature. /// /// This is used to specify what types are valid for function arguments in a more flexible way than @@ -328,22 +369,45 @@ impl TypeSignature { /// arguments that can be coerced to a particular class of types. #[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)] pub enum TypeSignatureClass { + /// Allows an arbitrary type argument without coercing the argument. + Any, + /// Timestamps, allowing arbitrary (or no) timezones Timestamp, + /// All time types Time, + /// All interval types Interval, + /// All duration types Duration, + /// A specific native type Native(LogicalTypeRef), + /// Signed and unsigned integers Integer, + /// All float types Float, + /// All decimal types, allowing arbitrary precision & scale Decimal, + /// Integers, floats and decimals Numeric, - /// Encompasses both the native Binary as well as arbitrarily sized FixedSizeBinary types + /// Encompasses both the native Binary/LargeBinary types as well as arbitrarily sized FixedSizeBinary types Binary, } impl Display for TypeSignatureClass { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "TypeSignatureClass::{self:?}") + match self { + Self::Any => write!(f, "Any"), + Self::Timestamp => write!(f, "Timestamp"), + Self::Time => write!(f, "Time"), + Self::Interval => write!(f, "Interval"), + Self::Duration => write!(f, "Duration"), + Self::Native(logical_type) => write!(f, "{logical_type}"), + Self::Integer => write!(f, "Integer"), + Self::Float => write!(f, "Float"), + Self::Decimal => write!(f, "Decimal"), + Self::Numeric => write!(f, "Numeric"), + Self::Binary => write!(f, "Binary"), + } } } @@ -354,6 +418,9 @@ impl TypeSignatureClass { /// documentation or error messages. fn get_example_types(&self) -> Vec { match self { + // TODO: might be too much info to return every single type here + // maybe https://github.com/apache/datafusion/issues/14761 will help here? + TypeSignatureClass::Any => vec![], TypeSignatureClass::Native(l) => get_data_types(l.native()), TypeSignatureClass::Timestamp => { vec![ @@ -396,6 +463,7 @@ impl TypeSignatureClass { } match self { + TypeSignatureClass::Any => true, TypeSignatureClass::Native(t) if t.native() == logical_type => true, TypeSignatureClass::Timestamp if logical_type.is_timestamp() => true, TypeSignatureClass::Time if logical_type.is_time() => true, @@ -417,6 +485,7 @@ impl TypeSignatureClass { origin_type: &DataType, ) -> Result { match self { + TypeSignatureClass::Any => Ok(origin_type.to_owned()), TypeSignatureClass::Native(logical_type) => { logical_type.native().default_cast_for(origin_type) } @@ -526,6 +595,20 @@ impl Display for ArrayFunctionArgument { } } +static NUMERICS: &[DataType] = &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, +]; + impl TypeSignature { pub fn to_string_repr(&self) -> Vec { match self { @@ -558,9 +641,11 @@ impl TypeSignature { vec![Self::join_types(types, ", ")] } TypeSignature::Any(arg_count) => { - vec![std::iter::repeat_n("Any", *arg_count) - .collect::>() - .join(", ")] + vec![ + std::iter::repeat_n("Any", *arg_count) + .collect::>() + .join(", "), + ] } TypeSignature::UserDefined => { vec!["UserDefined".to_string()] @@ -607,87 +692,103 @@ impl TypeSignature { match self { TypeSignature::Exact(types) => { if let Some(names) = parameter_names { - vec![names - .iter() - .zip(types.iter()) - .map(|(name, typ)| format!("{name}: {typ}")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .zip(types.iter()) + .map(|(name, typ)| format!("{name}: {typ}")) + .collect::>() + .join(", "), + ] } else { vec![Self::join_types(types, ", ")] } } TypeSignature::Any(count) => { if let Some(names) = parameter_names { - vec![names - .iter() - .take(*count) - .map(|name| format!("{name}: Any")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(*count) + .map(|name| format!("{name}: Any")) + .collect::>() + .join(", "), + ] } else { - vec![std::iter::repeat_n("Any", *count) - .collect::>() - .join(", ")] + vec![ + std::iter::repeat_n("Any", *count) + .collect::>() + .join(", "), + ] } } TypeSignature::Uniform(count, types) => { if let Some(names) = parameter_names { let type_str = Self::join_types(types, "/"); - vec![names - .iter() - .take(*count) - .map(|name| format!("{name}: {type_str}")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(*count) + .map(|name| format!("{name}: {type_str}")) + .collect::>() + .join(", "), + ] } else { self.to_string_repr() } } TypeSignature::Coercible(coercions) => { if let Some(names) = parameter_names { - vec![names - .iter() - .zip(coercions.iter()) - .map(|(name, coercion)| format!("{name}: {coercion}")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .zip(coercions.iter()) + .map(|(name, coercion)| format!("{name}: {coercion}")) + .collect::>() + .join(", "), + ] } else { vec![Self::join_types(coercions, ", ")] } } TypeSignature::Comparable(count) => { if let Some(names) = parameter_names { - vec![names - .iter() - .take(*count) - .map(|name| format!("{name}: Comparable")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(*count) + .map(|name| format!("{name}: Comparable")) + .collect::>() + .join(", "), + ] } else { self.to_string_repr() } } TypeSignature::Numeric(count) => { if let Some(names) = parameter_names { - vec![names - .iter() - .take(*count) - .map(|name| format!("{name}: Numeric")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(*count) + .map(|name| format!("{name}: Numeric")) + .collect::>() + .join(", "), + ] } else { self.to_string_repr() } } TypeSignature::String(count) => { if let Some(names) = parameter_names { - vec![names - .iter() - .take(*count) - .map(|name| format!("{name}: String")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(*count) + .map(|name| format!("{name}: String")) + .collect::>() + .join(", "), + ] } else { self.to_string_repr() } @@ -697,28 +798,34 @@ impl TypeSignature { if let Some(names) = parameter_names { match array_sig { ArrayFunctionSignature::Array { arguments, .. } => { - vec![names - .iter() - .zip(arguments.iter()) - .map(|(name, arg_type)| format!("{name}: {arg_type}")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .zip(arguments.iter()) + .map(|(name, arg_type)| format!("{name}: {arg_type}")) + .collect::>() + .join(", "), + ] } ArrayFunctionSignature::RecursiveArray => { - vec![names - .iter() - .take(1) - .map(|name| format!("{name}: recursive_array")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(1) + .map(|name| format!("{name}: recursive_array")) + .collect::>() + .join(", "), + ] } ArrayFunctionSignature::MapArray => { - vec![names - .iter() - .take(1) - .map(|name| format!("{name}: map_array")) - .collect::>() - .join(", ")] + vec![ + names + .iter() + .take(1) + .map(|name| format!("{name}: map_array")) + .collect::>() + .join(", "), + ] } } } else { @@ -864,8 +971,56 @@ fn get_data_types(native_type: &NativeType) -> Vec { NativeType::String => { vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] } - // TODO: support other native types - _ => vec![], + NativeType::Decimal(precision, scale) => { + // We assume incoming NativeType is valid already, in terms of precision & scale + let mut types = vec![DataType::Decimal256(*precision, *scale)]; + if *precision <= DECIMAL32_MAX_PRECISION { + types.push(DataType::Decimal32(*precision, *scale)); + } + if *precision <= DECIMAL64_MAX_PRECISION { + types.push(DataType::Decimal64(*precision, *scale)); + } + if *precision <= DECIMAL128_MAX_PRECISION { + types.push(DataType::Decimal128(*precision, *scale)); + } + types + } + NativeType::Timestamp(time_unit, timezone) => { + vec![DataType::Timestamp(*time_unit, timezone.to_owned())] + } + NativeType::Time(TimeUnit::Second) => vec![DataType::Time32(TimeUnit::Second)], + NativeType::Time(TimeUnit::Millisecond) => { + vec![DataType::Time32(TimeUnit::Millisecond)] + } + NativeType::Time(TimeUnit::Microsecond) => { + vec![DataType::Time64(TimeUnit::Microsecond)] + } + NativeType::Time(TimeUnit::Nanosecond) => { + vec![DataType::Time64(TimeUnit::Nanosecond)] + } + NativeType::Duration(time_unit) => vec![DataType::Duration(*time_unit)], + NativeType::Interval(interval_unit) => vec![DataType::Interval(*interval_unit)], + NativeType::FixedSizeBinary(size) => vec![DataType::FixedSizeBinary(*size)], + NativeType::FixedSizeList(logical_field, size) => { + get_data_types(logical_field.logical_type.native()) + .iter() + .map(|child_dt| { + let field = Field::new( + logical_field.name.clone(), + child_dt.clone(), + logical_field.nullable, + ); + DataType::FixedSizeList(Arc::new(field), *size) + }) + .collect() + } + // TODO: implement for nested types + NativeType::List(_) + | NativeType::Struct(_) + | NativeType::Union(_) + | NativeType::Map(_) => { + vec![] + } } } @@ -970,12 +1125,7 @@ impl Coercion { impl Display for Coercion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Coercion({}", self.desired_type())?; - if let Some(implicit_coercion) = self.implicit_coercion() { - write!(f, ", implicit_coercion={implicit_coercion}",) - } else { - write!(f, ")") - } + write!(f, "{}", self.desired_type()) } } @@ -1027,11 +1177,14 @@ pub struct ImplicitCoercion { impl Display for ImplicitCoercion { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "ImplicitCoercion({:?}, default_type={:?})", - self.allowed_source_types, self.default_casted_type - ) + write!(f, "ImplicitCoercion(")?; + for (i, source_type) in self.allowed_source_types.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{source_type}")?; + } + write!(f, "; default={}", self.default_casted_type) } } @@ -1324,7 +1477,7 @@ impl Signature { Arity::Variable => { // For UserDefined signatures, allow parameter names // The function implementer is responsible for validating the names match the actual arguments - if !matches!(self.type_signature, TypeSignature::UserDefined) { + if self.type_signature != TypeSignature::UserDefined { return plan_err!( "Cannot specify parameter names for variable arity signature: {:?}", self.type_signature @@ -1346,7 +1499,9 @@ impl Signature { #[cfg(test)] mod tests { - use datafusion_common::types::{logical_int32, logical_int64, logical_string}; + use datafusion_common::types::{ + NativeType, logical_float64, logical_int32, logical_int64, logical_string, + }; use super::*; use crate::signature::{ @@ -1493,6 +1648,7 @@ mod tests { vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt32, DataType::UInt32], vec![DataType::UInt64, DataType::UInt64], + vec![DataType::Float16, DataType::Float16], vec![DataType::Float32, DataType::Float32], vec![DataType::Float64, DataType::Float64] ] @@ -1538,10 +1694,12 @@ mod tests { .with_parameter_names(vec!["count".to_string()]); // Only 1 name for 2 args assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("does not match signature arity")); + assert!( + result + .unwrap_err() + .to_string() + .contains("does not match signature arity") + ); } #[test] @@ -1553,10 +1711,12 @@ mod tests { .with_parameter_names(vec!["count".to_string(), "count".to_string()]); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Duplicate parameter name")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Duplicate parameter name") + ); } #[test] @@ -1565,10 +1725,12 @@ mod tests { .with_parameter_names(vec!["arg".to_string()]); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("variable arity signature")); + assert!( + result + .unwrap_err() + .to_string() + .contains("variable arity signature") + ); } #[test] @@ -1935,4 +2097,124 @@ mod tests { let sig = TypeSignature::UserDefined; assert_eq!(sig.arity(), Arity::Variable); } + + #[test] + fn test_type_signature_display() { + use insta::assert_snapshot; + + assert_snapshot!(TypeSignature::Nullary, @"Nullary"); + assert_snapshot!(TypeSignature::Any(2), @"Any(2)"); + assert_snapshot!(TypeSignature::Numeric(3), @"Numeric(3)"); + assert_snapshot!(TypeSignature::String(1), @"String(1)"); + assert_snapshot!(TypeSignature::Comparable(2), @"Comparable(2)"); + assert_snapshot!(TypeSignature::VariadicAny, @"VariadicAny"); + assert_snapshot!(TypeSignature::UserDefined, @"UserDefined"); + + assert_snapshot!( + TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]), + @"Exact(Int32, Utf8)" + ); + assert_snapshot!( + TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]), + @"Variadic(Utf8, LargeUtf8)" + ); + assert_snapshot!( + TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Float64]), + @"Uniform(2, [Float32, Float64])" + ); + + assert_snapshot!( + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_float64())), + Coercion::new_exact(TypeSignatureClass::Native(logical_int32())), + ]), + @"Coercible(Float64, Int32)" + ); + + assert_snapshot!( + TypeSignature::OneOf(vec![ + TypeSignature::Nullary, + TypeSignature::VariadicAny, + ]), + @"OneOf(Nullary, VariadicAny)" + ); + } + + #[test] + fn test_type_signature_class_display() { + use insta::assert_snapshot; + + assert_snapshot!(TypeSignatureClass::Any, @"Any"); + assert_snapshot!(TypeSignatureClass::Numeric, @"Numeric"); + assert_snapshot!(TypeSignatureClass::Integer, @"Integer"); + assert_snapshot!(TypeSignatureClass::Float, @"Float"); + assert_snapshot!(TypeSignatureClass::Decimal, @"Decimal"); + assert_snapshot!(TypeSignatureClass::Timestamp, @"Timestamp"); + assert_snapshot!(TypeSignatureClass::Time, @"Time"); + assert_snapshot!(TypeSignatureClass::Interval, @"Interval"); + assert_snapshot!(TypeSignatureClass::Duration, @"Duration"); + assert_snapshot!(TypeSignatureClass::Binary, @"Binary"); + assert_snapshot!(TypeSignatureClass::Native(logical_int32()), @"Int32"); + assert_snapshot!(TypeSignatureClass::Native(logical_string()), @"String"); + } + + #[test] + fn test_coercion_display() { + use insta::assert_snapshot; + + let exact_int = Coercion::new_exact(TypeSignatureClass::Native(logical_int32())); + assert_snapshot!(exact_int, @"Int32"); + + let exact_numeric = Coercion::new_exact(TypeSignatureClass::Numeric); + assert_snapshot!(exact_numeric, @"Numeric"); + + let implicit = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + assert_snapshot!(implicit, @"Float64"); + + let implicit_with_multiple_sources = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer, TypeSignatureClass::Numeric], + NativeType::Int64, + ); + assert_snapshot!(implicit_with_multiple_sources, @"Int64"); + } + + #[test] + fn test_to_string_repr_coercible() { + use insta::assert_snapshot; + + // Simulates a function like round(Float64, Int64) with coercion + let sig = TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ]); + let repr = sig.to_string_repr(); + assert_eq!(repr.len(), 1); + assert_snapshot!(repr[0], @"Float64, Int64"); + } + + #[test] + fn test_to_string_repr_coercible_exact() { + use insta::assert_snapshot; + + let sig = TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_int64())), + ]); + let repr = sig.to_string_repr(); + assert_eq!(repr.len(), 1); + assert_snapshot!(repr[0], @"String, Int64"); + } } diff --git a/datafusion/expr-common/src/statistics.rs b/datafusion/expr-common/src/statistics.rs index 7284673d9a8f7..034358b043135 100644 --- a/datafusion/expr-common/src/statistics.rs +++ b/datafusion/expr-common/src/statistics.rs @@ -15,9 +15,19 @@ // specific language governing permissions and limitations // under the License. +//! Probabilistic distributions for expression-level statistics (unused). +//! +//! Note: All public items in this module are **deprecated** as of `54.0.0`. +//! +//! See for details. + +// The whole module is deprecated; suppress warnings from intra-module uses +// of the deprecated types so the module continues to compile. +#![allow(deprecated)] + use std::f64::consts::LN_2; -use crate::interval_arithmetic::{apply_operator, Interval}; +use crate::interval_arithmetic::{Interval, apply_operator}; use crate::operator::Operator; use crate::type_coercion::binary::binary_numeric_coercion; @@ -25,8 +35,8 @@ use arrow::array::ArrowNativeTypeOp; use arrow::datatypes::DataType; use datafusion_common::rounding::alter_fp_rounding_mode; use datafusion_common::{ - assert_eq_or_internal_err, assert_ne_or_internal_err, assert_or_internal_err, - internal_err, not_impl_err, Result, ScalarValue, + Result, ScalarValue, assert_eq_or_internal_err, assert_ne_or_internal_err, + assert_or_internal_err, internal_err, not_impl_err, }; /// This object defines probabilistic distributions that encode uncertain @@ -37,6 +47,10 @@ use datafusion_common::{ /// is the main unit of calculus when evaluating expressions in a statistical /// context. Notions like column and table statistics are built on top of this /// object and the operations it supports. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug, PartialEq)] pub enum Distribution { Uniform(UniformDistribution), @@ -214,6 +228,10 @@ impl Distribution { /// /// /// +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug, PartialEq)] pub struct UniformDistribution { interval: Interval, @@ -236,6 +254,10 @@ pub struct UniformDistribution { /// For more information, see: /// /// +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug, PartialEq)] pub struct ExponentialDistribution { rate: ScalarValue, @@ -249,6 +271,10 @@ pub struct ExponentialDistribution { /// For a more in-depth discussion, see: /// /// +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug, PartialEq)] pub struct GaussianDistribution { mean: ScalarValue, @@ -259,6 +285,10 @@ pub struct GaussianDistribution { /// the success probability is unknown. For a more in-depth discussion, see: /// /// +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug, PartialEq)] pub struct BernoulliDistribution { p: ScalarValue, @@ -268,6 +298,10 @@ pub struct BernoulliDistribution { /// approximated via some summary statistics. For a more in-depth discussion, see: /// /// +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug, PartialEq)] pub struct GenericDistribution { mean: ScalarValue, @@ -594,6 +628,10 @@ impl GenericDistribution { /// This function takes a logical operator and two Bernoulli distributions, /// and it returns a new Bernoulli distribution that represents the result of /// the operation. Currently, only `AND` and `OR` operations are supported. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn combine_bernoullis( op: &Operator, left: &BernoulliDistribution, @@ -649,6 +687,10 @@ pub fn combine_bernoullis( /// see: /// /// +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn combine_gaussians( op: &Operator, left: &GaussianDistribution, @@ -673,6 +715,10 @@ pub fn combine_gaussians( /// Expects `op` to be a comparison operator, with `left` and `right` having /// numeric distributions. The resulting distribution has the `Float64` data /// type. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn create_bernoulli_from_comparison( op: &Operator, left: &Distribution, @@ -751,6 +797,10 @@ pub fn create_bernoulli_from_comparison( /// given binary operation on two unknown quantities represented by their /// [`Distribution`] objects. The function computes the mean, median and /// variance if possible. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn new_generic_from_binary_op( op: &Operator, left: &Distribution, @@ -766,6 +816,10 @@ pub fn new_generic_from_binary_op( /// Computes the mean value for the result of the given binary operation on /// two unknown quantities represented by their [`Distribution`] objects. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn compute_mean( op: &Operator, left: &Distribution, @@ -798,6 +852,10 @@ pub fn compute_mean( /// the median is calculable only for addition and subtraction operations on: /// - [`Uniform`] and [`Uniform`] distributions, and /// - [`Gaussian`] and [`Gaussian`] distributions. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn compute_median( op: &Operator, left: &Distribution, @@ -835,6 +893,10 @@ pub fn compute_median( /// Computes the variance value for the result of the given binary operation on /// two unknown quantities represented by their [`Distribution`] objects. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] pub fn compute_variance( op: &Operator, left: &Distribution, @@ -878,11 +940,11 @@ pub fn compute_variance( #[cfg(test)] mod tests { use super::{ + BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution, combine_bernoullis, combine_gaussians, compute_mean, compute_median, compute_variance, create_bernoulli_from_comparison, new_generic_from_binary_op, - BernoulliDistribution, Distribution, GaussianDistribution, UniformDistribution, }; - use crate::interval_arithmetic::{apply_operator, Interval}; + use crate::interval_arithmetic::{Interval, apply_operator}; use crate::operator::Operator; use arrow::datatypes::DataType; @@ -1632,3 +1694,47 @@ mod tests { all_ops.into_iter().collect() } } + +use std::sync::Arc; + +use datafusion_common::Column; + +/// A statistic a caller would like a provider to supply, if it can do so +/// cheaply. +/// +/// A small, query-aware extension to the existing `Statistics` model: instead +/// of "give me everything you have for every column", a caller can ask for a +/// specific list of stats by name. `StatisticsRequest` is just that vocabulary +/// — DataFusion itself does not populate or consume it. It exists so a request +/// can be threaded from a `TableScan` (see `TableScan::statistics_requests`) +/// through `ScanArgs::statistics_requests` to a `TableProvider`, which is enough +/// for a query-aware statistics feature to be implemented outside of DataFusion. +/// +/// Each variant maps onto a field of [`datafusion_common::Statistics`] / +/// [`datafusion_common::ColumnStatistics`], so a provider that already +/// populates one can answer the request trivially. +/// +/// The per-column variants hold an `Arc` rather than an owned +/// [`Column`] (which carries owned strings) so cloning a request — and the +/// `BTreeSet` stored on `TableScan`, which is cloned with +/// the plan during optimization — stays cheap. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum StatisticsRequest { + /// Smallest non-null value of `column`. + Min(Arc), + /// Largest non-null value of `column`. + Max(Arc), + /// Number of NULLs in `column`. + NullCount(Arc), + /// Number of distinct values in `column` (exact or estimated). + DistinctCount(Arc), + /// Sum of values in `column` (numerics, widened per + /// `ColumnStatistics::sum_value`). + Sum(Arc), + /// Encoded/output byte size of `column`. + ByteSize(Arc), + /// Number of rows in the container (table / file). + RowCount, + /// Total byte size of the container's output. + TotalByteSize, +} diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 55a8843394b51..ada0bd26b8d06 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -18,10 +18,9 @@ use crate::signature::TypeSignature; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; -// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures -// see https://github.com/apache/datafusion/issues/18092 +#[deprecated(since = "54.0.0", note = "Use functions signatures")] pub static INTEGERS: &[DataType] = &[ DataType::Int8, DataType::Int16, @@ -33,6 +32,7 @@ pub static INTEGERS: &[DataType] = &[ DataType::UInt64, ]; +#[deprecated(since = "54.0.0", note = "Use functions signatures")] pub static NUMERICS: &[DataType] = &[ DataType::Int8, DataType::Int16, @@ -42,6 +42,7 @@ pub static NUMERICS: &[DataType] = &[ DataType::UInt16, DataType::UInt32, DataType::UInt64, + DataType::Float16, DataType::Float32, DataType::Float64, ]; @@ -60,8 +61,7 @@ pub fn check_arg_count( TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { if input_fields.len() != *agg_count { return plan_err!( - "The function {func_name} expects {:?} arguments, but {:?} were provided", - agg_count, + "The function {func_name} expects {agg_count} arguments, but {} were provided", input_fields.len() ); } @@ -69,7 +69,7 @@ pub fn check_arg_count( TypeSignature::Exact(types) => { if types.len() != input_fields.len() { return plan_err!( - "The function {func_name} expects {:?} arguments, but {:?} were provided", + "The function {func_name} expects {} arguments, but {} were provided", types.len(), input_fields.len() ); @@ -81,7 +81,7 @@ pub fn check_arg_count( .any(|v| check_arg_count(func_name, input_fields, v).is_ok()); if !ok { return plan_err!( - "The function {func_name} does not accept {:?} function arguments.", + "The function {func_name} does not accept {} function arguments.", input_fields.len() ); } @@ -100,9 +100,7 @@ pub fn check_arg_count( // Numeric and Coercible signature is validated in `get_valid_types` } _ => { - return internal_err!( - "Aggregate functions do not support this {signature:?}" - ); + return internal_err!("Aggregate functions do not support this {signature}"); } } Ok(()) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 4aacc7533c64a..e700d4a04da3b 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -17,23 +17,25 @@ //! Coercion rules for matching argument types for binary operators +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; use crate::operator::Operator; -use arrow::array::{new_empty_array, Array}; +use arrow::array::{Array, new_empty_array}; use arrow::compute::can_cast_types; +use arrow::datatypes::IntervalUnit::MonthDayNano; +use arrow::datatypes::TimeUnit::*; use arrow::datatypes::{ - DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, - DECIMAL64_MAX_SCALE, + DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, Field, FieldRef, Fields, + TimeUnit, }; -use datafusion_common::types::NativeType; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Diagnostic, - Result, Span, Spans, + Diagnostic, Result, Span, Spans, exec_err, internal_err, not_impl_err, + plan_datafusion_err, plan_err, }; use itertools::Itertools; @@ -184,8 +186,8 @@ impl<'a> BinaryTypeCoercer<'a> { } fn signature_inner(&'a self, lhs: &DataType, rhs: &DataType) -> Result { - use arrow::datatypes::DataType::*; use Operator::*; + use arrow::datatypes::DataType::*; let result = match self.op { Eq | NotEq | @@ -258,15 +260,36 @@ impl<'a> BinaryTypeCoercer<'a> { ) }) } + Minus if is_date_minus_date(lhs, rhs) => { + return Ok(Signature { + lhs: lhs.clone(), + rhs: rhs.clone(), + ret: Int64, + }); + } Plus | Minus | Multiply | Divide | Modulo => { if let Ok(ret) = self.get_result(lhs, rhs) { + // Temporal arithmetic, e.g. Date32 + Interval Ok(Signature{ lhs: lhs.clone(), rhs: rhs.clone(), ret, }) + } else if let Some((lhs, rhs)) = temporal_math_coercion(lhs, rhs) { + // Temporal arithmetic, e.g. Date32 + int64, Timestamp + duration, etc + let ret = self.get_result(&lhs, &rhs).map_err(|e| { + plan_datafusion_err!( + "Cannot get result type for temporal operation {} {} {}: {e}", self.lhs, self.op, self.rhs + ) + })?; + Ok(Signature { + lhs, + rhs, + ret, + }) } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { + // Temporal arithmetic by first coercing to a common time representation // e.g. Date32 - Timestamp let ret = self.get_result(&coerced, &coerced).map_err(|e| { @@ -300,6 +323,9 @@ impl<'a> BinaryTypeCoercer<'a> { ) } }, + Colon => { + Ok(Signature { lhs: lhs.clone(), rhs: rhs.clone(), ret: lhs.clone() }) + }, IntegerDivide | Arrow | LongArrow | HashArrow | HashLongArrow | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe => { not_impl_err!("Operator {} is not yet supported", self.op) @@ -327,13 +353,12 @@ impl<'a> BinaryTypeCoercer<'a> { // TODO Move the rest inside of BinaryTypeCoercer -fn is_decimal(data_type: &DataType) -> bool { +/// Returns true if both operands are Date types (Date32 or Date64) +/// Used to detect Date - Date operations which should return Int64 (days difference) +fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool { matches!( - data_type, - DataType::Decimal32(..) - | DataType::Decimal64(..) - | DataType::Decimal128(..) - | DataType::Decimal256(..) + (lhs, rhs), + (DataType::Date32, DataType::Date32) | (DataType::Date64, DataType::Date64) ) } @@ -353,6 +378,16 @@ fn math_decimal_coercion( let (lhs_type, value_type) = math_decimal_coercion(lhs_type, value_type)?; Some((lhs_type, value_type)) } + (RunEndEncoded(_, field), _) => { + let (value_type, rhs_type) = + math_decimal_coercion(field.data_type(), rhs_type)?; + Some((value_type, rhs_type)) + } + (_, RunEndEncoded(_, field)) => { + let (lhs_type, value_type) = + math_decimal_coercion(lhs_type, field.data_type())?; + Some((lhs_type, value_type)) + } ( Null, Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), @@ -369,8 +404,8 @@ fn math_decimal_coercion( } // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale (lhs, rhs) - if is_decimal(lhs) - && is_decimal(rhs) + if lhs.is_decimal() + && rhs.is_decimal() && std::mem::discriminant(lhs) != std::mem::discriminant(rhs) => { let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; @@ -447,7 +482,9 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option for TypeCategory { fn from(data_type: &DataType) -> Self { match data_type { - // Dict is a special type in arrow, we check the value type + // Dict and REE are special types in arrow, we check the value type. DataType::Dictionary(_, v) => { let v = v.as_ref(); TypeCategory::from(v) } + DataType::RunEndEncoded(_, v) => TypeCategory::from(v.data_type()), _ => { if data_type.is_numeric() { return TypeCategory::Numeric; } - if matches!(data_type, DataType::Boolean) { + if *data_type == DataType::Boolean { return TypeCategory::Boolean; } @@ -552,8 +590,8 @@ impl From<&DataType> for TypeCategory { } /// Coerce dissimilar data types to a single data type. -/// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are -/// examples that has the similar resolution rules. +/// ARRAY literals, VALUES, COALESCE, and array concatenation are examples +/// of contexts that use this function. /// See for more information. /// The rules in the document provide a clue, but adhering strictly to them doesn't precisely /// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules @@ -672,6 +710,27 @@ fn type_union_resolution_coercion( None => None, } } + ( + DataType::RunEndEncoded(lhs_run, lhs_val), + DataType::RunEndEncoded(rhs_run, rhs_val), + ) => { + let new_run = + type_union_resolution_coercion(lhs_run.data_type(), rhs_run.data_type())?; + let new_val = + type_union_resolution_coercion(lhs_val.data_type(), rhs_val.data_type())?; + Some(DataType::RunEndEncoded( + Arc::new(lhs_run.as_ref().clone().with_data_type(new_run)), + Arc::new(lhs_val.as_ref().clone().with_data_type(new_val)), + )) + } + (DataType::RunEndEncoded(run, val), other) + | (other, DataType::RunEndEncoded(run, val)) => { + let new_val = type_union_resolution_coercion(val.data_type(), other)?; + Some(DataType::RunEndEncoded( + Arc::clone(run), + Arc::new(val.as_ref().clone().with_data_type(new_val)), + )) + } (DataType::Struct(lhs), DataType::Struct(rhs)) => { if lhs.len() != rhs.len() { return None; @@ -713,30 +772,27 @@ fn type_union_resolution_coercion( .collect(); Some(DataType::Struct(fields.into())) } - _ => { - // Numeric coercion is the same as comparison coercion, both find the narrowest type - // that can accommodate both types - binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| list_coercion(lhs_type, rhs_type)) - .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) - .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) - .or_else(|| binary_coercion(lhs_type, rhs_type)) - } + _ => binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type, type_union_resolution_coercion)) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) + .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) + .or_else(|| binary_coercion(lhs_type, rhs_type)), } } /// Handle type union resolution including struct type and others. pub fn try_type_union_resolution(data_types: &[DataType]) -> Result> { - let err = match try_type_union_resolution_with_struct(data_types) { + let struct_err = match try_type_union_resolution_with_struct(data_types) { Ok(struct_types) => return Ok(struct_types), - Err(e) => Some(e), + Err(e) => e, }; if let Some(new_type) = type_union_resolution(data_types) { Ok(vec![new_type; data_types.len()]) } else { - exec_err!("Fail to find the coerced type, errors: {:?}", err) + exec_err!("Fail to find the coerced type, errors: {struct_err}") } } @@ -751,7 +807,11 @@ pub fn try_type_union_resolution_with_struct( let keys = fields.iter().map(|f| f.name().to_owned()).join(","); if let Some(ref k) = keys_string { if *k != keys { - return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys); + return exec_err!( + "Expect same keys for struct type but got mismatched pair {} and {}", + *k, + keys + ); } } else { keys_string = Some(keys); @@ -765,7 +825,9 @@ pub fn try_type_union_resolution_with_struct( { fields.iter().map(|f| f.data_type().to_owned()).collect() } else { - return internal_err!("Struct type is checked is the previous function, so this should be unreachable"); + return internal_err!( + "Struct type is checked is the previous function, so this should be unreachable" + ); }; for data_type in data_types.iter().skip(1) { @@ -809,102 +871,105 @@ pub fn try_type_union_resolution_with_struct( Ok(final_struct_types) } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a -/// comparison operation -/// -/// Example comparison operations are `lhs = rhs` and `lhs > rhs` +/// Coerce `lhs_type` and `rhs_type` to a common type for type unification +/// contexts — where two values must be brought to a common type but are not +/// being compared. Examples: UNION, CASE THEN/ELSE branches, NVL2. For other +/// contexts, [`comparison_coercion`] should typically be used instead. /// -/// Binary comparison kernels require the two arguments to be the (exact) same -/// data type. However, users can write queries where the two arguments are -/// different data types. In such cases, the data types are automatically cast -/// (coerced) to a single data type to pass to the kernels. -/// -/// # Numeric comparisons -/// -/// When comparing numeric values, the lower precision type is coerced to the -/// higher precision type to avoid losing data. For example when comparing -/// `Int32` to `Int64` the coerced type is `Int64` so the `Int32` argument will -/// be cast. -/// -/// # Numeric / String comparisons -/// -/// When comparing numeric values and strings, both values will be coerced to -/// strings. For example when comparing `'2' > 1`, the arguments will be -/// coerced to `Utf8` for comparison -pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// The intuition is that we try to find the "widest" type that can represent +/// all values from both sides. When one side is a string and the other is +/// numeric, this prefers strings because every number has a textual +/// representation but not every string can be parsed as a number (e.g., `SELECT +/// 1 UNION SELECT 'a'` coerces both sides to a string). This is in contrast to +/// [`comparison_coercion`], which prefers numeric types so that ordering and +/// equality follow numeric semantics. +pub fn type_union_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type.equals_datatype(rhs_type) { - // same type => equality is possible return Some(lhs_type.clone()); } binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true)) - .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, true)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true, type_union_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, true, type_union_coercion)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| list_coercion(lhs_type, rhs_type)) + .or_else(|| list_coercion(lhs_type, rhs_type, type_union_coercion)) .or_else(|| null_coercion(lhs_type, rhs_type)) - .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) + .or_else(|| string_numeric_union_coercion(lhs_type, rhs_type)) .or_else(|| string_temporal_coercion(lhs_type, rhs_type)) .or_else(|| binary_coercion(lhs_type, rhs_type)) - .or_else(|| struct_coercion(lhs_type, rhs_type)) - .or_else(|| map_coercion(lhs_type, rhs_type)) + .or_else(|| struct_coercion(lhs_type, rhs_type, type_union_coercion)) + .or_else(|| map_coercion(lhs_type, rhs_type, type_union_coercion)) } -/// Similar to [`comparison_coercion`] but prefers numeric if compares with -/// numeric and string +/// Coerce `lhs_type` and `rhs_type` to a common type for comparison +/// contexts — any context where two values are compared rather than +/// unified. This includes binary comparison operators, IN lists, +/// CASE/WHEN conditions, and BETWEEN. +/// +/// When the two types differ, this function determines the common type +/// to cast to. /// /// # Numeric comparisons /// -/// When comparing numeric values and strings, the values will be coerced to the -/// numeric type. For example, `'2' > 1` if `1` is an `Int32`, the arguments -/// will be coerced to `Int32`. -pub fn comparison_coercion_numeric( - lhs_type: &DataType, - rhs_type: &DataType, -) -> Option { - if lhs_type == rhs_type { +/// The lower precision type is widened to the higher precision type +/// (e.g., `Int32` vs `Int64` → `Int64`). +/// +/// # Numeric / String comparisons +/// +/// Prefers the numeric type (e.g., `'2' > 1` where `1` is `Int32` coerces +/// `'2'` to `Int32`). +/// +/// For type unification contexts (UNION, CASE THEN/ELSE), use +/// [`type_union_coercion`] instead. +pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + if lhs_type.equals_datatype(rhs_type) { // same type => equality is possible return Some(lhs_type.clone()); } binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_comparison_coercion_numeric(lhs_type, rhs_type, true)) - .or_else(|| ree_comparison_coercion_numeric(lhs_type, rhs_type, true)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true, comparison_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, true, comparison_coercion)) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| list_coercion(lhs_type, rhs_type, comparison_coercion)) .or_else(|| null_coercion(lhs_type, rhs_type)) - .or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type)) + .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) + .or_else(|| string_temporal_coercion(lhs_type, rhs_type)) + .or_else(|| binary_coercion(lhs_type, rhs_type)) + .or_else(|| struct_coercion(lhs_type, rhs_type, comparison_coercion)) + .or_else(|| map_coercion(lhs_type, rhs_type, comparison_coercion)) + .or_else(|| union_coercion(lhs_type, rhs_type)) } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// where one is numeric and one is `Utf8`/`LargeUtf8`. +/// Coerce a numeric/string pair to the numeric type. +/// +/// Used by [`comparison_coercion`]. fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Utf8, _) if rhs_type.is_numeric() => Some(Utf8), - (LargeUtf8, _) if rhs_type.is_numeric() => Some(LargeUtf8), - (Utf8View, _) if rhs_type.is_numeric() => Some(Utf8View), - (_, Utf8) if lhs_type.is_numeric() => Some(Utf8), - (_, LargeUtf8) if lhs_type.is_numeric() => Some(LargeUtf8), - (_, Utf8View) if lhs_type.is_numeric() => Some(Utf8View), + (lhs, Utf8 | LargeUtf8 | Utf8View) if lhs.is_numeric() => Some(lhs.clone()), + (Utf8 | LargeUtf8 | Utf8View, rhs) if rhs.is_numeric() => Some(rhs.clone()), _ => None, } } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// where one is numeric and one is `Utf8`/`LargeUtf8`. -fn string_numeric_coercion_as_numeric( +/// Coerce a numeric/string pair to the string type. +/// +/// Used by [`type_union_coercion`]. +fn string_numeric_union_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { - let lhs_logical_type = NativeType::from(lhs_type); - let rhs_logical_type = NativeType::from(rhs_type); - if lhs_logical_type.is_numeric() && rhs_logical_type == NativeType::String { - return Some(lhs_type.to_owned()); - } - if rhs_logical_type.is_numeric() && lhs_logical_type == NativeType::String { - return Some(rhs_type.to_owned()); + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (lhs @ (Utf8 | LargeUtf8 | Utf8View), _) if rhs_type.is_numeric() => { + Some(lhs.clone()) + } + (_, rhs @ (Utf8 | LargeUtf8 | Utf8View)) if lhs_type.is_numeric() => { + Some(rhs.clone()) + } + _ => None, } - - None } /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation @@ -919,7 +984,7 @@ fn string_numeric_coercion_as_numeric( /// ``` /// /// In the absence of a full type inference system, we can't determine the correct type -/// to parse the string argument +/// to parse the string argument. fn string_temporal_coercion( lhs_type: &DataType, rhs_type: &DataType, @@ -939,7 +1004,7 @@ fn string_temporal_coercion( None } } - Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), + Timestamp(_, tz) => Some(Timestamp(Nanosecond, tz.clone())), _ => None, } } @@ -979,8 +1044,8 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -988,8 +1053,8 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -1196,32 +1261,137 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Option { +/// Coerce two struct types by recursively coercing their fields using +/// `coerce_fn` (either [`comparison_coercion`] or [`type_union_coercion`]). +fn struct_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { (Struct(lhs_fields), Struct(rhs_fields)) => { + // Field count must match for coercion if lhs_fields.len() != rhs_fields.len() { return None; } - let coerced_types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()) - .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type())) - .collect::>>()?; - - // preserve the field name and nullability - let orig_fields = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()); + // If the two structs have exactly the same set of field names (possibly in + // different order), prefer name-based coercion. Otherwise fall back to + // positional coercion which preserves backward compatibility. + // + // Name-based coercion is used in: + // 1. Array construction: [s1, s2] where s1 and s2 have reordered fields + // 2. UNION operations: different field orders unified by name + // 3. VALUES clauses: heterogeneous struct rows unified by field name + // 4. JOIN conditions: structs with matching field names + // 5. Window functions: partitions/orders by struct fields + // 6. Aggregate functions: collecting structs with reordered fields + // + // See docs/source/user-guide/sql/struct_coercion.md for detailed examples. + if fields_have_same_names(lhs_fields, rhs_fields) { + return coerce_struct_by_name(lhs_fields, rhs_fields, coerce_fn); + } - let fields: Vec = coerced_types - .into_iter() - .zip(orig_fields) - .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) - .collect(); - Some(Struct(fields.into())) + coerce_struct_by_position(lhs_fields, rhs_fields, coerce_fn) } _ => None, } } +/// Return true if every left-field name exists in the right fields (and lengths are equal). +/// +/// # Assumptions +/// **This function assumes field names within each struct are unique.** This assumption is safe +/// because field name uniqueness is enforced at multiple levels: +/// - **Arrow level:** `StructType` construction enforces unique field names at the schema level +/// - **DataFusion level:** SQL parser rejects duplicate field names in `CREATE TABLE` and struct type definitions +/// - **Runtime level:** `StructArray::try_new()` validates field uniqueness +/// +/// Therefore, we don't need to handle degenerate cases like: +/// - `struct -> struct` (target has duplicate field names) +/// - `struct -> struct` (source has duplicate field names) +fn fields_have_same_names(lhs_fields: &Fields, rhs_fields: &Fields) -> bool { + // Debug assertions: field names should be unique within each struct + #[cfg(debug_assertions)] + { + let lhs_names: HashSet<_> = lhs_fields.iter().map(|f| f.name()).collect(); + assert_eq!( + lhs_names.len(), + lhs_fields.len(), + "Struct has duplicate field names (should be caught by Arrow schema validation)" + ); + + let rhs_names_check: HashSet<_> = rhs_fields.iter().map(|f| f.name()).collect(); + assert_eq!( + rhs_names_check.len(), + rhs_fields.len(), + "Struct has duplicate field names (should be caught by Arrow schema validation)" + ); + } + + let rhs_names: HashSet<&str> = rhs_fields.iter().map(|f| f.name().as_str()).collect(); + lhs_fields + .iter() + .all(|lf| rhs_names.contains(lf.name().as_str())) +} + +/// Coerce two structs by matching fields by name using `coerce_fn`. +/// Assumes the name-sets match. +fn coerce_struct_by_name( + lhs_fields: &Fields, + rhs_fields: &Fields, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { + use arrow::datatypes::DataType::*; + + let rhs_by_name: HashMap<&str, &FieldRef> = + rhs_fields.iter().map(|f| (f.name().as_str(), f)).collect(); + + let mut coerced: Vec = Vec::with_capacity(lhs_fields.len()); + + for lhs in lhs_fields.iter() { + let rhs = rhs_by_name.get(lhs.name().as_str()).unwrap(); // safe: caller ensured names match + let coerced_type = coerce_fn(lhs.data_type(), rhs.data_type())?; + let is_nullable = lhs.is_nullable() || rhs.is_nullable(); + coerced.push(Arc::new(Field::new( + lhs.name().clone(), + coerced_type, + is_nullable, + ))); + } + + Some(Struct(coerced.into())) +} + +/// Coerce two structs positionally (left-to-right) using `coerce_fn`. +/// Preserves field names from the left struct and uses combined nullability. +fn coerce_struct_by_position( + lhs_fields: &Fields, + rhs_fields: &Fields, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { + use arrow::datatypes::DataType::*; + + // First coerce individual types; fail early if any pair cannot be coerced. + let coerced_types: Vec = lhs_fields + .iter() + .zip(rhs_fields.iter()) + .map(|(l, r)| coerce_fn(l.data_type(), r.data_type())) + .collect::>>()?; + + // Build final fields preserving left-side names and combined nullability. + let orig_pairs = lhs_fields.iter().zip(rhs_fields.iter()); + let fields: Vec = coerced_types + .into_iter() + .zip(orig_pairs) + .map(|(datatype, (lhs, rhs))| coerce_fields(datatype, lhs, rhs)) + .collect(); + + Some(Struct(fields.into())) +} + /// returns the result of coercing two fields to a common type fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> FieldRef { let is_nullable = lhs.is_nullable() || rhs.is_nullable(); @@ -1229,13 +1399,17 @@ fn coerce_fields(common_type: DataType, lhs: &FieldRef, rhs: &FieldRef) -> Field Arc::new(Field::new(name, common_type, is_nullable)) } -/// coerce two types if they are Maps by coercing their inner 'entries' fields' types -/// using struct coercion -fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// Coerce two Map types by coercing their inner entry fields using +/// `coerce_fn` (either [`comparison_coercion`] or [`type_union_coercion`]). +fn map_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { (Map(lhs_field, lhs_ordered), Map(rhs_field, rhs_ordered)) => { - struct_coercion(lhs_field.data_type(), rhs_field.data_type()).map( + struct_coercion(lhs_field.data_type(), rhs_field.data_type(), coerce_fn).map( |key_value_type| { Map( Arc::new((**lhs_field).clone().with_data_type(key_value_type)), @@ -1248,6 +1422,28 @@ fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { } } +/// Coerce a Union and an "opaque" (non-Union) type for comparison. +/// +/// the resulting type is the opaque scalar type whenever any union variant +/// can be cast to it. at execution time, arrow's `cast(Union -> T)` extracts +/// values from the matching variant; rows whose active variant cannot be +/// cast to `T` become NULL. +/// +/// Identical union types are already handled by the `equals_datatype` fast path +/// in [`comparison_coercion`]; coercing between two different union types is not +/// supported. +fn union_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + + match (lhs_type, rhs_type) { + (Union(fields, _), opaque) | (opaque, Union(fields, _)) => fields + .iter() + .any(|(_, f)| can_cast_types(f.data_type(), opaque)) + .then(|| opaque.clone()), + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -1273,6 +1469,15 @@ fn mathematics_numerical_coercion( (_, Dictionary(_, value_type)) => { mathematics_numerical_coercion(lhs_type, value_type) } + (RunEndEncoded(_, lhs_field), RunEndEncoded(_, rhs_field)) => { + mathematics_numerical_coercion(lhs_field.data_type(), rhs_field.data_type()) + } + (RunEndEncoded(_, field), _) => { + mathematics_numerical_coercion(field.data_type(), rhs_type) + } + (_, RunEndEncoded(_, field)) => { + mathematics_numerical_coercion(lhs_type, field.data_type()) + } _ => numerical_coercion(lhs_type, rhs_type), } } @@ -1352,19 +1557,25 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> (_, Dictionary(_, value_type)) => { lhs_type.is_numeric() && value_type.is_numeric() } + (RunEndEncoded(_, lhs_field), RunEndEncoded(_, rhs_field)) => { + lhs_field.data_type().is_numeric() && rhs_field.data_type().is_numeric() + } + (RunEndEncoded(_, field), _) => { + field.data_type().is_numeric() && rhs_type.is_numeric() + } + (_, RunEndEncoded(_, field)) => { + lhs_type.is_numeric() && field.data_type().is_numeric() + } _ => lhs_type.is_numeric() && rhs_type.is_numeric(), } } -/// Generic coercion rules for Dictionaries: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// Not all operators support dictionaries, if `preserve_dictionaries` is true -/// dictionaries will be preserved if possible. +/// Coerce two Dictionary types by coercing their value types using +/// `coerce_fn` (either [`comparison_coercion`] or [`type_union_coercion`]). /// -/// The `coerce_fn` parameter determines which comparison coercion function to use -/// for comparing the dictionary value types. -fn dictionary_comparison_coercion_generic( +/// If `preserve_dictionaries` is true, dictionaries will be preserved +/// when possible. +fn dictionary_coercion( lhs_type: &DataType, rhs_type: &DataType, preserve_dictionaries: bool, @@ -1388,52 +1599,11 @@ fn dictionary_comparison_coercion_generic( } } -/// Coercion rules for Dictionaries: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// Not all operators support dictionaries, if `preserve_dictionaries` is true -/// dictionaries will be preserved if possible -fn dictionary_comparison_coercion( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_dictionaries: bool, -) -> Option { - dictionary_comparison_coercion_generic( - lhs_type, - rhs_type, - preserve_dictionaries, - comparison_coercion, - ) -} - -/// Coercion rules for Dictionaries with numeric preference: similar to -/// [`dictionary_comparison_coercion`] but uses [`comparison_coercion_numeric`] -/// which prefers numeric types over strings when both are present. +/// Coerce two RunEndEncoded types using `coerce_fn` +/// (either [`comparison_coercion`] or [`type_union_coercion`]). /// -/// This is used by [`comparison_coercion_numeric`] to maintain consistent -/// numeric-preferring semantics when dealing with dictionary types. -fn dictionary_comparison_coercion_numeric( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_dictionaries: bool, -) -> Option { - dictionary_comparison_coercion_generic( - lhs_type, - rhs_type, - preserve_dictionaries, - comparison_coercion_numeric, - ) -} - -/// Coercion rules for RunEndEncoded: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// Not all operators support REE, if `preserve_ree` is true -/// REE will be preserved if possible -/// -/// The `coerce_fn` parameter determines which comparison coercion function to use -/// for comparing the REE value types. -fn ree_comparison_coercion_generic( +/// If `preserve_ree` is true, REE will be preserved when possible. +fn ree_coercion( lhs_type: &DataType, rhs_type: &DataType, preserve_ree: bool, @@ -1460,45 +1630,41 @@ fn ree_comparison_coercion_generic( } } -/// Coercion rules for RunEndEncoded: the type that both lhs and rhs -/// can be casted to for the purpose of a computation. -/// -/// Not all operators support REE, if `preserve_ree` is true -/// REE will be preserved if possible -fn ree_comparison_coercion( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_ree: bool, -) -> Option { - ree_comparison_coercion_generic(lhs_type, rhs_type, preserve_ree, comparison_coercion) -} - -/// Coercion rules for RunEndEncoded with numeric preference: similar to -/// [`ree_comparison_coercion`] but uses [`comparison_coercion_numeric`] -/// which prefers numeric types over strings when both are present. -/// -/// This is used by [`comparison_coercion_numeric`] to maintain consistent -/// numeric-preferring semantics when dealing with REE types. -fn ree_comparison_coercion_numeric( - lhs_type: &DataType, - rhs_type: &DataType, - preserve_ree: bool, -) -> Option { - ree_comparison_coercion_generic( - lhs_type, - rhs_type, - preserve_ree, - comparison_coercion_numeric, - ) -} - /// Coercion rules for string concat. /// This is a union of string coercion rules and specified rules: /// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) /// 2. Data type of the other side should be able to cast to string type +/// 3. Binary and string types cannot be mixed fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; string_coercion(lhs_type, rhs_type).or_else(|| match (lhs_type, rhs_type) { + // Allow pure binary + binary + ( + Binary | LargeBinary | BinaryView | FixedSizeBinary(_), + Binary | LargeBinary | BinaryView | FixedSizeBinary(_), + ) => { + // Coerce fixed-sized binary to variable-sized `Binary` to make uniform signature + // with the `Binary` result + let lhs_type = match lhs_type { + FixedSizeBinary(_) => &Binary, + val => val, + }; + let rhs_type = match rhs_type { + FixedSizeBinary(_) => &Binary, + val => val, + }; + binary_coercion(lhs_type, rhs_type) + } + // Deny other mixed binary + string combinations + ( + Binary | LargeBinary | BinaryView | FixedSizeBinary(_), + Utf8 | LargeUtf8 | Utf8View, + ) => None, + ( + Utf8 | LargeUtf8 | Utf8View, + Binary | LargeBinary | BinaryView | FixedSizeBinary(_), + ) => None, + // Predicate-based coercion rules are following (Utf8View, from_type) | (from_type, Utf8View) => { string_concat_internal_coercion(from_type, &Utf8View) } @@ -1556,32 +1722,28 @@ pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - (Utf8 | LargeUtf8 | Utf8View, other_type) - | (other_type, Utf8 | LargeUtf8 | Utf8View) - if other_type.is_numeric() => - { - Some(other_type.clone()) - } - _ => None, - } -} - -/// Coerces two fields together, ensuring the field data (name and nullability) is correctly set. -fn coerce_list_children(lhs_field: &FieldRef, rhs_field: &FieldRef) -> Option { - let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()]; +/// Coerce two list element fields to a common type using the provided +/// coercion function for element types. +fn coerce_list_children( + lhs_field: &FieldRef, + rhs_field: &FieldRef, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { Some(Arc::new( (**lhs_field) .clone() - .with_data_type(type_union_resolution(&data_types)?) + .with_data_type(coerce_fn(lhs_field.data_type(), rhs_field.data_type())?) .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()), )) } -/// Coercion rules for list types. -fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// Coerce two list types by coercing their element types via `coerce_fn` +/// (either [`comparison_coercion`] or [`type_union_coercion`]). +fn list_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { // Coerce to the left side FixedSizeList type if the list lengths are the same, @@ -1589,11 +1751,11 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => { if ls == rs { Some(FixedSizeList( - coerce_list_children(lhs_field, rhs_field)?, + coerce_list_children(lhs_field, rhs_field, coerce_fn)?, *rs, )) } else { - Some(List(coerce_list_children(lhs_field, rhs_field)?)) + Some(List(coerce_list_children(lhs_field, rhs_field, coerce_fn)?)) } } // LargeList on any side @@ -1601,13 +1763,13 @@ fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { LargeList(lhs_field), List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _), ) - | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => { - Some(LargeList(coerce_list_children(lhs_field, rhs_field)?)) - } + | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => Some( + LargeList(coerce_list_children(lhs_field, rhs_field, coerce_fn)?), + ), // Lists on both sides (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _)) | (FixedSizeList(lhs_field, _), List(rhs_field)) => { - Some(List(coerce_list_children(lhs_field, rhs_field)?)) + Some(List(coerce_list_children(lhs_field, rhs_field, coerce_fn)?)) } _ => None, } @@ -1644,13 +1806,17 @@ pub fn binary_to_string_coercion( fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { + // Prefer symmetric coercion (in case the function is called directly) + (Binary, Binary) => Some(Binary), + (LargeBinary, LargeBinary) => Some(LargeBinary), + (BinaryView, BinaryView) => Some(BinaryView), // If BinaryView is in any side, we coerce to BinaryView. - (BinaryView, BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) + (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, BinaryView) => { Some(BinaryView) } // Prefer LargeBinary over Binary - (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary) + (Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary) | (LargeBinary, Binary | Utf8 | LargeUtf8 | Utf8View) => Some(LargeBinary), // If Utf8View/LargeUtf8 presents need to be large Binary @@ -1671,12 +1837,13 @@ fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option /// Coercion rules for like operations. /// This is a union of string coercion rules, dictionary coercion rules, and REE coercion rules +/// Note: list_coercion is intentionally NOT included here because LIKE is a string pattern +/// matching operation and is not supported for nested types (List, Struct, etc.) pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) - .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) - .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) - .or_else(|| ree_comparison_coercion(lhs_type, rhs_type, false)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, false, like_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, false, like_coercion)) .or_else(|| regex_null_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -1693,10 +1860,11 @@ fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { string_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, false, regex_coercion)) + .or_else(|| ree_coercion(lhs_type, rhs_type, false, regex_coercion)) .or_else(|| regex_null_coercion(lhs_type, rhs_type)) } @@ -1706,10 +1874,10 @@ pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option bool { matches!( datatype, - &DataType::Time32(TimeUnit::Second) - | &DataType::Time32(TimeUnit::Millisecond) - | &DataType::Time64(TimeUnit::Microsecond) - | &DataType::Time64(TimeUnit::Nanosecond) + &DataType::Time32(Second) + | &DataType::Time32(Millisecond) + | &DataType::Time64(Microsecond) + | &DataType::Time64(Nanosecond) ) } @@ -1795,6 +1963,73 @@ fn temporal_coercion_strict_timezone( } } +fn temporal_math_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option<(DataType, DataType)> { + use DataType::*; + + match (lhs_type, rhs_type) { + // Coerce Date + int -> Date + Interval + (Date32, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64) => { + Some((Date32, Interval(MonthDayNano))) + } + (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Date32) => { + Some((Interval(MonthDayNano), Date32)) + } + (Date64, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64) => { + Some((Date64, Interval(MonthDayNano))) + } + (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Date64) => { + Some((Interval(MonthDayNano), Date64)) + } + // Coerce Date + time -> timestamp + Duration + (Date32, Time32(_)) => Some((Timestamp(Nanosecond, None), Duration(Nanosecond))), + (Time32(_), Date32) => Some((Duration(Nanosecond), Timestamp(Nanosecond, None))), + + (Date32, Time64(_)) => Some((Timestamp(Nanosecond, None), Duration(Nanosecond))), + (Time64(_), Date32) => Some((Duration(Nanosecond), Timestamp(Nanosecond, None))), + + (Date64, Time32(_)) => Some((Timestamp(Nanosecond, None), Duration(Nanosecond))), + (Time32(_), Date64) => Some((Duration(Nanosecond), Timestamp(Nanosecond, None))), + + (Date64, Time64(_)) => Some((Timestamp(Nanosecond, None), Duration(Nanosecond))), + (Time64(_), Date64) => Some((Duration(Nanosecond), Timestamp(Nanosecond, None))), + + // Coerce Duration to match Timestamp's unit, + // e.g. Timestamp(ms) + Duration(s) → Timestamp(ms) + Duration(ms) + (Timestamp(ts_unit, tz), Duration(_)) => { + Some((Timestamp(*ts_unit, tz.clone()), Duration(*ts_unit))) + } + (Duration(_), Timestamp(ts_unit, tz)) => { + Some((Duration(*ts_unit), Timestamp(*ts_unit, tz.clone()))) + } + // time - time -> Interval + (Time32(_) | Time64(_), Time32(_) | Time64(_)) => { + Some((Interval(MonthDayNano), Interval(MonthDayNano))) + } + // time + interval -> Interval + (Time32(_) | Time64(_), Interval(_)) => { + Some((Interval(MonthDayNano), Interval(MonthDayNano))) + } + (Interval(_), Time32(_) | Time64(_)) => { + Some((Interval(MonthDayNano), Interval(MonthDayNano))) + } + // Interval * number => Interval + ( + Interval(_), + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float16 + | Float32 | Float64, + ) => Some((Interval(MonthDayNano), Interval(MonthDayNano))), + ( + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float16 + | Float32 | Float64, + Interval(_), + ) => Some((Interval(MonthDayNano), Interval(MonthDayNano))), + _ => None, + } +} + fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; use arrow::datatypes::IntervalUnit::*; @@ -1804,7 +2039,19 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { Some(Interval(MonthDayNano)) } + (Date32, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64) + | (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Date32) => { + Some(Date32) + } + (Date64, Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64) + | (Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, Date64) => { + Some(Date64) + } (Date64, Date32) | (Date32, Date64) => Some(Date64), + (Date32, Time32(_)) | (Time32(_), Date32) => Some(Timestamp(Nanosecond, None)), + (Date32, Time64(_)) | (Time64(_), Date32) => Some(Timestamp(Nanosecond, None)), + (Date64, Time32(_)) | (Time32(_), Date64) => Some(Timestamp(Nanosecond, None)), + (Date64, Time64(_)) | (Time64(_), Date64) => Some(Timestamp(Nanosecond, None)), (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) } @@ -1824,22 +2071,10 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option TimeUnit { use arrow::datatypes::TimeUnit::*; match (lhs_unit, rhs_unit) { - (Second, Millisecond) => Second, - (Second, Microsecond) => Second, - (Second, Nanosecond) => Second, - (Millisecond, Second) => Second, - (Millisecond, Microsecond) => Millisecond, - (Millisecond, Nanosecond) => Millisecond, - (Microsecond, Second) => Second, - (Microsecond, Millisecond) => Millisecond, - (Microsecond, Nanosecond) => Microsecond, - (Nanosecond, Second) => Second, - (Nanosecond, Millisecond) => Millisecond, - (Nanosecond, Microsecond) => Microsecond, - (l, r) => { - assert_eq!(l, r); - *l - } + (Second, Second) => Second, + (Nanosecond, _) | (_, Nanosecond) => Nanosecond, + (Microsecond, _) | (_, Microsecond) => Microsecond, + (Millisecond, _) | (_, Millisecond) => Millisecond, } } diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index 63945a4dabd0c..70a8fc0e35a15 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -25,20 +25,23 @@ fn test_coercion_error() -> Result<()> { let result_type = coercer.get_input_types(); let e = result_type.unwrap_err(); - assert_eq!(e.strip_backtrace(), "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types"); + assert_eq!( + e.strip_backtrace(), + "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types" + ); Ok(()) } #[test] fn test_date_timestamp_arithmetic_error() -> Result<()> { let (lhs, rhs) = BinaryTypeCoercer::new( - &DataType::Timestamp(TimeUnit::Nanosecond, None), + &DataType::Timestamp(Nanosecond, None), &Operator::Minus, - &DataType::Timestamp(TimeUnit::Millisecond, None), + &DataType::Timestamp(Millisecond, None), ) .get_input_types()?; - assert_eq!(lhs, DataType::Timestamp(TimeUnit::Millisecond, None)); - assert_eq!(rhs, DataType::Timestamp(TimeUnit::Millisecond, None)); + assert_eq!(lhs, DataType::Timestamp(Nanosecond, None)); + assert_eq!(rhs, DataType::Timestamp(Nanosecond, None)); let err = BinaryTypeCoercer::new(&DataType::Date32, &Operator::Plus, &DataType::Date64) @@ -146,14 +149,18 @@ fn test_type_coercion_arithmetic() -> Result<()> { // (_, Float32) | (Float32, _) => Some(Float32) test_coercion_binary_rule_multiple!( Float32, - [Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + [ + Float32, Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8 + ], Operator::Plus, Float32 ); // (_, Float16) | (Float16, _) => Some(Float16) test_coercion_binary_rule_multiple!( Float16, - [Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8], + [ + Float16, Int64, UInt64, Int32, UInt32, Int16, UInt16, Int8, UInt8 + ], Operator::Plus, Float16 ); @@ -221,6 +228,53 @@ fn test_type_coercion_arithmetic() -> Result<()> { Ok(()) } +#[test] +fn test_bitwise_coercion_non_integer_types() -> Result<()> { + let err = BinaryTypeCoercer::new( + &DataType::Float32, + &Operator::BitwiseAnd, + &DataType::Float32, + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Float32 & Float32" + ); + + let err = BinaryTypeCoercer::new( + &DataType::Float32, + &Operator::BitwiseAnd, + &DataType::Float64, + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Float32 & Float64" + ); + + let err = BinaryTypeCoercer::new( + &DataType::Decimal128(10, 2), + &Operator::BitwiseAnd, + &DataType::Decimal128(10, 2), + ) + .get_input_types() + .unwrap_err() + .to_string(); + assert_contains!( + &err, + "Cannot infer common type for bitwise operation Decimal128(10, 2) & Decimal128(10, 2)" + ); + + let dict_int8 = DataType::Dictionary(DataType::Int8.into(), DataType::Int8.into()); + test_coercion_binary_rule!(dict_int8, dict_int8, Operator::BitwiseAnd, dict_int8); + + Ok(()) +} + fn test_math_decimal_coercion_rule( lhs_type: DataType, rhs_type: DataType, diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs index 5401264e43e39..5f6b7dfcc1d4f 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs @@ -122,51 +122,51 @@ fn test_type_coercion() -> Result<()> { ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Time32(TimeUnit::Second), + DataType::Time32(Second), Operator::Eq, - DataType::Time32(TimeUnit::Second) + DataType::Time32(Second) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(Millisecond), Operator::Eq, - DataType::Time32(TimeUnit::Millisecond) + DataType::Time32(Millisecond) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(Microsecond), Operator::Eq, - DataType::Time64(TimeUnit::Microsecond) + DataType::Time64(Microsecond) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Time64(TimeUnit::Nanosecond), + DataType::Time64(Nanosecond), Operator::Eq, - DataType::Time64(TimeUnit::Nanosecond) + DataType::Time64(Nanosecond) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(Second, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(Millisecond, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(Microsecond, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, - DataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::Timestamp(Nanosecond, None), Operator::Lt, - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(Nanosecond, None) ); test_coercion_binary_rule!( DataType::Utf8, @@ -552,28 +552,46 @@ fn test_type_coercion_compare() -> Result<()> { // Timestamps let utc: Option> = Some("UTC".into()); test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, utc.clone()), - DataType::Timestamp(TimeUnit::Second, utc.clone()), + DataType::Timestamp(Second, utc.clone()), + DataType::Timestamp(Second, utc.clone()), Operator::Eq, - DataType::Timestamp(TimeUnit::Second, utc.clone()) + DataType::Timestamp(Second, utc.clone()) ); test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, utc.clone()), - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + DataType::Timestamp(Second, utc.clone()), + DataType::Timestamp(Second, Some("Europe/Brussels".into())), Operator::Eq, - DataType::Timestamp(TimeUnit::Second, utc.clone()) + DataType::Timestamp(Second, utc.clone()) ); test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())), - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + DataType::Timestamp(Second, Some("America/New_York".into())), + DataType::Timestamp(Second, Some("Europe/Brussels".into())), Operator::Eq, - DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())) + DataType::Timestamp(Second, Some("America/New_York".into())) ); test_coercion_binary_rule!( - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), - DataType::Timestamp(TimeUnit::Second, utc), + DataType::Timestamp(Second, Some("Europe/Brussels".into())), + DataType::Timestamp(Second, utc), Operator::Eq, - DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())) + DataType::Timestamp(Second, Some("Europe/Brussels".into())) + ); + test_coercion_binary_rule!( + DataType::Timestamp(Second, None), + DataType::Timestamp(Millisecond, None), + Operator::Eq, + DataType::Timestamp(Millisecond, None) + ); + test_coercion_binary_rule!( + DataType::Timestamp(Second, Some("America/New_York".into())), + DataType::Timestamp(Nanosecond, Some("Europe/Brussels".into())), + Operator::Lt, + DataType::Timestamp(Nanosecond, Some("America/New_York".into())) + ); + test_coercion_binary_rule!( + DataType::Timestamp(Microsecond, None), + DataType::Timestamp(Nanosecond, None), + Operator::GtEq, + DataType::Timestamp(Nanosecond, None) ); // list @@ -634,7 +652,7 @@ fn test_type_coercion_compare() -> Result<()> { ); let inner_timestamp_field = Arc::new(Field::new_list_field( - DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(Microsecond, None), true, )); let result_type = BinaryTypeCoercer::new( @@ -654,7 +672,7 @@ fn test_list_coercion() { let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); - let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + let coerced_type = list_coercion(&lhs_type, &rhs_type, comparison_coercion).unwrap(); assert_eq!( coerced_type, DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) @@ -778,10 +796,246 @@ fn test_decimal_cross_variant_comparison_coercion() -> Result<()> { for op in comparison_op_types { let (lhs, rhs) = BinaryTypeCoercer::new(&lhs_type, &op, &rhs_type).get_input_types()?; - assert_eq!(expected_type, lhs, "Coercion of type {lhs_type:?} with {rhs_type:?} resulted in unexpected type: {lhs:?}"); - assert_eq!(expected_type, rhs, "Coercion of type {rhs_type:?} with {lhs_type:?} resulted in unexpected type: {rhs:?}"); + assert_eq!( + expected_type, lhs, + "Coercion of type {lhs_type:?} with {rhs_type:?} resulted in unexpected type: {lhs:?}" + ); + assert_eq!( + expected_type, rhs, + "Coercion of type {rhs_type:?} with {lhs_type:?} resulted in unexpected type: {rhs:?}" + ); } } Ok(()) } + +/// Tests that `comparison_coercion` prefers the numeric type when one side is +/// numeric and the other is a string (e.g., `numeric_col < '123'`). +#[test] +fn test_comparison_coercion_prefers_numeric() { + assert_eq!( + comparison_coercion(&DataType::Int32, &DataType::Utf8), + Some(DataType::Int32) + ); + assert_eq!( + comparison_coercion(&DataType::Utf8, &DataType::Int32), + Some(DataType::Int32) + ); + assert_eq!( + comparison_coercion(&DataType::Utf8, &DataType::Float64), + Some(DataType::Float64) + ); + assert_eq!( + comparison_coercion(&DataType::Float64, &DataType::Utf8), + Some(DataType::Float64) + ); + assert_eq!( + comparison_coercion(&DataType::Int64, &DataType::LargeUtf8), + Some(DataType::Int64) + ); + assert_eq!( + comparison_coercion(&DataType::Utf8View, &DataType::Int16), + Some(DataType::Int16) + ); + // String-string stays string + assert_eq!( + comparison_coercion(&DataType::Utf8, &DataType::Utf8), + Some(DataType::Utf8) + ); + // Numeric-numeric stays numeric + assert_eq!( + comparison_coercion(&DataType::Int32, &DataType::Int64), + Some(DataType::Int64) + ); +} + +/// Tests that `type_union_coercion` prefers the string type when unifying +/// numeric and string types (for UNION, CASE THEN/ELSE, etc.). +#[test] +fn test_type_union_coercion_prefers_string() { + assert_eq!( + type_union_coercion(&DataType::Int32, &DataType::Utf8), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Utf8, &DataType::Int32), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Float64, &DataType::Utf8), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Utf8, &DataType::Float64), + Some(DataType::Utf8) + ); + assert_eq!( + type_union_coercion(&DataType::Int64, &DataType::LargeUtf8), + Some(DataType::LargeUtf8) + ); + assert_eq!( + type_union_coercion(&DataType::Utf8View, &DataType::Int16), + Some(DataType::Utf8View) + ); + // String-string stays string + assert_eq!( + type_union_coercion(&DataType::Utf8, &DataType::Utf8), + Some(DataType::Utf8) + ); + // Numeric-numeric stays numeric + assert_eq!( + type_union_coercion(&DataType::Int32, &DataType::Int64), + Some(DataType::Int64) + ); +} + +#[test] +fn test_type_union_coercion_prefers_finer_timestamp_unit() { + assert_eq!( + type_union_coercion( + &DataType::Timestamp(Second, None), + &DataType::Timestamp(Millisecond, None), + ), + Some(DataType::Timestamp(Millisecond, None)) + ); + assert_eq!( + type_union_resolution(&[ + DataType::Timestamp(Second, None), + DataType::Timestamp(Nanosecond, None), + ]), + Some(DataType::Timestamp(Nanosecond, None)) + ); +} + +/// Tests that comparison operators coerce to numeric when comparing +/// numeric and string types. +#[test] +fn test_binary_comparison_string_numeric_coercion() -> Result<()> { + let comparison_ops = [ + Operator::Eq, + Operator::NotEq, + Operator::Lt, + Operator::LtEq, + Operator::Gt, + Operator::GtEq, + ]; + for op in &comparison_ops { + let (lhs, rhs) = BinaryTypeCoercer::new(&DataType::Int64, op, &DataType::Utf8) + .get_input_types()?; + assert_eq!(lhs, DataType::Int64, "Op {op}: Int64 vs Utf8 -> lhs"); + assert_eq!(rhs, DataType::Int64, "Op {op}: Int64 vs Utf8 -> rhs"); + + let (lhs, rhs) = BinaryTypeCoercer::new(&DataType::Utf8, op, &DataType::Float64) + .get_input_types()?; + assert_eq!(lhs, DataType::Float64, "Op {op}: Utf8 vs Float64 -> lhs"); + assert_eq!(rhs, DataType::Float64, "Op {op}: Utf8 vs Float64 -> rhs"); + } + Ok(()) +} + +#[test] +fn test_string_concat_coercion() -> Result<()> { + // Binary + test_coercion_binary_rule!( + DataType::Binary, + DataType::Binary, + Operator::StringConcat, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::LargeBinary, + DataType::LargeBinary, + Operator::StringConcat, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::BinaryView, + DataType::BinaryView, + Operator::StringConcat, + DataType::BinaryView + ); + test_coercion_binary_rule!( + DataType::Binary, + DataType::LargeBinary, + Operator::StringConcat, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::BinaryView, + DataType::Binary, + Operator::StringConcat, + DataType::BinaryView + ); + test_coercion_binary_rule!( + DataType::FixedSizeBinary(4), + DataType::FixedSizeBinary(16), + Operator::StringConcat, + DataType::Binary + ); + test_coercion_binary_rule!( + DataType::FixedSizeBinary(4), + DataType::LargeBinary, + Operator::StringConcat, + DataType::LargeBinary + ); + test_coercion_binary_rule!( + DataType::FixedSizeBinary(4), + DataType::BinaryView, + Operator::StringConcat, + DataType::BinaryView + ); + + // String + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8, + Operator::StringConcat, + DataType::Utf8 + ); + test_coercion_binary_rule!( + DataType::LargeUtf8, + DataType::LargeUtf8, + Operator::StringConcat, + DataType::LargeUtf8 + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::StringConcat, + DataType::Utf8View + ); + + // Mixed string-binary + for string_dt in [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] { + for binary_dt in [ + DataType::Binary, + DataType::LargeBinary, + DataType::BinaryView, + DataType::FixedSizeBinary(8), + ] { + assert!( + BinaryTypeCoercer::new(&binary_dt, &Operator::StringConcat, &string_dt,) + .get_input_types() + .is_err(), + "{binary_dt} || {string_dt}" + ); + assert!( + BinaryTypeCoercer::new(&string_dt, &Operator::StringConcat, &binary_dt,) + .get_input_types() + .is_err(), + "{string_dt} || {binary_dt}" + ); + } + } + + // Mixed string-other + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Timestamp(Second, None), + Operator::StringConcat, + DataType::Utf8 + ); + + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs b/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs index 0fb56a4a2c536..f0aadfd3ce3a5 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/dictionary.rs @@ -24,49 +24,49 @@ fn test_dictionary_type_coercion() { let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Int32) ); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + dictionary_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Int32) ); - // Since we can coerce values of Int16 to Utf8 can support this + // In comparison context, numeric is preferred over string let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Utf8) + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), + Some(Int16) ); // Since we can coerce values of Utf8 to Binary can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Binary) ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + dictionary_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(lhs_type.clone()) ); let lhs_type = Utf8; let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + dictionary_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); assert_eq!( - dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + dictionary_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(rhs_type.clone()) ); } diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs b/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs index 9997db7a82688..38e9fb3908d5b 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/run_end_encoded.rs @@ -16,11 +16,17 @@ // under the License. use super::*; +use DataType::*; + +fn ree(value_type: DataType) -> DataType { + RunEndEncoded( + Arc::new(Field::new("run_ends", Int32, false)), + Arc::new(Field::new("values", value_type, false)), + ) +} #[test] fn test_ree_type_coercion() { - use DataType::*; - let lhs_type = RunEndEncoded( Arc::new(Field::new("run_ends", Int8, false)), Arc::new(Field::new("values", Int32, false)), @@ -30,15 +36,15 @@ fn test_ree_type_coercion() { Arc::new(Field::new("values", Int16, false)), ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Int32) ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, false), + ree_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Int32) ); - // Since we can coerce values of Int16 to Utf8 can support this: Coercion of Int16 to Utf8 + // In comparison context, numeric is preferred over string let lhs_type = RunEndEncoded( Arc::new(Field::new("run_ends", Int8, false)), Arc::new(Field::new("values", Utf8, false)), @@ -48,8 +54,8 @@ fn test_ree_type_coercion() { Arc::new(Field::new("values", Int16, false)), ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), - Some(Utf8) + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), + Some(Int16) ); // Since we can coerce values of Utf8 to Binary can support this @@ -62,7 +68,7 @@ fn test_ree_type_coercion() { Arc::new(Field::new("values", Binary, false)), ); assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(Binary) ); let lhs_type = RunEndEncoded( @@ -72,12 +78,12 @@ fn test_ree_type_coercion() { let rhs_type = Utf8; // Don't preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, false), + ree_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); // Preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(lhs_type.clone()) ); @@ -88,12 +94,38 @@ fn test_ree_type_coercion() { ); // Don't preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, false), + ree_coercion(&lhs_type, &rhs_type, false, comparison_coercion), Some(Utf8) ); // Preserve REE assert_eq!( - ree_comparison_coercion(&lhs_type, &rhs_type, true), + ree_coercion(&lhs_type, &rhs_type, true, comparison_coercion), Some(rhs_type.clone()) ); } + +#[test] +fn test_ree_arithmetic_coercion() -> Result<()> { + test_coercion_binary_rule!(ree(Int64), Int64, Operator::Plus, Int64); + test_coercion_binary_rule!(Int64, ree(Int64), Operator::Multiply, Int64); + test_coercion_binary_rule!(ree(Int32), ree(Int64), Operator::Plus, Int64); + + // Decimal unwrapping through math_decimal_coercion + let (lhs, rhs) = + BinaryTypeCoercer::new(&ree(Decimal128(10, 2)), &Operator::Plus, &Int32) + .get_input_types()?; + assert_eq!(lhs, Decimal128(10, 2)); + assert_eq!(rhs, Decimal128(10, 0)); + + let (lhs, rhs) = + BinaryTypeCoercer::new(&Int32, &Operator::Plus, &ree(Decimal128(10, 2))) + .get_input_types()?; + assert_eq!(lhs, Decimal128(10, 0)); + assert_eq!(rhs, Decimal128(10, 2)); + + let result = + BinaryTypeCoercer::new(&ree(Utf8), &Operator::Plus, &Int32).get_input_types(); + assert!(result.is_err()); + + Ok(()) +} diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 11d6ca1533db3..8cec01feb30b5 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -46,7 +46,8 @@ recursive_protection = ["dep:recursive"] sql = ["sqlparser"] [dependencies] -arrow = { workspace = true } +arrow = { workspace = true, features = ["canonical_extension_types"] } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = false } @@ -57,7 +58,6 @@ datafusion-functions-window-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } -paste = "^1.0" recursive = { workspace = true, optional = true } serde_json = { workspace = true } sqlparser = { workspace = true, optional = true } @@ -66,3 +66,6 @@ sqlparser = { workspace = true, optional = true } ctor = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } +# Makes sure `test_display_pg_json` behaves in a consistent way regardless of +# feature unification with dependencies +serde_json = { workspace = true, features = ["preserve_order"] } diff --git a/datafusion/expr/src/arguments.rs b/datafusion/expr/src/arguments.rs index 5653993db98fe..f10cf50f60b24 100644 --- a/datafusion/expr/src/arguments.rs +++ b/datafusion/expr/src/arguments.rs @@ -18,8 +18,21 @@ //! Argument resolution logic for named function parameters use crate::Expr; -use datafusion_common::{plan_err, Result}; -use std::collections::HashMap; +use datafusion_common::{Result, plan_err}; + +/// Represents a named function argument with its original case and quote information. +/// +/// This struct preserves whether an identifier was quoted in the SQL, which determines +/// whether case-sensitive or case-insensitive matching should be used per SQL standards. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ArgumentName { + /// The argument name in its original case as it appeared in the SQL + pub value: String, + /// Whether the identifier was quoted (e.g., "STR" vs STR) + /// - true: quoted identifier, requires case-sensitive matching + /// - false: unquoted identifier, uses case-insensitive matching + pub is_quoted: bool, +} /// Resolves function arguments, handling named and positional notation. /// @@ -50,7 +63,7 @@ use std::collections::HashMap; pub fn resolve_function_arguments( param_names: &[String], args: Vec, - arg_names: Vec>, + arg_names: Vec>, ) -> Result> { if args.len() != arg_names.len() { return plan_err!( @@ -71,7 +84,7 @@ pub fn resolve_function_arguments( } /// Validates that positional arguments come before named arguments -fn validate_argument_order(arg_names: &[Option]) -> Result<()> { +fn validate_argument_order(arg_names: &[Option]) -> Result<()> { let mut seen_named = false; for (i, arg_name) in arg_names.iter().enumerate() { match arg_name { @@ -93,15 +106,8 @@ fn validate_argument_order(arg_names: &[Option]) -> Result<()> { fn reorder_named_arguments( param_names: &[String], args: Vec, - arg_names: Vec>, + arg_names: Vec>, ) -> Result> { - // Build HashMap for O(1) parameter name lookups - let param_index_map: HashMap<&str, usize> = param_names - .iter() - .enumerate() - .map(|(idx, name)| (name.as_str(), idx)) - .collect(); - let positional_count = arg_names.iter().filter(|n| n.is_none()).count(); // Capture args length before consuming the vector @@ -120,19 +126,35 @@ fn reorder_named_arguments( let mut result: Vec> = vec![None; expected_arg_count]; for (i, (arg, arg_name)) in args.into_iter().zip(arg_names).enumerate() { - if let Some(name) = arg_name { - // Named argument - O(1) lookup in HashMap - let param_index = - param_index_map.get(name.as_str()).copied().ok_or_else(|| { + if let Some(arg_name) = arg_name { + // Named argument - find parameter index using linear search + // Match based on SQL identifier rules: + // - Quoted identifiers: case-sensitive (exact match) + // - Unquoted identifiers: case-insensitive match + let param_index = param_names + .iter() + .position(|p| { + if arg_name.is_quoted { + // Quoted: exact case match + p == &arg_name.value + } else { + // Unquoted: case-insensitive match + p.eq_ignore_ascii_case(&arg_name.value) + } + }) + .ok_or_else(|| { datafusion_common::plan_datafusion_err!( "Unknown parameter name '{}'. Valid parameters are: [{}]", - name, + arg_name.value, param_names.join(", ") ) })?; if result[param_index].is_some() { - return plan_err!("Parameter '{}' specified multiple times", name); + return plan_err!( + "Parameter '{}' specified multiple times", + arg_name.value + ); } result[param_index] = Some(arg); @@ -175,12 +197,111 @@ mod tests { let param_names = vec!["a".to_string(), "b".to_string()]; let args = vec![lit(1), lit("hello")]; - let arg_names = vec![Some("a".to_string()), Some("b".to_string())]; + let arg_names = vec![ + Some(ArgumentName { + value: "a".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "b".to_string(), + is_quoted: false, + }), + ]; let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); assert_eq!(result.len(), 2); } + #[test] + fn test_case_insensitive_parameter_matching() { + // Parameter names in function signature (lowercase) + let param_names = vec!["startpos".to_string(), "length".to_string()]; + + // Unquoted arguments with different casing should match case-insensitively + let args = vec![lit(1), lit(10)]; + let arg_names = vec![ + Some(ArgumentName { + value: "STARTPOS".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "LENGTH".to_string(), + is_quoted: false, + }), + ]; + + let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); + assert_eq!(result.len(), 2); + assert_eq!(result[0], lit(1)); + assert_eq!(result[1], lit(10)); + + // Test with reordering and different cases + let args2 = vec![lit(20), lit(5)]; + let arg_names2 = vec![ + Some(ArgumentName { + value: "Length".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "StartPos".to_string(), + is_quoted: false, + }), + ]; + + let result2 = + resolve_function_arguments(¶m_names, args2, arg_names2).unwrap(); + assert_eq!(result2.len(), 2); + assert_eq!(result2[0], lit(5)); // startpos + assert_eq!(result2[1], lit(20)); // length + } + + #[test] + fn test_quoted_parameter_case_sensitive() { + // Parameter names in function signature (lowercase) + let param_names = vec!["str".to_string(), "start_pos".to_string()]; + + // Quoted identifiers with wrong case should fail + let args = vec![lit("hello"), lit(1)]; + let arg_names = vec![ + Some(ArgumentName { + value: "STR".to_string(), + is_quoted: true, + }), + Some(ArgumentName { + value: "start_pos".to_string(), + is_quoted: true, + }), + ]; + + let result = resolve_function_arguments(¶m_names, args, arg_names); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Unknown parameter") + ); + + // Quoted identifiers with correct case should succeed + let args2 = vec![lit("hello"), lit(1)]; + let arg_names2 = vec![ + Some(ArgumentName { + value: "str".to_string(), + is_quoted: true, + }), + Some(ArgumentName { + value: "start_pos".to_string(), + is_quoted: true, + }), + ]; + + let result2 = + resolve_function_arguments(¶m_names, args2, arg_names2).unwrap(); + assert_eq!(result2.len(), 2); + assert_eq!(result2[0], lit("hello")); + assert_eq!(result2[1], lit(1)); + } + #[test] fn test_named_reordering() { let param_names = vec!["a".to_string(), "b".to_string(), "c".to_string()]; @@ -188,9 +309,18 @@ mod tests { // Call with: func(c => 3.0, a => 1, b => "hello") let args = vec![lit(3.0), lit(1), lit("hello")]; let arg_names = vec![ - Some("c".to_string()), - Some("a".to_string()), - Some("b".to_string()), + Some(ArgumentName { + value: "c".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "a".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "b".to_string(), + is_quoted: false, + }), ]; let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); @@ -208,7 +338,17 @@ mod tests { // Call with: func(1, c => 3.0, b => "hello") let args = vec![lit(1), lit(3.0), lit("hello")]; - let arg_names = vec![None, Some("c".to_string()), Some("b".to_string())]; + let arg_names = vec![ + None, + Some(ArgumentName { + value: "c".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "b".to_string(), + is_quoted: false, + }), + ]; let result = resolve_function_arguments(¶m_names, args, arg_names).unwrap(); @@ -225,14 +365,22 @@ mod tests { // Call with: func(a => 1, "hello") - ERROR let args = vec![lit(1), lit("hello")]; - let arg_names = vec![Some("a".to_string()), None]; + let arg_names = vec![ + Some(ArgumentName { + value: "a".to_string(), + is_quoted: false, + }), + None, + ]; let result = resolve_function_arguments(¶m_names, args, arg_names); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Positional argument")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Positional argument") + ); } #[test] @@ -241,14 +389,25 @@ mod tests { // Call with: func(x => 1, b => "hello") - ERROR let args = vec![lit(1), lit("hello")]; - let arg_names = vec![Some("x".to_string()), Some("b".to_string())]; + let arg_names = vec![ + Some(ArgumentName { + value: "x".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "b".to_string(), + is_quoted: false, + }), + ]; let result = resolve_function_arguments(¶m_names, args, arg_names); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Unknown parameter")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Unknown parameter") + ); } #[test] @@ -257,14 +416,25 @@ mod tests { // Call with: func(a => 1, a => 2) - ERROR let args = vec![lit(1), lit(2)]; - let arg_names = vec![Some("a".to_string()), Some("a".to_string())]; + let arg_names = vec![ + Some(ArgumentName { + value: "a".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "a".to_string(), + is_quoted: false, + }), + ]; let result = resolve_function_arguments(¶m_names, args, arg_names); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("specified multiple times")); + assert!( + result + .unwrap_err() + .to_string() + .contains("specified multiple times") + ); } #[test] @@ -273,13 +443,232 @@ mod tests { // Call with: func(a => 1, c => 3.0) - missing 'b' let args = vec![lit(1), lit(3.0)]; - let arg_names = vec![Some("a".to_string()), Some("c".to_string())]; + let arg_names = vec![ + Some(ArgumentName { + value: "a".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "c".to_string(), + is_quoted: false, + }), + ]; let result = resolve_function_arguments(¶m_names, args, arg_names); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Missing required parameter")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Missing required parameter") + ); + } + + #[test] + fn test_mixed_case_signature_unquoted_matching() { + // Test with mixed-case signature parameters (lowercase, camelCase, UPPERCASE) + // This proves case-insensitive matching works for unquoted identifiers + let param_names = vec![ + "prefix".to_string(), // lowercase + "startPos".to_string(), // camelCase + "LENGTH".to_string(), // UPPERCASE + ]; + + // Test 1: All lowercase unquoted arguments should match + let args1 = vec![lit("a"), lit(1), lit(5)]; + let arg_names1 = vec![ + Some(ArgumentName { + value: "prefix".to_string(), + is_quoted: false, + }), + Some(ArgumentName { + value: "startpos".to_string(), // lowercase version of startPos + is_quoted: false, + }), + Some(ArgumentName { + value: "length".to_string(), // lowercase version of LENGTH + is_quoted: false, + }), + ]; + + let result1 = + resolve_function_arguments(¶m_names, args1, arg_names1).unwrap(); + assert_eq!(result1.len(), 3); + assert_eq!(result1[0], lit("a")); + assert_eq!(result1[1], lit(1)); + assert_eq!(result1[2], lit(5)); + + // Test 2: All uppercase unquoted arguments should match + let args2 = vec![lit("b"), lit(2), lit(10)]; + let arg_names2 = vec![ + Some(ArgumentName { + value: "PREFIX".to_string(), // uppercase version of prefix + is_quoted: false, + }), + Some(ArgumentName { + value: "STARTPOS".to_string(), // uppercase version of startPos + is_quoted: false, + }), + Some(ArgumentName { + value: "LENGTH".to_string(), // matches UPPERCASE + is_quoted: false, + }), + ]; + + let result2 = + resolve_function_arguments(¶m_names, args2, arg_names2).unwrap(); + assert_eq!(result2.len(), 3); + assert_eq!(result2[0], lit("b")); + assert_eq!(result2[1], lit(2)); + assert_eq!(result2[2], lit(10)); + + // Test 3: Mixed case unquoted arguments should match + let args3 = vec![lit("c"), lit(3), lit(15)]; + let arg_names3 = vec![ + Some(ArgumentName { + value: "Prefix".to_string(), // Title case + is_quoted: false, + }), + Some(ArgumentName { + value: "StartPos".to_string(), // matches camelCase + is_quoted: false, + }), + Some(ArgumentName { + value: "Length".to_string(), // Title case + is_quoted: false, + }), + ]; + + let result3 = + resolve_function_arguments(¶m_names, args3, arg_names3).unwrap(); + assert_eq!(result3.len(), 3); + assert_eq!(result3[0], lit("c")); + assert_eq!(result3[1], lit(3)); + assert_eq!(result3[2], lit(15)); + } + + #[test] + fn test_mixed_case_signature_quoted_matching() { + // Test that quoted identifiers require exact case match with signature + let param_names = vec![ + "prefix".to_string(), // lowercase + "startPos".to_string(), // camelCase + "LENGTH".to_string(), // UPPERCASE + ]; + + // Test 1: Quoted with wrong case should fail for "prefix" + let args_wrong_prefix = vec![lit("a"), lit(1), lit(5)]; + let arg_names_wrong_prefix = vec![ + Some(ArgumentName { + value: "PREFIX".to_string(), // Wrong case + is_quoted: true, + }), + Some(ArgumentName { + value: "startPos".to_string(), + is_quoted: true, + }), + Some(ArgumentName { + value: "LENGTH".to_string(), + is_quoted: true, + }), + ]; + + let result = resolve_function_arguments( + ¶m_names, + args_wrong_prefix, + arg_names_wrong_prefix, + ); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Unknown parameter") + ); + + // Test 2: Quoted with wrong case should fail for "startPos" + let args_wrong_startpos = vec![lit("a"), lit(1), lit(5)]; + let arg_names_wrong_startpos = vec![ + Some(ArgumentName { + value: "prefix".to_string(), + is_quoted: true, + }), + Some(ArgumentName { + value: "STARTPOS".to_string(), // Wrong case + is_quoted: true, + }), + Some(ArgumentName { + value: "LENGTH".to_string(), + is_quoted: true, + }), + ]; + + let result2 = resolve_function_arguments( + ¶m_names, + args_wrong_startpos, + arg_names_wrong_startpos, + ); + assert!(result2.is_err()); + assert!( + result2 + .unwrap_err() + .to_string() + .contains("Unknown parameter") + ); + + // Test 3: Quoted with wrong case should fail for "LENGTH" + let args_wrong_length = vec![lit("a"), lit(1), lit(5)]; + let arg_names_wrong_length = vec![ + Some(ArgumentName { + value: "prefix".to_string(), + is_quoted: true, + }), + Some(ArgumentName { + value: "startPos".to_string(), + is_quoted: true, + }), + Some(ArgumentName { + value: "length".to_string(), // Wrong case + is_quoted: true, + }), + ]; + + let result3 = resolve_function_arguments( + ¶m_names, + args_wrong_length, + arg_names_wrong_length, + ); + assert!(result3.is_err()); + assert!( + result3 + .unwrap_err() + .to_string() + .contains("Unknown parameter") + ); + + // Test 4: Quoted with exact case should succeed + let args_correct = vec![lit("a"), lit(1), lit(5)]; + let arg_names_correct = vec![ + Some(ArgumentName { + value: "prefix".to_string(), // Exact match + is_quoted: true, + }), + Some(ArgumentName { + value: "startPos".to_string(), // Exact match + is_quoted: true, + }), + Some(ArgumentName { + value: "LENGTH".to_string(), // Exact match + is_quoted: true, + }), + ]; + + let result4 = + resolve_function_arguments(¶m_names, args_correct, arg_names_correct) + .unwrap(); + assert_eq!(result4.len(), 3); + assert_eq!(result4[0], lit("a")); + assert_eq!(result4[1], lit(1)); + assert_eq!(result4[2], lit(5)); } } diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 561ef1dc15e7d..02a6d2ece8cdb 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -63,7 +63,7 @@ impl PartialEq for AsyncScalarUDF { fn eq(&self, other: &Self) -> bool { // Deconstruct to catch any new fields added in future let Self { inner } = self; - inner.dyn_eq(other.inner.as_any()) + inner.as_ref().dyn_eq(other.inner.as_ref() as &dyn Any) } } impl Eq for AsyncScalarUDF {} @@ -102,10 +102,6 @@ impl AsyncScalarUDF { } impl ScalarUDFImpl for AsyncScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.inner.name() } @@ -146,8 +142,8 @@ mod tests { use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; use crate::{ - async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, ScalarFunctionArgs, ScalarUDFImpl, + async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, }; #[derive(Debug, PartialEq, Eq, Hash, Clone)] @@ -156,10 +152,6 @@ mod tests { } impl ScalarUDFImpl for TestAsyncUDFImpl1 { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { todo!() } @@ -193,10 +185,6 @@ mod tests { } impl ScalarUDFImpl for TestAsyncUDFImpl2 { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { todo!() } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index d02f522910c19..10a9fd6948e4f 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -17,9 +17,9 @@ //! Conditional expressions use crate::expr::Case; -use crate::{expr_schema::ExprSchemable, Expr}; +use crate::{Expr, expr_schema::ExprSchemable}; use arrow::datatypes::DataType; -use datafusion_common::{plan_err, DFSchema, HashSet, Result}; +use datafusion_common::{DFSchema, HashSet, Result, plan_err}; use itertools::Itertools as _; /// Helper struct for building [Expr::Case] diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index acfcc61b7eced..649f74ed3997c 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -16,14 +16,20 @@ // under the License. use crate::var_provider::{VarProvider, VarType}; -use chrono::{DateTime, TimeZone, Utc}; +use chrono::{DateTime, Utc}; +use datafusion_common::HashMap; +use datafusion_common::ScalarValue; +use datafusion_common::TableReference; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use datafusion_common::HashMap; -use std::sync::Arc; +use datafusion_common::{Result, internal_err}; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; -/// Holds per-query execution properties and data (such as statement -/// starting timestamps). +/// Holds properties and scratch state used while optimizing a [`LogicalPlan`] +/// and translating it into an executable physical plan, such as the statement +/// start time used during simplification. /// /// An [`ExecutionProps`] is created each time a `LogicalPlan` is /// prepared for execution (optimized). If the same plan is optimized @@ -31,15 +37,43 @@ use std::sync::Arc; /// /// It is important that this structure be cheap to create as it is /// done so during predicate pruning and expression simplification +/// +/// # Relationship with [`TaskContext`] +/// +/// [`ExecutionProps`] is intentionally distinct from [`TaskContext`]. +/// It is used while optimizing a logical plan and constructing physical +/// expressions and physical plans, before physical operators are run. +/// +/// [`TaskContext`] is the runtime context passed to physical operators during +/// physical-plan execution. +/// +/// Keeping these structures separate avoids threading execution/runtime state +/// through planning APIs, and avoids making execution depend on planner-only +/// scratch state. +/// +/// [`TaskContext`]: https://docs.rs/datafusion/latest/datafusion/execution/struct.TaskContext.html +/// [`LogicalPlan`]: crate::LogicalPlan #[derive(Clone, Debug)] pub struct ExecutionProps { - pub query_execution_start_time: DateTime, + /// The time at which the query execution started. If `None`, + /// functions like `now()` will not be simplified during optimization. + pub query_execution_start_time: Option>, /// Alias generator used by subquery optimizer rules pub alias_generator: Arc, /// Snapshot of config options when the query started pub config_options: Option>, /// Providers for scalar variables pub var_providers: Option>>, + /// Maps each logical `Subquery` to its index in `subquery_results`. + /// Populated by the physical planner before calling `create_physical_expr`. + pub subquery_indexes: HashMap, + /// Shared results container for uncorrelated scalar subquery values. + /// Populated at execution time by `ScalarSubqueryExec`. + pub subquery_results: ScalarSubqueryResults, + /// Maps each lambda variable name to its lambda qualifier generated + /// during physical planning. Populated by the physical planner for + /// each lambda before calling `create_physical_expr`. + pub lambda_variable_qualifier: HashMap, } impl Default for ExecutionProps { @@ -52,12 +86,13 @@ impl ExecutionProps { /// Creates a new execution props pub fn new() -> Self { ExecutionProps { - // Set this to a fixed sentinel to make it obvious if this is - // not being updated / propagated correctly - query_execution_start_time: Utc.timestamp_nanos(0), + query_execution_start_time: None, alias_generator: Arc::new(AliasGenerator::new()), config_options: None, var_providers: None, + subquery_indexes: HashMap::new(), + subquery_results: ScalarSubqueryResults::default(), + lambda_variable_qualifier: HashMap::new(), } } @@ -66,7 +101,7 @@ impl ExecutionProps { mut self, query_execution_start_time: DateTime, ) -> Self { - self.query_execution_start_time = query_execution_start_time; + self.query_execution_start_time = Some(query_execution_start_time); self } @@ -79,14 +114,13 @@ impl ExecutionProps { /// Marks the execution of query started timestamp. /// This also instantiates a new alias generator. pub fn mark_start_execution(&mut self, config_options: Arc) -> &Self { - self.query_execution_start_time = Utc::now(); + self.query_execution_start_time = Some(Utc::now()); self.alias_generator = Arc::new(AliasGenerator::new()); self.config_options = Some(config_options); &*self } - /// Registers a variable provider, returning the existing - /// provider, if any + /// Registers a variable provider, returning the existing provider, if any pub fn add_var_provider( &mut self, var_type: VarType, @@ -117,14 +151,167 @@ impl ExecutionProps { pub fn config_options(&self) -> Option<&Arc> { self.config_options.as_ref() } + + /// Adds a mapping for each variable to the given qualifier. Existing + /// variables with conflicting names get's shadowed + pub fn with_qualified_lambda_variables( + mut self, + qualifier: &TableReference, + variables: &[String], + ) -> Self { + for var in variables { + self.lambda_variable_qualifier + .entry_ref(var) + .insert(qualifier.clone()); + } + + self + } +} + +/// Index of a scalar subquery within a [`ScalarSubqueryResults`] container. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct SubqueryIndex(usize); + +impl SubqueryIndex { + /// Creates a new subquery index. + pub const fn new(index: usize) -> Self { + Self(index) + } + + /// Returns the underlying slot index. + pub const fn as_usize(self) -> usize { + self.0 + } +} + +/// Shared results container for uncorrelated scalar subqueries. +/// +/// Each entry corresponds to one scalar subquery, identified by its index. +/// Each slot is populated at execution time by `ScalarSubqueryExec`, read by +/// `ScalarSubqueryExpr` instances that share this container, and cleared when +/// the plan is reset for re-execution. +#[derive(Clone, Default)] +pub struct ScalarSubqueryResults { + slots: Arc>>>, +} + +impl ScalarSubqueryResults { + /// Creates a new shared results container with `n` empty slots. + pub fn new(n: usize) -> Self { + Self { + slots: Arc::new((0..n).map(|_| Mutex::new(None)).collect()), + } + } + + /// Returns the scalar value stored at `index`, if it has been populated. + pub fn get(&self, index: SubqueryIndex) -> Option { + let slot = self.slots.get(index.as_usize())?; + slot.lock().unwrap().clone() + } + + /// Stores `value` in the slot at `index`. + pub fn set(&self, index: SubqueryIndex, value: ScalarValue) -> Result<()> { + let Some(slot) = self.slots.get(index.as_usize()) else { + return internal_err!( + "ScalarSubqueryResults: result index {} is out of bounds", + index.as_usize() + ); + }; + + let mut slot = slot.lock().unwrap(); + if slot.is_some() { + return internal_err!( + "ScalarSubqueryResults: result for index {} was already populated", + index.as_usize() + ); + } + *slot = Some(value); + + Ok(()) + } + + /// Clears all populated results so the container can be reused. + pub fn clear(&self) { + for slot in self.slots.iter() { + *slot.lock().unwrap() = None; + } + } + + /// Returns true if `this` and `other` point to the same shared container. + pub fn ptr_eq(this: &Self, other: &Self) -> bool { + Arc::ptr_eq(&this.slots, &other.slots) + } +} + +impl fmt::Debug for ScalarSubqueryResults { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(self.slots.iter().map(|slot| slot.lock().unwrap().clone())) + .finish() + } +} + +impl PartialEq for ScalarSubqueryResults { + fn eq(&self, other: &Self) -> bool { + Self::ptr_eq(self, other) + } +} + +impl Eq for ScalarSubqueryResults {} + +impl Hash for ScalarSubqueryResults { + fn hash(&self, state: &mut H) { + Arc::as_ptr(&self.slots).hash(state); + } } #[cfg(test)] mod test { use super::*; + #[test] fn debug() { let props = ExecutionProps::new(); - assert_eq!("ExecutionProps { query_execution_start_time: 1970-01-01T00:00:00Z, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None }", format!("{props:?}")); + assert_eq!( + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [], lambda_variable_qualifier: {} }", + format!("{props:?}") + ); + } + + #[test] + fn scalar_subquery_results_set_and_get() -> Result<()> { + let results = ScalarSubqueryResults::new(1); + assert_eq!(results.get(SubqueryIndex::new(0)), None); + + results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(42)))?; + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(42))) + ); + assert!( + results + .set(SubqueryIndex::new(0), ScalarValue::Int32(Some(7))) + .is_err() + ); + + Ok(()) + } + + #[test] + fn scalar_subquery_results_clear() -> Result<()> { + let results = ScalarSubqueryResults::new(1); + results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(42)))?; + + results.clear(); + + assert_eq!(results.get(SubqueryIndex::new(0)), None); + results.set(SubqueryIndex::new(0), ScalarValue::Int32(Some(7)))?; + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(7))) + ); + + Ok(()) } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 13160d573ab4d..98d355fad800e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -26,23 +26,35 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::function::WindowFunctionSimplification; +use crate::higher_order_function::{HigherOrderUDF, resolve_lambda_variables}; use crate::logical_plan::Subquery; -use crate::{AggregateUDF, Volatility}; +use crate::type_coercion::functions::value_fields_with_higher_order_udf; +use crate::{AggregateUDF, LambdaParametersProgress, ValueOrLambda, Volatility}; use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; +use datafusion_common::datatype::DataTypeExt; +use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ - Column, DFSchema, HashMap, Result, ScalarValue, Spans, TableReference, + Column, DFSchema, ExprSchema, HashMap, Result, ScalarValue, Spans, TableReference, + plan_err, }; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_functions_window_common::field::WindowUDFFieldArgs; #[cfg(feature = "sql")] -use sqlparser::ast::{ - display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, - RenameSelectItem, ReplaceSelectElement, +pub use sqlparser::ast::{ + ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, RenameSelectItem, + ReplaceSelectElement, +}; +// Use shims for sqlparser types when the sql feature is disabled. +#[cfg(not(feature = "sql"))] +pub use crate::sql::{ + ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, RenameSelectItem, + ReplaceSelectElement, }; // Moved in 51.0.0 to datafusion_common @@ -80,7 +92,7 @@ impl From for NullTreatment { /// /// For example the expression `A + 1` will be represented as /// -///```text +/// ```text /// BinaryExpr { /// left: Expr::Column("A"), /// op: Operator::Plus, @@ -253,7 +265,7 @@ impl From for NullTreatment { /// /// [`ExplainFormat::Tree`]: crate::logical_plan::ExplainFormat::Tree /// -///``` +/// ``` /// # use datafusion_expr::{lit, col}; /// let expr = col("c1") + lit(42); /// assert_eq!(format!("{}", expr.human_display()), "c1 + 42"); @@ -289,7 +301,7 @@ impl From for NullTreatment { /// Rewrite an expression, replacing references to column "a" in an /// to the literal `42`: /// -/// ``` +/// ``` /// # use datafusion_common::tree_node::{Transformed, TreeNode}; /// # use datafusion_expr::{col, Expr, lit}; /// // expression a = 5 AND b = 6 @@ -309,6 +321,7 @@ impl From for NullTreatment { /// assert!(rewritten.transformed); /// // to 42 = 5 AND b = 6 /// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); +/// ``` #[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] pub enum Expr { /// An expression with a specific name. @@ -316,7 +329,7 @@ pub enum Expr { /// A named reference to a qualified field in a schema. Column(Column), /// A named reference to a variable in a registry. - ScalarVariable(DataType, Vec), + ScalarVariable(FieldRef, Vec), /// A constant value along with associated [`FieldMetadata`]. Literal(ScalarValue, Option), /// A binary expression such as "age > 21" @@ -372,6 +385,8 @@ pub enum Expr { Exists(Exists), /// IN subquery InSubquery(InSubquery), + /// Set comparison subquery (e.g. `= ANY`, `> ALL`) + SetComparison(SetComparison), /// Scalar subquery ScalarSubquery(Subquery), /// Represents a reference to all available fields in a specific schema, @@ -398,6 +413,133 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + /// Call a higher order function with a set of arguments. + /// + /// For example, `array_transform([1,2,3], v -> v+1)` would be equivalent to: + /// + /// ```text + /// HigherOrderFunction(array_transform) + /// ├── args[0]: Literal([1,2,3]) + /// └── args[1]: Lambda + /// ├── params: ["v"] + /// └── body: BinaryExpr(+) + /// ├── LambdaVariable("v") + /// └── Literal(1) + /// ``` + HigherOrderFunction(HigherOrderFunction), + /// A Lambda expression with a set of parameters names and a body + Lambda(Lambda), + /// A named reference to a lambda parameter + LambdaVariable(LambdaVariable), +} + +/// Invoke a [`HigherOrderUDF`] with a set of arguments +#[derive(Clone, Eq, PartialOrd, Debug)] +pub struct HigherOrderFunction { + /// The function + pub func: Arc, + /// List of expressions to feed to the functions as arguments + pub args: Vec, +} + +impl HigherOrderFunction { + /// Create a new `HigherOrderFunction` from a [`HigherOrderUDF`] + pub fn new(func: Arc, args: Vec) -> Self { + Self { func, args } + } + + pub fn name(&self) -> &str { + self.func.name() + } + + /// Invokes the inner function [`crate::HigherOrderUDFImpl::lambda_parameters`] + /// using the arguments of this invocation. This expression lambda + /// variables must be already resolved either by coming from the + /// default sql planner or by calling [Expr::resolve_lambda_variables] + /// or [LogicalPlan::resolve_lambda_variables] + /// + /// [LogicalPlan::resolve_lambda_variables]: crate::LogicalPlan::resolve_lambda_variables + pub fn lambda_parameters( + &self, + schema: &dyn ExprSchema, + ) -> Result>> { + let args = self + .args + .iter() + .map(|e| match e { + Expr::Lambda(lambda) => { + Ok(ValueOrLambda::Lambda(Some(lambda.body.to_field(schema)?.1))) + } + _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)), + }) + .collect::>>()?; + + let coerced_fields = + value_fields_with_higher_order_udf(&args, self.func.as_ref())?; + + match self.func.lambda_parameters(0, &coerced_fields)? { + LambdaParametersProgress::Partial(_) => plan_err!( + "{} lambda_parameters returned a partial result when the return type of all it's lambdas were provided", + self.name() + ), + LambdaParametersProgress::Complete(items) => Ok(items), + } + } +} + +impl Hash for HigherOrderFunction { + fn hash(&self, state: &mut H) { + self.func.hash(state); + self.args.hash(state); + } +} + +impl PartialEq for HigherOrderFunction { + fn eq(&self, other: &Self) -> bool { + self.func.as_ref() == other.func.as_ref() && self.args == other.args + } +} + +/// A named reference to a lambda parameter which includes it's own [`FieldRef`], +/// which is used to implement [`ExprSchemable`], for example. It is an option only to make +/// easier for `expr_api` users to construct lambda variables, but any expression +/// tree or [`LogicalPlan`] containing unresolved variables must be resolved before +/// usage with either [`Expr::resolve_lambda_variables`] or +/// [`LogicalPlan::resolve_lambda_variables`]. The default SQL planner produces +/// already resolved variables and no further resolving is required. +/// +/// After resolving, if any argument from the lambda function which this +/// variables originates from have it's field changed (type, nullability, +/// metadata, etc), the resolved variable may became outdated and must be +/// resolved again. +/// +/// [`LogicalPlan`]: crate::LogicalPlan +/// [`LogicalPlan::resolve_lambda_variables`]: crate::LogicalPlan::resolve_lambda_variables +#[derive(Clone, PartialEq, PartialOrd, Eq, Debug, Hash)] +pub struct LambdaVariable { + pub name: String, + pub field: Option, + pub spans: Spans, +} + +impl LambdaVariable { + /// Create a lambda variable from a name and an optional field. + /// If the field is none, the expression tree or LogicalPlan which + /// owns this variable must be resolved before usage with either + /// [`Expr::resolve_lambda_variables`] or [`LogicalPlan::resolve_lambda_variables`]. + /// + /// [`LogicalPlan::resolve_lambda_variables`]: crate::LogicalPlan::resolve_lambda_variables + pub fn new(name: String, field: Option) -> Self { + Self { + name, + field, + spans: Spans::new(), + } + } + + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl Default for Expr { @@ -479,7 +621,7 @@ impl<'a> TreeNodeContainer<'a, Self> for Expr { /// that may be missing in the physical data but present in the logical schema. /// See the [default_column_values.rs] example implementation. /// -/// [default_column_values.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/default_column_values.rs +/// [default_column_values.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_data_source/default_column_values.rs pub type SchemaFieldMetadata = std::collections::HashMap; /// Intersects multiple metadata instances for UNION operations. @@ -506,17 +648,26 @@ pub type SchemaFieldMetadata = std::collections::HashMap; pub fn intersect_metadata_for_union<'a>( metadatas: impl IntoIterator, ) -> SchemaFieldMetadata { - let mut metadatas = metadatas.into_iter(); - let Some(mut intersected) = metadatas.next().cloned() else { - return Default::default(); - }; + let mut intersected: Option = None; for metadata in metadatas { - // Only keep keys that exist in both with the same value - intersected.retain(|k, v| metadata.get(k) == Some(v)); + // Skip empty metadata (e.g. from NULL literals or computed expressions) + // to avoid dropping metadata from branches that have it. + if metadata.is_empty() { + continue; + } + match &mut intersected { + None => { + intersected = Some(metadata.clone()); + } + Some(current) => { + // Only keep keys that exist in both with the same value + current.retain(|k, v| metadata.get(k) == Some(&*v)); + } + } } - intersected + intersected.unwrap_or_default() } /// UNNEST expression. @@ -592,9 +743,31 @@ impl Alias { self.metadata = metadata; self } + + #[doc(hidden)] + pub fn with_expr(mut self, expr: Expr) -> Self { + self.expr = Box::new(expr); + self + } + + #[doc(hidden)] + pub fn try_map_expr(self, f: impl FnOnce(Expr) -> Result) -> Result { + let Alias { + expr, + relation, + name, + metadata, + } = self; + Ok(Expr::Alias(Alias { + expr: Box::new(f(*expr)?), + relation, + name, + metadata, + })) + } } -/// Binary expression +/// Binary expression for [`Expr::BinaryExpr`] #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct BinaryExpr { /// Left-hand side of the expression @@ -796,13 +969,20 @@ pub struct Cast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub field: FieldRef, } impl Cast { /// Create a new Cast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + field: data_type.into_nullable_field_ref(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { expr, field } } } @@ -812,13 +992,20 @@ pub struct TryCast { /// The expression being cast pub expr: Box, /// The `DataType` the expression will yield - pub data_type: DataType, + pub field: FieldRef, } impl TryCast { /// Create a new TryCast expression pub fn new(expr: Box, data_type: DataType) -> Self { - Self { expr, data_type } + Self { + expr, + field: data_type.into_nullable_field_ref(), + } + } + + pub fn new_from_field(expr: Box, field: FieldRef) -> Self { + Self { expr, field } } } @@ -953,7 +1140,7 @@ impl AggregateFunction { pub enum WindowFunctionDefinition { /// A user defined aggregate function AggregateUDF(Arc), - /// A user defined aggregate function + /// A user defined window function WindowUDF(Arc), } @@ -990,7 +1177,7 @@ impl WindowFunctionDefinition { } } - /// Return the inner window simplification function, if any + /// Returns this window function's simplification hook, if any. /// /// See [`WindowFunctionSimplification`] for more information pub fn simplify(&self) -> Option { @@ -1077,7 +1264,7 @@ impl WindowFunction { } } - /// Return the inner window simplification function, if any + /// Returns this window function's simplification hook, if any. /// /// See [`WindowFunctionSimplification`] for more information pub fn simplify(&self) -> Option { @@ -1101,6 +1288,54 @@ impl Exists { } } +/// Whether the set comparison uses `ANY`/`SOME` or `ALL` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum SetQuantifier { + /// `ANY` (or `SOME`) + Any, + /// `ALL` + All, +} + +impl Display for SetQuantifier { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + SetQuantifier::Any => write!(f, "ANY"), + SetQuantifier::All => write!(f, "ALL"), + } + } +} + +/// Set comparison subquery (e.g. `= ANY`, `> ALL`) +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct SetComparison { + /// The expression to compare + pub expr: Box, + /// Subquery that will produce a single column of data to compare against + pub subquery: Subquery, + /// Comparison operator (e.g. `=`, `>`, `<`) + pub op: Operator, + /// Quantifier (`ANY`/`ALL`) + pub quantifier: SetQuantifier, +} + +impl SetComparison { + /// Create a new set comparison expression + pub fn new( + expr: Box, + subquery: Subquery, + op: Operator, + quantifier: SetQuantifier, + ) -> Self { + Self { + expr, + subquery, + op, + quantifier, + } + } +} + /// InList expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { @@ -1211,64 +1446,25 @@ impl GroupingSet { } } +/// A Lambda expression with a set of parameters names and a body #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -#[cfg(not(feature = "sql"))] -pub struct IlikeSelectItem { - pub pattern: String, -} -#[cfg(not(feature = "sql"))] -impl Display for IlikeSelectItem { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "ILIKE '{}'", &self.pattern)?; - Ok(()) - } -} -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -#[cfg(not(feature = "sql"))] -pub enum ExcludeSelectItem { - Single(Ident), - Multiple(Vec), -} -#[cfg(not(feature = "sql"))] -impl Display for ExcludeSelectItem { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "EXCLUDE")?; - match self { - Self::Single(column) => { - write!(f, " {column}")?; - } - Self::Multiple(columns) => { - write!(f, " ({})", display_comma_separated(columns))?; - } - } - Ok(()) - } -} -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -#[cfg(not(feature = "sql"))] -pub struct ExceptSelectItem { - pub first_element: Ident, - pub additional_elements: Vec, +pub struct Lambda { + /// The parameters names + pub params: Vec, + /// The body expression + pub body: Box, } -#[cfg(not(feature = "sql"))] -impl Display for ExceptSelectItem { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "EXCEPT ")?; - if self.additional_elements.is_empty() { - write!(f, "({})", self.first_element)?; - } else { - write!( - f, - "({}, {})", - self.first_element, - display_comma_separated(&self.additional_elements) - )?; + +impl Lambda { + /// Create a new lambda expression + pub fn new(params: Vec, body: Expr) -> Self { + Self { + params, + body: Box::new(body), } - Ok(()) } } -#[cfg(not(feature = "sql"))] pub fn display_comma_separated(slice: &[T]) -> String where T: Display, @@ -1277,64 +1473,6 @@ where slice.iter().map(|v| format!("{v}")).join(", ") } -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -#[cfg(not(feature = "sql"))] -pub enum RenameSelectItem { - Single(String), - Multiple(Vec), -} -#[cfg(not(feature = "sql"))] -impl Display for RenameSelectItem { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "RENAME")?; - match self { - Self::Single(column) => { - write!(f, " {column}")?; - } - Self::Multiple(columns) => { - write!(f, " ({})", display_comma_separated(columns))?; - } - } - Ok(()) - } -} - -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -#[cfg(not(feature = "sql"))] -pub struct Ident { - /// The value of the identifier without quotes. - pub value: String, - /// The starting quote if any. Valid quote characters are the single quote, - /// double quote, backtick, and opening square bracket. - pub quote_style: Option, - /// The span of the identifier in the original SQL string. - pub span: String, -} -#[cfg(not(feature = "sql"))] -impl Display for Ident { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "[{}]", self.value) - } -} - -#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] -#[cfg(not(feature = "sql"))] -pub struct ReplaceSelectElement { - pub expr: String, - pub column_name: Ident, - pub as_keyword: bool, -} -#[cfg(not(feature = "sql"))] -impl Display for ReplaceSelectElement { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - if self.as_keyword { - write!(f, "{} AS {}", self.expr, self.column_name) - } else { - write!(f, "{} {}", self.expr, self.column_name) - } - } -} - /// Additional options for wildcards, e.g. Snowflake `EXCLUDE`/`RENAME` and Bigquery `EXCEPT`. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug, Default)] pub struct WildcardOptions { @@ -1487,6 +1625,24 @@ impl Expr { } } + /// Returns placement information for this expression. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + pub fn placement(&self) -> ExpressionPlacement { + match self { + Expr::Column(_) => ExpressionPlacement::Column, + Expr::Literal(_, _) => ExpressionPlacement::Literal, + Expr::Alias(inner) => inner.expr.placement(), + Expr::ScalarFunction(func) => { + let arg_placements: Vec<_> = + func.args.iter().map(|arg| arg.placement()).collect(); + func.func.placement(&arg_placements) + } + _ => ExpressionPlacement::KeepInPlace, + } + } + /// Return String representation of the variant represented by `self` /// Useful for non-rust based bindings pub fn variant_name(&self) -> &str { @@ -1503,6 +1659,7 @@ impl Expr { Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", Expr::InSubquery(..) => "InSubquery", + Expr::SetComparison(..) => "SetComparison", Expr::IsNotNull(..) => "IsNotNull", Expr::IsNull(..) => "IsNull", Expr::Like { .. } => "Like", @@ -1525,6 +1682,9 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::HigherOrderFunction { .. } => "HigherOrderFunction", + Expr::Lambda { .. } => "Lambda", + Expr::LambdaVariable { .. } => "LambdaVariable", } } @@ -1964,6 +2124,12 @@ impl Expr { .expect("exists closure is infallible") } + /// Returns true if the expression contains a scalar subquery. + pub fn contains_scalar_subquery(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("exists closure is infallible") + } + /// Returns true if the expression node is volatile, i.e. whether it can return /// different results when evaluated multiple times with the same input. /// Note: unlike [`Self::is_volatile`], this function does not consider inputs: @@ -2020,6 +2186,32 @@ impl Expr { rewrite_placeholder(item, expr.as_ref(), schema)?; } } + Expr::InSubquery(InSubquery { + expr, + subquery, + negated: _, + }) => { + let subquery_schema = subquery.subquery.schema(); + match &subquery_schema.fields()[..] { + [subquery_field] => { + let column = Expr::Column(Column::new_unqualified( + subquery_field.name().clone(), + )); + rewrite_placeholder( + expr.as_mut(), + &column, + subquery_schema, + )?; + } + _ => { + return plan_err!( + "InSubquery should only return one column, but found {}: {}", + subquery_schema.fields().len(), + subquery_schema.field_names().join(", ") + ); + } + } + } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { rewrite_placeholder(pattern.as_mut(), expr.as_ref(), schema)?; @@ -2040,6 +2232,9 @@ impl Expr { pub fn short_circuits(&self) -> bool { match self { Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), + Expr::HigherOrderFunction(HigherOrderFunction { func, .. }) => { + func.short_circuits() + } Expr::BinaryExpr(BinaryExpr { op, .. }) => { matches!(op, Operator::And | Operator::Or) } @@ -2058,6 +2253,7 @@ impl Expr { | Expr::GroupingSet(..) | Expr::InList(..) | Expr::InSubquery(..) + | Expr::SetComparison(..) | Expr::IsFalse(..) | Expr::IsNotFalse(..) | Expr::IsNotNull(..) @@ -2078,7 +2274,9 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Placeholder(..) => false, + | Expr::Placeholder(..) + | Expr::Lambda(..) + | Expr::LambdaVariable(..) => false, } } @@ -2100,11 +2298,21 @@ impl Expr { None } } + + /// Return a `Expr` with all [`LambdaVariable`] resolved only if all of them + /// are contained in the subtree of the [`HigherOrderFunction`] it originates from, + /// otherwise returns an error + pub fn resolve_lambda_variables( + self, + schema: &DFSchema, + ) -> Result> { + resolve_lambda_variables(self, schema, &mut HashMap::new()) + } } impl Normalizeable for Expr { fn can_normalize(&self) -> bool { - #[allow(clippy::match_like_matches_macro)] + #[expect(clippy::match_like_matches_macro)] match self { Expr::BinaryExpr(BinaryExpr { op: @@ -2252,23 +2460,23 @@ impl NormalizeEq for Expr { ( Expr::Cast(Cast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::Cast(Cast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), ) | ( Expr::TryCast(TryCast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::TryCast(TryCast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), - ) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ) => self_field == other_field && self_expr.normalize_eq(other_expr), ( Expr::ScalarFunction(ScalarFunction { func: self_func, @@ -2529,8 +2737,8 @@ impl HashNode for Expr { Expr::Column(column) => { column.hash(state); } - Expr::ScalarVariable(data_type, name) => { - data_type.hash(state); + Expr::ScalarVariable(field, name) => { + field.hash(state); name.hash(state); } Expr::Literal(scalar_value, _) => { @@ -2584,15 +2792,9 @@ impl HashNode for Expr { when_then_expr: _when_then_expr, else_expr: _else_expr, }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { - data_type.hash(state); + Expr::Cast(Cast { expr: _expr, field }) + | Expr::TryCast(TryCast { expr: _expr, field }) => { + field.hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(state); @@ -2651,6 +2853,16 @@ impl HashNode for Expr { subquery.hash(state); negated.hash(state); } + Expr::SetComparison(SetComparison { + expr: _, + subquery, + op, + quantifier, + }) => { + subquery.hash(state); + op.hash(state); + quantifier.hash(state); + } Expr::ScalarSubquery(subquery) => { subquery.hash(state); } @@ -2674,6 +2886,20 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::HigherOrderFunction(HigherOrderFunction { func, args: _args }) => { + func.hash(state); + } + Expr::Lambda(Lambda { params, body: _ }) => { + params.hash(state); + } + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => { + name.hash(state); + field.hash(state); + } }; } } @@ -2681,24 +2907,23 @@ impl HashNode for Expr { // Modifies expr to match the DataType, metadata, and nullability of other if it is // a placeholder with previously unspecified type information (i.e., most placeholders) fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { - if let Expr::Placeholder(Placeholder { id: _, field }) = expr { - if field.is_none() { - let other_field = other.to_field(schema); - match other_field { - Err(e) => { - Err(e.context(format!( - "Can not find type of {other} needed to infer type of {expr}" - )))?; - } - Ok((_, other_field)) => { - // We can't infer the nullability of the future parameter that might - // be bound, so ensure this is set to true - *field = - Some(other_field.as_ref().clone().with_nullable(true).into()); - } + if let Expr::Placeholder(Placeholder { id: _, field }) = expr + && field.is_none() + { + let other_field = other.to_field(schema); + match other_field { + Err(e) => { + Err(e.context(format!( + "Can not find type of {other} needed to infer type of {expr}" + )))?; } - }; - } + Ok((_, other_field)) => { + // We can't infer the nullability of the future parameter that might + // be bound, so ensure this is set to true + *field = Some(other_field.as_ref().clone().with_nullable(true).into()); + } + } + }; Ok(()) } @@ -2842,6 +3067,12 @@ impl Display for SchemaDisplay<'_> { write!(f, "NOT IN") } Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), + Expr::SetComparison(SetComparison { + expr, + op, + quantifier, + .. + }) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())), Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), Expr::IsNotTrue(expr) => { @@ -2987,6 +3218,25 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { + match func.schema_name(args) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {e}") + } + } + } + Expr::Lambda(Lambda { params, body }) => { + write!( + f, + "({}) -> {}", + display_comma_separated(params), + SchemaDisplay(body) + ) + } + Expr::LambdaVariable(c) => f.write_str(&c.name), } } } @@ -3167,6 +3417,9 @@ impl Display for SqlDisplay<'_> { } } } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) + } _ => write!(f, "{}", self.0), } } @@ -3283,11 +3536,15 @@ impl Display for Expr { } write!(f, "END") } - Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr} AS {data_type})") + Expr::Cast(Cast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); + write!(f, "CAST({expr} AS {formatted})") } - Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type})") + Expr::TryCast(TryCast { expr, field }) => { + let formatted = + format_type_and_metadata(field.data_type(), Some(field.metadata())); + write!(f, "TRY_CAST({expr} AS {formatted})") } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), @@ -3317,6 +3574,12 @@ impl Display for Expr { subquery, negated: false, }) => write!(f, "{expr} IN ({subquery:?})"), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), Expr::ScalarFunction(fun) => { @@ -3474,6 +3737,13 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::HigherOrderFunction(fun) => { + fmt_function(f, fun.name(), false, &fun.args, true) + } + Expr::Lambda(Lambda { params, body }) => { + write!(f, "({}) -> {body}", params.join(", ")) + } + Expr::LambdaVariable(c) => f.write_str(&c.name), } } } @@ -3511,13 +3781,12 @@ pub fn physical_name(expr: &Expr) -> Result { mod test { use crate::expr_fn::col; use crate::{ - case, lit, placeholder, qualified_wildcard, wildcard, wildcard_with_options, - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Volatility, case, + lit, placeholder, qualified_wildcard, wildcard, wildcard_with_options, }; use arrow::datatypes::{Field, Schema}; use sqlparser::ast; use sqlparser::ast::{Ident, IdentWithAlias}; - use std::any::Any; #[test] fn infer_placeholder_in_clause() { @@ -3574,6 +3843,108 @@ mod test { } } + #[test] + fn infer_placeholder_in_subquery() { + // WHERE $1 IN (SELECT a FROM t) + let subquery_field = Field::new("a", DataType::Int32, false); + let subquery_schema = Arc::new( + DFSchema::from_unqualified_fields( + vec![subquery_field].into(), + Default::default(), + ) + .unwrap(), + ); + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: subquery_schema, + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let in_subquery = Expr::InSubquery(InSubquery { + expr: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + field: None, + })), + subquery, + negated: false, + }); + + let outer_schema = DFSchema::empty(); + let (inferred_expr, contains_placeholder) = + in_subquery.infer_placeholder_types(&outer_schema).unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::InSubquery(in_subquery) => match *in_subquery.expr { + Expr::Placeholder(placeholder) => { + let inferred = placeholder.field.expect("placeholder field"); + assert_eq!(inferred.data_type(), &DataType::Int32); + assert!(inferred.is_nullable()); + } + _ => panic!("Expected Placeholder expression in InSubquery"), + }, + _ => panic!("Expected InSubquery expression"), + } + } + + #[test] + fn infer_placeholder_not_in_subquery() { + // WHERE $1 NOT IN (SELECT a FROM t) + let subquery_field = Field::new("a", DataType::Int32, false); + let subquery_schema = Arc::new( + DFSchema::from_unqualified_fields( + vec![subquery_field].into(), + Default::default(), + ) + .unwrap(), + ); + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: subquery_schema, + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let not_in_subquery = Expr::InSubquery(InSubquery { + expr: Box::new(Expr::Placeholder(Placeholder { + id: "$1".to_string(), + field: None, + })), + subquery, + negated: true, + }); + + let outer_schema = DFSchema::empty(); + let (inferred_expr, contains_placeholder) = not_in_subquery + .infer_placeholder_types(&outer_schema) + .unwrap(); + + assert!(contains_placeholder); + + match inferred_expr { + Expr::InSubquery(in_subquery) => { + assert!(in_subquery.negated, "negated flag must be preserved"); + match *in_subquery.expr { + Expr::Placeholder(placeholder) => { + let inferred = placeholder.field.expect("placeholder field"); + assert_eq!(inferred.data_type(), &DataType::Int32); + assert!(inferred.is_nullable()); + } + _ => { + panic!("Expected Placeholder expression in InSubquery") + } + } + } + _ => panic!("Expected InSubquery expression"), + } + } + #[test] fn infer_placeholder_like_and_similar_to() { // name LIKE $1 @@ -3628,11 +3999,11 @@ mod test { #[test] fn infer_placeholder_with_metadata() { // name == $1, where name is a non-nullable string - let schema = - Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false) - .with_metadata( - [("some_key".to_string(), "some_value".to_string())].into(), - )])); + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, false).with_metadata( + [("some_key".to_string(), "some_value".to_string())].into(), + ), + ])); let df_schema = DFSchema::try_from(schema).unwrap(); let expr = binary_expr(col("name"), Operator::Eq, placeholder("$1")); @@ -3673,7 +4044,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), - data_type: DataType::Utf8, + field: DataType::Utf8.into_nullable_field_ref(), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); @@ -3683,6 +4054,24 @@ mod test { Ok(()) } + #[test] + fn format_decimal_literal() { + let expr = lit(ScalarValue::Decimal128(Some(1), 1, 1)); + assert_eq!("Decimal128(0.1,1,1)", format!("{expr}")); + assert_eq!("Decimal128(0.1,1,1)", expr.schema_name().to_string()); + assert_eq!("0.1", expr.human_display().to_string()); + + let expr = lit(ScalarValue::Decimal128(Some(120), 3, 2)); + assert_eq!("Decimal128(1.20,3,2)", format!("{expr}")); + assert_eq!("Decimal128(1.20,3,2)", expr.schema_name().to_string()); + assert_eq!("1.20", expr.human_display().to_string()); + + let null_expr = lit(ScalarValue::Decimal128(None, 10, 2)); + assert_eq!("Decimal128(NULL,10,2)", format!("{null_expr}")); + assert_eq!("Decimal128(NULL,10,2)", null_expr.schema_name().to_string()); + assert_eq!("NULL", null_expr.human_display().to_string()); + } + #[test] fn test_partial_ord() { // Test validates that partial ord is defined for Expr, not @@ -3762,9 +4151,6 @@ mod test { signature: Signature, } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "TestScalarUDF" } @@ -3800,6 +4186,7 @@ mod test { } use super::*; + use crate::logical_plan::{EmptyRelation, LogicalPlan}; #[test] fn test_display_wildcard() { @@ -3826,8 +4213,8 @@ mod test { wildcard_with_options(wildcard_options( None, Some(ExcludeSelectItem::Multiple(vec![ - Ident::from("c1"), - Ident::from("c2") + Ident::from("c1").into(), + Ident::from("c2").into() ])), None, None, @@ -3890,6 +4277,28 @@ mod test { ) } + #[test] + fn test_display_set_comparison() { + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let expr = Expr::SetComparison(SetComparison::new( + Box::new(Expr::Column(Column::from_name("a"))), + subquery, + Operator::Gt, + SetQuantifier::Any, + )); + + assert_eq!(format!("{expr}"), "a > ANY ()"); + assert_eq!(format!("{}", expr.human_display()), "a > ANY ()"); + } + #[test] fn test_schema_display_alias_with_relation() { assert_eq!( @@ -3914,6 +4323,36 @@ mod test { ); } + #[test] + fn test_unalias_nested_respects_user_metadata() { + use std::collections::HashMap; + + let base_expr = col("id"); + + let no_metadata = base_expr.clone().alias("alias"); + assert_eq!(no_metadata.unalias_nested().data, base_expr); + + let Expr::Alias(empty_metadata_alias) = base_expr.clone().alias("alias") else { + unreachable!(); + }; + let empty_metadata_alias = Expr::Alias( + empty_metadata_alias.with_metadata(Some(FieldMetadata::default())), + ); + assert_eq!(empty_metadata_alias.unalias_nested().data, base_expr); + + let user_metadata = FieldMetadata::from(HashMap::from([( + "some_key".to_string(), + "some_value".to_string(), + )])); + + let Expr::Alias(user_alias) = base_expr.clone().alias("alias") else { + unreachable!(); + }; + let user_alias = + Expr::Alias(user_alias.with_metadata(Some(user_metadata.clone()))); + assert_eq!(user_alias.clone().unalias_nested().data, user_alias); + } + fn wildcard_options( opt_ilike: Option, opt_exclude: Option, @@ -3974,10 +4413,6 @@ mod test { #[derive(Debug, PartialEq, Eq, Hash)] struct TestUDF {} impl ScalarUDFImpl for TestUDF { - fn as_any(&self) -> &dyn Any { - unimplemented!() - } - fn name(&self) -> &str { unimplemented!() } @@ -3998,4 +4433,67 @@ mod test { } } } + + mod intersect_metadata_tests { + use super::super::intersect_metadata_for_union; + use std::collections::HashMap; + + #[test] + fn all_branches_same_metadata() { + let m1 = HashMap::from([("key".into(), "val".into())]); + let m2 = HashMap::from([("key".into(), "val".into())]); + let result = intersect_metadata_for_union([&m1, &m2]); + assert_eq!(result, HashMap::from([("key".into(), "val".into())])); + } + + #[test] + fn conflicting_metadata_dropped() { + let m1 = HashMap::from([("key".into(), "a".into())]); + let m2 = HashMap::from([("key".into(), "b".into())]); + let result = intersect_metadata_for_union([&m1, &m2]); + assert!(result.is_empty()); + } + + #[test] + fn empty_metadata_branch_skipped() { + let m1 = HashMap::from([("key".into(), "val".into())]); + let m2 = HashMap::new(); // e.g. NULL literal + let result = intersect_metadata_for_union([&m1, &m2]); + assert_eq!(result, HashMap::from([("key".into(), "val".into())])); + } + + #[test] + fn empty_metadata_first_branch_skipped() { + let m1 = HashMap::new(); + let m2 = HashMap::from([("key".into(), "val".into())]); + let result = intersect_metadata_for_union([&m1, &m2]); + assert_eq!(result, HashMap::from([("key".into(), "val".into())])); + } + + #[test] + fn all_branches_empty_metadata() { + let m1: HashMap = HashMap::new(); + let m2: HashMap = HashMap::new(); + let result = intersect_metadata_for_union([&m1, &m2]); + assert!(result.is_empty()); + } + + #[test] + fn mixed_empty_and_conflicting() { + let m1 = HashMap::from([("key".into(), "a".into())]); + let m2 = HashMap::new(); + let m3 = HashMap::from([("key".into(), "b".into())]); + let result = intersect_metadata_for_union([&m1, &m2, &m3]); + // m2 is skipped; m1 and m3 conflict → dropped + assert!(result.is_empty()); + } + + #[test] + fn no_inputs() { + let result = intersect_metadata_for_union(std::iter::empty::< + &HashMap, + >()); + assert!(result.is_empty()); + } + } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 94d8009ce814e..9d711113e4f74 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -18,8 +18,9 @@ //! Functions for creating logical expressions use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, + AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Lambda, + LambdaVariable, NullTreatment, Placeholder, TryCast, Unnest, WildcardOptions, + WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -28,9 +29,9 @@ use crate::function::{ use crate::ptr_eq::PtrEq; use crate::select_expr::SelectExpr; use crate::{ - conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, AggregateUDF, Expr, LimitEffect, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, + conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -39,11 +40,10 @@ use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference}; +use datafusion_common::{Column, Result, ScalarValue, Spans, TableReference, plan_err}; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use std::any::Any; use std::collections::HashMap; use std::fmt::Debug; use std::hash::Hash; @@ -478,10 +478,6 @@ impl SimpleScalarUDF { } impl ScalarUDFImpl for SimpleScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -592,10 +588,6 @@ impl SimpleAggregateUDF { } impl AggregateUDFImpl for SimpleAggregateUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -685,10 +677,6 @@ impl SimpleWindowUDF { } impl WindowUDFImpl for SimpleWindowUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { &self.name } @@ -732,6 +720,25 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval), None) } +/// Create a lambda expression +pub fn lambda(params: impl IntoIterator>, body: Expr) -> Expr { + Expr::Lambda(Lambda::new( + params.into_iter().map(Into::into).collect(), + body, + )) +} + +/// Create an unresolved lambda variable expression +/// +/// The expression tree or [`LogicalPlan`] which +/// owns this variable must be resolved before usage with either +/// [`Expr::resolve_lambda_variables`] or [`LogicalPlan::resolve_lambda_variables`]. +/// +/// [LogicalPlan::resolve_lambda_variables]: crate::LogicalPlan::resolve_lambda_variables +pub fn lambda_var(name: impl Into) -> Expr { + Expr::LambdaVariable(LambdaVariable::new(name.into(), None)) +} + /// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] /// /// Adds methods to [`Expr`] that make it easy to set optional options diff --git a/datafusion/expr/src/expr_rewriter/guarantees.rs b/datafusion/expr/src/expr_rewriter/guarantees.rs index b8589a17df3e4..61fbbdba43aa9 100644 --- a/datafusion/expr/src/expr_rewriter/guarantees.rs +++ b/datafusion/expr/src/expr_rewriter/guarantees.rs @@ -17,7 +17,7 @@ //! Rewrite expressions based on external expression value range guarantees. -use crate::{expr::InList, lit, Between, BinaryExpr, Expr}; +use crate::{Between, BinaryExpr, Expr, expr::InList, lit}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::{Interval, NullableInterval}; @@ -102,10 +102,10 @@ fn rewrite_expr( guarantees: &HashMap<&Expr, &NullableInterval>, ) -> Result> { // If an expression collapses to a single value, replace it with a literal - if let Some(interval) = guarantees.get(&expr) { - if let Some(value) = interval.single_value() { - return Ok(Transformed::yes(lit(value))); - } + if let Some(interval) = guarantees.get(&expr) + && let Some(value) = interval.single_value() + { + return Ok(Transformed::yes(lit(value))); } let result = match expr { @@ -302,9 +302,8 @@ fn rewrite_inlist( mod tests { use super::*; - use crate::{col, Operator}; + use crate::{Operator, col}; use datafusion_common::tree_node::TransformedResult; - use datafusion_common::ScalarValue; #[test] fn test_not_null_guarantee() { diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 31759f1cc9cfe..a9a0c156538f9 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -26,15 +26,15 @@ use crate::expr::{Alias, Sort, Unnest}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; +use datafusion_common::TableReference; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::TableReference; use datafusion_common::{Column, DFSchema, Result}; mod guarantees; +pub use guarantees::GuaranteeRewriter; pub use guarantees::rewrite_with_guarantees; pub use guarantees::rewrite_with_guarantees_map; -pub use guarantees::GuaranteeRewriter; mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; @@ -260,7 +260,18 @@ fn coerce_exprs_for_schema( } #[expect(deprecated)] Expr::Wildcard { .. } => Ok(expr), - _ => expr.cast_to(new_type, src_schema), + _ => { + match expr { + // maintain the original name when casting a column, to avoid the + // tablename being added to it when not explicitly set by the query + // (see: https://github.com/apache/datafusion/issues/18818) + Expr::Column(ref column) => { + let name = column.name().to_owned(); + Ok(expr.cast_to(new_type, src_schema)?.alias(name)) + } + _ => Ok(expr.cast_to(new_type, src_schema)?), + } + } } } else { Ok(expr) @@ -329,8 +340,16 @@ impl NamePreserver { pub fn save(&self, expr: &Expr) -> SavedName { if self.use_alias { - let (relation, name) = expr.qualified_name(); - SavedName::Saved { relation, name } + match expr { + Expr::Alias(alias) => SavedName::Saved { + relation: alias.relation.clone(), + name: alias.name.clone(), + }, + _ => { + let (relation, name) = expr.qualified_name(); + SavedName::Saved { relation, name } + } + } } else { SavedName::None } @@ -360,10 +379,10 @@ mod test { use super::*; use crate::literal::lit_with_metadata; - use crate::{col, lit, Cast}; + use crate::{Cast, col, lit}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::ScalarValue; + use datafusion_common::tree_node::TreeNodeRewriter; #[derive(Default)] struct RecordingRewriter { @@ -464,7 +483,7 @@ mod test { normalize_col_with_schemas_and_ambiguity_check(expr, &[&schemas], &[]) .unwrap_err() .strip_backtrace(); - let expected = "Schema error: No field named b. \ + let expected = "Schema error: No field named b.\n\ Valid fields are \"tableA\".a."; assert_eq!(error, expected); } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c21c6e6222a05..720788113c6cb 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -19,11 +19,9 @@ use crate::expr::Alias; use crate::expr_rewriter::normalize_col; -use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast}; +use crate::{Cast, Expr, LogicalPlan, TryCast, expr::Sort}; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output @@ -77,8 +75,10 @@ fn rewrite_in_terms_of_projection( // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" expr.transform(|expr| { - // search for unnormalized names first such as "c1" (such as aliases) - if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { + // search for unnormalized names first such as "c1" (such as aliases). + // Also look inside aliases so e.g. `count(Int64(1))` matches + // `count(Int64(1)) AS count(*)`. + if let Some(found) = proj_exprs.iter().find(|a| expr_match(&expr, a)) { let (qualifier, field_name) = found.qualified_name(); let col = Expr::Column(Column::new(qualifier, field_name)); return Ok(Transformed::yes(col)); @@ -102,29 +102,27 @@ fn rewrite_in_terms_of_projection( let search_col = Expr::Column(Column::new_unqualified(name)); - // look for the column named the same as this expr - let mut found = None; - for proj_expr in proj_exprs { - proj_expr.apply(|e| { - if expr_match(&search_col, e) { - found = Some(e.clone()); - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) - })?; - } + // Search only top-level projection expressions for a match. + // We intentionally avoid a recursive search (e.g. `apply`) to + // prevent matching sub-expressions of composites like + // `min(c2) + max(c3)` when the ORDER BY is just `min(c2)`. + let found = proj_exprs + .iter() + .find(|proj_expr| expr_match(&search_col, proj_expr)); if let Some(found) = found { + let (qualifier, field_name) = found.qualified_name(); + let col = Expr::Column(Column::new(qualifier, field_name)); return Ok(Transformed::yes(match normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { - expr: Box::new(found), - data_type, + Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast { + expr: Box::new(col), + field, }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { - expr: Box::new(found), - data_type, + Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast { + expr: Box::new(col), + field, }), - _ => found, + _ => col, })); } @@ -152,13 +150,16 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, - LogicalPlanBuilder, + LogicalPlanBuilder, cast, col, lit, logical_plan::builder::LogicalTableSource, + try_cast, }; use super::*; use crate::test::function_stub::avg; + use crate::test::function_stub::count; + use crate::test::function_stub::max; use crate::test::function_stub::min; + use crate::test::function_stub::sum; #[test] fn rewrite_sort_cols_by_agg() { @@ -235,18 +236,19 @@ mod test { TestCase { desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#, input: sort(min(col("c2"))), - expected: sort(col("min(t.c2)")), + expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))), }, TestCase { desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#, input: sort(col("c1") + min(col("c2"))), - // should be "c1" not t.c1 - expected: sort(col("c1") + col("min(t.c2)")), + expected: sort( + col("c1") + Expr::Column(Column::new_unqualified("min(t.c2)")), + ), }, TestCase { - desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#, + desc: r#"avg(c3) --> "average" (column *named* "average", from alias)"#, input: sort(avg(col("c3"))), - expected: sort(col("avg(t.c3)").alias("average")), + expected: sort(col("average")), }, ]; @@ -255,6 +257,202 @@ mod test { } } + /// When an aggregate is aliased in the projection, + /// ORDER BY on the original aggregate expression should resolve to + /// a Column reference using the alias name — not leak the inner + /// Alias expression node or resolve to a descendant subtree. + #[test] + fn rewrite_sort_resolves_alias_to_column_ref() { + let plan = make_input() + .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))]) + .unwrap() + .project(vec![ + col("c1"), + min(col("c2")).alias("min_val"), + max(col("c3")).alias("max_val"), + ]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![ + TestCase { + desc: "min(c2) with alias 'min_val' should resolve to col(min_val)", + input: sort(min(col("c2"))), + expected: sort(col("min_val")), + }, + TestCase { + desc: "max(c3) with alias 'max_val' should resolve to col(max_val)", + input: sort(max(col("c3"))), + expected: sort(col("max_val")), + }, + ]; + + for case in cases { + case.run(&plan) + } + } + + #[test] + fn composite_proj_expr_containing_sort_col_as_subexpr() { + let plan = make_input() + .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c3"))]) + .unwrap() + .project(vec![ + col("c1"), + (min(col("c2")) + max(col("c3"))).alias("range"), + min(col("c2")).alias("min_val"), + max(col("c3")).alias("max_val"), + ]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![ + TestCase { + desc: "sort by min(c2) should resolve to col(min_val), not col(range)", + input: sort(min(col("c2"))), + expected: sort(col("min_val")), + }, + TestCase { + desc: "sort by max(c3) should resolve to col(max_val), not col(range)", + input: sort(max(col("c3"))), + expected: sort(col("max_val")), + }, + ]; + + for case in cases { + case.run(&plan) + } + } + + #[test] + fn composite_before_standalone_should_not_shadow() { + let plan = make_input() + .aggregate(vec![col("c1")], vec![min(col("c2")), max(col("c2"))]) + .unwrap() + .project(vec![ + col("c1"), + (min(col("c2")) + max(col("c2"))).alias("combined"), + min(col("c2")), + ]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![TestCase { + desc: "sort by min(c2) should resolve to col(min(t.c2)), not col(combined)", + input: sort(min(col("c2"))), + expected: sort(Expr::Column(Column::new_unqualified("min(t.c2)"))), + }]; + + for case in cases { + case.run(&plan) + } + } + + #[test] + fn duplicate_aggregate_in_multiple_proj_exprs() { + let plan = make_input() + .aggregate(vec![col("c1")], vec![min(col("c2"))]) + .unwrap() + .project(vec![ + col("c1"), + min(col("c2")).alias("first_alias"), + min(col("c2")).alias("second_alias"), + ]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![TestCase { + desc: "sort by min(c2) with two aliases picks first_alias", + input: sort(min(col("c2"))), + expected: sort(col("first_alias")), + }]; + + for case in cases { + case.run(&plan) + } + } + + #[test] + fn sort_agg_not_in_select_with_aliased_aggs() { + let plan = make_input() + .aggregate( + vec![col("c1")], + vec![min(col("c2")), max(col("c3")), sum(col("c3"))], + ) + .unwrap() + .project(vec![ + col("c1"), + min(col("c2")).alias("min_val"), + max(col("c3")).alias("max_val"), + ]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![TestCase { + desc: "sort by sum(c3) not in projection should not be rewritten", + input: sort(sum(col("c3"))), + expected: sort(sum(col("c3"))), + }]; + + for case in cases { + case.run(&plan) + } + } + + #[test] + fn cast_on_aliased_aggregate() { + let plan = make_input() + .aggregate(vec![col("c1")], vec![min(col("c2"))]) + .unwrap() + .project(vec![col("c1"), min(col("c2")).alias("min_val")]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![ + TestCase { + desc: "CAST on aliased aggregate should preserve cast and resolve alias", + input: sort(cast(min(col("c2")), DataType::Int64)), + expected: sort(cast(col("min_val"), DataType::Int64)), + }, + TestCase { + desc: "TryCast on aliased aggregate should preserve try_cast and resolve alias", + input: sort(try_cast(min(col("c2")), DataType::Int64)), + expected: sort(try_cast(col("min_val"), DataType::Int64)), + }, + ]; + + for case in cases { + case.run(&plan) + } + } + + #[test] + fn count_star_with_alias() { + let plan = make_input() + .aggregate(vec![col("c1")], vec![count(lit(1))]) + .unwrap() + .project(vec![col("c1"), count(lit(1)).alias("cnt")]) + .unwrap() + .build() + .unwrap(); + + let cases = vec![TestCase { + desc: "sort by count(1) should resolve to cnt alias", + input: sort(count(lit(1))), + expected: sort(col("cnt")), + }]; + + for case in cases { + case.run(&plan) + } + } + #[test] fn preserve_cast() { let plan = make_input() @@ -269,12 +467,12 @@ mod test { TestCase { desc: "Cast is preserved by rewrite_sort_cols_by_aggs", input: sort(cast(col("c2"), DataType::Int64)), - expected: sort(cast(col("c2").alias("c2"), DataType::Int64)), + expected: sort(cast(col("c2"), DataType::Int64)), }, TestCase { desc: "TryCast is preserved by rewrite_sort_cols_by_aggs", input: sort(try_cast(col("c2"), DataType::Int64)), - expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)), + expected: sort(try_cast(col("c2"), DataType::Int64)), }, ]; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3ef61da91bd82..039bbad65a660 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -15,23 +15,26 @@ // specific language governing permissions and limitations // under the License. -use super::{predicate_bounds, Between, Expr, Like}; +use super::{Between, Expr, Like, predicate_bounds}; +use crate::ValueOrLambda; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; -use crate::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, -}; +use crate::expr::{FieldMetadata, LambdaVariable}; +use crate::higher_order_function::HigherOrderReturnFieldArgs; +use crate::type_coercion::functions::value_fields_with_higher_order_udf_and_lambdas; +use crate::type_coercion::functions::{UDFCoercionExt, fields_with_udf}; use crate::udf::ReturnFieldArgs; -use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; +use crate::{LogicalPlan, Projection, Subquery, WindowFunctionDefinition, utils}; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::metadata::FieldMetadata; +use arrow::datatypes::FieldRef; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::datatype::FieldExt; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, - Result, ScalarValue, Spans, TableReference, + Column, DataFusionError, ExprSchema, Result, ScalarValue, Spans, TableReference, + not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; use datafusion_functions_window_common::field::WindowUDFFieldArgs; @@ -58,8 +61,30 @@ pub trait ExprSchemable { fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; /// Given a schema, return the type and nullability of the expr + #[deprecated( + since = "51.0.0", + note = "Use `to_field().1.is_nullable` and `to_field().1.data_type()` directly instead" + )] fn data_type_and_nullable(&self, schema: &dyn ExprSchema) - -> Result<(DataType, bool)>; + -> Result<(DataType, bool)>; +} + +/// Derives the output field for a cast expression from the source field. +/// For `TryCast`, `force_nullable` is `true` since a failed cast returns NULL. +fn cast_output_field( + source_field: &FieldRef, + target_type: &DataType, + force_nullable: bool, +) -> Arc { + let mut f = source_field + .as_ref() + .clone() + .with_data_type(target_type.clone()) + .with_metadata(source_field.metadata().clone()); + if force_nullable { + f = f.with_nullable(true); + } + Arc::new(f) } impl ExprSchemable for Expr { @@ -116,7 +141,7 @@ impl ExprSchemable for Expr { Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(field, _) => Ok(field.data_type().clone()), - Expr::ScalarVariable(ty, _) => Ok(ty.clone()), + Expr::ScalarVariable(field, _) => Ok(field.data_type().clone()), Expr::Literal(l, _) => Ok(l.data_type()), Expr::Case(case) => { for (_, then_expr) in &case.when_then_expr { @@ -129,15 +154,18 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), + Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, .. }) => { + Ok(field.data_type().clone()) + } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list match arg_data_type { DataType::List(field) | DataType::LargeList(field) - | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()), + | DataType::FixedSizeList(field, _) + | DataType::ListView(field) + | DataType::LargeListView(field) => Ok(field.data_type().clone()), DataType::Struct(_) => Ok(arg_data_type), DataType::Null => { not_impl_err!("unnest() does not support null yet") @@ -149,48 +177,16 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(_func) => { - let (return_type, _) = self.data_type_and_nullable(schema)?; - Ok(return_type) - } - Expr::WindowFunction(window_function) => self - .data_type_and_nullable_with_window_function(schema, window_function) - .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { - func, - params: AggregateFunctionParams { args, .. }, - }) => { - let fields = args - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - let new_fields = fields_with_aggregate_udf(&fields, func) - .map_err(|err| { - let data_types = fields - .iter() - .map(|f| f.data_type().clone()) - .collect::>(); - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - Ok(func.return_field(&new_fields)?.data_type().clone()) + Expr::ScalarFunction(_) + | Expr::WindowFunction(_) + | Expr::AggregateFunction(_) => { + Ok(self.to_field(schema)?.1.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Between { .. } | Expr::InList { .. } | Expr::IsNotNull(_) @@ -203,11 +199,7 @@ impl ExprSchemable for Expr { Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).data_type().clone()) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - ref op, - }) => BinaryTypeCoercer::new( + Expr::BinaryExpr(BinaryExpr { left, right, op }) => BinaryTypeCoercer::new( &left.get_type(schema)?, op, &right.get_type(schema)?, @@ -229,6 +221,16 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::HigherOrderFunction(_func) => { + Ok(self.to_field(schema)?.1.data_type().clone()) + } + Expr::Lambda(_lambda) => Ok(DataType::Null), + Expr::LambdaVariable(LambdaVariable { field, .. }) => match field { + Some(f) => Ok(f.data_type().clone()), + // If the lambda variable's field hasn't been specified, treat it as + // null (unspecified lambda variables generate an error during planning) + None => Ok(DataType::Null), + }, } } @@ -349,25 +351,11 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(_func) => { - let (_, nullable) = self.data_type_and_nullable(input_schema)?; - Ok(nullable) - } - Expr::AggregateFunction(AggregateFunction { func, .. }) => { - Ok(func.is_nullable()) - } - Expr::WindowFunction(window_function) => self - .data_type_and_nullable_with_window_function( - input_schema, - window_function, - ) - .map(|(_, nullable)| nullable), - Expr::Placeholder(Placeholder { id: _, field }) => { - Ok(field.as_ref().map(|f| f.is_nullable()).unwrap_or(true)) - } - Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::Unnest(_) => { - Ok(true) - } + Expr::ScalarFunction(_) + | Expr::AggregateFunction(_) + | Expr::WindowFunction(_) => Ok(self.to_field(input_schema)?.1.is_nullable()), + Expr::ScalarVariable(field, _) => Ok(field.is_nullable()), + Expr::TryCast { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -377,15 +365,14 @@ impl ExprSchemable for Expr { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Exists { .. } => Ok(false), + Expr::SetComparison(_) => Ok(true), Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - .. - }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?), + Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { + Ok(left.nullable(input_schema)? || right.nullable(input_schema)?) + } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) @@ -397,6 +384,16 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::HigherOrderFunction(_func) => { + Ok(self.to_field(input_schema)?.1.is_nullable()) + } + Expr::Lambda(_lambda) => Ok(true), + Expr::LambdaVariable(LambdaVariable { field, .. }) => match field { + Some(f) => Ok(f.is_nullable()), + // If the lambda variable's field hasn't been specified, treat it as + // null (unspecified lambda variables generate an error during planning) + None => Ok(true), + }, } } @@ -463,7 +460,7 @@ impl ExprSchemable for Expr { /// with the default implementation returning empty field metadata /// - **Aggregate functions**: Generate metadata via function's [`return_field`] method, /// with the default implementation returning empty field metadata - /// - **Window functions**: field metadata is empty + /// - **Window functions**: field metadata follows the function's return field /// /// ## Table Reference Scoping /// - Establishes proper qualified field references when columns belong to specific tables @@ -479,7 +476,7 @@ impl ExprSchemable for Expr { schema: &dyn ExprSchema, ) -> Result<(Option, Arc)> { let (relation, schema_name) = self.qualified_name(); - #[allow(deprecated)] + #[expect(deprecated)] let field = match self { Expr::Alias(Alias { expr, @@ -487,30 +484,26 @@ impl ExprSchemable for Expr { metadata, .. }) => { - let field = expr.to_field(schema).map(|(_, f)| f.as_ref().clone())?; - let mut combined_metadata = expr.metadata(schema)?; if let Some(metadata) = metadata { combined_metadata.extend(metadata.clone()); } - Ok(Arc::new(combined_metadata.add_to_field(field))) + Ok(expr + .to_field(schema) + .map(|(_, f)| f)? + .with_field_metadata(&combined_metadata)) } Expr::Negative(expr) => expr.to_field(schema).map(|(_, f)| f), - Expr::Column(c) => schema.field_from_column(c).map(|f| Arc::new(f.clone())), + Expr::Column(c) => schema.field_from_column(c).map(Arc::clone), Expr::OuterReferenceColumn(field, _) => { - Ok(Arc::new(field.as_ref().clone().with_name(&schema_name))) - } - Expr::ScalarVariable(ty, _) => { - Ok(Arc::new(Field::new(&schema_name, ty.clone(), true))) - } - Expr::Literal(l, metadata) => { - let mut field = Field::new(&schema_name, l.data_type(), l.is_null()); - if let Some(metadata) = metadata { - field = metadata.add_to_field(field); - } - Ok(Arc::new(field)) + Ok(Arc::clone(field).renamed(&schema_name)) } + Expr::ScalarVariable(field, _) => Ok(Arc::clone(field).renamed(&schema_name)), + Expr::Literal(l, metadata) => Ok(Arc::new( + Field::new(&schema_name, l.data_type(), l.is_null()) + .with_field_metadata_opt(metadata.as_ref()), + )), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -525,14 +518,15 @@ impl ExprSchemable for Expr { Expr::ScalarSubquery(subquery) => { Ok(Arc::clone(&subquery.subquery.schema().fields()[0])) } - Expr::BinaryExpr(BinaryExpr { - ref left, - ref right, - ref op, - }) => { - let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?; - let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?; - let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); + Expr::BinaryExpr(BinaryExpr { left, right, op }) => { + let (left_field, right_field) = + (left.to_field(schema)?.1, right.to_field(schema)?.1); + + let (lhs_type, lhs_nullable) = + (left_field.data_type(), left_field.is_nullable()); + let (rhs_type, rhs_nullable) = + (right_field.data_type(), right_field.is_nullable()); + let mut coercer = BinaryTypeCoercer::new(lhs_type, op, rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); Ok(Arc::new(Field::new( @@ -542,79 +536,49 @@ impl ExprSchemable for Expr { ))) } Expr::WindowFunction(window_function) => { - let (dt, nullable) = self.data_type_and_nullable_with_window_function( - schema, - window_function, - )?; - Ok(Arc::new(Field::new(&schema_name, dt, nullable))) - } - Expr::AggregateFunction(aggregate_function) => { - let AggregateFunction { - func, - params: AggregateFunctionParams { args, .. }, + let WindowFunction { + fun, + params: WindowFunctionParams { args, .. }, .. - } = aggregate_function; + } = window_function.as_ref(); let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_fields = fields_with_aggregate_udf(&fields, func) - .map_err(|err| { - let arg_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_types, - ) - ) - })? - .into_iter() - .collect::>(); - + match fun { + WindowFunctionDefinition::AggregateUDF(udaf) => { + let new_fields = + verify_function_arguments(udaf.as_ref(), &fields)?; + let return_field = udaf.return_field(&new_fields)?; + Ok(return_field) + } + WindowFunctionDefinition::WindowUDF(udwf) => { + let new_fields = + verify_function_arguments(udwf.as_ref(), &fields)?; + let return_field = udwf + .field(WindowUDFFieldArgs::new(&new_fields, &schema_name))?; + Ok(return_field) + } + } + } + Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + }) => { + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()?; + let new_fields = verify_function_arguments(func.as_ref(), &fields)?; func.return_field(&new_fields) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, fields): (Vec, Vec>) = args + let fields = args .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - .into_iter() - .map(|f| (f.data_type().clone(), f)) - .unzip(); - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_data_types = data_types_with_scalar_udf(&arg_types, func) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_types, - ) - ) - })?; - let new_fields = fields - .into_iter() - .zip(new_data_types) - .map(|(f, d)| f.as_ref().clone().with_data_type(d)) - .map(Arc::new) - .collect::>(); + .collect::>>()?; + let new_fields = verify_function_arguments(func.as_ref(), &fields)?; let arguments = args .iter() @@ -631,35 +595,85 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, data_type }) => expr - .to_field(schema) - .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) - .map(Arc::new), + Expr::Cast(Cast { expr, field }) => { + expr.to_field(schema).map(|(_table_ref, src)| { + cast_output_field(&src, field.data_type(), false) + }) + } Expr::Placeholder(Placeholder { id: _, field: Some(field), - }) => Ok(field.as_ref().clone().with_name(&schema_name).into()), + }) => Ok(Arc::clone(field).renamed(&schema_name)), + Expr::TryCast(TryCast { expr, field }) => { + expr.to_field(schema).map(|(_table_ref, src)| { + cast_output_field(&src, field.data_type(), true) + }) + } + Expr::LambdaVariable(LambdaVariable { + field: Some(field), .. + }) => Ok(Arc::clone(field).renamed(&schema_name)), Expr::Like(_) | Expr::SimilarTo(_) | Expr::Not(_) | Expr::Between(_) | Expr::Case(_) - | Expr::TryCast(_) | Expr::InList(_) | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( + | Expr::Unnest(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, ))), + Expr::HigherOrderFunction(func) => { + let arg_fields = func + .args + .iter() + .map(|arg| match arg { + Expr::Lambda(Lambda { params: _, body }) => { + // use the name of the lambda instead of just the body to help with debugging + Ok(ValueOrLambda::Lambda(Arc::new(Field::new( + arg.qualified_name().1, + body.get_type(schema)?, + body.nullable(schema)?, + )))) + } + _ => Ok(ValueOrLambda::Value(arg.to_field(schema)?.1)), + }) + .collect::>>()?; + + let new_fields = value_fields_with_higher_order_udf_and_lambdas( + &arg_fields, + func.func.as_ref(), + )?; + + let arguments = func + .args + .iter() + .map(|e| match e { + Expr::Literal(sv, _) => Some(sv), + _ => None, + }) + .collect::>(); + + let args = HigherOrderReturnFieldArgs { + arg_fields: &new_fields, + scalar_arguments: &arguments, + }; + + func.func.return_field_from_args(args) + } }?; Ok(( relation, - Arc::new(field.as_ref().clone().with_name(schema_name)), + // todo avoid this rename / use the name above + field.renamed(&schema_name), )) } @@ -679,7 +693,16 @@ impl ExprSchemable for Expr { // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? - if can_cast_types(&this_type, cast_to_type) { + // Special handling for struct-to-struct casts with name-based field matching + let can_cast = match (&this_type, cast_to_type) { + (DataType::Struct(_), DataType::Struct(_)) => { + // Always allow struct-to-struct casts; field matching happens at runtime + true + } + _ => can_cast_types(&this_type, cast_to_type), + }; + + if can_cast { match self { Expr::ScalarSubquery(subquery) => { Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) @@ -692,6 +715,33 @@ impl ExprSchemable for Expr { } } +/// Verify that function is invoked with correct number and type of arguments as +/// defined in `TypeSignature`. +fn verify_function_arguments( + function: &F, + input_fields: &[FieldRef], +) -> Result> { + fields_with_udf(input_fields, function).map_err(|err| { + let data_types = input_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{}. {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_message( + function.name(), + function.signature(), + &data_types + ) + ) + }) +} + /// Returns the innermost [Expr] that is provably null if `expr` is null. fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { match expr { @@ -702,93 +752,6 @@ fn unwrap_certainly_null_expr(expr: &Expr) -> &Expr { } } -impl Expr { - /// Common method for window functions that applies type coercion - /// to all arguments of the window function to check if it matches - /// its signature. - /// - /// If successful, this method returns the data type and - /// nullability of the window function's result. - /// - /// Otherwise, returns an error if there's a type mismatch between - /// the window function's signature and the provided arguments. - fn data_type_and_nullable_with_window_function( - &self, - schema: &dyn ExprSchema, - window_function: &WindowFunction, - ) -> Result<(DataType, bool)> { - let WindowFunction { - fun, - params: WindowFunctionParams { args, .. }, - .. - } = window_function; - - let fields = args - .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()?; - match fun { - WindowFunctionDefinition::AggregateUDF(udaf) => { - let data_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let new_fields = fields_with_aggregate_udf(&fields, udaf) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - fun.name(), - fun.signature(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - - let return_field = udaf.return_field(&new_fields)?; - - Ok((return_field.data_type().clone(), return_field.is_nullable())) - } - WindowFunctionDefinition::WindowUDF(udwf) => { - let data_types = fields - .iter() - .map(|f| f.data_type()) - .cloned() - .collect::>(); - let new_fields = fields_with_window_udf(&fields, udwf) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - fun.name(), - fun.signature(), - &data_types - ) - ) - })? - .into_iter() - .collect::>(); - let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); - - udwf.field(field_args) - .map(|field| (field.data_type().clone(), field.is_nullable())) - } - } - } -} - /// Cast subquery in InSubquery/ScalarSubquery to a given type. /// /// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific @@ -835,7 +798,7 @@ mod tests { use super::*; use crate::{and, col, lit, not, or, out_ref_col_with_metadata, when}; - use datafusion_common::{assert_or_internal_err, DFSchema, ScalarValue}; + use datafusion_common::{DFSchema, assert_or_internal_err}; macro_rules! test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ @@ -848,9 +811,10 @@ mod tests { fn expr_schema_nullability() { let expr = col("foo").eq(lit(1)); assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); - assert!(expr - .nullable(&MockExprSchema::new().with_nullable(true)) - .unwrap()); + assert!( + expr.nullable(&MockExprSchema::new().with_nullable(true)) + .unwrap() + ); test_is_expr_nullable!(is_null); test_is_expr_nullable!(is_not_null); @@ -1029,9 +993,10 @@ mod tests { assert!(!expr.nullable(&get_schema(false)).unwrap()); assert!(expr.nullable(&get_schema(true)).unwrap()); // Testing nullable() returns an error. - assert!(expr - .nullable(&get_schema(false).with_error_on_nullable(true)) - .is_err()); + assert!( + expr.nullable(&get_schema(false).with_error_on_nullable(true)) + .is_err() + ); let null = lit(ScalarValue::Int32(None)); let expr = col("foo").in_list(vec![null, lit(1)], false); @@ -1108,6 +1073,27 @@ mod tests { assert_eq!(meta, outer_ref.metadata(&schema).unwrap()); } + #[test] + fn test_alias_metadata_is_preserved_in_field_metadata() { + let schema = MockExprSchema::new().with_data_type(DataType::Int32); + let alias_metadata = FieldMetadata::from(HashMap::from([( + "some_key".to_string(), + "some_value".to_string(), + )])); + + let Expr::Alias(alias) = col("foo").alias("alias") else { + unreachable!(); + }; + let expr = Expr::Alias(alias.with_metadata(Some(alias_metadata.clone()))); + + let field = expr.to_field(&schema).unwrap().1; + assert_eq!( + field.metadata().get("some_key"), + Some(&"some_value".to_string()) + ); + assert_eq!(expr.metadata(&schema).unwrap(), alias_metadata); + } + #[test] fn test_expr_placeholder() { let schema = MockExprSchema::new(); @@ -1125,16 +1111,18 @@ mod tests { ), )); + let field = expr.to_field(&schema).unwrap().1; assert_eq!( - expr.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, true) + (field.data_type(), field.is_nullable()), + (&DataType::Utf8, true) ); assert_eq!(placeholder_meta, expr.metadata(&schema).unwrap()); let expr_alias = expr.alias("a placeholder by any other name"); + let expr_alias_field = expr_alias.to_field(&schema).unwrap().1; assert_eq!( - expr_alias.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, true) + (expr_alias_field.data_type(), expr_alias_field.is_nullable()), + (&DataType::Utf8, true) ); assert_eq!(placeholder_meta, expr_alias.metadata(&schema).unwrap()); @@ -1143,38 +1131,41 @@ mod tests { "".to_string(), Some(Field::new("", DataType::Utf8, false).into()), )); + let expr_field = expr.to_field(&schema).unwrap().1; assert_eq!( - expr.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, false) + (expr_field.data_type(), expr_field.is_nullable()), + (&DataType::Utf8, false) ); + let expr_alias = expr.alias("a placeholder by any other name"); + let expr_alias_field = expr_alias.to_field(&schema).unwrap().1; assert_eq!( - expr_alias.data_type_and_nullable(&schema).unwrap(), - (DataType::Utf8, false) + (expr_alias_field.data_type(), expr_alias_field.is_nullable()), + (&DataType::Utf8, false) ); } #[derive(Debug)] struct MockExprSchema { - field: Field, + field: FieldRef, error_on_nullable: bool, } impl MockExprSchema { fn new() -> Self { Self { - field: Field::new("mock_field", DataType::Null, false), + field: Arc::new(Field::new("mock_field", DataType::Null, false)), error_on_nullable: false, } } fn with_nullable(mut self, nullable: bool) -> Self { - self.field = self.field.with_nullable(nullable); + Arc::make_mut(&mut self.field).set_nullable(nullable); self } fn with_data_type(mut self, data_type: DataType) -> Self { - self.field = self.field.with_data_type(data_type); + Arc::make_mut(&mut self.field).set_data_type(data_type); self } @@ -1184,7 +1175,8 @@ mod tests { } fn with_metadata(mut self, metadata: FieldMetadata) -> Self { - self.field = metadata.add_to_field(self.field); + self.field = + Arc::new(metadata.add_to_field(Arc::unwrap_or_clone(self.field))); self } } @@ -1195,8 +1187,25 @@ mod tests { Ok(self.field.is_nullable()) } - fn field_from_column(&self, _col: &Column) -> Result<&Field> { + fn field_from_column(&self, _col: &Column) -> Result<&FieldRef> { Ok(&self.field) } } + + #[test] + fn test_scalar_variable() { + let mut meta = HashMap::new(); + meta.insert("bar".to_string(), "buzz".to_string()); + let meta = FieldMetadata::from(meta); + + let field = Field::new("foo", DataType::Int32, true); + let field = meta.add_to_field(field); + let field = Arc::new(field); + + let expr = Expr::ScalarVariable(field, vec!["foo".to_string()]); + + let schema = MockExprSchema::new(); + + assert_eq!(meta, expr.metadata(&schema).unwrap()); + } } diff --git a/datafusion/expr/src/extension_types/array_formatter_factory.rs b/datafusion/expr/src/extension_types/array_formatter_factory.rs new file mode 100644 index 0000000000000..f0239d3978801 --- /dev/null +++ b/datafusion/expr/src/extension_types/array_formatter_factory.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::registry::ExtensionTypeRegistryRef; +use arrow::array::Array; +use arrow::util::display::{ArrayFormatter, ArrayFormatterFactory, FormatOptions}; +use arrow_schema::{ArrowError, Field}; + +/// A factory for creating [`ArrayFormatter`]s that checks whether a registered extension type can +/// format a given array based on its metadata. +#[derive(Debug)] +pub struct DFArrayFormatterFactory { + /// The extension type registry + registry: ExtensionTypeRegistryRef, +} + +impl DFArrayFormatterFactory { + /// Creates a new [`DFArrayFormatterFactory`]. + pub fn new(registry: ExtensionTypeRegistryRef) -> Self { + Self { registry } + } +} + +impl ArrayFormatterFactory for DFArrayFormatterFactory { + fn create_array_formatter<'formatter>( + &self, + array: &'formatter dyn Array, + options: &FormatOptions<'formatter>, + field: Option<&'formatter Field>, + ) -> Result>, ArrowError> { + let Some(field) = field else { + return Ok(None); + }; + + let Some(extension_type_name) = field.extension_type_name() else { + return Ok(None); + }; + + let Some(registration) = self + .registry + .extension_type_registration(extension_type_name) + .ok() + else { + // If the extension type is not registered, we fall back to the default formatter + return Ok(None); + }; + + registration + .create_df_extension_type(field.data_type(), field.extension_type_metadata())? + .create_array_formatter(array, options) + .map_err(ArrowError::from) + } +} diff --git a/datafusion/expr/src/extension_types/mod.rs b/datafusion/expr/src/extension_types/mod.rs new file mode 100644 index 0000000000000..55ec1ad95b5a1 --- /dev/null +++ b/datafusion/expr/src/extension_types/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains code that enables DataFusion's extension type capabilities. + +mod array_formatter_factory; + +pub use array_formatter_factory::*; diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index e0235d32292fa..68865cbe1ca54 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -27,6 +27,8 @@ pub use datafusion_functions_aggregate_common::accumulator::{ AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs, }; +use crate::expr::{AggregateFunction, WindowFunction}; +use crate::simplify::SimplifyContext; pub use datafusion_functions_window_common::expr::ExpressionArgs; pub use datafusion_functions_window_common::field::WindowUDFFieldArgs; pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -64,28 +66,22 @@ pub type PartitionEvaluatorFactory = pub type StateTypeFunction = Arc Result>> + Send + Sync>; -/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure -/// A closure with two arguments: -/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked -/// * 'info': [crate::simplify::SimplifyInfo] +/// Type alias for [crate::udaf::AggregateUDFImpl::simplify]. /// -/// Closure returns simplified [Expr] or an error. -pub type AggregateFunctionSimplification = Box< - dyn Fn( - crate::expr::AggregateFunction, - &dyn crate::simplify::SimplifyInfo, - ) -> Result, ->; +/// This closure is invoked with: +/// * `aggregate_function`: [AggregateFunction] with already simplified arguments +/// * `info`: [SimplifyContext] +/// +/// It returns a simplified [Expr] or an error. +pub type AggregateFunctionSimplification = + Box Result>; -/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure -/// A closure with two arguments: -/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked -/// * 'info': [crate::simplify::SimplifyInfo] +/// Type alias for [crate::udwf::WindowUDFImpl::simplify]. +/// +/// This closure is invoked with: +/// * `window_function`: [WindowFunction] with already simplified arguments +/// * `info`: [SimplifyContext] /// -/// Closure returns simplified [Expr] or an error. -pub type WindowFunctionSimplification = Box< - dyn Fn( - crate::expr::WindowFunction, - &dyn crate::simplify::SimplifyInfo, - ) -> Result, ->; +/// It returns a simplified [Expr] or an error. +pub type WindowFunctionSimplification = + Box Result>; diff --git a/datafusion/expr/src/higher_order_function.rs b/datafusion/expr/src/higher_order_function.rs new file mode 100644 index 0000000000000..413714f498164 --- /dev/null +++ b/datafusion/expr/src/higher_order_function.rs @@ -0,0 +1,1684 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`HigherOrderUDF`]: User Defined Higher Order Functions + +use crate::expr::{ + HigherOrderFunction, display_comma_separated, + schema_name_from_exprs_comma_separated_without_space, +}; +use crate::type_coercion::functions::value_fields_with_higher_order_udf; +use crate::udf_eq::UdfEq; +use crate::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, FieldRef, Schema}; +use arrow_schema::SchemaRef; +use datafusion_common::config::ConfigOptions; +use datafusion_common::datatype::FieldExt; +use datafusion_common::hash_map::EntryRef; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, +}; +use datafusion_common::{ + DFSchema, HashMap, HashSet, Result, ScalarValue, exec_err, internal_datafusion_err, + internal_err, not_impl_err, plan_datafusion_err, plan_err, +}; +use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; +use datafusion_expr_common::signature::Volatility; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::mem; +use std::sync::Arc; + +/// The types of arguments for which a function has implementations. +/// +/// [`HigherOrderTypeSignature`] **DOES NOT** define the types that a user query could call the +/// function with. DataFusion will automatically coerce (cast) argument types to +/// one of the supported function signatures, if possible. +/// +/// # Overview +/// Functions typically provide implementations for a small number of different +/// argument [`DataType`]s, rather than all possible combinations. If a user +/// calls a function with arguments that do not match any of the declared types, +/// DataFusion will attempt to automatically coerce (add casts to) function +/// arguments so they match the [`HigherOrderTypeSignature`]. See the [`type_coercion`] module +/// for more details +/// +/// [`type_coercion`]: crate::type_coercion +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub enum HigherOrderTypeSignature { + /// The acceptable signature and coercions rules are special for this + /// function. + /// + /// If this signature is specified, + /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare argument types. + UserDefined, + /// One or more lambdas or arguments with arbitrary types + VariadicAny, + /// The specified number of lambdas or arguments with arbitrary types. + Any(usize), + /// Exactly the specified arguments in the given order, with arbitrary types. + /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value + /// argument types. + Exact(Vec>), +} + +/// Provides information necessary for calling a higher order function. +/// +/// - [`HigherOrderTypeSignature`] defines the argument types that a function has implementations +/// for. +/// +/// - [`Volatility`] defines how the output of the function changes with the input. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct HigherOrderSignature { + /// The data types that the function accepts. See [HigherOrderTypeSignature] for more information. + pub type_signature: HigherOrderTypeSignature, + /// The volatility of the function. See [Volatility] for more information. + pub volatility: Volatility, + /// The max number of times to call [HigherOrderUDFImpl::lambda_parameters] before raising an error. + /// Used to guard against implementations that causes an infinite loop by endlessly returning + /// [LambdaParametersProgress::Partial]. Defaults to 256 + pub lambda_parameters_max_iterations: usize, +} + +const LAMBDA_PARAMETERS_MAX_ITERATIONS: usize = 256; + +impl HigherOrderSignature { + /// Creates a new `HigherOrderSignature` from a given type signature and volatility. + pub fn new(type_signature: HigherOrderTypeSignature, volatility: Volatility) -> Self { + HigherOrderSignature { + type_signature, + volatility, + lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS, + } + } + + /// User-defined coercion rules for the function. + pub fn user_defined(volatility: Volatility) -> Self { + Self { + type_signature: HigherOrderTypeSignature::UserDefined, + volatility, + lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS, + } + } + + /// An arbitrary number of lambdas or arguments of any type. + pub fn variadic_any(volatility: Volatility) -> Self { + Self { + type_signature: HigherOrderTypeSignature::VariadicAny, + volatility, + lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS, + } + } + + /// A specified number of arguments of any type + pub fn any(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: HigherOrderTypeSignature::Any(arg_count), + volatility, + lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS, + } + } + + /// Exactly the specified arguments in the given order, with arbitrary types. + /// DataFusion will call [`HigherOrderUDFImpl::coerce_value_types`] to prepare the value + /// argument types. + /// + /// # Example + /// A function that takes one value argument followed by one lambda: + /// ``` + /// # use datafusion_expr::{HigherOrderSignature, ValueOrLambda, Volatility}; + /// let sig = HigherOrderSignature::exact( + /// vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + /// Volatility::Immutable, + /// ); + /// ``` + pub fn exact(args: Vec>, volatility: Volatility) -> Self { + Self { + type_signature: HigherOrderTypeSignature::Exact(args), + volatility, + lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS, + } + } +} + +impl PartialEq for dyn HigherOrderUDFImpl { + fn eq(&self, other: &Self) -> bool { + self.dyn_eq(other as _) + } +} + +impl PartialOrd for dyn HigherOrderUDFImpl { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), + other.name() + ); + Some(cmp) + } +} + +impl Eq for dyn HigherOrderUDFImpl {} + +impl Hash for dyn HigherOrderUDFImpl { + fn hash(&self, state: &mut H) { + self.dyn_hash(state) + } +} + +/// Arguments passed to [`HigherOrderUDFImpl::invoke_with_args`] when invoking a +/// higher order function. +#[derive(Debug, Clone)] +pub struct HigherOrderFunctionArgs { + /// The evaluated arguments and lambdas to the function + pub args: Vec>, + /// Field associated with each arg, if it exists + /// For lambdas, it will be the field of the result of + /// the lambda if evaluated with the parameters + /// returned from [`HigherOrderUDFImpl::lambda_parameters`] + pub arg_fields: Vec>, + /// The number of rows in record batch being evaluated + pub number_rows: usize, + /// The return field of the higher order function returned + /// (from `return_field_from_args`) when creating the + /// physical expression from the logical expression + pub return_field: FieldRef, + /// The config options at execution time + pub config_options: Arc, +} + +impl HigherOrderFunctionArgs { + /// The return type of the function. See [`Self::return_field`] for more + /// details. + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } +} + +/// A lambda argument to a HigherOrderFunction +#[derive(Clone, Debug)] +pub struct LambdaArgument { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + params: Vec, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + body: Arc, + /// Cached schema built from `params`. Reused across every `evaluate` call + /// (and across every nested-list iteration when the lambda is called once + /// per outer sublist), avoiding the per-call `Schema::new` build that + /// includes constructing the internal name -> index map. + schema: SchemaRef, + /// A RecordBatch containing the captured columns inside this lambda body, if any + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + captures: Option, +} + +impl LambdaArgument { + pub fn new( + params: Vec, + body: Arc, + captures: Option, + ) -> Self { + let fields = match &captures { + Some(batch) => batch + .schema_ref() + .fields() + .iter() + .cloned() + .chain(params.clone()) + .collect(), + None => params.clone(), + }; + + let schema = Arc::new(Schema::new(fields)); + + Self { + params, + body, + schema, + captures, + } + } + + /// Evaluate this lambda + /// `args` should evaluate to the value of each parameter + /// of the correspondent lambda returned in [HigherOrderUDFImpl::lambda_parameters]. + /// + /// `spread_captures` is responsible for transforming the captured column arrays + /// so they align with the evaluation batch. Captures are snapshotted from the + /// outer batch at construction time, giving one value per outer row, but the + /// function may evaluate the lambda body over a batch with a different number + /// of rows. It is the function's responsibility to provide the appropriate + /// `spread_captures` closure to expand (or otherwise reshape) the captures + /// to match. + /// + /// Taking as an example the following table: + /// + /// ```sql + /// CREATE TABLE t (arr INT[], a INT) AS VALUES + /// ([1, 2, 3], 10), + /// ([], 20), + /// ([4], 30); + /// ``` + /// + /// `SELECT array_transform(arr, v -> v + a) from t` would execute over three outer rows: + /// + /// ```text + /// arr (ListArray): [[1, 2, 3], [], [4]] -- 3 outer rows, 4 total elements + /// a (captured): [10, 20, 30] -- one value per outer row + /// ``` + /// + /// `array_transform` flattens the list elements into a single batch of 4 rows, + /// so `spread_captures` must repeat/drop captured values to match: + /// + /// ```text + /// v (flattened args): [1, 2, 3, 4] + /// a (spread): [10, 10, 10, 30] -- 10 repeated for 3 elements in row 0, + /// -- 20 dropped for the empty sublist in row 1, + /// -- 30 once for the single element in row 2 + /// ``` + /// + /// The lambda body `v + a` then evaluates element-wise over these 4-row arrays, + /// producing `[11, 12, 13, 34]`, which `array_transform` reassembles into `[[11, 12, 13], [], [34]]`. + /// + /// If the lambda has no captures, `spread_captures` is never called. + pub fn evaluate( + &self, + args: &[&dyn Fn() -> Result], + spread_captures: impl FnOnce(&[ArrayRef]) -> Result>, + ) -> Result { + let spread_captures = self + .captures + .as_ref() + .map(|captures| { + let spread_columns = spread_captures(captures.columns())?; + + RecordBatch::try_new(captures.schema(), spread_columns) + }) + .transpose()?; + + let merged = merge_captures_with_variables( + spread_captures.as_ref(), + Arc::clone(&self.schema), + &self.params, + args, + )?; + + self.body.evaluate(&merged) + } +} + +fn merge_captures_with_variables( + captures: Option<&RecordBatch>, + schema: SchemaRef, + params: &[FieldRef], + variables: &[&dyn Fn() -> Result], +) -> Result { + if variables.len() < params.len() { + return exec_err!( + "expected at least {} lambda arguments to merge with captures, got {}", + params.len(), + variables.len() + ); + } + + let columns = match captures { + Some(captures) => { + let mut columns = captures.columns().to_vec(); + + for arg in &variables[..params.len()] { + columns.push(arg()?); + } + + columns + } + None => variables + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::>()?, + }; + + Ok(RecordBatch::try_new(schema, columns)?) +} + +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`HigherOrderUDFImpl::return_field_from_args`] for more information +#[derive(Clone, Debug)] +pub struct HigherOrderReturnFieldArgs<'a> { + /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field of the result of the + /// lambda if evaluated with the parameters returned from [`HigherOrderUDFImpl::lambda_parameters`] + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be + /// ```ignore + /// [ + /// ValueOrLambda::Value(Field::new("", DataType::new_list(DataType::Int32, true), true)), + /// ValueOrLambda::Lambda(Field::new("", DataType::Boolean, true)) + /// ] + /// ``` + pub arg_fields: &'a [ValueOrLambda], + /// Is argument `i` to the function a scalar (constant)? + /// + /// If the argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `array_transform([1], v -> v == 5)` + /// this field will be `[Some(ScalarValue::List(...), None]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], +} + +/// An argument to a higher order function +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] +pub enum ValueOrLambda { + /// A value with associated data + Value(V), + /// A lambda with associated data + Lambda(L), +} + +/// Represents a step during the resolution of the parameters of all lambdas of a given +/// higher-order function via [HigherOrderUDFImpl::lambda_parameters]. It's valid that the +/// fields of a given lambda changes between steps, and is up to the implementation to +/// provide during the function evaluation the parameters that matches the fields returned +/// at the [LambdaParametersProgress::Complete] step. See [HigherOrderUDFImpl::lambda_parameters] +/// docs for more details +pub enum LambdaParametersProgress { + /// The parameters of some lambdas are unknown due to a dependency on another lambda output field + /// or are placeholders due to a dependency on it's own output field. It's perfectly valid to + /// contain only `Some`'s and not a single `None`, representing lambdas that depends only on itself + /// and not on others. [HigherOrderUDFImpl::lambda_parameters] will be called again with the output + /// field of all lambdas with known parameters. + Partial(Vec>>), + /// There are no unmet dependencies and all parameters are known, [HigherOrderUDFImpl::lambda_parameters] + /// will not be called again + Complete(Vec>), +} + +/// Trait for implementing user defined higher order functions. +/// +/// This trait exposes the full API for implementing user defined functions and +/// can be used to implement any function. +/// +/// New higher order functions typically implement this trait and are then +/// wrapped in a [`HigherOrderUDF`] for registration with DataFusion. +/// +/// See [`array_transform.rs`] for a commented complete implementation +/// +/// [`array_transform.rs`]: https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/array_transform.rs +pub trait HigherOrderUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any { + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns any aliases (alternate names) for this function. + /// + /// Aliases can be used to invoke the same function using different names. + /// For example in some databases `now()` and `current_timestamp()` are + /// aliases for the same function. This behavior can be obtained by + /// returning `current_timestamp` as an alias for the `now` function. + /// + /// Note: `aliases` should only include names other than [`Self::name`]. + /// Defaults to `[]` (no aliases) + fn aliases(&self) -> &[String] { + &[] + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, args: &[Expr]) -> Result { + Ok(format!( + "{}({})", + self.name(), + schema_name_from_exprs_comma_separated_without_space(args)? + )) + } + + /// Returns a [`HigherOrderSignature`] describing the argument types for which this + /// function has an implementation, and the function's [`Volatility`]. + /// + /// See [`HigherOrderSignature`] for more details on argument type handling + /// and [`Self::return_field_from_args`] for computing the return type. + /// + /// [`Volatility`]: datafusion_expr_common::signature::Volatility + fn signature(&self) -> &HigherOrderSignature; + + /// Return the field of all the parameters supported by the lambdas in `fields`. + /// If a lambda support multiple parameters, all should be returned, regardless of + /// whether they are used or not on a particular invocation + /// + /// Tip: If you have a [`HigherOrderFunction`] invocation, you can call the helper + /// [`HigherOrderFunction::lambda_parameters`] instead of this method directly + /// + /// The name of the returned fields are ignored. + /// + /// This function is repeatedelly called until [LambdaParametersProgress::Complete] is returned, with + /// `step` increased by one at each invocation, starting at 0. + /// + /// For functions which all lambda parameters depend only on the field of it's value arguments, + /// this can return [LambdaParametersProgress::Complete] at step 0. Taking as an example a strict + /// array_reduce with the signature `(arr: [V], initial_value: I, (I, V) -> I, (I) -> O) -> O`, which + /// requires it's initial value to be the exact same type of it's merge output, which is also the + /// parameter of it's finish lambda, the expression + /// + /// `array_reduce([1.2, 2.1], 0.0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)` + /// + /// would result in this function being called as the following: + /// + /// ```ignore + /// let lambda_parameters = array_reduce.lambda_parameters( + /// 0, + /// &[ + /// // the Field of the literal `[1.2, 2.1]`, the array being reduced + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))), + /// // the Field of the literal `0.0`, the initial value + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))), + /// // the Field of the output of the merge lambda, which is unknown at this point because it depends + /// // on the return of this call + /// ValueOrLambda::Lambda(None), + /// // the Field of the output of the finish lambda, unknown for the same reason as above + /// ValueOrLambda::Lambda(None), + /// ])?; + /// + /// assert_eq!( + /// lambda_parameters, + /// LambdaParametersProgress::Complete(vec![ + /// // the finish lambda supported parameters, regardless of how many are actually used + /// vec![ + /// // the accumulator which is the field of the initial value + /// Arc::new(Field::new("ignored_name", DataType::Float32, true)), + /// // the array values being reduced + /// Arc::new(Field::new("", DataType::Float32, true)), + /// ], + /// // the merge lambda supported parameters + /// vec![ + /// // the reduced value which is the field of the initial value + /// Arc::new(Field::new("ignored_name", DataType::Float32, true)), + /// ], + /// ]) + /// ); + /// ``` + /// + /// For functions which lambda parameters depends on the output of other lambdas, or on their own lambda, + /// this can return [LambdaParametersProgress::Partial] until all dependencies are met. Note that for + /// lambda with cyclic dependencies, you likely want to use [HigherOrderUDFImpl::coerce_values_for_lambdas] too. + /// Take as an example a flexible array_reduce with the signature `(arr: [V], initial_value: I, (ACC, V) -> ACC, (ACC) -> O) -> O`. + /// It has a cyclic dependency in the merge lambda, and a dependency of the finish lambda in the merge lambda, + /// and only requires the initial value to be *coercible* to the output of the merge lambda, which is defined by + /// it's [HigherOrderUDFImpl::coerce_values_for_lambdas] implementation. The expression + /// + /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 5.1)` + /// + /// would result in this function being called as the following: + /// + /// ```ignore + /// let lambda_parameters = array_reduce.lambda_parameters( + /// 0, + /// &[ + /// // the Field of the literal `[1.2, 2.1]`, the array being reduced + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))), + /// // the Field of the literal `0`, the initial value + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))), + /// // the Field of the output of the merge lambda, which is unknown at this point because it depends on + /// // the return this call + /// ValueOrLambda::Lambda(None), + /// // the Field of the output of the finish lambda, unknown for the same reason as above + /// ValueOrLambda::Lambda(None), + /// ])?; + /// + /// assert_eq!( + /// lambda_parameters, + /// LambdaParametersProgress::Partial(vec![ + /// // the finish lambda supported parameters, regardless of how many are actually used + /// Some(vec![ + /// // at step 0, use the field of the initial value + /// Arc::new(Field::new("ignored_name", DataType::Int32, true)), + /// // the array values being reduced + /// Arc::new(Field::new("", DataType::Float32, true)), + /// ]), + /// // the merge lambda supported parameters, unknown at this point due to dependency on the merge output + /// None, + /// ]) + /// ); + /// + /// let lambda_parameters = array_reduce.lambda_parameters( + /// 1, + /// &[ + /// // the Field of the literal `[1.2, 2.1]`, the array being reduced + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::new_list(DataType::Float32, true), true))), + /// // the Field of the literal `0`, the initial value + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))), + /// // the Field of the output of the merge lambda, which could be inferred to be a Float32 based on the + /// // returned values of the previous step + /// ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))), + /// // the Field of the output of the finish lambda, which is unknown at this point because it depends + /// // on the return of this call + /// ValueOrLambda::Lambda(None), + /// ])?; + /// + /// assert_eq!( + /// lambda_parameters, + /// LambdaParametersProgress::Complete(vec![ + /// // the finish lambda supported parameters, regardless of how many are actually used + /// vec![ + /// // the finish lambda own output now used as it's accumulator + /// Arc::new(Field::new("ignored_name", DataType::Float32, true)), + /// // the array values being reduced + /// Arc::new(Field::new("", DataType::Float32, true)), + /// ], + /// // the merge lambda supported parameters, which is the output of the merge lambda, + /// vec![ + /// // the output of the merge lambda + /// Arc::new(Field::new("", DataType::Float32, true)), + /// ], + /// ]) + /// ); + /// + /// let coerce_to = array_reduce.coerce_values_for_lambdas(&[ + /// // the literal `[1.2, 2.1]` data type, the array being reduced + /// ValueOrLambda::Value(DataType::new_list(DataType::Float32, true)), + /// // the literal `0` data type, the initial value + /// ValueOrLambda::Value(DataType::Int32), + /// // the output data type of the merge lambda + /// ValueOrLambda::Lambda(DataType::Float32), + /// // the output data type of the finish lambda + /// ValueOrLambda::Lambda(DataType::Boolean), + /// ])?; + /// + /// assert_eq!( + /// coerce_to, + /// Some(vec![ + /// // return the same type for the array being reduced + /// DataType::new_list(DataType::Float32, true), + /// // coerce the initial value to the output of the merge lambda + /// DataType::Float32, + /// ]) + /// ); + /// + /// ``` + /// + /// Note this may also be called at step 0 with all lambda outputs already set, and in that case, + /// [LambdaParametersProgress::Complete] must be returned + /// + /// The implementation can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`], except the coercion defined by + /// [Self::coerce_values_for_lambdas]. + /// + /// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction + /// [`HigherOrderFunction::lambda_parameters`]: crate::expr::HigherOrderFunction::lambda_parameters + fn lambda_parameters( + &self, + step: usize, + fields: &[ValueOrLambda>], + ) -> Result; + + /// Coerce value arguments of a function call to types that the function can evaluate also taking into + /// account the *output type of it's lambdas*. This differs from [HigherOrderUDFImpl::coerce_value_types] + /// that only has access to the type of it's value arguments because it's called before the output type + /// of lambdas are known. + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// # Parameters + /// * `fields`: The argument types of the value arguments of this function, or the output type of lambdas + /// + /// # Return value + /// If `Some`, contains a Vec with the same number of [ValueOrLambda::Value] in `fields`. + /// DataFusion will `CAST` the function call arguments to these specific types. If `None`, no + /// coercion will be applied beyond the one defined by the function signature. + /// + /// For example, a flexible array_reduce implementation (see [Self::lambda_parameters] docs), when working + /// with the expression below, may want to coerce it's initial value argument, the *integer* `0`, + /// to match the output of it's merge function, which is a *float*: + /// + /// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)` + fn coerce_values_for_lambdas( + &self, + _fields: &[ValueOrLambda], + ) -> Result>> { + Ok(None) + } + + /// What type will be returned by this function, given the arguments? + /// + /// The implementation can assume that some other part of the code has coerced + /// the actual argument types to match [`Self::signature`], including the coercion + /// defined by [Self::coerce_values_for_lambdas]. + /// + /// # Example creating `Field` + /// + /// Note the name of the `Field` is ignored, except for structured types such as + /// `DataType::Struct`. + /// + /// ```rust + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, FieldRef}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::HigherOrderReturnFieldArgs; + /// # struct Example{} + /// # impl Example { + /// fn return_field_from_args(&self, args: HigherOrderReturnFieldArgs) -> Result { + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// Ok(field) + /// } + /// # } + /// ``` + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result; + + /// Whether List or LargeList arguments should have it's non-empty null + /// sublists cleaned with [remove_list_null_values] before invoking this function + /// + /// The default implementation always returns true and should only be implemented + /// if you want to handle non-empty null sublists yourself + /// + /// [remove_list_null_values]: datafusion_common::utils::remove_list_null_values + // todo: extend this to listview and maps when remove_list_null_values supports it + fn clear_null_values(&self) -> bool { + true + } + + /// Invoke the function returning the appropriate result. + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result; + + /// Returns true if some of this `exprs` subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered. + /// + /// Setting this to true prevents certain optimizations such as common + /// subexpression elimination + /// + /// When overriding this function to return `true`, [HigherOrderUDFImpl::conditional_arguments] can also be + /// overridden to report more accurately which arguments are eagerly evaluated and which ones + /// lazily. + fn short_circuits(&self) -> bool { + false + } + + /// Determines which of the arguments passed to *this higher-order function* + /// are evaluated eagerly and which may be evaluated lazily. Note that this + /// does *not* applies to the arguments that *lambda functions* pass to it's + /// body expression + /// + /// If this function returns `None`, all arguments are eagerly evaluated. + /// Returning `None` is a micro optimization that saves a needless `Vec` + /// allocation. + /// + /// If the function returns `Some`, returns (`eager`, `lazy`) where `eager` + /// are the arguments that are always evaluated, and `lazy` are the + /// arguments that may be evaluated lazily (i.e. may not be evaluated at all + /// in some cases). + /// + /// Implementations must ensure that the two returned `Vec`s are disjunct, + /// and that each argument from `args` is present in one the two `Vec`s. + /// + /// When overriding this function, [HigherOrderUDFImpl::short_circuits] must + /// be overridden to return `true`. + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + if self.short_circuits() { + Some((vec![], args.iter().collect())) + } else { + None + } + } + + /// Coerce value arguments of a function call to types that the function can evaluate. + /// Note that if you need to coerce values based on the output type of lambdas, you + /// must use [HigherOrderUDFImpl::coerce_values_for_lambdas], as this function is used before + /// the output type of lambdas are known + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a contiguous list argument, but the user calls + /// it like `my_func(c, v -> v+2)` (i.e. with `c` as a ListView), coerce_types can return `[DataType::List(..)]` + /// to ensure the argument is converted to a List + /// + /// # Parameters + /// * `arg_types`: The argument types of the value arguments of this function, excluding lambdas + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_value_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!( + "Function {} does not implement coerce_value_types", + self.name() + ) + } + + /// Returns the documentation for this function. + /// + /// Documentation can be accessed programmatically as well as generating + /// publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } +} + +/// Logical representation of a Higher Order User Defined Function. +/// +/// A higher order function takes one or more lambda arguments in addition to +/// regular value arguments. This struct contains the information DataFusion +/// needs to plan and invoke functions you supply such as name, type signature, +/// return type, and actual implementation. +#[derive(Debug, Clone)] +pub struct HigherOrderUDF { + inner: Arc, +} + +impl PartialEq for HigherOrderUDF { + fn eq(&self, other: &Self) -> bool { + self.inner.as_ref().dyn_eq(other.inner.as_ref()) + } +} + +impl PartialOrd for HigherOrderUDF { + fn partial_cmp(&self, other: &Self) -> Option { + let mut cmp = self.name().cmp(other.name()); + if cmp == Ordering::Equal { + cmp = self.signature().partial_cmp(other.signature())?; + } + if cmp == Ordering::Equal { + cmp = self.aliases().partial_cmp(other.aliases())?; + } + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if cmp == Ordering::Equal && self != other { + // Functions may have other properties besides name and signature + // that differentiate two instances (e.g. type, or arbitrary parameters). + // We cannot return Some(Equal) in such case. + return None; + } + debug_assert!( + cmp == Ordering::Equal || self != other, + "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ + The functions compare as equal, but they are not equal based on general properties that \ + the PartialOrd implementation observes,", + self.name(), + other.name() + ); + Some(cmp) + } +} + +impl Eq for HigherOrderUDF {} + +impl Hash for HigherOrderUDF { + fn hash(&self, state: &mut H) { + self.inner.dyn_hash(state) + } +} + +impl HigherOrderUDF { + /// Create a new `HigherOrderUDF` from a [`HigherOrderUDFImpl`] trait object. + /// + /// Note this is the same as using the `From` impl (`HigherOrderUDF::from`). + pub fn new_from_impl(fun: F) -> HigherOrderUDF + where + F: HigherOrderUDFImpl + 'static, + { + Self::new_from_shared_impl(Arc::new(fun)) + } + + /// Create a new `HigherOrderUDF` from a shared [`HigherOrderUDFImpl`] trait object. + pub fn new_from_shared_impl(fun: Arc) -> HigherOrderUDF { + Self { inner: fun } + } + + /// Return the underlying [`HigherOrderUDFImpl`] trait object for this function. + pub fn inner(&self) -> &Arc { + &self.inner + } + + /// Adds additional names that can be used to invoke this function, in + /// addition to `name`. + /// + /// If you implement [`HigherOrderUDFImpl`] directly you should return aliases + /// directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedHigherOrderUDFImpl::new( + Arc::clone(&self.inner), + aliases, + )) + } + + /// Returns this function's name. + /// + /// See [`HigherOrderUDFImpl::name`] for more details. + pub fn name(&self) -> &str { + self.inner.name() + } + + /// Returns the aliases for this function. + /// + /// See [`HigherOrderUDF::with_aliases`] for more details. + pub fn aliases(&self) -> &[String] { + self.inner.aliases() + } + + /// Returns this function's schema_name. + /// + /// See [`HigherOrderUDFImpl::schema_name`] for more details. + pub fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + + /// Returns this function's [`HigherOrderSignature`]. + pub fn signature(&self) -> &HigherOrderSignature { + self.inner.signature() + } + + /// Returns the parameters of all lambdas of this function for the current step. + /// + /// See [`HigherOrderUDFImpl::lambda_parameters`] for more details. + pub fn lambda_parameters( + &self, + step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + self.inner.lambda_parameters(step, fields) + } + + /// Coerce value arguments based on lambda output types. + /// + /// See [`HigherOrderUDFImpl::coerce_values_for_lambdas`] for more details. + pub fn coerce_values_for_lambdas( + &self, + fields: &[ValueOrLambda], + ) -> Result>> { + self.inner.coerce_values_for_lambdas(fields) + } + + /// Returns the return field of the function given its arguments. + /// + /// See [`HigherOrderUDFImpl::return_field_from_args`] for more details. + pub fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result { + self.inner.return_field_from_args(args) + } + + /// Whether List or LargeList arguments should have non-empty null sublists + /// cleaned before invoking this function. + pub fn clear_null_values(&self) -> bool { + self.inner.clear_null_values() + } + + /// Invoke the function returning the appropriate result. + /// + /// See [`HigherOrderUDFImpl::invoke_with_args`] for more details. + pub fn invoke_with_args( + &self, + args: HigherOrderFunctionArgs, + ) -> Result { + self.inner.invoke_with_args(args) + } + + /// Returns true if some of this function's subexpressions may not be evaluated. + /// + /// See [`HigherOrderUDFImpl::short_circuits`] for more details. + pub fn short_circuits(&self) -> bool { + self.inner.short_circuits() + } + + /// Returns which arguments are evaluated eagerly vs lazily. + /// + /// See [`HigherOrderUDFImpl::conditional_arguments`] for more details. + pub fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + + /// Coerce value arguments of a function call to types that the function can evaluate. + /// + /// See [`HigherOrderUDFImpl::coerce_value_types`] for more details. + pub fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_value_types(arg_types) + } + + /// Returns the documentation for this function, if any. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +impl From for HigherOrderUDF +where + F: HigherOrderUDFImpl + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// `HigherOrderUDFImpl` that adds aliases to the underlying function. It is +/// better to implement [`HigherOrderUDFImpl`], which supports aliases, directly +/// if possible. +#[derive(Debug, PartialEq, Eq, Hash)] +struct AliasedHigherOrderUDFImpl { + inner: UdfEq>, + aliases: Vec, +} + +impl AliasedHigherOrderUDFImpl { + fn new( + inner: Arc, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + Self { + inner: inner.into(), + aliases, + } + } +} + +#[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method +impl HigherOrderUDFImpl for AliasedHigherOrderUDFImpl { + fn name(&self) -> &str { + self.inner.name() + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + + fn signature(&self) -> &HigherOrderSignature { + self.inner.signature() + } + + fn lambda_parameters( + &self, + step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + self.inner.lambda_parameters(step, fields) + } + + fn coerce_values_for_lambdas( + &self, + fields: &[ValueOrLambda], + ) -> Result>> { + self.inner.coerce_values_for_lambdas(fields) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result { + self.inner.return_field_from_args(args) + } + + fn clear_null_values(&self) -> bool { + self.inner.clear_null_values() + } + + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { + self.inner.invoke_with_args(args) + } + + fn short_circuits(&self) -> bool { + self.inner.short_circuits() + } + + fn conditional_arguments<'a>( + &self, + args: &'a [Expr], + ) -> Option<(Vec<&'a Expr>, Vec<&'a Expr>)> { + self.inner.conditional_arguments(args) + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_value_types(arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} + +pub(crate) fn resolve_lambda_variables( + expr: Expr, + schema: &DFSchema, + // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ] + vars: &mut HashMap>, +) -> Result> { + expr.transform_down(|expr| match expr { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { + // not inlined to reduce nesting + resolve_higher_order_function(func, args, schema, vars) + } + Expr::LambdaVariable(mut var) => { + let field_stack = vars.get(&var.name).ok_or_else(|| { + plan_datafusion_err!( + "missing field of lambda variable {} while resolving", + var.name + ) + })?; + + let field = field_stack.last().ok_or_else(|| { + internal_datafusion_err!("every entry should have at least one field") + })?; + + let field = Arc::clone(field).renamed(&var.name); + + let transformed = var.field.as_ref().is_none_or(|old| old != &field); + + var.field = Some(field); + + Ok(Transformed::new_transformed( + Expr::LambdaVariable(var), + transformed, + )) + } + _ => Ok(Transformed::no(expr)), + }) +} + +fn resolve_higher_order_function( + func: Arc, + args: Vec, + schema: &DFSchema, + // a map of lambda variable name => a never empty stack of fields [ [..shadowed], in_scope ] + vars: &mut HashMap>, +) -> Result> { + let args = if !vars.is_empty() { + /* if this is a nested lambda, we must resolve non-lambda args before invoking + lambda_parameters because it will invoke ExprSchemable::to_field for every + non-lambda parameter, and if one them contains a lambda variable, it will fail + due to it being unresolved. Example query: + + array_transform([[1, 2]], a -> array_transform(a, b -> b+1)) + + the nested array_transform's lambda_parameters will call Lambdavariable::to_field + on it's first argument, the variable `a`, which must be resolved + */ + args.map_elements(|arg| match arg { + Expr::Lambda(_) => Ok(Transformed::no(arg)), + _ => resolve_lambda_variables(arg, schema, vars), + })? + } else { + Transformed::no(args) + }; + + let transformed = args.transformed; + let mut args = args.data; + + let current_fields = args + .iter() + .map(|e| match e { + Expr::Lambda(_lambda_function) => Ok(ValueOrLambda::Lambda(None)), + _ => Ok(ValueOrLambda::Value(e.to_field(schema)?.1)), + }) + .collect::>>()?; + + // coerce fields because coercion may alter the lambda parameters + let mut fields = value_fields_with_higher_order_udf(¤t_fields, func.as_ref())?; + + let num_lambdas = args.iter().filter(|a| matches!(a, Expr::Lambda(_))).count(); + + let mut step = 0; + + let lambda_params = loop { + match func.lambda_parameters(step, &fields)? { + LambdaParametersProgress::Partial(params) => { + let mut params = params.into_iter(); + + if params.len() != num_lambdas { + return plan_err!( + "{} lambda_parameters returned {} lambdas but {num_lambdas} expected", + func.name(), + params.len() + ); + } + + for (arg, field) in std::iter::zip(&mut args, &mut fields) { + match (arg, field) { + (Expr::Lambda(lambda), ValueOrLambda::Lambda(field)) => { + let params = params.next().ok_or_else(|| { + internal_datafusion_err!( + "params len should have been checked above" + ) + })?; + + if let Some(params) = params { + for (name, field) in + std::iter::zip(&lambda.params, params) + { + vars.entry_ref(name) + .or_default() + .push(field.renamed(name.as_str())); + } + + let body_with_vars = resolve_lambda_variables( + mem::take(lambda.body.as_mut()), + schema, + vars, + )?; + + remove_scope(vars, &lambda.params)?; + + *field = Some(body_with_vars.data.to_field(schema)?.1); + *lambda.body = body_with_vars.data; + } + } + (_, ValueOrLambda::Lambda(_)) => { + return internal_err!( + "value_fields_with_higher_order_udf returned a value for a lambda argument" + ); + } + (Expr::Lambda(_), ValueOrLambda::Value(_)) => { + return internal_err!( + "value_fields_with_higher_order_udf returned a lambda for a value argument" + ); + } + (_, ValueOrLambda::Value(_)) => {} // nothing to do + } + } + } + LambdaParametersProgress::Complete(params) => break params, + } + + let limit = func.signature().lambda_parameters_max_iterations; + + step += 1; + + if step > limit { + return plan_err!( + "{} lambda_parameters called {limit} times without completion", + func.name() + ); + } + }; + + let mut lambda_params = lambda_params.into_iter(); + + if num_lambdas != lambda_params.len() { + return plan_err!( + "{} lambda_parameters returned {} values for {num_lambdas} lambdas", + func.name(), + lambda_params.len() + ); + } + + let args = args.map_elements(|arg| match arg { + Expr::Lambda(mut lambda) => { + let lambda_params = lambda_params.next().ok_or_else(|| { + internal_datafusion_err!( + "lambda_params len should have been checked above" + ) + })?; + + if lambda.params.len() > lambda_params.len() { + return plan_err!( + "{} lambda defined {} params ({}), but only {} supported", + func.name(), + lambda.params.len(), + display_comma_separated(&lambda.params), + lambda_params.len() + ); + } + + if !all_unique(&lambda.params) { + return plan_err!( + "lambda params must be unique, got ({})", + lambda.params.join(", ") + ); + } + + for (param, field) in std::iter::zip(&lambda.params, lambda_params) { + vars.entry_ref(param) + .or_default() + .push(field.renamed(param.as_str())); + } + + let transformed = + resolve_lambda_variables(mem::take(lambda.body.as_mut()), schema, vars)?; + + *lambda.body = transformed.data; + + remove_scope(vars, &lambda.params)?; + + Ok(Transformed::new( + Expr::Lambda(lambda), + transformed.transformed, + TreeNodeRecursion::Jump, + )) + } + arg => Ok(Transformed::no(arg)), // resolved at the start of the function + })?; + + Ok(Transformed::new( + Expr::HigherOrderFunction(HigherOrderFunction::new(func, args.data)), + transformed || args.transformed, + TreeNodeRecursion::Jump, + )) +} + +fn remove_scope( + vars: &mut HashMap>, + scope: &[String], +) -> Result<()> { + for param in scope { + match vars.entry_ref(param) { + EntryRef::Occupied(mut v) => { + if v.get().len() == 1 { + v.remove(); + } else { + v.get_mut().pop().ok_or_else(|| { + internal_datafusion_err!( + "every entry should have at least one field" + ) + })?; + } + } + EntryRef::Vacant(_v) => { + return internal_err!("no empty value should be in the map"); + } + } + } + + Ok(()) +} + +fn all_unique(params: &[String]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0] != params[1], + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.as_str())) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::hash::DefaultHasher; + use std::sync::Arc; + + use arrow_schema::{DataType, Field, FieldRef, Schema}; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::signature::Volatility; + + use crate::{ + Expr, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl, + LambdaParametersProgress, ValueOrLambda, col, + expr::{HigherOrderFunction, LambdaVariable}, + lambda, lambda_var, lit, + }; + + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestHigherOrderUDF { + name: &'static str, + field: &'static str, + signature: HigherOrderSignature, + } + impl HigherOrderUDFImpl for TestHigherOrderUDF { + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn lambda_parameters( + &self, + _step: usize, + _fields: &[ValueOrLambda>], + ) -> Result { + unimplemented!() + } + + fn return_field_from_args( + &self, + _args: HigherOrderReturnFieldArgs, + ) -> Result { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: HigherOrderFunctionArgs, + ) -> Result { + unimplemented!() + } + } + + // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd + // must be consistent, so they are tested together. + #[test] + fn test_partial_eq_hash_and_partial_ord() { + // A parameterized function + let f = test_func("foo", "a"); + + // Same like `f`, different instance + let f2 = test_func("foo", "a"); + assert_eq!(&f, &f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Different parameter + let b = test_func("foo", "b"); + assert_ne!(&f, &b); + assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&b), None); + + // Different name + let o = test_func("other", "a"); + assert_ne!(&f, &o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), Some(Ordering::Less)); + + // Different name and parameter + assert_ne!(&b, &o); + assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(b.partial_cmp(&o), Some(Ordering::Less)); + } + + fn test_func(name: &'static str, parameter: &'static str) -> Arc { + Arc::new(HigherOrderUDF::new_from_impl(TestHigherOrderUDF { + name, + field: parameter, + signature: HigherOrderSignature::variadic_any(Volatility::Immutable), + })) + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockArrayReduce { + signature: HigherOrderSignature, + } + + impl HigherOrderUDFImpl for MockArrayReduce { + fn name(&self) -> &str { + "array_reduce" + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn lambda_parameters( + &self, + step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + // optional finish not supported for simplicity + let [ + ValueOrLambda::Value(list), + ValueOrLambda::Value(initial_value), + ValueOrLambda::Lambda(merge), + ValueOrLambda::Lambda(_finish), + ] = fields + else { + unreachable!() + }; + + let list_field = match list.data_type() { + DataType::List(field) => field, + _ => unreachable!(), + }; + + Ok(match (step, merge) { + (0, None) => { + // at the first step, we use the initial_value as merge accumulator, + // and return None for finish since we don't know the output of merge + LambdaParametersProgress::Partial(vec![ + // merge + Some(vec![Arc::clone(initial_value), Arc::clone(list_field)]), + // finish + None, + ]) + } + (1, Some(accumulator)) | (0, Some(accumulator)) => { + // now we can use the merge output as it's accumulator and + // as the finish parameter + LambdaParametersProgress::Complete(vec![ + // merge + vec![Arc::clone(accumulator), Arc::clone(list_field)], + // finish + vec![Arc::clone(accumulator)], + ]) + } + (1, None) => { + unreachable!() + } + _ => unreachable!(), + }) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result { + // optional finish not supported for simplicity + let [ + ValueOrLambda::Value(_list), + ValueOrLambda::Value(_initial_value), + ValueOrLambda::Lambda(_merge), + ValueOrLambda::Lambda(finish), + ] = args.arg_fields + else { + unreachable!() + }; + + Ok(Arc::clone(finish)) + } + + fn invoke_with_args( + &self, + _args: HigherOrderFunctionArgs, + ) -> Result { + unreachable!() + } + } + + #[test] + fn test_resolve_lambda_variables() { + let schema = DFSchema::try_from(Schema::new(vec![Field::new( + "c", + DataType::new_list(DataType::new_list(DataType::Int32, true), true), + true, + )])) + .unwrap(); + + let func = Arc::new(HigherOrderUDF::new_from_impl(MockArrayReduce { + signature: HigherOrderSignature::variadic_any(Volatility::Immutable), + })); + + /* + array_reduce( + c, + 0, + (acc1, v) -> acc + array_reduce( + v, + 0, + (acc2, v) -> acc2 + acc1 + v, + reduced -> reduced * 2.0 + ), + reduced -> reduced * 2 + ) + */ + let expr = Expr::HigherOrderFunction(HigherOrderFunction::new( + Arc::clone(&func), + vec![ + col("c"), + lit(0), + lambda( + ["acc1", "v"], + lambda_var("acc1") + + Expr::HigherOrderFunction(HigherOrderFunction::new( + Arc::clone(&func), + vec![ + lambda_var("v"), + lit(0), + lambda( + ["acc2", "v"], + lambda_var("acc2") + + lambda_var("acc1") + + lambda_var("v"), + ), + lambda(["reduced"], lambda_var("reduced") * lit(2.0)), + ], + )), + ), + lambda(["reduced"], lambda_var("reduced") * lit(2)), + ], + )); + + let resolved_expr = expr.resolve_lambda_variables(&schema).unwrap().data; + + /* + array_reduce( + c@[[Int32]], + 0@Int64, + (acc1@Float64, v@[Int32]) -> acc@Float64 + array_reduce( + v@[Int32], + 0@Int64, + (acc2@Float64, v@Int32) -> acc2@Float64 + acc1@Float64 + v@Int32, + reducedFloat64 -> reduced@Float64 * 2.0@Float64 + ), + reduced@Float64 -> reduced@Float64 * 2@Int64 + ) + */ + let expected = Expr::HigherOrderFunction(HigherOrderFunction::new( + Arc::clone(&func), + vec![ + col("c"), + lit(0), + lambda( + ["acc1", "v"], + resolved_lambda_var("acc1", DataType::Float64, true) + + Expr::HigherOrderFunction(HigherOrderFunction::new( + Arc::clone(&func), + vec![ + resolved_lambda_var( + "v", + DataType::new_list(DataType::Int32, true), + true, + ), + lit(0), + lambda( + ["acc2", "v"], + resolved_lambda_var("acc2", DataType::Float64, true) + + resolved_lambda_var( + "acc1", + DataType::Float64, + true, + ) + + resolved_lambda_var("v", DataType::Int32, true), + ), + lambda( + ["reduced"], + resolved_lambda_var( + "reduced", + DataType::Float64, + true, + ) * lit(2.0), + ), + ], + )), + ), + lambda( + ["reduced"], + resolved_lambda_var("reduced", DataType::Float64, true) * lit(2), + ), + ], + )); + + assert_eq!(resolved_expr, expected); + } + + fn resolved_lambda_var(name: &str, dt: DataType, nullable: bool) -> Expr { + Expr::LambdaVariable(LambdaVariable::new( + name.into(), + Some(Arc::new(Field::new(name, dt, nullable))), + )) + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index c82b56aa58a3c..b52a784df931a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! [DataFusion](https://github.com/apache/datafusion) @@ -39,6 +37,7 @@ extern crate core; +mod higher_order_function; mod literal; mod operation; mod partition_evaluator; @@ -54,6 +53,7 @@ pub mod expr; pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; +pub mod extension_types; pub mod function; pub mod select_expr; pub mod groups_accumulator { @@ -63,6 +63,10 @@ pub mod interval_arithmetic { pub use datafusion_expr_common::interval_arithmetic::*; } pub mod logical_plan; +pub mod dml { + //! DML (Data Manipulation Language) types for DELETE, UPDATE operations. + pub use crate::logical_plan::dml::*; +} pub mod planner; pub mod registry; pub mod simplify; @@ -74,7 +78,10 @@ pub mod statistics { pub use datafusion_expr_common::statistics::*; } mod predicate_bounds; +pub mod preimage; pub mod ptr_eq; +#[cfg(not(feature = "sql"))] +pub mod sql; pub mod test; pub mod tree_node; pub mod type_coercion; @@ -85,16 +92,17 @@ pub mod window_frame; pub mod window_state; pub use datafusion_doc::{ - aggregate_doc_sections, scalar_doc_sections, window_doc_sections, DocSection, - Documentation, DocumentationBuilder, + DocSection, Documentation, DocumentationBuilder, aggregate_doc_sections, + scalar_doc_sections, window_doc_sections, }; pub use datafusion_expr_common::accumulator::Accumulator; pub use datafusion_expr_common::columnar_value::ColumnarValue; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; +pub use datafusion_expr_common::placement::ExpressionPlacement; pub use datafusion_expr_common::signature::{ - ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, TypeSignature, - TypeSignatureClass, Volatility, TIMEZONE_WILDCARD, + ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, + TIMEZONE_WILDCARD, TypeSignature, TypeSignatureClass, Volatility, }; pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ @@ -107,8 +115,13 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; +pub use higher_order_function::{ + HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature, + HigherOrderTypeSignature, HigherOrderUDF, HigherOrderUDFImpl, LambdaArgument, + LambdaParametersProgress, ValueOrLambda, +}; pub use literal::{ - lit, lit_timestamp_nano, lit_with_metadata, Literal, TimestampLiteral, + Literal, TimestampLiteral, lit, lit_timestamp_nano, lit_with_metadata, }; pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; @@ -116,17 +129,19 @@ pub use partition_evaluator::PartitionEvaluator; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ + AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, udaf_default_display_name, udaf_default_human_display, udaf_default_return_field, udaf_default_schema_name, udaf_default_window_function_display_name, - udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, - ReversedUDAF, SetMonotonicity, StatisticsArgs, + udaf_default_window_function_schema_name, +}; +pub use udf::{ + ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, StructFieldMapping, }; -pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] -#[ctor::ctor] +#[ctor::ctor(unsafe)] fn init() { // Enable RUST_LOG logging configuration for test let _ = env_logger::try_init(); diff --git a/datafusion/expr/src/literal.rs b/datafusion/expr/src/literal.rs index 09d8e9bb58b23..2e2980d607648 100644 --- a/datafusion/expr/src/literal.rs +++ b/datafusion/expr/src/literal.rs @@ -18,7 +18,7 @@ //! Literal module contains foundational types that are used to represent literals in DataFusion. use crate::Expr; -use datafusion_common::{metadata::FieldMetadata, ScalarValue}; +use datafusion_common::{ScalarValue, metadata::FieldMetadata}; /// Create a literal expression #[expect(clippy::needless_pass_by_value)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index b9afd894d77d3..2ecb12c30afad 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -17,7 +17,6 @@ //! This module provides a builder for creating LogicalPlans -use std::any::Any; use std::borrow::Cow; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; @@ -34,8 +33,8 @@ use crate::expr_rewriter::{ use crate::logical_plan::{ Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, - Window, + Projection, Repartition, Sort, SubqueryAlias, TableScanBuilder, Union, Unnest, + Values, Window, }; use crate::select_expr::SelectExpr; use crate::utils::{ @@ -44,8 +43,8 @@ use crate::utils::{ group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, lit, DmlStatement, ExplainOption, Expr, ExprSchemable, Operator, - RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp, + DmlStatement, ExplainOption, Expr, ExprSchemable, Operator, RecursiveQuery, + Statement, TableProviderFilterPushDown, TableSource, WriteOp, and, binary_expr, lit, }; use super::dml::InsertOp; @@ -55,9 +54,10 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ - exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err, - plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, - NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + Column, Constraints, DFSchema, DFSchemaRef, NullEquality, Result, ScalarValue, + TableReference, ToDFSchema, UnnestOptions, exec_err, + get_target_functional_dependencies, internal_datafusion_err, plan_datafusion_err, + plan_err, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -179,30 +179,26 @@ impl LogicalPlanBuilder { recursive_term: LogicalPlan, is_distinct: bool, ) -> Result { - // TODO: we need to do a bunch of validation here. Maybe more. - if is_distinct { - return not_impl_err!( - "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported" - ); - } // Ensure that the static term and the recursive term have the same number of fields let static_fields_len = self.plan.schema().fields().len(); let recursive_fields_len = recursive_term.schema().fields().len(); if static_fields_len != recursive_fields_len { return plan_err!( "Non-recursive term and recursive term must have the same number of columns ({} != {})", - static_fields_len, recursive_fields_len + static_fields_len, + recursive_fields_len ); } // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; - Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + let recursive_query = RecursiveQuery::try_new( name, - static_term: self.plan, - recursive_term: Arc::new(coerced_recursive_term), + self.plan, + Arc::new(coerced_recursive_term), is_distinct, - }))) + )?; + Ok(Self::from(LogicalPlan::RecursiveQuery(recursive_query))) } /// Create a values list based relation, and the schema is inferred from data, consuming @@ -307,39 +303,50 @@ impl LogicalPlanBuilder { for j in 0..n_cols { let mut common_type: Option = None; let mut common_metadata: Option = None; + let mut nullable = false; for (i, row) in values.iter().enumerate() { let value = &row[j]; let metadata = value.metadata(&schema)?; if let Some(ref cm) = common_metadata { if &metadata != cm { - return plan_err!("Inconsistent metadata across values list at row {i} column {j}. Was {:?} but found {:?}", cm, metadata); + return plan_err!( + "Inconsistent metadata across values list at row {i} column {j}. Was {:?} but found {:?}", + cm, + metadata + ); } } else { common_metadata = Some(metadata.clone()); } + if !nullable && value.nullable(&schema)? { + nullable = true; + } let data_type = value.get_type(&schema)?; if data_type == DataType::Null { continue; } if let Some(prev_type) = common_type { - // get common type of each column values. + // Widen the running type so that it can hold both the + // previously seen rows and this row's value. let data_types = vec![prev_type.clone(), data_type.clone()]; let Some(new_type) = type_union_resolution(&data_types) else { - return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}"); + return plan_err!( + "Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}" + ); }; common_type = Some(new_type); } else { common_type = Some(data_type); } } - // assuming common_type was not set, and no error, therefore the type should be NULL - // since the code loop skips NULL - fields.push_with_metadata( - common_type.unwrap_or(DataType::Null), - true, - common_metadata, - ); + // If common_type is not set, every value in this column had type + // NULL. A DataType::Null field is always nullable. + let (data_type, nullable) = match common_type { + Some(t) => (t, nullable), + None => (DataType::Null, true), + }; + fields.push_with_metadata(data_type, nullable, common_metadata); } Self::infer_inner(values, fields, &schema) @@ -509,33 +516,34 @@ impl LogicalPlanBuilder { filters: Vec, fetch: Option, ) -> Result { - let table_scan = - TableScan::try_new(table_name, table_source, projection, filters, fetch)?; + let table_scan = TableScanBuilder::new(table_name, table_source) + .with_projection(projection) + .with_filters(filters) + .with_fetch(fetch) + .build()?; // Inline TableScan - if table_scan.filters.is_empty() { - if let Some(p) = table_scan.source.get_logical_plan() { - let sub_plan = p.into_owned(); - - if let Some(proj) = table_scan.projection { - let projection_exprs = proj - .into_iter() - .map(|i| { - Expr::Column(Column::from( - sub_plan.schema().qualified_field(i), - )) - }) - .collect::>(); - return Self::new(sub_plan) - .project(projection_exprs)? - .alias(table_scan.table_name); - } + if table_scan.filters.is_empty() + && let Some(p) = table_scan.source.get_logical_plan() + { + let sub_plan = p.into_owned(); - // Ensures that the reference to the inlined table remains the - // same, meaning we don't have to change any of the parent nodes - // that reference this table. - return Self::new(sub_plan).alias(table_scan.table_name); + if let Some(proj) = table_scan.projection { + let projection_exprs = proj + .into_iter() + .map(|i| { + Expr::Column(Column::from(sub_plan.schema().qualified_field(i))) + }) + .collect::>(); + return Self::new(sub_plan) + .project(projection_exprs)? + .alias(table_scan.table_name); } + + // Ensures that the reference to the inlined table remains the + // same, meaning we don't have to change any of the parent nodes + // that reference this table. + return Self::new(sub_plan).alias(table_scan.table_name); } Ok(Self::new(LogicalPlan::TableScan(table_scan))) @@ -593,7 +601,23 @@ impl LogicalPlanBuilder { self, expr: Vec<(impl Into, bool)>, ) -> Result { - project_with_validation(Arc::unwrap_or_clone(self.plan), expr).map(Self::new) + project_with_validation(Arc::unwrap_or_clone(self.plan), expr, None) + .map(Self::new) + } + + /// Apply a projection, aliasing non-Column/non-Alias expressions to + /// match the field names from the provided schema. + pub fn project_with_validation_and_schema( + self, + expr: impl IntoIterator>, + schema: &DFSchemaRef, + ) -> Result { + project_with_validation( + Arc::unwrap_or_clone(self.plan), + expr.into_iter().map(|e| (e, true)), + Some(schema), + ) + .map(Self::new) } /// Select the given column indices @@ -771,7 +795,9 @@ impl LogicalPlanBuilder { .map(|col| col.flat_name()) .collect::(); - plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") + plan_err!( + "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list" + ) } /// Apply a sort by provided expressions with default direction @@ -1009,6 +1035,25 @@ impl LogicalPlanBuilder { join_keys: (Vec>, Vec>), filter: Option, null_equality: NullEquality, + ) -> Result { + self.join_detailed_with_options( + right, + join_type, + join_keys, + filter, + null_equality, + false, + ) + } + + pub fn join_detailed_with_options( + self, + right: LogicalPlan, + join_type: JoinType, + join_keys: (Vec>, Vec>), + filter: Option, + null_equality: NullEquality, + null_aware: bool, ) -> Result { if join_keys.0.len() != join_keys.1.len() { return plan_err!("left_keys and right_keys were not the same length"); @@ -1126,6 +1171,7 @@ impl LogicalPlanBuilder { join_constraint: JoinConstraint::On, schema: DFSchemaRef::new(join_schema), null_equality, + null_aware, }))) } @@ -1199,6 +1245,7 @@ impl LogicalPlanBuilder { join_type, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1215,6 +1262,7 @@ impl LogicalPlanBuilder { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1288,8 +1336,11 @@ impl LogicalPlanBuilder { if explain_option.analyze { Ok(Self::new(LogicalPlan::Analyze(Analyze { verbose: explain_option.verbose, + format: explain_option.format, input: self.plan, schema, + analyze_level: explain_option.analyze_level, + analyze_categories: explain_option.analyze_categories, }))) } else { let stringified_plans = @@ -1302,6 +1353,7 @@ impl LogicalPlanBuilder { stringified_plans, schema, logical_optimization_succeeded: false, + show_statistics: explain_option.show_statistics, }))) } } @@ -1350,6 +1402,15 @@ impl LogicalPlanBuilder { ); } + // Requalify sides if needed to avoid duplicate qualified field names + // (e.g., when both sides reference the same table) + let left_builder = LogicalPlanBuilder::from(left_plan); + let right_builder = LogicalPlanBuilder::from(right_plan); + let (left_builder, right_builder, _requalified) = + requalify_sides_if_needed(left_builder, right_builder)?; + let left_plan = left_builder.build()?; + let right_plan = right_builder.build()?; + let join_keys = left_plan .schema() .fields() @@ -1460,6 +1521,7 @@ impl LogicalPlanBuilder { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, // null_aware )?; Ok(Self::new(LogicalPlan::Join(join))) @@ -1729,23 +1791,61 @@ pub fn requalify_sides_if_needed( ) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder, bool)> { let left_cols = left.schema().columns(); let right_cols = right.schema().columns(); - if left_cols.iter().any(|l| { - right_cols.iter().any(|r| { - l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) - }) - }) { - // These names have no connection to the original plan, but they'll make the columns - // (mostly) unique. - Ok(( - left.alias(TableReference::bare("left"))?, - right.alias(TableReference::bare("right"))?, - true, - )) - } else { - Ok((left, right, false)) + + // Requalify if merging the schemas would cause an error during join. + // This can happen in several cases: + // 1. Duplicate qualified fields: both sides have same relation.name + // 2. Duplicate unqualified fields: both sides have same unqualified name + // 3. Ambiguous reference: one side qualified, other unqualified, same name + // + // Implementation note: This uses a simple O(n*m) nested loop rather than + // a HashMap-based O(n+m) approach. The nested loop is preferred because: + // - Schemas are typically small (in TPCH benchmark, max is 16 columns), + // so n*m is negligible + // - Early return on first conflict makes common case very fast + // - Code is simpler and easier to reason about + // - Called only during plan construction, not in execution hot path + for l in &left_cols { + for r in &right_cols { + if l.name != r.name { + continue; + } + + // Same name - check if this would cause a conflict + match (&l.relation, &r.relation) { + // Both qualified with same relation - duplicate qualified field + (Some(l_rel), Some(r_rel)) if l_rel == r_rel => { + return Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + true, + )); + } + // Both unqualified - duplicate unqualified field + (None, None) => { + return Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + true, + )); + } + // One qualified, one not - ambiguous reference + (Some(_), None) | (None, Some(_)) => { + return Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + true, + )); + } + // Different qualifiers - OK, no conflict + _ => {} + } + } } -} + // No conflicts found + Ok((left, right, false)) +} /// Add additional "synthetic" group by expressions based on functional /// dependencies. /// @@ -1844,7 +1944,7 @@ pub fn project( plan: LogicalPlan, expr: impl IntoIterator>, ) -> Result { - project_with_validation(plan, expr.into_iter().map(|e| (e, true))) + project_with_validation(plan, expr.into_iter().map(|e| (e, true)), None) } /// Create Projection. Similar to project except that the expressions @@ -1857,12 +1957,15 @@ pub fn project( fn project_with_validation( plan: LogicalPlan, expr: impl IntoIterator, bool)>, + schema: Option<&DFSchemaRef>, ) -> Result { let mut projected_expr = vec![]; + let mut has_wildcard = false; for (e, validate) in expr { let e = e.into(); match e { SelectExpr::Wildcard(opt) => { + has_wildcard = true; let expanded = expand_wildcard(plan.schema(), &plan, Some(&opt))?; // If there is a REPLACE statement, replace that column with the given @@ -1883,6 +1986,7 @@ fn project_with_validation( } } SelectExpr::QualifiedWildcard(table_ref, opt) => { + has_wildcard = true; let expanded = expand_qualified_wildcard(&table_ref, plan.schema(), Some(&opt))?; @@ -1912,6 +2016,24 @@ fn project_with_validation( } } } + + if has_wildcard && projected_expr.is_empty() && !plan.schema().fields().is_empty() { + return plan_err!( + "SELECT list is empty after resolving * expressions, \ + the wildcard expanded to zero columns" + ); + } + + // When inside a set expression, alias non-Column/non-Alias expressions + // to match the left side's field names, avoiding duplicate name errors. + if let Some(schema) = &schema { + for (expr, field) in projected_expr.iter_mut().zip(schema.fields()) { + if !matches!(expr, Expr::Column(_) | Expr::Alias(_)) { + *expr = std::mem::take(expr).alias(field.name()); + } + } + } + validate_unique_names("Projections", projected_expr.iter())?; Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) @@ -1926,15 +2048,14 @@ fn replace_columns( replace: &PlannedReplaceSelectItem, ) -> Result> { for expr in exprs.iter_mut() { - if let Expr::Column(Column { name, .. }) = expr { - if let Some((_, new_expr)) = replace + if let Expr::Column(Column { name, .. }) = expr + && let Some((_, new_expr)) = replace .items() .iter() .zip(replace.expressions().iter()) .find(|(item, _)| item.column_name.value == *name) - { - *expr = new_expr.clone().alias(name.clone()) - } + { + *expr = new_expr.clone().alias(name.clone()) } } Ok(exprs) @@ -2052,6 +2173,8 @@ pub fn wrap_projection_for_join_if_necessary( .into_iter() .map(Expr::Column) .collect::>(); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let join_key_items = alias_join_keys .iter() .flat_map(|expr| expr.try_as_col().is_none().then_some(expr)) @@ -2105,10 +2228,6 @@ impl LogicalTableSource { } impl TableSource for LogicalTableSource { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.table_schema) } @@ -2188,7 +2307,7 @@ mod tests { use super::*; use crate::lit_with_metadata; use crate::logical_plan::StringifiedPlan; - use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; + use crate::{col, expr, expr_fn::exists, in_subquery, scalar_subquery}; use crate::test::function_stub::sum; use datafusion_common::{ @@ -2708,12 +2827,12 @@ mod tests { assert_snapshot!(plan, @r" Union - Cross Join: + Cross Join: SubqueryAlias: left Values: (Int32(1)) SubqueryAlias: right Values: (Int32(1)) - Cross Join: + Cross Join: SubqueryAlias: left Values: (Int32(1)) SubqueryAlias: right @@ -2828,11 +2947,13 @@ mod tests { .into_iter() .collect(); let metadata2 = FieldMetadata::from(metadata2); - assert!(LogicalPlanBuilder::values(vec![ - vec![lit_with_metadata(1, Some(metadata.clone()))], - vec![lit_with_metadata(2, Some(metadata2.clone()))], - ]) - .is_err()); + assert!( + LogicalPlanBuilder::values(vec![ + vec![lit_with_metadata(1, Some(metadata.clone()))], + vec![lit_with_metadata(2, Some(metadata2.clone()))], + ]) + .is_err() + ); Ok(()) } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 74fe7a2d009d0..1990a31edb95f 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -24,9 +24,9 @@ use std::{ hash::{Hash, Hasher}, }; -#[cfg(not(feature = "sql"))] -use crate::expr::Ident; use crate::expr::Sort; +#[cfg(not(feature = "sql"))] +use crate::sql::Ident; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; use datafusion_common::{ @@ -38,8 +38,10 @@ use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum DdlStatement { - /// Creates an external table. - CreateExternalTable(CreateExternalTable), + /// Creates an external table. Boxed to keep `LogicalPlan` enum size down + /// — `CreateExternalTable` is ~312 bytes, dwarfing every other variant + /// in the plan tree and forcing the whole enum to that width. + CreateExternalTable(Box), /// Creates an in memory table. CreateMemoryTable(CreateMemoryTable), /// Creates a new view. @@ -56,8 +58,9 @@ pub enum DdlStatement { DropView(DropView), /// Drops a catalog schema DropCatalogSchema(DropCatalogSchema), - /// Create function statement - CreateFunction(CreateFunction), + /// Create function statement. Boxed for the same reason as + /// [`Self::CreateExternalTable`] (~288 bytes). + CreateFunction(Box), /// Drop function statement DropFunction(DropFunction), } @@ -66,9 +69,7 @@ impl DdlStatement { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { match self { - DdlStatement::CreateExternalTable(CreateExternalTable { schema, .. }) => { - schema - } + DdlStatement::CreateExternalTable(ce) => &ce.schema, DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) | DdlStatement::CreateView(CreateView { input, .. }) => input.schema(), DdlStatement::CreateCatalogSchema(CreateCatalogSchema { schema, .. }) => { @@ -79,7 +80,7 @@ impl DdlStatement { DdlStatement::DropTable(DropTable { schema, .. }) => schema, DdlStatement::DropView(DropView { schema, .. }) => schema, DdlStatement::DropCatalogSchema(DropCatalogSchema { schema, .. }) => schema, - DdlStatement::CreateFunction(CreateFunction { schema, .. }) => schema, + DdlStatement::CreateFunction(cf) => &cf.schema, DdlStatement::DropFunction(DropFunction { schema, .. }) => schema, } } @@ -131,11 +132,9 @@ impl DdlStatement { impl Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { - DdlStatement::CreateExternalTable(CreateExternalTable { - ref name, - constraints, - .. - }) => { + DdlStatement::CreateExternalTable(ce) => { + let name = &ce.name; + let constraints = &ce.constraints; if constraints.is_empty() { write!(f, "CreateExternalTable: {name:?}") } else { @@ -186,9 +185,13 @@ impl DdlStatement { cascade, .. }) => { - write!(f, "DropCatalogSchema: {name:?} if not exist:={if_exists} cascade:={cascade}") + write!( + f, + "DropCatalogSchema: {name:?} if not exist:={if_exists} cascade:={cascade}" + ) } - DdlStatement::CreateFunction(CreateFunction { name, .. }) => { + DdlStatement::CreateFunction(cf) => { + let name = &cf.name; write!(f, "CreateFunction: name {name:?}") } DdlStatement::DropFunction(DropFunction { name, .. }) => { @@ -234,6 +237,158 @@ pub struct CreateExternalTable { pub column_defaults: HashMap, } +impl CreateExternalTable { + /// Creates a builder for [`CreateExternalTable`] with required fields. + /// + /// # Arguments + /// * `name` - The table name + /// * `location` - The physical location of the table files + /// * `file_type` - The file type (e.g., "parquet", "csv", "json") + /// * `schema` - The table schema + /// + /// # Example + /// ``` + /// # use datafusion_expr::CreateExternalTable; + /// # use datafusion_common::{DFSchema, TableReference}; + /// # use std::sync::Arc; + /// let table = CreateExternalTable::builder( + /// TableReference::bare("my_table"), + /// "/path/to/data", + /// "parquet", + /// Arc::new(DFSchema::empty()) + /// ).build(); + /// ``` + pub fn builder( + name: impl Into, + location: impl Into, + file_type: impl Into, + schema: DFSchemaRef, + ) -> CreateExternalTableBuilder { + CreateExternalTableBuilder { + name: name.into(), + location: location.into(), + file_type: file_type.into(), + schema, + table_partition_cols: vec![], + if_not_exists: false, + or_replace: false, + temporary: false, + definition: None, + order_exprs: vec![], + unbounded: false, + options: HashMap::new(), + constraints: Default::default(), + column_defaults: HashMap::new(), + } + } +} + +/// Builder for [`CreateExternalTable`] that provides a fluent API for construction. +/// +/// Created via [`CreateExternalTable::builder`]. +#[derive(Debug, Clone)] +pub struct CreateExternalTableBuilder { + name: TableReference, + location: String, + file_type: String, + schema: DFSchemaRef, + table_partition_cols: Vec, + if_not_exists: bool, + or_replace: bool, + temporary: bool, + definition: Option, + order_exprs: Vec>, + unbounded: bool, + options: HashMap, + constraints: Constraints, + column_defaults: HashMap, +} + +impl CreateExternalTableBuilder { + /// Set the partition columns + pub fn with_partition_cols(mut self, cols: Vec) -> Self { + self.table_partition_cols = cols; + self + } + + /// Set the if_not_exists flag + pub fn with_if_not_exists(mut self, if_not_exists: bool) -> Self { + self.if_not_exists = if_not_exists; + self + } + + /// Set the or_replace flag + pub fn with_or_replace(mut self, or_replace: bool) -> Self { + self.or_replace = or_replace; + self + } + + /// Set the temporary flag + pub fn with_temporary(mut self, temporary: bool) -> Self { + self.temporary = temporary; + self + } + + /// Set the SQL definition + pub fn with_definition(mut self, definition: Option) -> Self { + self.definition = definition; + self + } + + /// Set the order expressions + pub fn with_order_exprs(mut self, order_exprs: Vec>) -> Self { + self.order_exprs = order_exprs; + self + } + + /// Set the unbounded flag + pub fn with_unbounded(mut self, unbounded: bool) -> Self { + self.unbounded = unbounded; + self + } + + /// Set the table options + pub fn with_options(mut self, options: HashMap) -> Self { + self.options = options; + self + } + + /// Set the table constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + /// Set the column defaults + pub fn with_column_defaults( + mut self, + column_defaults: HashMap, + ) -> Self { + self.column_defaults = column_defaults; + self + } + + /// Build the [`CreateExternalTable`] + pub fn build(self) -> CreateExternalTable { + CreateExternalTable { + schema: self.schema, + name: self.name, + location: self.location, + file_type: self.file_type, + table_partition_cols: self.table_partition_cols, + if_not_exists: self.if_not_exists, + or_replace: self.or_replace, + temporary: self.temporary, + definition: self.definition, + order_exprs: self.order_exprs, + unbounded: self.unbounded, + options: self.options, + constraints: self.constraints, + column_defaults: self.column_defaults, + } + } +} + // Hashing refers to a subset of fields considered in PartialEq. impl Hash for CreateExternalTable { fn hash(&self, state: &mut H) { diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index b60126335598f..27b86a6d8cdd5 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -21,17 +21,17 @@ use std::collections::HashMap; use std::fmt; use crate::{ - expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, - Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, - Unnest, Values, Window, + Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, Filter, Join, + Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, + Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, Unnest, Values, + Window, expr_vec_fmt, }; use crate::dml::CopyTo; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::{internal_datafusion_err, Column, DataFusionError}; +use datafusion_common::{Column, DataFusionError, internal_datafusion_err}; use serde_json::json; /// Formats plans with a single line per node. For example: @@ -117,13 +117,7 @@ pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ { write!(f, ", ")?; } let nullable_str = if field.is_nullable() { ";N" } else { "" }; - write!( - f, - "{}:{:?}{}", - field.name(), - field.data_type(), - nullable_str - )?; + write!(f, "{}:{}{}", field.name(), field.data_type(), nullable_str)?; } write!(f, "]") } @@ -319,7 +313,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Is Distinct": is_distinct, }) } - LogicalPlan::Values(Values { ref values, .. }) => { + LogicalPlan::Values(Values { values, .. }) => { let str_values = values .iter() // limit to only 5 values to avoid horrible display @@ -344,10 +338,10 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) } LogicalPlan::TableScan(TableScan { - ref source, - ref table_name, - ref filters, - ref fetch, + source, + table_name, + filters, + fetch, .. }) => { let mut object = json!({ @@ -403,7 +397,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } - LogicalPlan::Projection(Projection { ref expr, .. }) => { + LogicalPlan::Projection(Projection { expr, .. }) => { json!({ "Node Type": "Projection", "Expressions": expr.iter().map(|e| e.to_string()).collect::>() @@ -443,25 +437,22 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) } LogicalPlan::Filter(Filter { - predicate: ref expr, - .. + predicate: expr, .. }) => { json!({ "Node Type": "Filter", "Condition": format!("{}", expr) }) } - LogicalPlan::Window(Window { - ref window_expr, .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { json!({ "Node Type": "WindowAggr", "Expressions": expr_vec_fmt!(window_expr) }) } LogicalPlan::Aggregate(Aggregate { - ref group_expr, - ref aggr_expr, + group_expr, + aggr_expr, .. }) => { json!({ @@ -483,7 +474,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } LogicalPlan::Join(Join { - on: ref keys, + on: keys, filter, join_constraint, join_type, @@ -524,6 +515,23 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Partitioning Key": hash_expr }) } + Partitioning::Range(range) => { + let range_expr: Vec = + range.ordering().iter().map(|e| format!("{e}")).collect(); + let split_points: Vec = range + .split_points() + .iter() + .map(|e| format!("{e}")) + .collect(); + + json!({ + "Node Type": "Repartition", + "Partitioning Scheme": "Range", + "Partition Count": range.partition_count(), + "Partitioning Key": range_expr, + "Split Points": split_points + }) + } Partitioning::DistributeBy(expr) => { let dist_by_expr: Vec = expr.iter().map(|e| format!("{e}")).collect(); @@ -534,11 +542,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { }) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit(Limit { skip, fetch, .. }) => { let mut object = serde_json::json!( { "Node Type": "Limit", @@ -557,7 +561,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Node Type": "Subquery" }) } - LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { json!({ "Node Type": "Subquery", "Alias": alias.table(), diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index b8448a5da6c42..b668cbfe2cc35 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -122,11 +122,9 @@ impl CopyTo { /// * `INSERT` - Appends new rows to the existing table. Calls /// [`TableProvider::insert_into`] /// -/// * `DELETE` - Removes rows from the table. Currently NOT supported by the -/// [`TableProvider`] trait or builtin sources. +/// * `DELETE` - Removes rows from the table. Calls [`TableProvider::delete_from`] /// -/// * `UPDATE` - Modifies existing rows in the table. Currently NOT supported by -/// the [`TableProvider`] trait or builtin sources. +/// * `UPDATE` - Modifies existing rows in the table. Calls [`TableProvider::update`] /// /// * `CREATE TABLE AS SELECT` - Creates a new table and populates it with data /// from a query. This is similar to the `INSERT` operation, but it creates a new @@ -136,6 +134,8 @@ impl CopyTo { /// /// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html /// [`TableProvider::insert_into`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#method.insert_into +/// [`TableProvider::delete_from`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#method.delete_from +/// [`TableProvider::update`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html#method.update #[derive(Clone)] pub struct DmlStatement { /// The table name @@ -237,6 +237,8 @@ pub enum WriteOp { Update, /// `CREATE TABLE AS SELECT` operation Ctas, + /// `TRUNCATE` operation + Truncate, } impl WriteOp { @@ -247,6 +249,7 @@ impl WriteOp { WriteOp::Delete => "Delete", WriteOp::Update => "Update", WriteOp::Ctas => "Ctas", + WriteOp::Truncate => "Truncate", } } } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 1c2c8a2a936f5..0889afd08fee4 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -16,16 +16,15 @@ // under the License. use datafusion_common::{ - assert_or_internal_err, plan_err, + DFSchemaRef, Result, assert_or_internal_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, - DFSchemaRef, Result, }; use crate::{ - expr::{Exists, InSubquery}, + Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, + expr::{Exists, InSubquery, SetComparison}, expr_rewriter::strip_outer_reference, utils::{collect_subquery_cols, split_conjunction}, - Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, }; use super::Extension; @@ -82,6 +81,7 @@ fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Re match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { assert_valid_extension_nodes(&subquery.subquery, check)?; } @@ -134,6 +134,7 @@ fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } @@ -198,21 +199,26 @@ pub fn check_subquery_expr( } }?; match outer_plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => { + LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()), + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" - ) + "Correlated scalar subquery can only be used in Projection, \ + Filter, Aggregate plan nodes" + ), }?; } check_correlations_in_subquery(inner_plan) @@ -227,6 +233,20 @@ pub fn check_subquery_expr( ); } } + if let Expr::SetComparison(set_comparison) = expr + && set_comparison.subquery.subquery.schema().fields().len() > 1 + { + return plan_err!( + "Set comparison subquery should only return one column, but found {}: {}", + set_comparison.subquery.subquery.schema().fields().len(), + set_comparison + .subquery + .subquery + .schema() + .field_names() + .join(", ") + ); + } match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -235,7 +255,7 @@ pub fn check_subquery_expr( | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( - "In/Exist subquery can only be used in \ + "In/Exist/SetComparison subquery can only be used in \ Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ but was used in [{}]", outer_plan.display() diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index f0212be294a96..e0e51d7e470c3 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -21,15 +21,15 @@ pub mod display; pub mod dml; mod extension; pub(crate) mod invariants; -pub use invariants::{assert_expected_schema, check_subquery_expr, InvariantLevel}; +pub use invariants::{InvariantLevel, assert_expected_schema, check_subquery_expr}; mod plan; mod statement; pub mod tree_node; pub use builder::{ + LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE, build_join_schema, requalify_sides_if_needed, table_scan, union, - wrap_projection_for_join_if_necessary, LogicalPlanBuilder, LogicalPlanBuilderOptions, - LogicalTableSource, UNNAMED_TABLE, + wrap_projection_for_join_if_necessary, }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, @@ -38,11 +38,12 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, - DistinctOn, EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, - Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn, + EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, + RangePartitioning, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, + Subquery, SubqueryAlias, TableScan, TableScanBuilder, ToStringifiedPlan, Union, + Unnest, Values, Window, projection_schema, }; pub use statement::{ Deallocate, Execute, Prepare, ResetVariable, SetVariable, Statement, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index ad9a46b004fca..3608c81878d17 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -18,53 +18,56 @@ //! Logical plan types use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::{Arc, LazyLock}; +use super::DdlStatement; use super::dml::CopyTo; use super::invariants::{ - assert_always_invariants_at_current_node, assert_executable_invariants, - InvariantLevel, + InvariantLevel, assert_always_invariants_at_current_node, + assert_executable_invariants, }; -use super::DdlStatement; use crate::builder::{unique_field_aliases, unnest_with_options}; use crate::expr::{ - intersect_metadata_for_union, Alias, Placeholder, Sort as SortExpr, WindowFunction, - WindowFunctionParams, + Alias, Placeholder, Sort as SortExpr, WindowFunction, WindowFunctionParams, + intersect_metadata_for_union, }; use crate::expr_rewriter::{ - create_col_from_scalar_expr, normalize_cols, normalize_sorts, NamePreserver, + NamePreserver, create_col_from_scalar_expr, normalize_cols, normalize_sorts, }; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, merge_schema, split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, requalify_sides_if_needed, BinaryExpr, - CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, LogicalPlanBuilder, - Operator, Prepare, TableProviderFilterPushDown, TableSource, - WindowFunctionDefinition, + BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, GroupingSet, + LogicalPlanBuilder, Operator, Prepare, TableProviderFilterPushDown, TableSource, + WindowFunctionDefinition, build_join_schema, expr_vec_fmt, requalify_sides_if_needed, }; +use crate::statistics::StatisticsRequest; +use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion_common::cse::{NormalizeEq, Normalizeable}; -use datafusion_common::format::ExplainFormat; +use datafusion_common::format::{ExplainAnalyzeCategories, ExplainFormat, MetricType}; use datafusion_common::metadata::check_metadata_with_storage_equal; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ + Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, + FunctionalDependence, FunctionalDependencies, NullEquality, ParamValues, Result, + ScalarValue, Spans, SplitPoint, TableReference, UnnestOptions, aggregate_functional_dependencies, assert_eq_or_internal_err, assert_or_internal_err, - internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, - Dependency, FunctionalDependence, FunctionalDependencies, NullEquality, ParamValues, - Result, ScalarValue, Spans, TableReference, UnnestOptions, + internal_err, plan_err, validate_range_split_points, }; use indexmap::IndexSet; +use itertools::Itertools as _; // backwards compatibility use crate::display::PgJsonVisitor; @@ -295,9 +298,12 @@ pub enum LogicalPlan { impl Default for LogicalPlan { fn default() -> Self { + // `Default` is used as a transient placeholder on hot paths (e.g. + // `Box`/`Arc` `map_elements`), so use a shared empty schema to avoid + // allocating. LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: Arc::new(DFSchema::empty()), + schema: Arc::clone(DFSchema::empty_ref()), }) } } @@ -351,10 +357,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema, LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { - // we take the schema of the static term as the schema of the entire recursive query - static_term.schema() - } + LogicalPlan::RecursiveQuery(RecursiveQuery { schema, .. }) => schema, } } @@ -662,6 +665,7 @@ impl LogicalPlan { on, schema: _, null_equality, + null_aware, }) => { let schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -683,6 +687,7 @@ impl LogicalPlan { filter, schema: DFSchemaRef::new(schema), null_equality, + null_aware, })) } LogicalPlan::Subquery(_) => Ok(self), @@ -736,7 +741,14 @@ impl LogicalPlan { }; Ok(LogicalPlan::Distinct(distinct)) } - LogicalPlan::RecursiveQuery(_) => Ok(self), + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + schema: _, + }) => RecursiveQuery::try_new(name, static_term, recursive_term, is_distinct) + .map(LogicalPlan::RecursiveQuery), LogicalPlan::Analyze(_) => Ok(self), LogicalPlan::Explain(_) => Ok(self), LogicalPlan::TableScan(_) => Ok(self), @@ -860,6 +872,32 @@ impl LogicalPlan { input: Arc::new(input), })) } + Partitioning::Range(range) => { + if expr.len() != range.ordering().len() { + return internal_err!( + "Incorrect number of expressions for Range partitioning" + ); + } + let input = self.only_input(inputs)?; + let ordering = range + .ordering() + .iter() + .zip(expr) + .map(|(sort_expr, expr)| SortExpr { + expr, + asc: sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }) + .collect(); + let range = RangePartitioning::try_new( + ordering, + range.split_points().to_vec(), + )?; + Ok(LogicalPlan::Repartition(Repartition { + partitioning_scheme: Partitioning::Range(range), + input: Arc::new(input), + })) + } Partitioning::DistributeBy(_) => { let input = self.only_input(inputs)?; Ok(LogicalPlan::Repartition(Repartition { @@ -902,6 +940,7 @@ impl LogicalPlan { join_constraint, on, null_equality, + null_aware, .. }) => { let (left, right) = self.only_two_inputs(inputs)?; @@ -925,7 +964,9 @@ impl LogicalPlan { let mut iter = expr.into_iter(); while let Some(left) = iter.next() { let Some(right) = iter.next() else { - internal_err!("Expected a pair of expressions to construct the join on expression")? + internal_err!( + "Expected a pair of expressions to construct the join on expression" + )? }; // SimplifyExpression rule may add alias to the equi_expr. @@ -941,6 +982,7 @@ impl LogicalPlan { filter: filter_expr, schema: DFSchemaRef::new(schema), null_equality: *null_equality, + null_aware: *null_aware, })) } LogicalPlan::Subquery(Subquery { @@ -1053,7 +1095,10 @@ impl LogicalPlan { let input = self.only_input(inputs)?; let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); let select_expr = expr.split_off(on_expr.len()); - assert!(sort_expr.is_empty(), "with_new_exprs for Distinct does not support sort expressions"); + assert!( + sort_expr.is_empty(), + "with_new_exprs for Distinct does not support sort expressions" + ); Distinct::On(DistinctOn::try_new( expr, select_expr, @@ -1069,20 +1114,24 @@ impl LogicalPlan { }) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: *is_distinct, - })) + RecursiveQuery::try_new( + name.clone(), + Arc::new(static_term), + Arc::new(recursive_term), + *is_distinct, + ) + .map(LogicalPlan::RecursiveQuery) } LogicalPlan::Analyze(a) => { self.assert_no_expressions(expr)?; let input = self.only_input(inputs)?; Ok(LogicalPlan::Analyze(Analyze { verbose: a.verbose, + format: a.format.clone(), schema: Arc::clone(&a.schema), input: Arc::new(input), + analyze_level: a.analyze_level, + analyze_categories: a.analyze_categories.clone(), })) } LogicalPlan::Explain(e) => { @@ -1095,6 +1144,7 @@ impl LogicalPlan { stringified_plans: e.stringified_plans.clone(), schema: Arc::clone(&e.schema), logical_optimization_succeeded: e.logical_optimization_succeeded, + show_statistics: e.show_statistics, })) } LogicalPlan::Statement(Statement::Prepare(Prepare { @@ -1384,6 +1434,82 @@ impl LogicalPlan { } } + /// Returns the skip (offset) of this plan node, if it has one. + /// + /// Only [`LogicalPlan::Limit`] carries a skip value; all other variants + /// return `Ok(None)`. Returns `Ok(None)` for a zero skip. + pub fn skip(&self) -> Result> { + match self { + LogicalPlan::Limit(limit) => match limit.get_skip_type()? { + SkipType::Literal(0) => Ok(None), + SkipType::Literal(n) => Ok(Some(n)), + SkipType::UnsupportedExpr => Ok(None), + }, + LogicalPlan::Sort(_) => Ok(None), + LogicalPlan::TableScan(_) => Ok(None), + LogicalPlan::Projection(_) => Ok(None), + LogicalPlan::Filter(_) => Ok(None), + LogicalPlan::Window(_) => Ok(None), + LogicalPlan::Aggregate(_) => Ok(None), + LogicalPlan::Join(_) => Ok(None), + LogicalPlan::Repartition(_) => Ok(None), + LogicalPlan::Union(_) => Ok(None), + LogicalPlan::EmptyRelation(_) => Ok(None), + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::SubqueryAlias(_) => Ok(None), + LogicalPlan::Statement(_) => Ok(None), + LogicalPlan::Values(_) => Ok(None), + LogicalPlan::Explain(_) => Ok(None), + LogicalPlan::Analyze(_) => Ok(None), + LogicalPlan::Extension(_) => Ok(None), + LogicalPlan::Distinct(_) => Ok(None), + LogicalPlan::Dml(_) => Ok(None), + LogicalPlan::Ddl(_) => Ok(None), + LogicalPlan::Copy(_) => Ok(None), + LogicalPlan::DescribeTable(_) => Ok(None), + LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::RecursiveQuery(_) => Ok(None), + } + } + + /// Returns the fetch (limit) of this plan node, if it has one. + /// + /// [`LogicalPlan::Sort`], [`LogicalPlan::TableScan`], and + /// [`LogicalPlan::Limit`] may carry a fetch value; all other variants + /// return `Ok(None)`. + pub fn fetch(&self) -> Result> { + match self { + LogicalPlan::Sort(Sort { fetch, .. }) => Ok(*fetch), + LogicalPlan::TableScan(TableScan { fetch, .. }) => Ok(*fetch), + LogicalPlan::Limit(limit) => match limit.get_fetch_type()? { + FetchType::Literal(s) => Ok(s), + FetchType::UnsupportedExpr => Ok(None), + }, + LogicalPlan::Projection(_) => Ok(None), + LogicalPlan::Filter(_) => Ok(None), + LogicalPlan::Window(_) => Ok(None), + LogicalPlan::Aggregate(_) => Ok(None), + LogicalPlan::Join(_) => Ok(None), + LogicalPlan::Repartition(_) => Ok(None), + LogicalPlan::Union(_) => Ok(None), + LogicalPlan::EmptyRelation(_) => Ok(None), + LogicalPlan::Subquery(_) => Ok(None), + LogicalPlan::SubqueryAlias(_) => Ok(None), + LogicalPlan::Statement(_) => Ok(None), + LogicalPlan::Values(_) => Ok(None), + LogicalPlan::Explain(_) => Ok(None), + LogicalPlan::Analyze(_) => Ok(None), + LogicalPlan::Extension(_) => Ok(None), + LogicalPlan::Distinct(_) => Ok(None), + LogicalPlan::Dml(_) => Ok(None), + LogicalPlan::Ddl(_) => Ok(None), + LogicalPlan::Copy(_) => Ok(None), + LogicalPlan::DescribeTable(_) => Ok(None), + LogicalPlan::Unnest(_) => Ok(None), + LogicalPlan::RecursiveQuery(_) => Ok(None), + } + } + /// If this node's expressions contains any references to an outer subquery pub fn contains_outer_reference(&self) -> bool { let mut contains = false; @@ -1767,16 +1893,19 @@ impl LogicalPlan { impl Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { - LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, schema: _ }) => { + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row, + schema: _, + }) => { let rows = if *produce_one_row { 1 } else { 0 }; write!(f, "EmptyRelation: rows={rows}") - }, + } LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => { write!(f, "RecursiveQuery: is_distinct={is_distinct}") } - LogicalPlan::Values(Values { ref values, .. }) => { + LogicalPlan::Values(Values { values, .. }) => { let str_values: Vec<_> = values .iter() // limit to only 5 values to avoid horrible display @@ -1796,11 +1925,11 @@ impl LogicalPlan { } LogicalPlan::TableScan(TableScan { - ref source, - ref table_name, - ref projection, - ref filters, - ref fetch, + source, + table_name, + projection, + filters, + fetch, .. }) => { let projected_fields = match projection { @@ -1870,7 +1999,7 @@ impl LogicalPlan { Ok(()) } - LogicalPlan::Projection(Projection { ref expr, .. }) => { + LogicalPlan::Projection(Projection { expr, .. }) => { write!(f, "Projection:")?; for (i, expr_item) in expr.iter().enumerate() { if i > 0 { @@ -1896,18 +2025,19 @@ impl LogicalPlan { .collect::>() .join(", "); - write!(f, "CopyTo: format={} output_url={output_url} options: ({op_str})", file_type.get_ext()) + write!( + f, + "CopyTo: format={} output_url={output_url} options: ({op_str})", + file_type.get_ext() + ) } LogicalPlan::Ddl(ddl) => { write!(f, "{}", ddl.display()) } LogicalPlan::Filter(Filter { - predicate: ref expr, - .. + predicate: expr, .. }) => write!(f, "Filter: {expr}"), - LogicalPlan::Window(Window { - ref window_expr, .. - }) => { + LogicalPlan::Window(Window { window_expr, .. }) => { write!( f, "WindowAggr: windowExpr=[[{}]]", @@ -1915,8 +2045,8 @@ impl LogicalPlan { ) } LogicalPlan::Aggregate(Aggregate { - ref group_expr, - ref aggr_expr, + group_expr, + aggr_expr, .. }) => write!( f, @@ -1939,7 +2069,7 @@ impl LogicalPlan { Ok(()) } LogicalPlan::Join(Join { - on: ref keys, + on: keys, filter, join_constraint, join_type, @@ -1951,20 +2081,26 @@ impl LogicalPlan { .as_ref() .map(|expr| format!(" Filter: {expr}")) .unwrap_or_else(|| "".to_string()); - let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) { + let join_type = if filter.is_none() + && keys.is_empty() + && *join_type == JoinType::Inner + { "Cross".to_string() } else { join_type.to_string() }; match join_constraint { JoinConstraint::On => { - write!( - f, - "{} Join: {}{}", - join_type, - join_expr.join(", "), - filter_expr - ) + write!(f, "{join_type} Join:",)?; + if !join_expr.is_empty() || !filter_expr.is_empty() { + write!( + f, + " {}{}", + join_expr.join(", "), + filter_expr + )?; + } + Ok(()) } JoinConstraint::Using => { write!( @@ -1994,6 +2130,9 @@ impl LogicalPlan { n ) } + Partitioning::Range(range) => { + write!(f, "Repartition: {range}") + } Partitioning::DistributeBy(expr) => { let dist_by_expr: Vec = expr.iter().map(|e| format!("{e}")).collect(); @@ -2008,22 +2147,25 @@ impl LogicalPlan { // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. let skip_str = match limit.get_skip_type() { Ok(SkipType::Literal(n)) => n.to_string(), - _ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()), + _ => limit + .skip + .as_ref() + .map_or_else(|| "None".to_string(), |x| x.to_string()), }; let fetch_str = match limit.get_fetch_type() { Ok(FetchType::Literal(Some(n))) => n.to_string(), Ok(FetchType::Literal(None)) => "None".to_string(), - _ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()) + _ => limit + .fetch + .as_ref() + .map_or_else(|| "None".to_string(), |x| x.to_string()), }; - write!( - f, - "Limit: skip={skip_str}, fetch={fetch_str}", - ) + write!(f, "Limit: skip={skip_str}, fetch={fetch_str}",) } LogicalPlan::Subquery(Subquery { .. }) => { write!(f, "Subquery:") } - LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => { + LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { write!(f, "SubqueryAlias: {alias}") } LogicalPlan::Statement(statement) => { @@ -2041,7 +2183,11 @@ impl LogicalPlan { "DistinctOn: on_expr=[[{}]], select_expr=[[{}]], sort_expr=[[{}]]", expr_vec_fmt!(on_expr), expr_vec_fmt!(select_expr), - if let Some(sort_expr) = sort_expr { expr_vec_fmt!(sort_expr) } else { "".to_string() }, + if let Some(sort_expr) = sort_expr { + expr_vec_fmt!(sort_expr) + } else { + "".to_string() + }, ), }, LogicalPlan::Explain { .. } => write!(f, "Explain"), @@ -2054,28 +2200,48 @@ impl LogicalPlan { LogicalPlan::Unnest(Unnest { input: plan, list_type_columns: list_col_indices, - struct_type_columns: struct_col_indices, .. }) => { + struct_type_columns: struct_col_indices, + .. + }) => { let input_columns = plan.schema().columns(); let list_type_columns = list_col_indices .iter() - .map(|(i,unnest_info)| - format!("{}|depth={}", &input_columns[*i].to_string(), - unnest_info.depth)) + .map(|(i, unnest_info)| { + format!( + "{}|depth={}", + &input_columns[*i].to_string(), + unnest_info.depth + ) + }) .collect::>(); let struct_type_columns = struct_col_indices .iter() .map(|i| &input_columns[*i]) .collect::>(); // get items from input_columns indexed by list_col_indices - write!(f, "Unnest: lists[{}] structs[{}]", - expr_vec_fmt!(list_type_columns), - expr_vec_fmt!(struct_type_columns)) + write!( + f, + "Unnest: lists[{}] structs[{}]", + expr_vec_fmt!(list_type_columns), + expr_vec_fmt!(struct_type_columns) + ) } } } } Wrapper(self) } + + /// Return a `LogicalPLan` with all [`LambdaVariable`]'s resolved + /// + /// [`LambdaVariable`]: crate::expr::LambdaVariable + pub fn resolve_lambda_variables(self) -> Result> { + self.transform_with_subqueries(|plan| { + let schema = merge_schema(&plan.inputs()); + + plan.map_expressions(|expr| expr.resolve_lambda_variables(&schema)) + }) + } } impl Display for LogicalPlan { @@ -2133,7 +2299,7 @@ impl PartialOrd for EmptyRelation { /// intermediate table, then empty the intermediate table. /// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -2145,6 +2311,90 @@ pub struct RecursiveQuery { /// Should the output of the recursive term be deduplicated (`UNION`) or /// not (`UNION ALL`). pub is_distinct: bool, + /// Schema exposed to parent plans after reconciling the static and recursive terms. + pub schema: DFSchemaRef, +} + +impl PartialOrd for RecursiveQuery { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name.partial_cmp(&other.name) { + Some(Ordering::Equal) => { + match self.static_term.partial_cmp(&other.static_term) { + Some(Ordering::Equal) => { + match self.recursive_term.partial_cmp(&other.recursive_term) { + Some(Ordering::Equal) => { + self.is_distinct.partial_cmp(&other.is_distinct) + } + cmp => cmp, + } + } + cmp => cmp, + } + } + cmp => cmp, + } + // If the query definition compares equal but the derived schema differs, + // return `None` instead of contradicting `PartialEq` with `Some(Equal)`. + // TODO (https://github.com/apache/datafusion/issues/17477) avoid recomparing all fields + .filter(|cmp| *cmp != Ordering::Equal || self == other) + } +} + +impl RecursiveQuery { + pub fn try_new( + name: String, + static_term: Arc, + recursive_term: Arc, + is_distinct: bool, + ) -> Result { + let schema = + recursive_query_output_schema(static_term.schema(), recursive_term.schema())?; + Ok(Self { + name, + static_term, + recursive_term, + is_distinct, + schema, + }) + } +} + +/// Compute a recursive query's output schema by considering both its static and +/// recursive terms. +/// +/// Field names, types, and metadata come from the static term. A field is +/// nullable if either the static or the recursive term produces a nullable +/// value in that position, matching how `UNION` reconciles branch nullability. +/// +/// Functional dependencies are intentionally dropped: the recursive term +/// appends rows that can duplicate values the static term guarantees unique, so +/// any FDs carried by the static term may not hold over the combined output. +fn recursive_query_output_schema( + static_schema: &DFSchemaRef, + recursive_schema: &DFSchemaRef, +) -> Result { + if static_schema.fields().len() != recursive_schema.fields().len() { + return Err(DataFusionError::Plan(format!( + "Non-recursive term and recursive term must have the same number of columns ({} != {})", + static_schema.fields().len(), + recursive_schema.fields().len() + ))); + } + + let fields = static_schema + .iter() + .zip(recursive_schema.fields()) + .map(|((qualifier, static_field), recursive_field)| { + let nullable = static_field.is_nullable() || recursive_field.is_nullable(); + ( + qualifier.cloned(), + static_field.as_ref().clone().with_nullable(nullable).into(), + ) + }) + .collect::>(); + + DFSchema::new_with_metadata(fields, static_schema.metadata().clone()) + .map(DFSchemaRef::new) } /// Values expression. See @@ -2211,7 +2461,11 @@ impl Projection { if !expr.iter().any(|e| matches!(e, Expr::Wildcard { .. })) && expr.len() != schema.fields().len() { - return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); + return plan_err!( + "Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", + expr.len(), + schema.fields().len() + ); } Ok(Self { expr, @@ -2367,6 +2621,19 @@ pub struct Filter { } impl Filter { + /// Create a new filter operator. + /// + /// Skips the type-checking and dealiasing done in [Self::try_new]. + /// For internal use in DataFusion only. + /// + /// **Preconditions:** + /// - the `predicate` expression returns a boolean value + /// - the `predicate` expression is not aliased + #[doc(hidden)] + pub fn new(predicate: Expr, input: Arc) -> Self { + Self { predicate, input } + } + /// Create a new filter operator. /// /// Notes: as Aliases have no effect on the output of a filter operator, @@ -2398,12 +2665,12 @@ impl Filter { // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and // ignore errors resolving the expression against the schema. - if let Ok(predicate_type) = predicate.get_type(input.schema()) { - if !Filter::is_allowed_filter_type(&predicate_type) { - return plan_err!( - "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" - ); - } + if let Ok(predicate_type) = predicate.get_type(input.schema()) + && !Filter::is_allowed_filter_type(&predicate_type) + { + return plan_err!( + "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" + ); } Ok(Self { @@ -2652,6 +2919,12 @@ pub struct TableScan { pub filters: Vec, /// Optional number of rows to read pub fetch: Option, + /// Statistics the planner would like the provider to answer for this + /// scan, typically attached by a custom optimizer rule from the + /// surrounding plan (e.g. Min/Max for sort keys). + /// + /// A [`BTreeSet`], not a `Vec` to keep the resulting plan deterministic. + pub statistics_requests: BTreeSet, } impl Debug for TableScan { @@ -2726,6 +2999,7 @@ impl Hash for TableScan { impl TableScan { /// Initialize TableScan with appropriate schema from the given /// arguments. + #[deprecated(since = "54.0.0", note = "use `TableScanBuilder` instead")] pub fn try_new( table_name: impl Into, table_source: Arc, @@ -2733,14 +3007,92 @@ impl TableScan { filters: Vec, fetch: Option, ) -> Result { - let table_name = table_name.into(); + TableScanBuilder::new(table_name, table_source) + .with_projection(projection) + .with_filters(filters) + .with_fetch(fetch) + .build() + } +} + +/// Builder for [`TableScan`]. +/// +/// Prefer this over constructing a [`TableScan`] directly: it derives the +/// `projected_schema` from the source schema and projection, and is resilient +/// to new fields being added to [`TableScan`]. An existing scan can be turned +/// back into a builder with `TableScanBuilder::from(scan)`, tweaked, and +/// rebuilt with [`TableScanBuilder::build`]. +pub struct TableScanBuilder { + table_name: TableReference, + source: Arc, + projection: Option>, + filters: Vec, + fetch: Option, + statistics_requests: BTreeSet, +} + +impl TableScanBuilder { + /// Create a new builder for a scan of `source` named `table_name`. + pub fn new( + table_name: impl Into, + source: Arc, + ) -> Self { + Self { + table_name: table_name.into(), + source, + projection: None, + filters: vec![], + fetch: None, + statistics_requests: BTreeSet::new(), + } + } + + /// Set the column projection (indices into the source schema). + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + /// Set the filter expressions offered to the table provider. + pub fn with_filters(mut self, filters: Vec) -> Self { + self.filters = filters; + self + } + + /// Set the maximum number of rows to read. + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Set the statistics requests for the scan. See + /// [`TableScan::statistics_requests`]. + pub fn with_statistics_requests( + mut self, + statistics_requests: BTreeSet, + ) -> Self { + self.statistics_requests = statistics_requests; + self + } + + /// Build the [`TableScan`], deriving its `projected_schema` from the + /// source schema and projection. + pub fn build(self) -> Result { + let TableScanBuilder { + table_name, + source, + projection, + filters, + fetch, + statistics_requests, + } = self; if table_name.table().is_empty() { return plan_err!("table_name cannot be empty"); } - let schema = table_source.schema(); + let schema = source.schema(); let func_dependencies = FunctionalDependencies::new_from_constraints( - table_source.constraints(), + source.constraints(), schema.fields.len(), ); let projected_schema = projection @@ -2766,17 +3118,31 @@ impl TableScan { })?; let projected_schema = Arc::new(projected_schema); - Ok(Self { + Ok(TableScan { table_name, - source: table_source, + source, projection, projected_schema, filters, fetch, + statistics_requests, }) } } +impl From for TableScanBuilder { + fn from(scan: TableScan) -> Self { + Self { + table_name: scan.table_name, + source: scan.source, + projection: scan.projection, + filters: scan.filters, + fetch: scan.fetch, + statistics_requests: scan.statistics_requests, + } + } +} + // Repartition the plan based on a partitioning scheme. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Repartition { @@ -3088,6 +3454,15 @@ pub struct ExplainOption { pub analyze: bool, /// Output syntax/format pub format: ExplainFormat, + /// Statement-level override for `datafusion.explain.show_statistics`. + /// `None` means "fall back to session config". + pub show_statistics: Option, + /// Statement-level override for `datafusion.explain.analyze_level`. + /// `None` means "fall back to session config". + pub analyze_level: Option, + /// Statement-level override for `datafusion.explain.analyze_categories`. + /// `None` means "fall back to session config". + pub analyze_categories: Option, } impl Default for ExplainOption { @@ -3096,6 +3471,9 @@ impl Default for ExplainOption { verbose: false, analyze: false, format: ExplainFormat::Indent, + show_statistics: None, + analyze_level: None, + analyze_categories: None, } } } @@ -3118,6 +3496,30 @@ impl ExplainOption { self.format = format; self } + + /// Builder-style setter for a statement-level override of + /// `datafusion.explain.show_statistics`. + pub fn with_show_statistics(mut self, show_statistics: Option) -> Self { + self.show_statistics = show_statistics; + self + } + + /// Builder-style setter for a statement-level override of + /// `datafusion.explain.analyze_level`. + pub fn with_analyze_level(mut self, analyze_level: Option) -> Self { + self.analyze_level = analyze_level; + self + } + + /// Builder-style setter for a statement-level override of + /// `datafusion.explain.analyze_categories`. + pub fn with_analyze_categories( + mut self, + analyze_categories: Option, + ) -> Self { + self.analyze_categories = analyze_categories; + self + } } /// Produces a relation with string representations of @@ -3141,6 +3543,9 @@ pub struct Explain { pub schema: DFSchemaRef, /// Used by physical planner to check if should proceed with planning pub logical_optimization_succeeded: bool, + /// Statement-level override for `datafusion.explain.show_statistics`. + /// When `None`, the session-config value is used. + pub show_statistics: Option, } // Manual implementation needed because of `schema` field. Comparison excludes this field. @@ -3156,18 +3561,22 @@ impl PartialOrd for Explain { pub stringified_plans: &'a Vec, /// Used by physical planner to check if should proceed with planning pub logical_optimization_succeeded: &'a bool, + /// Statement-level override for show_statistics + pub show_statistics: &'a Option, } let comparable_self = ComparableExplain { verbose: &self.verbose, plan: &self.plan, stringified_plans: &self.stringified_plans, logical_optimization_succeeded: &self.logical_optimization_succeeded, + show_statistics: &self.show_statistics, }; let comparable_other = ComparableExplain { verbose: &other.verbose, plan: &other.plan, stringified_plans: &other.stringified_plans, logical_optimization_succeeded: &other.logical_optimization_succeeded, + show_statistics: &other.show_statistics, }; comparable_self .partial_cmp(&comparable_other) @@ -3182,13 +3591,24 @@ impl PartialOrd for Explain { pub struct Analyze { /// Should extra detail be included? pub verbose: bool, + /// Output syntax/format for the rendered physical plan + metrics. + pub format: ExplainFormat, /// The logical plan that is being EXPLAIN ANALYZE'd pub input: Arc, /// The output schema of the explain (2 columns of text) pub schema: DFSchemaRef, + /// Statement-level override for `datafusion.explain.analyze_level`. + /// When `None`, the session-config value is used. + pub analyze_level: Option, + /// Statement-level override for `datafusion.explain.analyze_categories`. + /// When `None`, the session-config value is used. + pub analyze_categories: Option, } -// Manual implementation needed because of `schema` field. Comparison excludes this field. +// Manual implementation needed because of `schema` field and the lack of +// `PartialOrd` on `MetricType` / `ExplainAnalyzeCategories`. Ordering is +// defined over `(verbose, input)` and then falls back to `==` for the +// remaining statement-level override fields. impl PartialOrd for Analyze { fn partial_cmp(&self, other: &Self) -> Option { match self.verbose.partial_cmp(&other.verbose) { @@ -3204,6 +3624,7 @@ impl PartialOrd for Analyze { // TODO(clippy): This clippy `allow` should be removed if // the manual `PartialEq` is removed in favor of a derive. // (see `PartialEq` the impl for details.) +#[allow(clippy::allow_attributes)] #[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Clone, Eq, Hash)] pub struct Extension { @@ -3451,7 +3872,9 @@ pub struct Aggregate { pub input: Arc, /// Grouping expressions pub group_expr: Vec, - /// Aggregate expressions + /// Aggregate expressions. + /// + /// Note these *must* be either [`Expr::AggregateFunction`] or [`Expr::Alias`] pub aggr_expr: Vec, /// The schema description of the aggregate output pub schema: DFSchemaRef, @@ -3478,11 +3901,12 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); + let max_ordinal = max_grouping_set_duplicate_ordinal(&group_expr); qualified_fields.push(( None, Field::new( Self::INTERNAL_GROUPING_ID, - Self::grouping_id_type(qualified_fields.len()), + Self::grouping_id_type(qualified_fields.len(), max_ordinal), false, ) .into(), @@ -3504,7 +3928,6 @@ impl Aggregate { /// /// This method should only be called when you are absolutely sure that the schema being /// provided is correct for the aggregate. If in doubt, call [try_new](Self::try_new) instead. - #[expect(clippy::needless_pass_by_value)] pub fn try_new_with_schema( input: Arc, group_expr: Vec, @@ -3530,7 +3953,7 @@ impl Aggregate { let aggregate_func_dependencies = calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; - let new_schema = schema.as_ref().clone(); + let new_schema = Arc::unwrap_or_clone(schema); let schema = Arc::new( new_schema.with_functional_dependencies(aggregate_func_dependencies)?, ); @@ -3568,15 +3991,24 @@ impl Aggregate { } /// Returns the data type of the grouping id. - /// The grouping ID value is a bitmask where each set bit - /// indicates that the corresponding grouping expression is - /// null - pub fn grouping_id_type(group_exprs: usize) -> DataType { - if group_exprs <= 8 { + /// + /// The grouping ID packs two pieces of information into a single integer: + /// - The low `group_exprs` bits are the semantic bitmask (a set bit means the + /// corresponding grouping expression is NULL for this grouping set). + /// - The bits above position `group_exprs` encode a duplicate ordinal that + /// distinguishes multiple occurrences of the same grouping set pattern. + /// + /// `max_ordinal` is the highest ordinal value that will appear (0 when there + /// are no duplicate grouping sets). The type is chosen to be the smallest + /// unsigned integer that can represent both parts. + pub fn grouping_id_type(group_exprs: usize, max_ordinal: usize) -> DataType { + let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize; + let total_bits = group_exprs + ordinal_bits; + if total_bits <= 8 { DataType::UInt8 - } else if group_exprs <= 16 { + } else if total_bits <= 16 { DataType::UInt16 - } else if group_exprs <= 32 { + } else if total_bits <= 32 { DataType::UInt32 } else { DataType::UInt64 @@ -3585,21 +4017,36 @@ impl Aggregate { /// Internal column used when the aggregation is a grouping set. /// - /// This column contains a bitmask where each bit represents a grouping - /// expression. The least significant bit corresponds to the rightmost - /// grouping expression. A bit value of 0 indicates that the corresponding - /// column is included in the grouping set, while a value of 1 means it is excluded. + /// This column packs two values into a single unsigned integer: + /// + /// - **Low bits (positions 0 .. n-1)**: a semantic bitmask where each bit + /// represents one of the `n` grouping expressions. The least significant + /// bit corresponds to the rightmost grouping expression. A `1` bit means + /// the corresponding column is replaced with `NULL` for this grouping set; + /// a `0` bit means it is included. + /// - **High bits (positions n and above)**: a *duplicate ordinal* that + /// distinguishes multiple occurrences of the same semantic grouping set + /// pattern within a single query. The ordinal is `0` for the first + /// occurrence, `1` for the second, and so on. + /// + /// The integer type is chosen by [`Self::grouping_id_type`] to be the + /// smallest `UInt8 / UInt16 / UInt32 / UInt64` that can represent both + /// parts. /// - /// For example, for the grouping expressions CUBE(a, b), the grouping ID - /// column will have the following values: + /// For example, for the grouping expressions CUBE(a, b) (no duplicates), + /// the grouping ID column will have the following values: /// 0b00: Both `a` and `b` are included /// 0b01: `b` is excluded /// 0b10: `a` is excluded /// 0b11: Both `a` and `b` are excluded /// - /// This internal column is necessary because excluded columns are replaced - /// with `NULL` values. To handle these cases correctly, we must distinguish - /// between an actual `NULL` value in a column and a column being excluded from the set. + /// When the same set appears twice and `n = 2`, the duplicate ordinal is + /// packed into bit 2: + /// first occurrence: `0b0_01` (ordinal = 0, mask = 0b01) + /// second occurrence: `0b1_01` (ordinal = 1, mask = 0b01) + /// + /// The GROUPING function always masks the value with `(1 << n) - 1` before + /// interpreting it so the ordinal bits are invisible to user-facing SQL. pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } @@ -3620,6 +4067,25 @@ impl PartialOrd for Aggregate { } } +/// Returns the highest duplicate ordinal across all grouping sets in `group_expr`. +/// +/// The ordinal for each occurrence of a grouping set pattern is its 0-based +/// index among identical entries. For example, if the same set appears three +/// times, the ordinals are 0, 1, 2 and this function returns 2. +/// Returns 0 when no grouping set is duplicated. +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key +fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize { + if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() { + let mut counts: HashMap<&[Expr], usize> = HashMap::new(); + for set in sets { + *counts.entry(set).or_insert(0) += 1; + } + counts.into_values().max().unwrap_or(0).saturating_sub(1) + } else { + 0 + } +} + /// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. fn contains_grouping_set(group_expr: &[Expr]) -> bool { group_expr @@ -3749,6 +4215,14 @@ pub struct Join { pub schema: DFSchemaRef, /// Defines the null equality for the join. pub null_equality: NullEquality, + /// Whether this is a null-aware anti join (for NOT IN semantics). + /// + /// Only applies to LeftAnti joins. When true, implements SQL NOT IN semantics where: + /// - If the right side (subquery) contains any NULL in join keys, no rows are output + /// - Left side rows with NULL in join keys are not output + /// + /// This is required for correct NOT IN subquery behavior with three-valued logic. + pub null_aware: bool, } impl Join { @@ -3766,10 +4240,12 @@ impl Join { /// * `join_type` - Type of join (Inner, Left, Right, etc.) /// * `join_constraint` - Join constraint (On, Using) /// * `null_equality` - How to handle nulls in join comparisons + /// * `null_aware` - Whether this is a null-aware anti join (for NOT IN semantics) /// /// # Returns /// /// A new Join operator with the computed schema + #[expect(clippy::too_many_arguments)] pub fn try_new( left: Arc, right: Arc, @@ -3778,6 +4254,7 @@ impl Join { join_type: JoinType, join_constraint: JoinConstraint, null_equality: NullEquality, + null_aware: bool, ) -> Result { let join_schema = build_join_schema(left.schema(), right.schema(), &join_type)?; @@ -3790,6 +4267,7 @@ impl Join { join_constraint, schema: Arc::new(join_schema), null_equality, + null_aware, }) } @@ -3845,6 +4323,7 @@ impl Join { join_constraint: original_join.join_constraint, schema: Arc::new(join_schema), null_equality: original_join.null_equality, + null_aware: original_join.null_aware, }, requalified, )) @@ -3950,11 +4429,16 @@ impl Debug for Subquery { } } -/// Logical partitioning schemes supported by [`LogicalPlan::Repartition`] +/// Logical partitioning schemes. +/// +/// A scheme can describe either requested repartitioning in +/// [`LogicalPlan::Repartition`] or a partitioning property declared by a source. +/// Some schemes are only valid as metadata until planner support is added. /// -/// See [`Partitioning`] for more details on partitioning +/// For physical execution partitioning, see +/// [`datafusion_physical_expr::Partitioning`]. /// -/// [`Partitioning`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/enum.Partitioning.html# +/// [`datafusion_physical_expr::Partitioning`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/enum.Partitioning.html# #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Partitioning { /// Allocate batches using a round-robin algorithm and the specified number of partitions @@ -3962,10 +4446,118 @@ pub enum Partitioning { /// Allocate rows based on a hash of one of more expressions and the specified number /// of partitions. Hash(Vec, usize), + /// Partition rows by ranges. + /// See [`RangePartitioning`] for the logical contract. + Range(RangePartitioning), /// The DISTRIBUTE BY clause is used to repartition the data based on the input expressions DistributeBy(Vec), } +impl Partitioning { + /// Return the number of partitions, if known. + pub fn partition_count(&self) -> Option { + match self { + Self::RoundRobinBatch(partition_count) | Self::Hash(_, partition_count) => { + Some(*partition_count) + } + Self::Range(range) => Some(range.partition_count()), + Self::DistributeBy(_) => None, + } + } +} + +/// Logical range partitioning. +/// +/// [`RangePartitioning`] describes an ordered logical key space with split points. +/// +/// - `ordering` defines the partitioning key and ordering using logical +/// [`SortExpr`]s. +/// - `split_points` define the boundaries between adjacent partitions. +/// +/// Comparisons use the lexicographic order defined by `ordering`, +/// including `ASC`/`DESC` and null ordering. Split points must be ordered +/// according to that ordering, and each split point must have one value per +/// ordering expression. See [`SplitPoint`] for the shared boundary contract. +/// +/// The expressions are resolved against the declaring plan's schema. This +/// constructor does not validate split point value types against the resolved +/// expression types. Like other user-specified data properties such as +/// sortedness, if a source declares range partitioning, it is responsible for +/// placing each row in the partition described by the split points. DataFusion +/// will not validate this is upheld. +/// +/// NOTE: Range-aware optimizer and execution behavior will be introduced +/// incrementally. See +/// . +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +pub struct RangePartitioning { + /// Ordered logical partitioning key. + ordering: Vec, + /// Boundaries between adjacent partitions. + split_points: Vec, +} + +impl RangePartitioning { + /// Creates logical range partitioning metadata and validates split point + /// shape and ordering. + pub fn try_new( + ordering: Vec, + split_points: Vec, + ) -> Result { + if ordering.is_empty() { + return plan_err!("Range partitioning requires non-empty ordering"); + } + + validate_range_split_points(&split_points, &logical_sort_options(&ordering))?; + + Ok(Self { + ordering, + split_points, + }) + } + + /// Return the number of partitions. + pub fn partition_count(&self) -> usize { + self.split_points.len() + 1 + } + + /// Returns the ordering that defines the range key. + pub fn ordering(&self) -> &[SortExpr] { + &self.ordering + } + + /// Returns the ordered split points between partitions. + pub fn split_points(&self) -> &[SplitPoint] { + &self.split_points + } +} + +fn logical_sort_options(ordering: &[SortExpr]) -> Vec { + ordering + .iter() + .map(|sort_expr| SortOptions { + descending: !sort_expr.asc, + nulls_first: sort_expr.nulls_first, + }) + .collect() +} + +impl Display for RangePartitioning { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let ordering = self.ordering().iter().map(ToString::to_string).join(", "); + let split_points = self + .split_points() + .iter() + .map(ToString::to_string) + .join(", "); + write!( + f, + "Range([{ordering}], [{split_points}], {})", + self.partition_count() + ) + } +} + /// Represent the unnesting operation on a list column, such as the recursion depth and /// the output column name after unnesting /// @@ -4142,7 +4734,9 @@ impl Unnest { } DataType::List(_) | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) => { + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) => { list_columns.push(( index, ColumnUnnestList { @@ -4217,7 +4811,11 @@ fn get_unnested_columns( let mut qualified_columns = Vec::with_capacity(1); match data_type { - DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) => { let data_type = get_unnested_list_datatype_recursive(data_type, depth)?; let new_field = Arc::new(Field::new( col_name, data_type, @@ -4254,7 +4852,9 @@ fn get_unnested_list_datatype_recursive( match data_type { DataType::List(field) | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => { + | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) => { if depth == 1 { return Ok(field.data_type().clone()); } @@ -4274,17 +4874,56 @@ mod tests { use crate::select_expr::SelectExpr; use crate::test::function_stub::{count, count_udaf}; use crate::{ - binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, - GroupingSet, + GroupingSet, binary_expr, col, exists, in_subquery, lit, placeholder, + scalar_subquery, }; use datafusion_common::metadata::ScalarAndMetadata; use datafusion_common::tree_node::{ TransformedResult, TreeNodeRewriter, TreeNodeVisitor, }; - use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use datafusion_common::{Constraint, not_impl_err}; use insta::{assert_debug_snapshot, assert_snapshot}; use std::hash::DefaultHasher; + /// `LogicalPlan` is moved/swapped on every step of the planning hot path + /// (every `mem::take` in an in-place rewriter, every `Arc` + /// write, every owned `map_*` traversal). Its size is set by the largest + /// variant, so an oversized variant balloons cost for every other variant. + /// + /// Today the size-setter should be `Join` (~176 bytes); `DdlStatement` is + /// boxed precisely so it does not dominate. If you grow a variant, please + /// box the new large fields rather than letting this number creep up — + /// see the analogous `test_size_of_expr` in `expr.rs`. + #[test] + fn test_size_of_logical_plan() { + // `LogicalPlan` enum on aarch64 / x86_64. Today this matches + // `Join`'s 176 bytes (the enum discriminant fits in `Join`'s + // alignment padding); if `Join` grows or another variant overtakes + // it, this number will move with the new size-setter. + assert_eq!(size_of::(), 176); + // `DdlStatement` is `Ddl(DdlStatement)`'s payload; keep it below the + // `Join` ceiling so it never re-becomes the size-setter. + assert!( + size_of::() < size_of::(), + "DdlStatement ({} bytes) should stay smaller than Join ({} bytes); \ + box the new large variant rather than letting it dominate `LogicalPlan`.", + size_of::(), + size_of::(), + ); + // Sanity check the two boxed variants stay boxed (so the payload + // sits on the heap, not in the enum). + assert_eq!( + size_of::>(), + 8, + "CreateExternalTable should be Box'd inside DdlStatement" + ); + assert_eq!( + size_of::>(), + 8, + "CreateFunction should be Box'd inside DdlStatement" + ); + } + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -4295,6 +4934,134 @@ mod tests { ]) } + fn i32_split_point(value: i32) -> SplitPoint { + SplitPoint::new(vec![ScalarValue::Int32(Some(value))]) + } + + fn null_i32_split_point() -> SplitPoint { + SplitPoint::new(vec![ScalarValue::Int32(None)]) + } + + #[test] + fn logical_range_partitioning_validates_shape() { + let range = RangePartitioning::try_new( + vec![col("id").sort(true, true)], + vec![i32_split_point(10), i32_split_point(20)], + ) + .unwrap(); + assert_eq!(range.partition_count(), 3); + + let range = RangePartitioning::try_new( + vec![col("id").sort(false, true)], + vec![i32_split_point(20), i32_split_point(10)], + ) + .unwrap(); + assert_eq!(range.partition_count(), 3); + + let err = RangePartitioning::try_new(vec![], vec![]).unwrap_err(); + assert!(err.to_string().contains("non-empty ordering")); + + let err = RangePartitioning::try_new( + vec![col("id").sort(true, true), col("salary").sort(true, true)], + vec![i32_split_point(10)], + ) + .unwrap_err(); + assert!( + err.to_string() + .contains("split point 0 has width 1, but ordering has width 2") + ); + + let err = RangePartitioning::try_new( + vec![col("id").sort(true, true)], + vec![i32_split_point(20), i32_split_point(10)], + ) + .unwrap_err(); + assert!( + err.to_string() + .contains("split points must be strictly ordered") + ); + + let err = RangePartitioning::try_new( + vec![col("id").sort(true, true)], + vec![i32_split_point(10), i32_split_point(10)], + ) + .unwrap_err(); + assert!( + err.to_string() + .contains("split points must be strictly ordered") + ); + + let range = RangePartitioning::try_new( + vec![col("id").sort(true, true)], + vec![null_i32_split_point(), i32_split_point(10)], + ) + .unwrap(); + assert_eq!(range.partition_count(), 3); + } + + #[test] + fn logical_partitioning_reports_known_partition_count() -> Result<()> { + let range = RangePartitioning::try_new( + vec![col("id").sort(true, true)], + vec![i32_split_point(10)], + )?; + + assert_eq!(Partitioning::RoundRobinBatch(4).partition_count(), Some(4)); + assert_eq!( + Partitioning::Hash(vec![col("id")], 8).partition_count(), + Some(8) + ); + assert_eq!(Partitioning::Range(range).partition_count(), Some(2)); + assert_eq!( + Partitioning::DistributeBy(vec![col("id")]).partition_count(), + None + ); + + Ok(()) + } + + #[test] + fn logical_range_partitioning_participates_in_expression_rewrite() -> Result<()> { + let input = + table_scan(Some("employee_csv"), &employee_schema(), None)?.build()?; + let plan = LogicalPlan::Repartition(Repartition { + input: Arc::new(input), + partitioning_scheme: Partitioning::Range(RangePartitioning::try_new( + vec![col("id").sort(true, true)], + vec![i32_split_point(10)], + )?), + }); + + let mut visited_exprs = vec![]; + plan.apply_expressions(|expr| { + visited_exprs.push(expr.to_string()); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(visited_exprs, vec!["id"]); + + let plan = plan + .map_expressions(|expr| { + if expr == col("id") { + Ok(Transformed::yes(col("salary"))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + + let LogicalPlan::Repartition(Repartition { + partitioning_scheme: Partitioning::Range(range), + .. + }) = plan + else { + unreachable!("expected range repartition"); + }; + assert_eq!(range.ordering()[0].expr, col("salary")); + assert_eq!(range.partition_count(), 2); + + Ok(()) + } + fn display_plan() -> Result { let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))? .build()?; @@ -4305,6 +5072,74 @@ mod tests { .build() } + fn recursive_term_scan(name: &str, fields: Vec) -> Result> { + Ok(Arc::new( + table_scan(Some(name), &Schema::new(fields), None)?.build()?, + )) + } + + #[test] + fn recursive_query_widens_nullability_per_column() -> Result<()> { + // Column `a` is non-nullable in both terms and must stay non-nullable; + // column `b` is non-nullable in the static term but nullable in the + // recursive term, so the output must widen it to nullable. + let static_term = recursive_term_scan( + "static", + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ], + )?; + let recursive_term = recursive_term_scan( + "rec", + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, true), + ], + )?; + + let query = + RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false)?; + + // Names and types are taken from the static term. + assert_eq!(query.schema.field(0).name(), "a"); + assert_eq!(query.schema.field(1).name(), "b"); + assert_eq!(query.schema.field(0).data_type(), &DataType::Int32); + assert_eq!(query.schema.field(1).data_type(), &DataType::Int32); + // Nullability is widened independently per column. + assert!(!query.schema.field(0).is_nullable()); + assert!(query.schema.field(1).is_nullable()); + // `schema()` returns the widened recursive-query schema. + assert_eq!( + LogicalPlan::RecursiveQuery(query.clone()).schema(), + &query.schema + ); + Ok(()) + } + + #[test] + fn recursive_query_rejects_column_count_mismatch() -> Result<()> { + let static_term = + recursive_term_scan("static", vec![Field::new("a", DataType::Int32, false)])?; + let recursive_term = recursive_term_scan( + "rec", + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ], + )?; + + let err = + RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false) + .unwrap_err(); + assert!( + err.strip_backtrace() + .contains("must have the same number of columns"), + "unexpected error: {err}" + ); + Ok(()) + } + #[test] fn test_display_indent() -> Result<()> { let plan = display_plan()?; @@ -4404,49 +5239,49 @@ mod tests { [ { "Plan": { + "Node Type": "Projection", "Expressions": [ "employee_csv.id" ], - "Node Type": "Projection", - "Output": [ - "id" - ], "Plans": [ { - "Condition": "employee_csv.state IN ()", "Node Type": "Filter", - "Output": [ - "id", - "state" - ], + "Condition": "employee_csv.state IN ()", "Plans": [ { "Node Type": "Subquery", - "Output": [ - "state" - ], "Plans": [ { "Node Type": "TableScan", + "Relation Name": "employee_csv", + "Plans": [], "Output": [ "state" - ], - "Plans": [], - "Relation Name": "employee_csv" + ] } + ], + "Output": [ + "state" ] }, { "Node Type": "TableScan", + "Relation Name": "employee_csv", + "Plans": [], "Output": [ "id", "state" - ], - "Plans": [], - "Relation Name": "employee_csv" + ] } + ], + "Output": [ + "id", + "state" ] } + ], + "Output": [ + "id" ] } } @@ -4901,14 +5736,26 @@ mod tests { let output_schema = plan.schema(); - assert!(output_schema - .field_with_name(None, "foo") - .unwrap() - .is_nullable(),); - assert!(output_schema - .field_with_name(None, "bar") - .unwrap() - .is_nullable()); + assert!( + output_schema + .field_with_name(None, "foo") + .unwrap() + .is_nullable(), + ); + assert!( + output_schema + .field_with_name(None, "bar") + .unwrap() + .is_nullable() + ); + } + + #[test] + fn grouping_id_type_accounts_for_duplicate_ordinal_bits() { + // 8 grouping columns fit in UInt8 when there are no duplicate ordinals, + // but adding one duplicate ordinal bit widens the type to UInt16. + assert_eq!(Aggregate::grouping_id_type(8, 0), DataType::UInt8); + assert_eq!(Aggregate::grouping_id_type(8, 1), DataType::UInt16); } #[test] @@ -4932,6 +5779,7 @@ mod tests { projected_schema: Arc::clone(&schema), filters: vec![], fetch: None, + statistics_requests: BTreeSet::new(), })); let col = schema.field_names()[0].clone(); @@ -4962,6 +5810,7 @@ mod tests { projected_schema: Arc::clone(&unique_schema), filters: vec![], fetch: None, + statistics_requests: BTreeSet::new(), })); let col = schema.field_names()[0].clone(); @@ -5173,7 +6022,11 @@ mod tests { .transform_down_with_subqueries(|plan| { match plan { LogicalPlan::Projection(..) => { - return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + return Ok(Transformed::new( + plan, + false, + TreeNodeRecursion::Jump, + )); } LogicalPlan::Filter(..) => filter_found = true, _ => {} @@ -5193,7 +6046,7 @@ mod tests { plan, false, TreeNodeRecursion::Jump, - )) + )); } LogicalPlan::Filter(..) => filter_found = true, _ => {} @@ -5223,7 +6076,11 @@ mod tests { fn f_down(&mut self, node: Self::Node) -> Result> { match node { LogicalPlan::Projection(..) => { - return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)) + return Ok(Transformed::new( + node, + false, + TreeNodeRecursion::Jump, + )); } LogicalPlan::Filter(..) => self.filter_found = true, _ => {} @@ -5285,6 +6142,7 @@ mod tests { join_constraint: JoinConstraint::On, schema: Arc::new(left_schema.join(&right_schema)?), null_equality: NullEquality::NullEqualsNothing, + null_aware: false, })) } @@ -5396,6 +6254,7 @@ mod tests { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; match join_type { @@ -5541,6 +6400,7 @@ mod tests { JoinType::Inner, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5592,6 +6452,7 @@ mod tests { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5641,6 +6502,7 @@ mod tests { JoinType::Inner, JoinConstraint::On, NullEquality::NullEqualsNull, + false, )?; assert_eq!(join.null_equality, NullEquality::NullEqualsNull); @@ -5683,6 +6545,7 @@ mod tests { join_type, JoinConstraint::On, NullEquality::NullEqualsNothing, + false, )?; let fields = join.schema.fields(); @@ -5722,6 +6585,7 @@ mod tests { JoinType::Inner, JoinConstraint::Using, NullEquality::NullEqualsNothing, + false, )?; assert_eq!( diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 49a938c9f6eb9..daf29d7c81d3f 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -20,9 +20,9 @@ use datafusion_common::metadata::format_type_and_metadata; use datafusion_common::{DFSchema, DFSchemaRef}; use itertools::Itertools as _; use std::fmt::{self, Display}; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; -use crate::{expr_vec_fmt, Expr, LogicalPlan}; +use crate::{Expr, LogicalPlan, expr_vec_fmt}; /// Various types of Statements. /// @@ -55,10 +55,7 @@ impl Statement { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { // Statements have an unchanging empty schema. - static STATEMENT_EMPTY_SCHEMA: LazyLock = - LazyLock::new(|| Arc::new(DFSchema::empty())); - - &STATEMENT_EMPTY_SCHEMA + DFSchema::empty_ref() } /// Return a descriptive string describing the type of this diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 47088370a1d93..cba2dac24b610 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -37,21 +37,22 @@ //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions +use crate::logical_plan::plan::RangePartitioning; use crate::{ - dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, - Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, - Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, - Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, - UserDefinedLogicalNode, Values, Window, + Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, + DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, Limit, + LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, Sort, + Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, + Values, Window, dml::CopyTo, }; use datafusion_common::tree_node::TreeNodeRefContainer; -use crate::expr::{Exists, InSubquery}; +use crate::expr::{Exists, InSubquery, SetComparison}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( @@ -133,6 +134,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, @@ -143,6 +145,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) }), LogicalPlan::Limit(Limit { skip, fetch, input }) => input @@ -201,6 +204,7 @@ impl TreeNode for LogicalPlan { stringified_plans, schema, logical_optimization_succeeded, + show_statistics, }) => plan.map_elements(f)?.update_data(|plan| { LogicalPlan::Explain(Explain { verbose, @@ -209,17 +213,24 @@ impl TreeNode for LogicalPlan { stringified_plans, schema, logical_optimization_succeeded, + show_statistics, }) }), LogicalPlan::Analyze(Analyze { verbose, + format, input, schema, + analyze_level, + analyze_categories, }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Analyze(Analyze { verbose, + format, input, schema, + analyze_level, + analyze_categories, }) }), LogicalPlan::Dml(DmlStatement { @@ -327,13 +338,18 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, + schema, }) => (static_term, recursive_term).map_elements(f)?.update_data( |(static_term, recursive_term)| { + // Ordinary child rewrites preserve derived schemas. Call + // `LogicalPlan::recompute_schema` when child schemas should + // be reconciled again. LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct, + schema, }) }, ), @@ -412,6 +428,7 @@ impl LogicalPlan { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { expr.apply_elements(f) } + Partitioning::Range(range) => range.ordering().to_vec().apply_elements(f), Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { @@ -517,6 +534,19 @@ impl LogicalPlan { Partitioning::DistributeBy(expr) => expr .map_elements(f)? .update_data(Partitioning::DistributeBy), + Partitioning::Range(range) => { + let split_points = range.split_points().to_vec(); + range + .ordering() + .to_vec() + .map_elements(f)? + .map_data(|ordering| { + Ok(Partitioning::Range(RangePartitioning::try_new( + ordering, + split_points, + )?)) + })? + } Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } .update_data(|partitioning_scheme| { @@ -564,6 +594,7 @@ impl LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, @@ -574,23 +605,36 @@ impl LogicalPlan { join_constraint, schema, null_equality, + null_aware, }) }), LogicalPlan::Sort(Sort { expr, input, fetch }) => expr .map_elements(f)? .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Extension(Extension { node }) => { - // would be nice to avoid this copy -- maybe can - // update extension to just observer Exprs - let exprs = node.expressions().map_elements(f)?; - let plan = LogicalPlan::Extension(Extension { - node: UserDefinedLogicalNode::with_exprs_and_inputs( - node.as_ref(), - exprs.data, - node.inputs().into_iter().cloned().collect::>(), - )?, - }); - Transformed::new(plan, exprs.transformed, exprs.tnr) + let raw_exprs = node.expressions(); + if raw_exprs.is_empty() { + // No expressions to transform — skip expensive clone of + // all inputs and reconstruction via with_exprs_and_inputs. + Transformed::no(LogicalPlan::Extension(Extension { node })) + } else { + // TODO: a more general optimization would be to change + // `UserDefinedLogicalNode::expressions()` to return + // references (`&[Expr]`) instead of cloned `Vec`, + // and only clone + rebuild when the transform actually + // modifies an expression. This would avoid the clone + + // `with_exprs_and_inputs` rebuild even for non-empty + // expression lists when the transform is a no-op. + let exprs = raw_exprs.map_elements(f)?; + let plan = LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::with_exprs_and_inputs( + node.as_ref(), + exprs.data, + node.inputs().into_iter().cloned().collect::>(), + )?, + }); + Transformed::new(plan, exprs.transformed, exprs.tnr) + } } LogicalPlan::TableScan(TableScan { table_name, @@ -599,6 +643,7 @@ impl LogicalPlan { projected_schema, filters, fetch, + statistics_requests, }) => filters.map_elements(f)?.update_data(|filters| { LogicalPlan::TableScan(TableScan { table_name, @@ -607,6 +652,7 @@ impl LogicalPlan { projected_schema, filters, fetch, + statistics_requests, }) }), LogicalPlan::Distinct(Distinct::On(DistinctOn { @@ -804,7 +850,7 @@ impl LogicalPlan { transform_down_up_with_subqueries_impl(self, &mut f_down, &mut f_up) } - /// Similarly to [`Self::apply`], calls `f` on this node and its inputs + /// Similarly to [`Self::apply`], calls `f` on this node and its inputs, /// including subqueries that may appear in expressions such as `IN (SELECT /// ...)`. pub fn apply_subqueries Result>( @@ -815,10 +861,9 @@ impl LogicalPlan { expr.apply(|expr| match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { - // use a synthetic plan so the collector sees a - // LogicalPlan::Subquery (even though it is - // actually a Subquery alias) + // Wrap in LogicalPlan::Subquery to match f's signature f(&LogicalPlan::Subquery(subquery.clone())) } _ => Ok(TreeNodeRecursion::Continue), @@ -826,6 +871,32 @@ impl LogicalPlan { }) } + /// Returns true if any expression in this node contains a subquery + /// (Exists, InSubquery, SetComparison, or ScalarSubquery). + fn has_subquery_expressions(&self) -> bool { + let mut found = false; + let _ = self.apply_expressions(|expr| { + if found { + return Ok(TreeNodeRecursion::Stop); + } + expr.apply(|e| { + if matches!( + e, + Expr::Exists(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + ) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + }); + found + } + /// Similarly to [`Self::map_children`], rewrites all subqueries that may /// appear in expressions such as `IN (SELECT ...)` using `f`. /// @@ -834,6 +905,14 @@ impl LogicalPlan { self, mut f: F, ) -> Result> { + // Fast path: skip the expensive ownership-based expression traversal + // when this node has no subquery expressions. This avoids + // map_expressions → transform_down walking every expression node + // via consume+recreate just to find no subqueries. + if !self.has_subquery_expressions() { + return Ok(Transformed::no(self)); + } + self.map_expressions(|expr| { expr.transform_down(|expr| match expr { Expr::Exists(Exists { subquery, negated }) => { @@ -856,6 +935,22 @@ impl LogicalPlan { })), _ => internal_err!("Transformation should return Subquery"), }), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + })) + } + _ => internal_err!("Transformation should return Subquery"), + }), Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? .map_data(|s| match s { LogicalPlan::Subquery(subquery) => { @@ -867,4 +962,18 @@ impl LogicalPlan { }) }) } + + /// Similar to [`Self::map_subqueries`], but only applies `f` to + /// uncorrelated subqueries (those with no outer column references). + pub fn map_uncorrelated_subqueries Result>>( + self, + mut f: F, + ) -> Result> { + self.map_subqueries(|subquery_plan| match &subquery_plan { + LogicalPlan::Subquery(sq) if sq.outer_ref_columns.is_empty() => { + f(subquery_plan) + } + _ => Ok(Transformed::no(subquery_plan)), + }) + } } diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index a0f0988b4f4e5..5a4e20e5ac9ac 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -18,7 +18,7 @@ //! Partition evaluation module use arrow::array::ArrayRef; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; use std::fmt::Debug; use std::ops::Range; @@ -86,7 +86,11 @@ use crate::window_state::WindowAggState; /// [`uses_window_frame`]: Self::uses_window_frame /// [`include_rank`]: Self::include_rank /// [`supports_bounded_execution`]: Self::supports_bounded_execution -pub trait PartitionEvaluator: Debug + Send { +/// +/// For more background, please also see the [User defined Window Functions in DataFusion blog] +/// +/// [User defined Window Functions in DataFusion blog]: https://datafusion.apache.org/blog/2025/04/19/user-defined-window-functions +pub trait PartitionEvaluator: Debug + Send + std::any::Any { /// When the window frame has a fixed beginning (e.g UNBOUNDED /// PRECEDING), some functions such as FIRST_VALUE, LAST_VALUE and /// NTH_VALUE do not need the (unbounded) input once they have @@ -175,7 +179,7 @@ pub trait PartitionEvaluator: Debug + Send { } /// Evaluate window function on a range of rows in an input - /// partition.x + /// partition. /// /// This is the simplest and most general function to implement /// but also the least performant as it creates output one row at @@ -210,7 +214,7 @@ pub trait PartitionEvaluator: Debug + Send { /// A | 1 /// C | 3 /// D | 4 - /// D | 5 + /// D | 4 /// ``` /// /// For this case, `num_rows` would be `5` and the diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 25a0f83947eee..7aaf3a98cbe5d 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -21,15 +21,20 @@ use std::fmt::Debug; use std::sync::Arc; use crate::expr::NullTreatment; +#[cfg(feature = "sql")] +use crate::logical_plan::LogicalPlan; use crate::{ - AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame, - WindowFunctionDefinition, WindowUDF, + AggregateUDF, Expr, GetFieldAccess, HigherOrderUDF, ScalarUDF, SortExpr, TableSource, + WindowFrame, WindowFunctionDefinition, WindowUDF, }; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ - config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, - Result, TableReference, + DFSchema, Result, TableReference, config::ConfigOptions, + file_options::file_type::FileType, not_impl_err, }; +#[cfg(feature = "sql")] +use sqlparser::ast::{Expr as SQLExpr, Ident, ObjectName, TableAlias, TableFactor}; /// Provides the `SQL` query planner meta-data about tables and /// functions referenced in SQL statements, without a direct dependency on the @@ -56,7 +61,8 @@ pub trait ContextProvider { not_impl_err!("Table Functions are not supported") } - /// Provides an intermediate table that is used to store the results of a CTE during execution + /// Provides an intermediate table that is used to expose a recursive CTE + /// self-reference during planning and execution. /// /// CTE stands for "Common Table Expression" /// @@ -67,6 +73,9 @@ pub trait ContextProvider { /// of the sql crate (for example [`CteWorkTable`]). /// /// The [`ContextProvider`] provides a way to "hide" this dependency. + /// The schema argument is the schema to expose for scans of the recursive + /// self-reference, which may be more conservative than the final recursive + /// query output schema. /// /// [`SqlToRel`]: https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html /// [`CteWorkTable`]: https://docs.rs/datafusion/latest/datafusion/datasource/cte_worktable/struct.CteWorkTable.html @@ -83,6 +92,12 @@ pub trait ContextProvider { &[] } + /// Return [`RelationPlanner`] extensions for planning table factors + #[cfg(feature = "sql")] + fn get_relation_planners(&self) -> &[Arc] { + &[] + } + /// Return [`TypePlanner`] extensions for planning data types #[cfg(feature = "sql")] fn get_type_planner(&self) -> Option> { @@ -92,6 +107,9 @@ pub trait ContextProvider { /// Return the scalar function with a given name, if any fn get_function_meta(&self, name: &str) -> Option>; + /// Return the higher order function with a given name, if any + fn get_higher_order_meta(&self, name: &str) -> Option>; + /// Return the aggregate function with a given name, if any fn get_aggregate_meta(&self, name: &str) -> Option>; @@ -103,12 +121,26 @@ pub trait ContextProvider { /// A user defined variable is typically accessed via `@var_name` fn get_variable_type(&self, variable_names: &[String]) -> Option; + /// Return metadata about a system/user-defined variable, if any. + /// + /// By default, this wraps [`Self::get_variable_type`] in an Arrow [`Field`] + /// with nullable set to `true` and no metadata. Implementations that can + /// provide richer information (such as nullability or extension metadata) + /// should override this method. + fn get_variable_field(&self, variable_names: &[String]) -> Option { + self.get_variable_type(variable_names) + .map(|data_type| data_type.into_nullable_field_ref()) + } + /// Return overall configuration options fn options(&self) -> &ConfigOptions; /// Return all scalar function names fn udf_names(&self) -> Vec; + /// Return all higher order function names + fn higher_order_function_names(&self) -> Vec; + /// Return all aggregate function names fn udaf_names(&self) -> Vec; @@ -117,6 +149,10 @@ pub trait ContextProvider { } /// Customize planning of SQL AST expressions to [`Expr`]s +/// +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait ExprPlanner: Debug + Send + Sync { /// Plan the binary operation between two expressions, returns original /// BinaryExpr if not possible @@ -227,13 +263,6 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, such as `expr = ANY(array_expr)` - /// - /// Returns origin binary expression if not possible - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - Ok(PlannerResult::Original(expr)) - } - /// Plans aggregate functions, such as `COUNT()` /// /// Returns original expression arguments if not possible @@ -324,16 +353,119 @@ pub enum PlannerResult { Original(T), } +/// Result of planning a relation with [`RelationPlanner`] +#[cfg(feature = "sql")] +#[derive(Debug, Clone)] +pub struct PlannedRelation { + /// The logical plan for the relation + pub plan: LogicalPlan, + /// Optional table alias for the relation + pub alias: Option, +} + +#[cfg(feature = "sql")] +impl PlannedRelation { + /// Create a new `PlannedRelation` with the given plan and alias + pub fn new(plan: LogicalPlan, alias: Option) -> Self { + Self { plan, alias } + } +} + +/// Result of attempting to plan a relation with extension planners +#[cfg(feature = "sql")] +#[derive(Debug)] +pub enum RelationPlanning { + /// The relation was successfully planned by an extension planner + Planned(Box), + /// No extension planner handled the relation, return it for default processing + Original(Box), +} + +/// Customize planning SQL table factors to [`LogicalPlan`]s. +#[cfg(feature = "sql")] +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql +pub trait RelationPlanner: Debug + Send + Sync { + /// Plan a table factor into a [`LogicalPlan`]. + /// + /// Returning [`RelationPlanning::Planned`] short-circuits further planning and uses the + /// provided plan. Returning [`RelationPlanning::Original`] allows the next registered planner, + /// or DataFusion's default logic, to handle the relation. + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result; +} + +/// Provides utilities for relation planners to interact with DataFusion's SQL +/// planner. +/// +/// This trait provides SQL planning utilities specific to relation planning, +/// such as converting SQL expressions to logical expressions and normalizing +/// identifiers. It uses composition to provide access to session context via +/// [`ContextProvider`]. +#[cfg(feature = "sql")] +pub trait RelationPlannerContext { + /// Provides access to the underlying context provider for reading session + /// configuration, accessing tables, functions, and other metadata. + fn context_provider(&self) -> &dyn ContextProvider; + + /// Plans the specified relation through the full planner pipeline, starting + /// from the first registered relation planner. + fn plan(&mut self, relation: TableFactor) -> Result; + + /// Converts a SQL expression into a logical expression using the current + /// planner context. + fn sql_to_expr(&mut self, expr: SQLExpr, schema: &DFSchema) -> Result; + + /// Converts a SQL expression into a logical expression without DataFusion + /// rewrites. + fn sql_expr_to_logical_expr( + &mut self, + expr: SQLExpr, + schema: &DFSchema, + ) -> Result; + + /// Normalizes an identifier according to session settings. + fn normalize_ident(&self, ident: Ident) -> String; + + /// Normalizes a SQL object name into a [`TableReference`]. + fn object_name_to_table_reference(&self, name: ObjectName) -> Result; +} + /// Customize planning SQL types to DataFusion (Arrow) types. #[cfg(feature = "sql")] +/// For more background, please also see the [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog] +/// +/// [Extending SQL in DataFusion: from ->> to TABLESAMPLE blog]: https://datafusion.apache.org/blog/2026/01/12/extending-sql pub trait TypePlanner: Debug + Send + Sync { /// Plan SQL [`sqlparser::ast::DataType`] to DataFusion [`DataType`] /// /// Returns None if not possible + #[deprecated(since = "53.0.0", note = "Use plan_type_field()")] fn plan_type( &self, _sql_type: &sqlparser::ast::DataType, ) -> Result> { Ok(None) } + + /// Plan SQL [`sqlparser::ast::DataType`] to DataFusion [`FieldRef`] + /// + /// Returns None if not possible. Unlike [`Self::plan_type`], `plan_type_field()` + /// makes it possible to express extension types (e.g., `arrow.uuid`) or otherwise + /// insert metadata into the DataFusion type representation. The default implementation + /// falls back on [`Self::plan_type`] for backward compatibility and wraps the result + /// in a nullable field reference. + fn plan_type_field( + &self, + sql_type: &sqlparser::ast::DataType, + ) -> Result> { + #[expect(deprecated)] + Ok(self + .plan_type(sql_type)? + .map(|data_type| data_type.into_nullable_field_ref())) + } } diff --git a/datafusion/expr/src/predicate_bounds.rs b/datafusion/expr/src/predicate_bounds.rs index 192b2929fdeb8..992d9f88bb14a 100644 --- a/datafusion/expr/src/predicate_bounds.rs +++ b/datafusion/expr/src/predicate_bounds.rs @@ -171,10 +171,10 @@ impl PredicateBoundsEvaluator<'_> { } // Check if the expression is the `certainly_null_expr` that was passed in. - if let Some(certainly_null_expr) = &self.certainly_null_expr { - if expr.eq(certainly_null_expr) { - return NullableInterval::TRUE; - } + if let Some(certainly_null_expr) = &self.certainly_null_expr + && expr.eq(certainly_null_expr) + { + return NullableInterval::TRUE; } // `expr` is nullable, so our default answer for `is null` is going to be `{ TRUE, FALSE }`. @@ -235,8 +235,8 @@ mod tests { use crate::expr::ScalarFunction; use crate::predicate_bounds::evaluate_bounds; use crate::{ - binary_expr, col, create_udf, is_false, is_not_false, is_not_null, is_not_true, - is_not_unknown, is_null, is_true, is_unknown, lit, not, Expr, + Expr, binary_expr, col, create_udf, is_false, is_not_false, is_not_null, + is_not_true, is_not_unknown, is_null, is_true, is_unknown, lit, not, }; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{DFSchema, Result, ScalarValue}; diff --git a/datafusion/expr/src/preimage.rs b/datafusion/expr/src/preimage.rs new file mode 100644 index 0000000000000..67ca7a91bbf38 --- /dev/null +++ b/datafusion/expr/src/preimage.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr_common::interval_arithmetic::Interval; + +use crate::Expr; + +/// Return from [`crate::ScalarUDFImpl::preimage`] +pub enum PreimageResult { + /// No preimage exists for the specified value + None, + /// The expression always evaluates to the specified constant + /// given that `expr` is within the interval + Range { expr: Expr, interval: Box }, +} diff --git a/datafusion/expr/src/ptr_eq.rs b/datafusion/expr/src/ptr_eq.rs index 0bbfba5e8d063..79ea3d7219143 100644 --- a/datafusion/expr/src/ptr_eq.rs +++ b/datafusion/expr/src/ptr_eq.rs @@ -39,7 +39,7 @@ pub fn arc_ptr_hash(a: &Arc, hasher: &mut impl Hasher) { /// /// If you have pointers to a `dyn UDF impl` consider using [`super::udf_eq::UdfEq`]. #[derive(Clone)] -#[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. +#[expect(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. pub struct PtrEq(Ptr); impl PartialEq for PtrEq> diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 9554dd68e1758..4b9744d9573b6 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -18,18 +18,32 @@ //! FunctionRegistry trait use crate::expr_rewriter::FunctionRewrite; +use crate::higher_order_function::HigherOrderUDF; use crate::planner::ExprPlanner; use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; -use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; +use arrow::datatypes::Field; +use arrow_schema::DataType; +use arrow_schema::extension::{ + Bool8, ExtensionType, FixedShapeTensor, Json, Opaque, TimestampWithOffset, Uuid, + VariableShapeTensor, +}; +use datafusion_common::types::{ + DFBool8, DFExtensionTypeRef, DFFixedShapeTensor, DFJson, DFOpaque, + DFTimestampWithOffset, DFUuid, DFVariableShapeTensor, +}; +use datafusion_common::{HashMap, Result, not_impl_err, plan_datafusion_err}; use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; +use std::fmt::{Debug, Formatter}; +use std::sync::{Arc, RwLock}; /// A registry knows how to build logical expressions out of user-defined function' names pub trait FunctionRegistry { /// Returns names of all available scalar user defined functions. fn udfs(&self) -> HashSet; + /// Returns names of all available higher order user defined functions. + fn higher_order_function_names(&self) -> HashSet; + /// Returns names of all available aggregate user defined functions. fn udafs(&self) -> HashSet; @@ -40,6 +54,10 @@ pub trait FunctionRegistry { /// `name`. fn udf(&self, name: &str) -> Result>; + /// Returns a reference to the user defined higher order function named + /// `name`. + fn higher_order_function(&self, name: &str) -> Result>; + /// Returns a reference to the user defined aggregate function (udaf) named /// `name`. fn udaf(&self, name: &str) -> Result>; @@ -56,6 +74,17 @@ pub trait FunctionRegistry { fn register_udf(&mut self, _udf: Arc) -> Result>> { not_impl_err!("Registering ScalarUDF") } + /// Registers a new [`HigherOrderUDF`], returning any previously registered + /// implementation. + /// + /// Returns an error (the default) if the function can not be registered, + /// for example if the registry is read only. + fn register_higher_order_function( + &mut self, + _function: Arc, + ) -> Result>> { + not_impl_err!("Registering HigherOrderUDF") + } /// Registers a new [`AggregateUDF`], returning any previously registered /// implementation. /// @@ -85,6 +114,18 @@ pub trait FunctionRegistry { not_impl_err!("Deregistering ScalarUDF") } + /// Deregisters a [`HigherOrderUDF`], returning the implementation that was + /// deregistered. + /// + /// Returns an error (the default) if the function can not be deregistered, + /// for example if the registry is read only. + fn deregister_higher_order_function( + &mut self, + _name: &str, + ) -> Result>> { + not_impl_err!("Deregistering HigherOrderUDF") + } + /// Deregisters a [`AggregateUDF`], returning the implementation that was /// deregistered. /// @@ -156,6 +197,8 @@ pub struct MemoryFunctionRegistry { udafs: HashMap>, /// Window Functions udwfs: HashMap>, + /// Higher Order Functions + higher_order_functions: HashMap>, } impl MemoryFunctionRegistry { @@ -176,6 +219,13 @@ impl FunctionRegistry for MemoryFunctionRegistry { .ok_or_else(|| plan_datafusion_err!("Function {name} not found")) } + fn higher_order_function(&self, name: &str) -> Result> { + self.higher_order_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Higher Order Function {name} not found")) + } + fn udaf(&self, name: &str) -> Result> { self.udafs .get(name) @@ -193,6 +243,14 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udf(&mut self, udf: Arc) -> Result>> { Ok(self.udfs.insert(udf.name().to_string(), udf)) } + fn register_higher_order_function( + &mut self, + function: Arc, + ) -> Result>> { + Ok(self + .higher_order_functions + .insert(function.name().into(), function)) + } fn register_udaf( &mut self, udaf: Arc, @@ -207,6 +265,10 @@ impl FunctionRegistry for MemoryFunctionRegistry { vec![] } + fn higher_order_function_names(&self) -> HashSet { + self.higher_order_functions.keys().cloned().collect() + } + fn udafs(&self) -> HashSet { self.udafs.keys().cloned().collect() } @@ -215,3 +277,320 @@ impl FunctionRegistry for MemoryFunctionRegistry { self.udwfs.keys().cloned().collect() } } + +/// A cheaply cloneable pointer to an [ExtensionTypeRegistry]. +pub type ExtensionTypeRegistryRef = Arc; + +/// Manages [`ExtensionTypeRegistration`]s, which allow users to register custom behavior for +/// extension types. +/// +/// Each registration is connected to the extension type name, which can also be looked up to get +/// the registration. +pub trait ExtensionTypeRegistry: Debug + Send + Sync { + /// Returns a reference to registration of an extension type named `name`. + /// + /// Returns an error if there is no extension type with that name. + fn extension_type_registration( + &self, + name: &str, + ) -> Result; + + /// Creates a [`DFExtensionTypeRef`] from the type information in the `field`. + /// + /// The result `Ok(None)` indicates that there is no extension type metadata. Returns an error + /// if the extension type in the metadata is not found. + fn create_extension_type_for_field( + &self, + field: &Field, + ) -> Result> { + let Some(extension_type_name) = field.extension_type_name() else { + return Ok(None); + }; + + let registration = self.extension_type_registration(extension_type_name)?; + registration + .create_df_extension_type(field.data_type(), field.extension_type_metadata()) + .map(Some) + } + + /// Returns all registered [ExtensionTypeRegistration]. + fn extension_type_registrations(&self) -> Vec; + + /// Registers a new [ExtensionTypeRegistrationRef], returning any previously registered + /// implementation. + /// + /// Returns an error if the type cannot be registered, for example, if the registry is + /// read-only. + fn add_extension_type_registration( + &self, + extension_type: ExtensionTypeRegistrationRef, + ) -> Result>; + + /// Extends the registry with the provided extension types. + /// + /// Returns an error if the type cannot be registered, for example, if the registry is + /// read-only. + fn extend(&self, extension_types: &[ExtensionTypeRegistrationRef]) -> Result<()> { + for extension_type in extension_types.iter().cloned() { + self.add_extension_type_registration(extension_type)?; + } + Ok(()) + } + + /// Deregisters an extension type registration with the name `name`, returning the + /// implementation that was deregistered. + /// + /// Returns an error if the type cannot be deregistered, for example, if the registry is + /// read-only. + fn remove_extension_type_registration( + &self, + name: &str, + ) -> Result>; +} + +/// A factory that creates instances of extension types from a storage [`DataType`] and the +/// metadata. +pub type ExtensionTypeFactory = + dyn Fn(&DataType, Option<&str>) -> Result + Send + Sync; + +/// A cheaply cloneable pointer to an [ExtensionTypeRegistration]. +pub type ExtensionTypeRegistrationRef = Arc; + +/// The registration of an extension type. Implementations of this trait are responsible for +/// *creating* instances of [`DFExtensionType`] that represent the entire semantics of an extension +/// type. +/// +/// # Why do we need a Registration? +/// +/// A good question is why this trait is even necessary. Why not directly register the +/// [`DFExtensionType`] in a registry? +/// +/// While this works for extension types requiring no additional metadata (e.g., `arrow.uuid`), it +/// does not work for more complex extension types with metadata. For example, consider an extension +/// type `custom.shortened(n)` that aims to short the pretty-printing string to `n` characters. +/// Here, `n` is a parameter of the extension type and should be a field in the struct that +/// implements the [`DFExtensionType`]. The job of the registration is to read the metadata from the +/// field and create the corresponding [`DFExtensionType`] instance with the correct `n` set. +/// +/// [`DFExtensionType`]: datafusion_common::types::DFExtensionType +pub struct ExtensionTypeRegistration { + /// The name of the extension type. + name: String, + /// A function that creates an instance of [`DFExtensionTypeRef`] from the storage type and the + /// metadata. + factory: Box, +} + +impl ExtensionTypeRegistration { + /// Creates a new registration for an extension type. The factory is required to validate that + /// the storage [`DataType`] is compatible with the extension type. + pub fn new_arc( + name: impl Into, + factory: impl Fn(&DataType, Option<&str>) -> Result + + Send + + Sync + + 'static, + ) -> ExtensionTypeRegistrationRef { + Arc::new(Self { + name: name.into(), + factory: Box::new(factory), + }) + } +} + +impl ExtensionTypeRegistration { + /// The name of the extension type. + /// + /// This name will be used to find the correct [ExtensionTypeRegistration] when an extension + /// type is encountered. + pub fn type_name(&self) -> &str { + &self.name + } + + /// Creates an extension type instance from the optional metadata. The name of the extension + /// type is not a parameter as it's already defined by the registration itself. + pub fn create_df_extension_type( + &self, + storage_type: &DataType, + metadata: Option<&str>, + ) -> Result { + self.factory.as_ref()(storage_type, metadata) + } +} + +impl Debug for ExtensionTypeRegistration { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("DefaultExtensionTypeRegistration") + .field("type_name", &self.name) + .finish() + } +} + +/// An [`ExtensionTypeRegistry`] that uses in memory [`HashMap`]s. +#[derive(Clone, Debug)] +pub struct MemoryExtensionTypeRegistry { + /// Holds a mapping between the name of an extension type and its logical type. + extension_types: Arc>>, +} + +impl Default for MemoryExtensionTypeRegistry { + fn default() -> Self { + Self::new_empty() + } +} + +impl MemoryExtensionTypeRegistry { + /// Creates an empty [MemoryExtensionTypeRegistry]. + pub fn new_empty() -> Self { + Self { + extension_types: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Pre-registers the [canonical extension types](https://arrow.apache.org/docs/format/CanonicalExtensions.html) + /// in the extension type registry. + pub fn new_with_canonical_extension_types() -> Self { + let mapping = [ + ExtensionTypeRegistration::new_arc( + FixedShapeTensor::NAME, + |storage_type, metadata| { + Ok(Arc::new(DFFixedShapeTensor::try_new( + storage_type, + FixedShapeTensor::deserialize_metadata(metadata)?, + )?)) + }, + ), + ExtensionTypeRegistration::new_arc( + VariableShapeTensor::NAME, + |storage_type, metadata| { + Ok(Arc::new(DFVariableShapeTensor::try_new( + storage_type, + VariableShapeTensor::deserialize_metadata(metadata)?, + )?)) + }, + ), + ExtensionTypeRegistration::new_arc(Json::NAME, |storage_type, metadata| { + Ok(Arc::new(DFJson::try_new( + storage_type, + Json::deserialize_metadata(metadata)?, + )?)) + }), + ExtensionTypeRegistration::new_arc(Uuid::NAME, |storage_type, metadata| { + Ok(Arc::new(DFUuid::try_new( + storage_type, + Uuid::deserialize_metadata(metadata)?, + )?)) + }), + ExtensionTypeRegistration::new_arc(Opaque::NAME, |storage_type, metadata| { + Ok(Arc::new(DFOpaque::try_new( + storage_type, + Opaque::deserialize_metadata(metadata)?, + )?)) + }), + ExtensionTypeRegistration::new_arc(Bool8::NAME, |storage_type, metadata| { + Ok(Arc::new(DFBool8::try_new( + storage_type, + Bool8::deserialize_metadata(metadata)?, + )?)) + }), + ExtensionTypeRegistration::new_arc( + TimestampWithOffset::NAME, + |storage_type, metadata| { + Ok(Arc::new(DFTimestampWithOffset::try_new( + storage_type, + TimestampWithOffset::deserialize_metadata(metadata)?, + )?)) + }, + ), + ]; + + let mut extension_types = HashMap::new(); + for registration in mapping.into_iter() { + extension_types.insert(registration.type_name().to_owned(), registration); + } + + Self { + extension_types: Arc::new(RwLock::new(HashMap::from(extension_types))), + } + } + + /// Creates a new [MemoryExtensionTypeRegistry] with the provided `types`. + /// + /// # Errors + /// + /// Returns an error if one of the `types` is a native type. + pub fn new_with_types( + types: impl IntoIterator, + ) -> Result { + let extension_types = types + .into_iter() + .map(|t| (t.type_name().to_owned(), t)) + .collect::>(); + Ok(Self { + extension_types: Arc::new(RwLock::new(extension_types)), + }) + } + + /// Returns a list of all registered types. + pub fn all_extension_types(&self) -> Vec { + self.extension_types + .read() + .expect("Extension type registry lock poisoned") + .values() + .cloned() + .collect() + } +} + +impl ExtensionTypeRegistry for MemoryExtensionTypeRegistry { + fn extension_type_registration( + &self, + name: &str, + ) -> Result { + self.extension_types + .write() + .expect("Extension type registry lock poisoned") + .get(name) + .ok_or_else(|| plan_datafusion_err!("Logical type not found.")) + .cloned() + } + + fn extension_type_registrations(&self) -> Vec { + self.extension_types + .read() + .expect("Extension type registry lock poisoned") + .values() + .cloned() + .collect() + } + + fn add_extension_type_registration( + &self, + extension_type: ExtensionTypeRegistrationRef, + ) -> Result> { + Ok(self + .extension_types + .write() + .expect("Extension type registry lock poisoned") + .insert(extension_type.type_name().to_owned(), extension_type)) + } + + fn remove_extension_type_registration( + &self, + name: &str, + ) -> Result> { + Ok(self + .extension_types + .write() + .expect("Extension type registry lock poisoned") + .remove(name)) + } +} + +impl From> for MemoryExtensionTypeRegistry { + fn from(value: HashMap) -> Self { + Self { + extension_types: Arc::new(RwLock::new(value)), + } + } +} diff --git a/datafusion/expr/src/select_expr.rs b/datafusion/expr/src/select_expr.rs index bfec4c5844d08..22b9660572a66 100644 --- a/datafusion/expr/src/select_expr.rs +++ b/datafusion/expr/src/select_expr.rs @@ -20,7 +20,7 @@ use std::fmt; use arrow::datatypes::FieldRef; use datafusion_common::{Column, TableReference}; -use crate::{expr::WildcardOptions, Expr}; +use crate::{Expr, expr::WildcardOptions}; /// Represents a SELECT expression in a SQL query. /// diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 02794271a9ee1..522cf122a273c 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -15,92 +15,167 @@ // specific language governing permissions and limitations // under the License. -//! Structs and traits to provide the information needed for expression simplification. +//! Structs to provide the information needed for expression simplification. -use arrow::datatypes::DataType; -use datafusion_common::{internal_datafusion_err, DFSchemaRef, Result}; - -use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; +use std::sync::Arc; -/// Provides the information necessary to apply algebraic simplification to an -/// [Expr]. See [SimplifyContext] for one concrete implementation. -/// -/// This trait exists so that other systems can plug schema -/// information in without having to create `DFSchema` objects. If you -/// have a [`DFSchemaRef`] you can use [`SimplifyContext`] -pub trait SimplifyInfo { - /// Returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result; - - /// Returns true of this expr is nullable (could possibly be NULL) - fn nullable(&self, expr: &Expr) -> Result; - - /// Returns details needed for partial expression evaluation - fn execution_props(&self) -> &ExecutionProps; +use arrow::datatypes::DataType; +use chrono::{DateTime, Utc}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; - /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result; -} +use crate::{Expr, ExprSchemable}; -/// Provides simplification information based on DFSchema and -/// [`ExecutionProps`]. This is the default implementation used by DataFusion +/// Provides simplification information based on schema, query execution time, +/// and configuration options. /// /// # Example /// See the `simplify_demo` in the [`expr_api` example] /// -/// [`expr_api` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs +/// [`expr_api` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/expr_api.rs #[derive(Debug, Clone)] -pub struct SimplifyContext<'a> { +pub struct SimplifyContext { + schema: DFSchemaRef, + query_execution_start_time: Option>, + config_options: Arc, +} + +/// Builder for [`SimplifyContext`]. +#[derive(Debug, Default)] +pub struct SimplifyContextBuilder { schema: Option, - props: &'a ExecutionProps, + query_execution_start_time: Option>, + config_options: Option>, } -impl<'a> SimplifyContext<'a> { - /// Create a new SimplifyContext - pub fn new(props: &'a ExecutionProps) -> Self { +impl Default for SimplifyContext { + fn default() -> Self { Self { - schema: None, - props, + schema: Arc::new(DFSchema::empty()), + query_execution_start_time: None, + config_options: Arc::new(ConfigOptions::default()), } } +} + +impl SimplifyContext { + /// Returns a builder for [`SimplifyContext`]. + pub fn builder() -> SimplifyContextBuilder { + SimplifyContextBuilder::default() + } - /// Register a [`DFSchemaRef`] with this context + #[deprecated( + since = "54.0.0", + note = "Use SimplifyContextBuilder if you intend to use non-default values." + )] + /// Set the [`ConfigOptions`] for this context + pub fn with_config_options(mut self, config_options: Arc) -> Self { + self.config_options = config_options; + self + } + + #[deprecated( + since = "54.0.0", + note = "Use SimplifyContextBuilder if you intend to use non-default values." + )] + /// Set the schema for this context pub fn with_schema(mut self, schema: DFSchemaRef) -> Self { - self.schema = Some(schema); + self.schema = schema; self } -} -impl SimplifyInfo for SimplifyContext<'_> { - /// Returns true if this Expr has boolean type - fn is_boolean_type(&self, expr: &Expr) -> Result { - if let Some(schema) = &self.schema { - if let Ok(DataType::Boolean) = expr.get_type(schema) { - return Ok(true); - } - } + #[deprecated( + since = "54.0.0", + note = "Use SimplifyContextBuilder if you intend to use non-default values." + )] + /// Set the query execution start time + pub fn with_query_execution_start_time( + mut self, + query_execution_start_time: Option>, + ) -> Self { + self.query_execution_start_time = query_execution_start_time; + self + } + + #[deprecated( + since = "54.0.0", + note = "Use SimplifyContextBuilder if you intend to use non-default values." + )] + /// Set the query execution start to the current time + pub fn with_current_time(mut self) -> Self { + self.query_execution_start_time = Some(Utc::now()); + self + } + + /// Returns the schema + pub fn schema(&self) -> &DFSchemaRef { + &self.schema + } - Ok(false) + /// Returns true if this Expr has boolean type + pub fn is_boolean_type(&self, expr: &Expr) -> Result { + Ok(expr.get_type(&self.schema)? == DataType::Boolean) } /// Returns true if expr is nullable - fn nullable(&self, expr: &Expr) -> Result { - let schema = self.schema.as_ref().ok_or_else(|| { - internal_datafusion_err!("attempt to get nullability without schema") - })?; - expr.nullable(schema.as_ref()) + pub fn nullable(&self, expr: &Expr) -> Result { + expr.nullable(self.schema.as_ref()) } /// Returns data type of this expr needed for determining optimized int type of a value - fn get_data_type(&self, expr: &Expr) -> Result { - let schema = self.schema.as_ref().ok_or_else(|| { - internal_datafusion_err!("attempt to get data type without schema") - })?; - expr.get_type(schema) + pub fn get_data_type(&self, expr: &Expr) -> Result { + expr.get_type(&self.schema) + } + + /// Returns the time at which the query execution started. + /// If `None`, time-dependent functions like `now()` will not be simplified. + pub fn query_execution_start_time(&self) -> Option> { + self.query_execution_start_time } - fn execution_props(&self) -> &ExecutionProps { - self.props + /// Returns the configuration options for the session. + pub fn config_options(&self) -> &Arc { + &self.config_options + } +} + +impl SimplifyContextBuilder { + /// Set the [`ConfigOptions`] for this context. + pub fn with_config_options(mut self, config_options: Arc) -> Self { + self.config_options = Some(config_options); + self + } + + /// Set the schema for this context. + pub fn with_schema(mut self, schema: DFSchemaRef) -> Self { + self.schema = Some(schema); + self + } + + /// Set the query execution start time. + pub fn with_query_execution_start_time( + mut self, + query_execution_start_time: Option>, + ) -> Self { + self.query_execution_start_time = query_execution_start_time; + self + } + + /// Set the query execution start to the current time. + pub fn with_current_time(mut self) -> Self { + self.query_execution_start_time = Some(Utc::now()); + self + } + + /// Build a [`SimplifyContext`], filling in any unspecified fields with defaults. + pub fn build(self) -> SimplifyContext { + SimplifyContext { + schema: self.schema.unwrap_or_else(|| Arc::new(DFSchema::empty())), + query_execution_start_time: self.query_execution_start_time, + config_options: self + .config_options + .unwrap_or_else(|| Arc::new(ConfigOptions::default())), + } } } @@ -113,3 +188,38 @@ pub enum ExprSimplifyResult { /// are return unmodified. Original(Vec), } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simplify_context_builder_builds_default_context() { + let context = SimplifyContext::builder().build(); + let default_options = ConfigOptions::default(); + + assert_eq!(context.schema().as_ref(), &DFSchema::empty()); + assert_eq!(context.query_execution_start_time(), None); + assert_eq!( + context.config_options().optimizer.max_passes, + default_options.optimizer.max_passes + ); + } + + #[test] + fn simplify_context_builder_uses_overrides() { + let schema = Arc::new(DFSchema::empty()); + let config_options = Arc::new(ConfigOptions::default()); + let current_time = Utc::now(); + + let context = SimplifyContext::builder() + .with_schema(Arc::clone(&schema)) + .with_config_options(Arc::clone(&config_options)) + .with_query_execution_start_time(Some(current_time)) + .build(); + + assert_eq!(context.schema().as_ref(), schema.as_ref()); + assert_eq!(context.query_execution_start_time(), Some(current_time)); + assert!(Arc::ptr_eq(context.config_options(), &config_options)); + } +} diff --git a/datafusion/expr/src/sql.rs b/datafusion/expr/src/sql.rs new file mode 100644 index 0000000000000..d582a0f6b95d1 --- /dev/null +++ b/datafusion/expr/src/sql.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Local copies of [`sqlparser::ast`] structures +//! +//! These types are used when the `sql` feature is disabled. When `sql` is +//! enabled, the upstream types from [`sqlparser`] are used instead. +//! +//! These definitions should be structurally compatible with the upstream +//! `sqlparser` types, so that code which switches between them via `cfg` keeps +//! compiling. +//! +//! See [#17332](https://github.com/apache/datafusion/pull/17332) for +//! more detail. + +use crate::expr::display_comma_separated; +use std::fmt; +use std::fmt::{Display, Formatter}; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct IlikeSelectItem { + pub pattern: String, +} + +impl Display for IlikeSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "ILIKE '{}'", &self.pattern)?; + Ok(()) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum ExcludeSelectItem { + Single(ObjectName), + Multiple(Vec), +} + +impl Display for ExcludeSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "EXCLUDE")?; + match self { + Self::Single(column) => { + write!(f, " {column}")?; + } + Self::Multiple(columns) => { + write!(f, " ({})", display_comma_separated(columns))?; + } + } + Ok(()) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct ObjectName(pub Vec); + +impl Display for ObjectName { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let parts: Vec = self.0.iter().map(|p| format!("{p}")).collect(); + write!(f, "{}", parts.join(".")) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum ObjectNamePart { + Identifier(Ident), +} + +impl ObjectNamePart { + pub fn as_ident(&self) -> Option<&Ident> { + match self { + ObjectNamePart::Identifier(ident) => Some(ident), + } + } +} + +impl Display for ObjectNamePart { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + ObjectNamePart::Identifier(ident) => write!(f, "{ident}"), + } + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct ExceptSelectItem { + pub first_element: Ident, + pub additional_elements: Vec, +} + +impl Display for ExceptSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "EXCEPT ")?; + if self.additional_elements.is_empty() { + write!(f, "({})", self.first_element)?; + } else { + write!( + f, + "({}, {})", + self.first_element, + display_comma_separated(&self.additional_elements) + )?; + } + Ok(()) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum RenameSelectItem { + Single(String), + Multiple(Vec), +} + +impl Display for RenameSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "RENAME")?; + match self { + Self::Single(column) => { + write!(f, " {column}")?; + } + Self::Multiple(columns) => { + write!(f, " ({})", display_comma_separated(columns))?; + } + } + Ok(()) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct Ident { + /// The value of the identifier without quotes. + pub value: String, + /// The starting quote if any. Valid quote characters are the single quote, + /// double quote, backtick, and opening square bracket. + pub quote_style: Option, + /// The span of the identifier in the original SQL string. + pub span: String, +} + +impl Display for Ident { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "[{}]", self.value) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct ReplaceSelectElement { + pub expr: String, + pub column_name: Ident, + pub as_keyword: bool, +} + +impl Display for ReplaceSelectElement { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if self.as_keyword { + write!(f, "{} AS {}", self.expr, self.column_name) + } else { + write!(f, "{} {}", self.expr, self.column_name) + } + } +} diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index d3b253c0e102c..65dce8f3c8b0b 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -91,9 +91,7 @@ impl std::fmt::Display for TableType { /// /// [`TableProvider`]: https://docs.rs/datafusion/latest/datafusion/datasource/trait.TableProvider.html /// [`DefaultTableSource`]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html -pub trait TableSource: Sync + Send { - fn as_any(&self) -> &dyn Any; - +pub trait TableSource: Any + Sync + Send { /// Get a reference to the schema for this table fn schema(&self) -> SchemaRef; @@ -130,3 +128,13 @@ pub trait TableSource: Sync + Send { None } } + +impl dyn TableSource { + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 8609afeae6018..a1f29b649b2f8 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -19,29 +19,27 @@ //! //! These are used to avoid a dependence on `datafusion-functions-aggregate` which live in a different crate -use std::any::Any; - use arrow::datatypes::{ - DataType, FieldRef, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, - DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, + DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, + DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, FieldRef, }; use datafusion_common::plan_err; -use datafusion_common::{exec_err, not_impl_err, utils::take_function_args, Result}; +use datafusion_common::{Result, exec_err, not_impl_err, utils::take_function_args}; -use crate::type_coercion::aggregates::NUMERICS; use crate::Volatility::Immutable; use crate::{ + Accumulator, AggregateUDFImpl, Coercion, Expr, GroupsAccumulator, ReversedUDAF, + Signature, TypeSignature, TypeSignatureClass, expr::AggregateFunction, function::{AccumulatorArgs, StateFieldsArgs}, utils::AggregateOrderSensitivity, - Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, }; +use datafusion_common::types::{NativeType, logical_float64}; macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { - paste::paste! { #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { // Singleton instance of [$UDAF], ensures the UDAF is only created once @@ -51,7 +49,6 @@ macro_rules! create_func { }); std::sync::Arc::clone(&INSTANCE) } - } } } @@ -115,10 +112,6 @@ impl Default for Sum { } impl AggregateUDFImpl for Sum { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "sum" } @@ -247,10 +240,6 @@ impl Count { } impl AggregateUDFImpl for Count { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "COUNT" } @@ -334,10 +323,6 @@ impl Min { } impl AggregateUDFImpl for Min { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "min" } @@ -416,10 +401,6 @@ impl Max { } impl AggregateUDFImpl for Max { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "max" } @@ -464,9 +445,22 @@ pub struct Avg { impl Avg { pub fn new() -> Self { + let signature = Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Decimal, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Integer, TypeSignatureClass::Float], + NativeType::Float64, + )]), + ], + Immutable, + ); Self { aliases: vec![String::from("mean")], - signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable), + signature, } } } @@ -478,10 +472,6 @@ impl Default for Avg { } impl AggregateUDFImpl for Avg { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "avg" } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f80608..010441b5a25d1 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,17 +17,21 @@ //! Tree node implementation for Logical Expressions -use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, +use crate::{ + Expr, + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, + Cast, GroupingSet, HigherOrderFunction, InList, InSubquery, Lambda, Like, + Placeholder, ScalarFunction, SetComparison, TryCast, Unnest, WindowFunction, + WindowFunctionParams, + }, }; -use crate::Expr; - -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, +use datafusion_common::{ + Result, + tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, + }, }; -use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -58,7 +62,8 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), + | Expr::InSubquery(InSubquery { expr, .. }) + | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), Expr::ScalarFunction(ScalarFunction { args, .. }) => { @@ -77,7 +82,8 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Placeholder(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { (left, right).apply_ref_elements(f) } @@ -106,6 +112,8 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::HigherOrderFunction(HigherOrderFunction { func: _, args}) => args.apply_elements(f), + Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -115,7 +123,7 @@ impl TreeNode for Expr { /// indicating whether the expression was transformed or left unchanged. fn map_children Result>>( self, - mut f: F, + f: F, ) -> Result> { Ok(match self { // TODO: remove the next line after `Expr::Wildcard` is removed @@ -127,7 +135,21 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_, _) => Transformed::no(self), + | Expr::Literal(_, _) + | Expr::LambdaVariable(_) => Transformed::no(self), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => expr.map_elements(f)?.update_data(|expr| { + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) + }), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), @@ -136,8 +158,13 @@ impl TreeNode for Expr { relation, name, metadata, - }) => f(*expr)?.update_data(|e| { - e.alias_qualified_with_metadata(relation, name, metadata) + }) => expr.map_elements(f)?.update_data(|expr| { + Expr::Alias(Alias { + expr, + relation, + name, + metadata, + }) }), Expr::InSubquery(InSubquery { expr, @@ -220,12 +247,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))), + Expr::TryCast(TryCast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, field))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( @@ -311,6 +338,14 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { + args.map_elements(f)?.update_data(|args| { + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) + }) + } + Expr::Lambda(Lambda { params, body }) => body + .map_elements(f)? + .update_data(|body| Expr::Lambda(Lambda { params, body })), }) } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index bcaff11bcdb49..33746a2c46b30 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -16,77 +16,98 @@ // under the License. use super::binary::binary_numeric_coercion; -use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; -use arrow::datatypes::FieldRef; +use crate::{ + AggregateUDF, HigherOrderTypeSignature, HigherOrderUDF, ScalarUDF, Signature, + TypeSignature, ValueOrLambda, WindowUDF, +}; +use arrow::datatypes::{Field, FieldRef}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; +use datafusion_common::internal_datafusion_err; use datafusion_common::types::LogicalType; use datafusion_common::utils::{ - base_type, coerced_fixed_size_list_to_list, ListCoercion, + ListCoercion, base_type, coerced_fixed_size_list_to_list, }; use datafusion_common::{ - exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, Result, + Result, exec_err, internal_err, plan_err, types::NativeType, utils::list_ndims, }; use datafusion_expr_common::signature::ArrayFunctionArgument; use datafusion_expr_common::type_coercion::binary::type_union_resolution; use datafusion_expr_common::{ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, - type_coercion::binary::comparison_coercion_numeric, + type_coercion::binary::comparison_coercion, type_coercion::binary::string_coercion, }; use itertools::Itertools as _; use std::sync::Arc; -/// Performs type coercion for scalar function arguments. -/// -/// Returns the data types to which each argument must be coerced to -/// match `signature`. -/// -/// For more details on coercion in general, please see the -/// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_scalar_udf( - current_types: &[DataType], - func: &ScalarUDF, -) -> Result> { - let signature = func.signature(); - let type_signature = &signature.type_signature; +/// Extension trait to unify common functionality between [`ScalarUDF`], [`AggregateUDF`] +/// and [`WindowUDF`] for use by signature coercion functions. +pub trait UDFCoercionExt { + /// Should delegate to [`ScalarUDF::name`], [`AggregateUDF::name`] or [`WindowUDF::name`]. + fn name(&self) -> &str; + /// Should delegate to [`ScalarUDF::signature`], [`AggregateUDF::signature`] + /// or [`WindowUDF::signature`]. + fn signature(&self) -> &Signature; + /// Should delegate to [`ScalarUDF::coerce_types`], [`AggregateUDF::coerce_types`] + /// or [`WindowUDF::coerce_types`]. + fn coerce_types(&self, arg_types: &[DataType]) -> Result>; +} - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { - if type_signature.supports_zero_argument() { - return Ok(vec![]); - } else if type_signature.used_to_support_zero_arguments() { - // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 - return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name()); - } else { - return plan_err!("'{}' does not support zero arguments", func.name()); - } +impl UDFCoercionExt for ScalarUDF { + fn name(&self) -> &str { + self.name() } - let valid_types = - get_valid_types_with_scalar_udf(type_signature, current_types, func)?; + fn signature(&self) -> &Signature { + self.signature() + } - if valid_types - .iter() - .any(|data_type| data_type == current_types) - { - return Ok(current_types.to_vec()); + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.coerce_types(arg_types) + } +} + +impl UDFCoercionExt for AggregateUDF { + fn name(&self) -> &str { + self.name() + } + + fn signature(&self) -> &Signature { + self.signature() } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.coerce_types(arg_types) + } } -/// Performs type coercion for aggregate function arguments. +impl UDFCoercionExt for WindowUDF { + fn name(&self) -> &str { + self.name() + } + + fn signature(&self) -> &Signature { + self.signature() + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.coerce_types(arg_types) + } +} + +/// Performs type coercion for UDF arguments. /// -/// Returns the fields to which each argument must be coerced to +/// Returns the data types to which each argument must be coerced to /// match `signature`. /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn fields_with_aggregate_udf( +pub fn fields_with_udf( current_fields: &[FieldRef], - func: &AggregateUDF, + func: &F, ) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; @@ -96,7 +117,10 @@ pub fn fields_with_aggregate_udf( return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 - return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name()); + return plan_err!( + "'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", + func.name() + ); } else { return plan_err!("'{}' does not support zero arguments", func.name()); } @@ -107,8 +131,7 @@ pub fn fields_with_aggregate_udf( .cloned() .collect::>(); - let valid_types = - get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?; + let valid_types = get_valid_types_with_udf(type_signature, ¤t_types, func)?; if valid_types .iter() .any(|data_type| data_type == ¤t_types) @@ -129,58 +152,274 @@ pub fn fields_with_aggregate_udf( .collect()) } -/// Performs type coercion for window function arguments. +/// Performs type coercion for higher order function arguments. /// -/// Returns the data types to which each argument must be coerced to -/// match `signature`. +/// For value arguments, returns the field to which each +/// argument must be coerced to match `signature`. +/// For lambda arguments, returns a clone of the associated data +/// +/// Note this does not invokes [crate::HigherOrderUDFImpl::coerce_values_for_lambdas]. +/// If that's required, use [value_fields_with_higher_order_udf_and_lambdas] +/// instead /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn fields_with_window_udf( - current_fields: &[FieldRef], - func: &WindowUDF, -) -> Result> { - let signature = func.signature(); - let type_signature = &signature.type_signature; +pub fn value_fields_with_higher_order_udf( + current_fields: &[ValueOrLambda], + func: &HigherOrderUDF, +) -> Result>> { + match func.signature().type_signature { + HigherOrderTypeSignature::UserDefined => { + let arg_types = current_fields + .iter() + .filter_map(|p| match p { + ValueOrLambda::Value(field) => Some(field.data_type().clone()), + ValueOrLambda::Lambda(_) => None, + }) + .collect::>(); - if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { - if type_signature.supports_zero_argument() { - return Ok(vec![]); - } else if type_signature.used_to_support_zero_arguments() { - // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 - return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name()); - } else { - return plan_err!("'{}' does not support zero arguments", func.name()); + let coerced_types = func.coerce_value_types(&arg_types)?; + + if coerced_types.len() != arg_types.len() { + return plan_err!( + "{} coerce_value_types should have returned {} items but returned {}", + func.name(), + arg_types.len(), + coerced_types.len() + ); + } + + // coerced_types has been partitioned from current_fields + // and refers only to values and not to lambdas, so instead + // of zipping them, we iterate over current_fields and only + // consume from coerced_types when a given argument is a value + // to reconstruct the arguments list with the correct order + // this supports any value and lambda positioning including + // multiple lambdas interleaved with values + let mut coerced_types = coerced_types.into_iter(); + + current_fields + .iter() + .map(|current_field| match current_field { + ValueOrLambda::Value(field) => { + let data_type = coerced_types.next().ok_or_else(|| { + internal_datafusion_err!( + "coerced_types len should have been checked above" + ) + })?; + + Ok(ValueOrLambda::Value(Arc::new( + field.as_ref().clone().with_data_type(data_type), + ))) + } + ValueOrLambda::Lambda(lambda) => { + Ok(ValueOrLambda::Lambda(lambda.clone())) + } + }) + .collect() + } + HigherOrderTypeSignature::VariadicAny => Ok(current_fields.to_vec()), + HigherOrderTypeSignature::Any(number) => { + if current_fields.len() != number { + return plan_err!( + "The function '{}' expected {number} arguments but received {}", + func.name(), + current_fields.len() + ); + } + + Ok(current_fields.to_vec()) + } + HigherOrderTypeSignature::Exact(ref expected) => { + if current_fields.len() != expected.len() { + let name = func.name(); + let expected_len = expected.len(); + let actual_len = current_fields.len(); + return plan_err!( + "The function '{name}' expected {expected_len} argument(s) but received {actual_len}" + ); + } + + for (i, (actual, expected)) in + current_fields.iter().zip(expected.iter()).enumerate() + { + match (actual, expected) { + (ValueOrLambda::Value(_), ValueOrLambda::Value(_)) => {} + (ValueOrLambda::Lambda(_), ValueOrLambda::Lambda(_)) => {} + (ValueOrLambda::Value(_), ValueOrLambda::Lambda(_)) => { + let name = func.name(); + return plan_err!( + "The function '{name}' expected a lambda at position {i} but received a value" + ); + } + (ValueOrLambda::Lambda(_), ValueOrLambda::Value(_)) => { + let name = func.name(); + return plan_err!( + "The function '{name}' expected a value at position {i} but received a lambda" + ); + } + } + } + + let arg_types = current_fields + .iter() + .filter_map(|p| match p { + ValueOrLambda::Value(field) => Some(field.data_type().clone()), + ValueOrLambda::Lambda(_) => None, + }) + .collect::>(); + + let coerced_types = func.coerce_value_types(&arg_types)?; + + if coerced_types.len() != arg_types.len() { + return plan_err!( + "{} coerce_value_types should have returned {} items but returned {}", + func.name(), + arg_types.len(), + coerced_types.len() + ); + } + + let mut coerced_types = coerced_types.into_iter(); + + current_fields + .iter() + .map(|current_field| match current_field { + ValueOrLambda::Value(field) => { + let data_type = coerced_types.next().ok_or_else(|| { + internal_datafusion_err!( + "coerced_types len should have been checked above" + ) + })?; + + Ok(ValueOrLambda::Value(Arc::new( + field.as_ref().clone().with_data_type(data_type), + ))) + } + ValueOrLambda::Lambda(lambda) => { + Ok(ValueOrLambda::Lambda(lambda.clone())) + } + }) + .collect() } } +} - let current_types = current_fields +/// Performs type coercion for higher order function arguments, +/// including those defined by [crate::HigherOrderUDFImpl::coerce_values_for_lambdas], +/// if it returns `Some(...)` instead of the default `None`. Note that +/// compared to [value_fields_with_higher_order_udf], this function requires +/// the [ValueOrLambda::Lambda] variant to contain the output field of the lambda. +/// +/// For value arguments, returns the field to which each +/// argument must be coerced to match `signature`. +/// For lambda arguments, returns a clone of the output field +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +pub fn value_fields_with_higher_order_udf_and_lambdas( + current_fields: &[ValueOrLambda], + func: &HigherOrderUDF, +) -> Result>> { + let mut new_fields = value_fields_with_higher_order_udf(current_fields, func)?; + + let new_types = new_fields .iter() - .map(|f| f.data_type()) - .cloned() + .map(|f| match f { + ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()), + ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()), + }) .collect::>(); - let valid_types = - get_valid_types_with_window_udf(type_signature, ¤t_types, func)?; - if valid_types - .iter() - .any(|data_type| data_type == ¤t_types) - { - return Ok(current_fields.to_vec()); - } - let updated_types = - try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + if let Some(new_value_types) = func.coerce_values_for_lambdas(&new_types)? { + let mut new_value_types = new_value_types.into_iter(); - Ok(current_fields + let value_types_count = new_types + .iter() + .filter(|e| matches!(e, ValueOrLambda::Value(_))) + .count(); + + if new_value_types.len() != value_types_count { + return plan_err!( + "{} coerce_values_for_lambdas returned {} values but {value_types_count} expected", + func.name(), + new_value_types.len() + ); + } + + for new_field in &mut new_fields { + match new_field { + ValueOrLambda::Value(value) => { + let coerce_to = new_value_types.next().ok_or_else(|| { + internal_datafusion_err!( + "new_value_types len should have been checked above" + ) + })?; + + if value.data_type() != &coerce_to { + Arc::make_mut(value).set_data_type(coerce_to); + } + } + ValueOrLambda::Lambda(_) => {} + } + } + }; + + Ok(new_fields) +} + +/// Performs type coercion for scalar function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn data_types_with_scalar_udf( + current_types: &[DataType], + func: &ScalarUDF, +) -> Result> { + let current_fields = current_types .iter() - .zip(updated_types) - .map(|(current_field, new_type)| { - current_field.as_ref().clone().with_data_type(new_type) - }) - .map(Arc::new) + .map(|dt| Arc::new(Field::new("f", dt.clone(), true))) + .collect::>(); + Ok(fields_with_udf(¤t_fields, func)? + .iter() + .map(|f| f.data_type().clone()) .collect()) } +/// Performs type coercion for aggregate function arguments. +/// +/// Returns the fields to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn fields_with_aggregate_udf( + current_fields: &[FieldRef], + func: &AggregateUDF, +) -> Result> { + fields_with_udf(current_fields, func) +} + +/// Performs type coercion for window function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] +pub fn fields_with_window_udf( + current_fields: &[FieldRef], + func: &WindowUDF, +) -> Result> { + fields_with_udf(current_fields, func) +} + /// Performs type coercion for function arguments. /// /// Returns the data types to which each argument must be coerced to @@ -188,6 +427,7 @@ pub fn fields_with_window_udf( /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. +#[deprecated(since = "52.0.0", note = "use fields_with_udf")] pub fn data_types( function_name: impl AsRef, current_types: &[DataType], @@ -201,12 +441,12 @@ pub fn data_types( } else if type_signature.used_to_support_zero_arguments() { // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763 return plan_err!( - "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments", + "function '{}' has signature {type_signature} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments", function_name.as_ref() ); } else { return plan_err!( - "Function '{}' has signature {type_signature:?} which does not support zero arguments", + "Function '{}' has signature {type_signature} which does not support zero arguments", function_name.as_ref() ); } @@ -230,20 +470,23 @@ pub fn data_types( } fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { - if let TypeSignature::OneOf(signatures) = type_signature { - return signatures.iter().all(is_well_supported_signature); - } - - matches!( - type_signature, + match type_signature { + TypeSignature::OneOf(type_signatures) => { + type_signatures.iter().all(is_well_supported_signature) + } TypeSignature::UserDefined - | TypeSignature::Numeric(_) - | TypeSignature::String(_) - | TypeSignature::Coercible(_) - | TypeSignature::Any(_) - | TypeSignature::Nullary - | TypeSignature::Comparable(_) - ) + | TypeSignature::Numeric(_) + | TypeSignature::String(_) + | TypeSignature::Coercible(_) + | TypeSignature::Any(_) + | TypeSignature::Nullary + | TypeSignature::Comparable(_) => true, + TypeSignature::Variadic(_) + | TypeSignature::VariadicAny + | TypeSignature::Uniform(_, _) + | TypeSignature::Exact(_) + | TypeSignature::ArraySignature(_) => false, + } } fn try_coerce_types( @@ -279,30 +522,32 @@ fn try_coerce_types( // none possible -> Error plan_err!( - "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature:?} failed", + "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {} to the signature {type_signature} failed", current_types.iter().join(", ") ) } -fn get_valid_types_with_scalar_udf( +fn get_valid_types_with_udf( signature: &TypeSignature, current_types: &[DataType], - func: &ScalarUDF, + func: &F, ) -> Result>> { - match signature { + let valid_types = match signature { TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => Ok(vec![coerced_types]), - Err(e) => exec_err!( - "Function '{}' user-defined coercion failed with {:?}", - func.name(), - e.strip_backtrace() - ), + Ok(coerced_types) => vec![coerced_types], + Err(e) => { + return exec_err!( + "Function '{}' user-defined coercion failed with: {}", + func.name(), + e.strip_backtrace() + ); + } }, TypeSignature::OneOf(signatures) => { let mut res = vec![]; let mut errors = vec![]; for sig in signatures { - match get_valid_types_with_scalar_udf(sig, current_types, func) { + match get_valid_types_with_udf(sig, current_types, func) { Ok(valid_types) => { res.extend(valid_types); } @@ -314,69 +559,15 @@ fn get_valid_types_with_scalar_udf( // Every signature failed, return the joined error if res.is_empty() { - internal_err!( + return internal_err!( "Function '{}' failed to match any signature, errors: {}", func.name(), errors.join(",") - ) + ); } else { - Ok(res) + res } } - _ => get_valid_types(func.name(), signature, current_types), - } -} - -fn get_valid_types_with_aggregate_udf( - signature: &TypeSignature, - current_types: &[DataType], - func: &AggregateUDF, -) -> Result>> { - let valid_types = match signature { - TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => vec![coerced_types], - Err(e) => { - return exec_err!( - "Function '{}' user-defined coercion failed with {:?}", - func.name(), - e.strip_backtrace() - ) - } - }, - TypeSignature::OneOf(signatures) => signatures - .iter() - .filter_map(|t| { - get_valid_types_with_aggregate_udf(t, current_types, func).ok() - }) - .flatten() - .collect::>(), - _ => get_valid_types(func.name(), signature, current_types)?, - }; - - Ok(valid_types) -} - -fn get_valid_types_with_window_udf( - signature: &TypeSignature, - current_types: &[DataType], - func: &WindowUDF, -) -> Result>> { - let valid_types = match signature { - TypeSignature::UserDefined => match func.coerce_types(current_types) { - Ok(coerced_types) => vec![coerced_types], - Err(e) => { - return exec_err!( - "Function '{}' user-defined coercion failed with {:?}", - func.name(), - e.strip_backtrace() - ) - } - }, - TypeSignature::OneOf(signatures) => signatures - .iter() - .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok()) - .flatten() - .collect::>(), _ => get_valid_types(func.name(), signature, current_types)?, }; @@ -395,6 +586,52 @@ fn get_valid_types( arguments: &[ArrayFunctionArgument], array_coercion: Option<&ListCoercion>, ) -> Result>> { + fn rebuild_array_type( + current_type: &DataType, + element_type: &DataType, + nullable: bool, + large_list: bool, + fixed_size: Option, + ) -> DataType { + // Preserve the original list field when possible so field name or + // metadata differences do not introduce otherwise unnecessary casts. + let field = match current_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Some(Arc::new( + field + .as_ref() + .clone() + .with_data_type(element_type.clone()) + .with_nullable(nullable), + )), + _ => None, + }; + + if large_list { + field.map_or_else( + || DataType::new_large_list(element_type.clone(), nullable), + DataType::LargeList, + ) + } else if let Some(size) = fixed_size { + field.map_or_else( + || { + DataType::new_fixed_size_list( + element_type.clone(), + size, + nullable, + ) + }, + |field| DataType::FixedSizeList(field, size), + ) + } else { + field.map_or_else( + || DataType::new_list(element_type.clone(), nullable), + DataType::List, + ) + } + } + if current_types.len() != arguments.len() { return Ok(vec![vec![]]); } @@ -418,12 +655,12 @@ fn get_valid_types( element_types.push(DataType::Null); nested_item_nullability.push(None); } - DataType::List(field) => { + DataType::List(field) | DataType::ListView(field) => { element_types.push(field.data_type().clone()); nested_item_nullability.push(Some(field.is_nullable())); fixed_size = false; } - DataType::LargeList(field) => { + DataType::LargeList(field) | DataType::LargeListView(field) => { element_types.push(field.data_type().clone()); nested_item_nullability.push(Some(field.is_nullable())); large_list = true; @@ -461,24 +698,18 @@ fn get_valid_types( ArrayFunctionArgument::Index => DataType::Int64, ArrayFunctionArgument::String => DataType::Utf8, ArrayFunctionArgument::Element => element_type.clone(), + // TODO: support maintaining ListView types here + // https://github.com/apache/datafusion/issues/21777 ArrayFunctionArgument::Array => { if current_type.is_null() { DataType::Null - } else if large_list { - DataType::new_large_list( - element_type.clone(), - is_nested_item_nullable.unwrap_or(true), - ) - } else if let Some(size) = list_sizes.next() { - DataType::new_fixed_size_list( - element_type.clone(), - size, - is_nested_item_nullable.unwrap_or(true), - ) } else { - DataType::new_list( - element_type.clone(), + rebuild_array_type( + current_type, + &element_type, is_nested_item_nullable.unwrap_or(true), + large_list, + list_sizes.next(), ) } } @@ -492,6 +723,8 @@ fn get_valid_types( match array_type { DataType::List(_) | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) | DataType::FixedSizeList(_, _) => { let array_type = coerced_fixed_size_list_to_list(array_type); Some(array_type) @@ -516,7 +749,7 @@ fn get_valid_types( let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() - .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) + .map(|valid_type| vec![valid_type.clone(); current_types.len()]) .collect(), TypeSignature::String(number) => { function_length_check(function_name, current_types.len(), *number)?; @@ -531,7 +764,7 @@ fn get_valid_types( new_types.push(DataType::Utf8); } else { return plan_err!( - "Function '{function_name}' expects NativeType::String but NativeType::received NativeType::{logical_data_type}" + "Function '{function_name}' expects String but received {logical_data_type}" ); } } @@ -591,7 +824,7 @@ fn get_valid_types( if !logical_data_type.is_numeric() { return plan_err!( - "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}" + "Function '{function_name}' expects Numeric but received {logical_data_type}" ); } @@ -612,7 +845,7 @@ fn get_valid_types( valid_type = DataType::Float64; } else if !logical_data_type.is_numeric() { return plan_err!( - "Function '{function_name}' expects NativeType::Numeric but received NativeType::{logical_data_type}" + "Function '{function_name}' expects Numeric but received {logical_data_type}" ); } @@ -622,10 +855,12 @@ fn get_valid_types( function_length_check(function_name, current_types.len(), *num)?; let mut target_type = current_types[0].to_owned(); for data_type in current_types.iter().skip(1) { - if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) { + if let Some(dt) = comparison_coercion(&target_type, data_type) { target_type = dt; } else { - return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable"); + return plan_err!( + "For function '{function_name}' {target_type} and {data_type} is not comparable" + ); } } // Convert null to String type. @@ -642,24 +877,33 @@ fn get_valid_types( for (current_type, param) in current_types.iter().zip(param_types.iter()) { let current_native_type: NativeType = current_type.into(); - if param.desired_type().matches_native_type(¤t_native_type) { - let casted_type = param.desired_type().default_casted_type( - ¤t_native_type, - current_type, - )?; + if param + .desired_type() + .matches_native_type(¤t_native_type) + { + let casted_type = param + .desired_type() + .default_casted_type(¤t_native_type, current_type)?; new_types.push(casted_type); } else if param - .allowed_source_types() - .iter() - .any(|t| t.matches_native_type(¤t_native_type)) { + .allowed_source_types() + .iter() + .any(|t| t.matches_native_type(¤t_native_type)) + { // If the condition is met which means `implicit coercion`` is provided so we can safely unwrap let default_casted_type = param.default_casted_type().unwrap(); - let casted_type = default_casted_type.default_cast_for(current_type)?; + let casted_type = + default_casted_type.default_cast_for(current_type)?; new_types.push(casted_type); } else { - return internal_err!( - "Expect {} but received NativeType::{}, DataType: {}", + let hint = if matches!(current_native_type, NativeType::Binary) { + "\n\nHint: Binary types are not automatically coerced to String. Use CAST(column AS VARCHAR) to convert Binary data to String." + } else { + "" + }; + return plan_err!( + "Function '{function_name}' requires {}, but received {} (DataType: {}).{hint}", param.desired_type(), current_native_type, current_type @@ -671,18 +915,20 @@ fn get_valid_types( } TypeSignature::Uniform(number, valid_types) => { if *number == 0 { - return plan_err!("The function '{function_name}' expected at least one argument"); + return plan_err!( + "The function '{function_name}' expected at least one argument" + ); } valid_types .iter() - .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) + .map(|valid_type| vec![valid_type.clone(); *number]) .collect() } TypeSignature::UserDefined => { return internal_err!( "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types" - ) + ); } TypeSignature::VariadicAny => { if current_types.is_empty() { @@ -693,10 +939,16 @@ fn get_valid_types( vec![current_types.to_vec()] } TypeSignature::Exact(valid_types) => vec![valid_types.clone()], - TypeSignature::ArraySignature(ref function_signature) => match function_signature { - ArrayFunctionSignature::Array { arguments, array_coercion, } => { - array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())? - } + TypeSignature::ArraySignature(function_signature) => match function_signature { + ArrayFunctionSignature::Array { + arguments, + array_coercion, + } => array_valid_types( + function_name, + current_types, + arguments, + array_coercion.as_ref(), + )?, ArrayFunctionSignature::RecursiveArray => { if current_types.len() != 1 { return Ok(vec![vec![]]); @@ -737,7 +989,7 @@ fn get_valid_types( current_types.len() ); } - vec![(0..*number).map(|i| current_types[i].clone()).collect()] + vec![current_types.to_vec()] } TypeSignature::OneOf(types) => types .iter() @@ -815,6 +1067,7 @@ fn maybe_data_types_without_coercion( /// (losslessly converted) into a value of `type_to` /// /// See the module level documentation for more detail on coercion. +#[deprecated(since = "53.0.0", note = "Unused internal function")] pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { if type_into == type_from { return true; @@ -861,10 +1114,13 @@ fn coerced_from<'a>( (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), + (Float16, Null | Int8 | Int16 | UInt8 | UInt16 | Float16) => { + Some(type_into.clone()) + } ( Float32, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - | Float32, + | Float16 | Float32, ) => Some(type_into.clone()), ( Float64, @@ -877,6 +1133,7 @@ fn coerced_from<'a>( | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal32(_, _) @@ -888,18 +1145,20 @@ fn coerced_from<'a>( Timestamp(TimeUnit::Nanosecond, None), Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8, ) => Some(type_into.clone()), - (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()), + (Interval(_), Null | Utf8 | LargeUtf8) => Some(type_into.clone()), // We can go into a Utf8View from a Utf8 or LargeUtf8 (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()), // Any type can be coerced into strings (Utf8 | LargeUtf8, _) => Some(type_into.clone()), + // We can go into a BinaryView from a Binary or LargeBinary + (BinaryView, Binary | LargeBinary | Null) => Some(type_into.clone()), (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()), (List(_), FixedSizeList(_, _)) => Some(type_into.clone()), // Only accept list and largelist with the same number of dimensions unless the type is Null. // List or LargeList with different dimensions should be handled in TypeSignature or other places before this - (List(_) | LargeList(_), _) + (List(_) | LargeList(_) | ListView(_) | LargeListView(_), _) if base_type(type_from).is_null() || list_ndims(type_from) == list_ndims(type_into) => { @@ -933,30 +1192,91 @@ fn coerced_from<'a>( (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => { Some(type_into.clone()) } + // Null can be coerced to any target type, provided the cast is valid. + // This mirrors null_coercion() in binary comparison coercion + // (expr-common/src/type_coercion/binary.rs) and is the symmetric + // counterpart of the (Null, _) arm above. Without this, untyped + // placeholders ($1, $foo) inside function calls fail signature matching + // because their Null type doesn't match any Exact(...) variant. + (_, Null) if can_cast_types(type_from, type_into) => Some(type_into.clone()), _ => None, } } #[cfg(test)] mod tests { - use crate::Volatility; + use crate::{ + HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderSignature, + HigherOrderUDFImpl, Volatility, + }; use super::*; - use arrow::datatypes::Field; - use datafusion_common::assert_contains; + use arrow::datatypes::IntervalUnit; + use datafusion_common::{ + assert_contains, + types::{logical_binary, logical_int64}, + }; + use datafusion_expr_common::{ + columnar_value::ColumnarValue, + signature::{Coercion, TypeSignatureClass}, + }; #[test] fn test_string_conversion() { let cases = vec![ - (DataType::Utf8View, DataType::Utf8, true), - (DataType::Utf8View, DataType::LargeUtf8, true), + (DataType::Utf8View, DataType::Utf8), + (DataType::Utf8View, DataType::LargeUtf8), + (DataType::Utf8View, DataType::Null), + ]; + + for case in cases { + assert_eq!(coerced_from(&case.0, &case.1), Some(case.0)); + } + } + + #[test] + fn test_binary_conversion() { + let cases = vec![ + (DataType::BinaryView, DataType::Binary), + (DataType::BinaryView, DataType::LargeBinary), + (DataType::BinaryView, DataType::Null), ]; for case in cases { - assert_eq!(can_coerce_from(&case.0, &case.1), case.2); + assert_eq!(coerced_from(&case.0, &case.1), Some(case.0)); } } + #[test] + fn test_coerced_from_null() { + // Null should coerce to Interval (the motivating case) + assert_eq!( + coerced_from( + &DataType::Interval(IntervalUnit::MonthDayNano), + &DataType::Null + ), + Some(DataType::Interval(IntervalUnit::MonthDayNano)) + ); + + // Null should coerce to Date32 + assert_eq!( + coerced_from(&DataType::Date32, &DataType::Null), + Some(DataType::Date32) + ); + + // Null should coerce to Timestamp with timezone + assert_eq!( + coerced_from( + &DataType::Timestamp(TimeUnit::Microsecond, Some("+00".into())), + &DataType::Null + ), + Some(DataType::Timestamp( + TimeUnit::Microsecond, + Some("+00".into()) + )) + ); + } + #[test] fn test_maybe_data_types() { // this vec contains: arg1, arg2, expected result @@ -1057,7 +1377,7 @@ mod tests { .unwrap_err(); assert_contains!( got.to_string(), - "Function 'test' expects NativeType::Numeric but received NativeType::String" + "Function 'test' expects Numeric but received String" ); // Fallbacks to float64 if the arg is of type null. @@ -1077,7 +1397,7 @@ mod tests { .unwrap_err(); assert_contains!( got.to_string(), - "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)" + "Function 'test' expects Numeric but received Timestamp(s)" ); Ok(()) @@ -1132,12 +1452,29 @@ mod tests { Ok(()) } + struct MockUdf(Signature); + + impl UDFCoercionExt for MockUdf { + fn name(&self) -> &str { + "test" + } + fn signature(&self) -> &Signature { + &self.0 + } + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + unimplemented!() + } + } + #[test] fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new_list_field(DataType::Int32, false)); - let current_types = vec![ - DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size - ]; + // able to coerce for any size + let current_fields = vec![Arc::new(Field::new( + "t", + DataType::FixedSizeList(Arc::clone(&inner), 2), + true, + ))]; let signature = Signature::exact( vec![DataType::FixedSizeList( @@ -1147,24 +1484,25 @@ mod tests { Volatility::Stable, ); - let coerced_data_types = data_types("test", ¤t_types, &signature)?; - assert_eq!(coerced_data_types, current_types); + let coerced_fields = fields_with_udf(¤t_fields, &MockUdf(signature))?; + assert_eq!(coerced_fields, current_fields); // make sure it can't coerce to a different size let signature = Signature::exact( vec![DataType::FixedSizeList(Arc::clone(&inner), 3)], Volatility::Stable, ); - let coerced_data_types = data_types("test", ¤t_types, &signature); - assert!(coerced_data_types.is_err()); + let coerced_fields = fields_with_udf(¤t_fields, &MockUdf(signature)); + assert!(coerced_fields.is_err()); // make sure it works with the same type. let signature = Signature::exact( vec![DataType::FixedSizeList(Arc::clone(&inner), 2)], Volatility::Stable, ); - let coerced_data_types = data_types("test", ¤t_types, &signature).unwrap(); - assert_eq!(coerced_data_types, current_types); + let coerced_fields = + fields_with_udf(¤t_fields, &MockUdf(signature)).unwrap(); + assert_eq!(coerced_fields, current_fields); Ok(()) } @@ -1271,6 +1609,54 @@ mod tests { ]] ); + let data_types = vec![ + DataType::ListView(Field::new_list_field(DataType::Int32, true).into()), + DataType::new_list(DataType::Int32, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int32, true), + DataType::new_list(DataType::Int32, true), + ]] + ); + + let data_types = vec![ + DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()), + DataType::new_list(DataType::Int32, true), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::new_large_list(DataType::Int32, true), + ]] + ); + + let data_types = vec![ + DataType::ListView(Field::new_list_field(DataType::Int32, true).into()), + DataType::ListView(Field::new_list_field(DataType::Int32, true).into()), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_list(DataType::Int32, true), + DataType::new_list(DataType::Int32, true), + ]] + ); + + let data_types = vec![ + DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()), + DataType::LargeListView(Field::new_list_field(DataType::Int32, true).into()), + ]; + assert_eq!( + get_valid_types(function, &signature.type_signature, &data_types)?, + vec![vec![ + DataType::new_large_list(DataType::Int32, true), + DataType::new_large_list(DataType::Int32, true), + ]] + ); + Ok(()) } @@ -1316,6 +1702,31 @@ mod tests { Ok(()) } + #[test] + fn test_get_valid_types_array_and_index_preserves_list_field_name() -> Result<()> { + let struct_fields = vec![ + Field::new("id", DataType::Utf8, true), + Field::new("prim", DataType::Boolean, true), + ]; + let current_type = DataType::List(Arc::new(Field::new( + "element", + DataType::Struct(struct_fields.into()), + true, + ))); + let signature = Signature::array_and_index(Volatility::Immutable); + + assert_eq!( + get_valid_types( + "array_element", + &signature.type_signature, + &[current_type.clone(), DataType::Int64], + )?, + vec![vec![current_type, DataType::Int64]] + ); + + Ok(()) + } + #[test] fn test_get_valid_types_element_and_array() -> Result<()> { let function = "element_and_array"; @@ -1336,6 +1747,164 @@ mod tests { Ok(()) } + #[test] + fn test_coercible_nulls() -> Result<()> { + fn null_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new("field", DataType::Null, true).into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts Null to Int64 if we use TypeSignatureClass::Native + let output = null_input(Coercion::new_exact(TypeSignatureClass::Native( + logical_int64(), + )))?; + assert_eq!(vec![DataType::Int64], output); + + let output = null_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // Null gets passed through if we use TypeSignatureClass apart from Native + let output = null_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![DataType::Null], output); + + let output = null_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Null], output); + + Ok(()) + } + + #[test] + fn test_coercible_dictionary() -> Result<()> { + let dictionary = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int64)); + fn dictionary_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new( + "field", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Int64), + ), + true, + ) + .into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts Dictionary to Int64 if we use TypeSignatureClass::Native + let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Native( + logical_int64(), + )))?; + assert_eq!(vec![DataType::Int64], output); + + let output = dictionary_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // Dictionary gets passed through if we use TypeSignatureClass apart from Native + let output = dictionary_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![dictionary.clone()], output); + + let output = dictionary_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![dictionary.clone()], output); + + Ok(()) + } + + #[test] + fn test_coercible_run_end_encoded() -> Result<()> { + let run_end_encoded = DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Int64, true).into(), + ); + fn run_end_encoded_input(coercion: Coercion) -> Result> { + fields_with_udf( + &[Field::new( + "field", + DataType::RunEndEncoded( + Field::new("run_ends", DataType::Int16, false).into(), + Field::new("values", DataType::Int64, true).into(), + ), + true, + ) + .into()], + &MockUdf(Signature::coercible(vec![coercion], Volatility::Immutable)), + ) + .map(|v| v.into_iter().map(|f| f.data_type().clone()).collect()) + } + + // Casts REE to Int64 if we use TypeSignatureClass::Native + let output = run_end_encoded_input(Coercion::new_exact( + TypeSignatureClass::Native(logical_int64()), + ))?; + assert_eq!(vec![DataType::Int64], output); + + let output = run_end_encoded_input(Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![DataType::Int64], output); + + // REE gets passed through if we use TypeSignatureClass apart from Native + let output = + run_end_encoded_input(Coercion::new_exact(TypeSignatureClass::Integer))?; + assert_eq!(vec![run_end_encoded.clone()], output); + + let output = run_end_encoded_input(Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![], + NativeType::Int64, + ))?; + assert_eq!(vec![run_end_encoded.clone()], output); + + Ok(()) + } + + #[test] + fn test_get_valid_types_coercible_binary() -> Result<()> { + let signature = Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Native( + logical_binary(), + ))], + Volatility::Immutable, + ); + + // Binary types should stay their original selves + for t in [ + DataType::Binary, + DataType::BinaryView, + DataType::LargeBinary, + ] { + assert_eq!( + get_valid_types("", &signature.type_signature, std::slice::from_ref(&t))?, + vec![vec![t]] + ); + } + + Ok(()) + } + #[test] fn test_get_valid_types_fixed_size_arrays() -> Result<()> { let function = "fixed_size_arrays"; @@ -1388,4 +1957,291 @@ mod tests { Ok(()) } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockHigherOrderUDF { + signature: HigherOrderSignature, + coerced_value_types: Vec, + } + + impl HigherOrderUDFImpl for MockHigherOrderUDF { + fn name(&self) -> &str { + "mock_higher_order_function" + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return plan_err!( + "mock_higher_order_function expects 1 value arguments, got {}", + arg_types.len() + ); + } + Ok(self.coerced_value_types.clone()) + } + + fn coerce_values_for_lambdas( + &self, + fields: &[ValueOrLambda], + ) -> Result>> { + // thoerical impl of array_reduce without finish + let [ + ValueOrLambda::Value(list), + ValueOrLambda::Value(_initial), + ValueOrLambda::Lambda(merge), + ] = fields + else { + unreachable!() + }; + + Ok(Some(vec![list.clone(), merge.clone()])) + } + + fn lambda_parameters( + &self, + _step: usize, + _fields: &[ValueOrLambda>], + ) -> Result { + unimplemented!("mock_higher_order_function") + } + + fn return_field_from_args( + &self, + _args: HigherOrderReturnFieldArgs, + ) -> Result { + unimplemented!("mock_higher_order_function") + } + + fn invoke_with_args( + &self, + _args: HigherOrderFunctionArgs, + ) -> Result { + unimplemented!("mock_higher_order_function") + } + } + + #[test] + fn test_higher_order_function_user_defined_type_coercion() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::user_defined(Volatility::Immutable), + coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)], + }); + + let new_fields = value_fields_with_higher_order_udf( + &[ + ValueOrLambda::Value(Arc::new(Field::new_list( + "", + Field::new_list_field(DataType::Int32, false), + false, + ))), + ValueOrLambda::Lambda(()), + ], + &fun, + ) + .unwrap(); + + // from List(Int32) to LargeList(Int32) + assert_eq!( + new_fields, + vec![ + ValueOrLambda::Value(Arc::new(Field::new_large_list( + "", + Field::new_list_field(DataType::Int32, false), + false + ))), + ValueOrLambda::Lambda(()), + ] + ) + } + + #[test] + fn test_higher_order_function_coerce_values_for_lambdas() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Immutable), + coerced_value_types: vec![], + }); + + let new_fields = value_fields_with_higher_order_udf_and_lambdas( + &[ + ValueOrLambda::Value(Arc::new(Field::new_list( + "", + Field::new_list_field(DataType::Float32, true), + true, + ))), + ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, true))), + ValueOrLambda::Lambda(Arc::new(Field::new("", DataType::Float32, true))), + ], + &fun, + ) + .unwrap(); + + // second parameter from Int32 to Float32 + assert_eq!( + new_fields, + vec![ + ValueOrLambda::Value(Arc::new(Field::new_list( + "", + Field::new_list_field(DataType::Float32, true), + true, + ))), + ValueOrLambda::Value(Arc::new(Field::new("", DataType::Float32, true))), + ValueOrLambda::Lambda(Arc::new(Field::new("", DataType::Float32, true))), + ] + ) + } + + #[test] + fn test_higher_order_function_user_defined_type_coercion_bad_args() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::user_defined(Volatility::Immutable), + coerced_value_types: vec![DataType::Int32], + }); + + let err = value_fields_with_higher_order_udf::<()>(&[], &fun).unwrap_err(); + + assert_contains!( + err.to_string(), + "mock_higher_order_function expects 1 value arguments, got 0" + ); + } + + #[test] + fn test_higher_order_function_faulty_user_defined_type_coercion() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::user_defined(Volatility::Immutable), + coerced_value_types: vec![DataType::Int32, DataType::Int32], + }); + + let err = value_fields_with_higher_order_udf::<()>( + &[ValueOrLambda::Value(Arc::new(Field::new( + "", + DataType::Int32, + false, + )))], + &fun, + ) + .unwrap_err(); + + assert_contains!( + err.to_string(), + "mock_higher_order_function coerce_value_types should have returned 1 items but returned 2" + ); + } + + #[test] + fn test_higher_order_function_any_signature() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::any(1, Volatility::Immutable), + coerced_value_types: vec![], + }); + + let new_fields = + value_fields_with_higher_order_udf(&[ValueOrLambda::Lambda(())], &fun) + .unwrap(); + + // no coercion, just number of args checked + assert_eq!(new_fields, vec![ValueOrLambda::Lambda(())]) + } + + #[test] + fn test_higher_order_function_any_signature_bad_args() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::any(1, Volatility::Immutable), + coerced_value_types: vec![], + }); + + let err = value_fields_with_higher_order_udf::<()>(&[], &fun).unwrap_err(); + + assert_contains!( + err.to_string(), + "The function 'mock_higher_order_function' expected 1 arguments but received 0" + ); + } + + #[test] + fn test_higher_order_function_exact_signature() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + coerced_value_types: vec![DataType::new_large_list(DataType::Int32, false)], + }); + + let new_fields = value_fields_with_higher_order_udf( + &[ + ValueOrLambda::Value(Arc::new(Field::new_list( + "", + Field::new_list_field(DataType::Int32, false), + false, + ))), + ValueOrLambda::Lambda(()), + ], + &fun, + ) + .unwrap(); + + // type coercion applied: List(Int32) -> LargeList(Int32) + assert_eq!( + new_fields, + vec![ + ValueOrLambda::Value(Arc::new(Field::new_large_list( + "", + Field::new_list_field(DataType::Int32, false), + false + ))), + ValueOrLambda::Lambda(()), + ] + ) + } + + #[test] + fn test_higher_order_function_exact_signature_wrong_value_count() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + coerced_value_types: vec![], + }); + + let err = value_fields_with_higher_order_udf::<()>( + &[ValueOrLambda::Lambda(()), ValueOrLambda::Lambda(())], + &fun, + ) + .unwrap_err(); + + assert_contains!( + err.to_string(), + "expected a value at position 0 but received a lambda" + ); + } + + #[test] + fn test_higher_order_function_exact_signature_wrong_lambda_count() { + let fun = HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + coerced_value_types: vec![], + }); + + let err = value_fields_with_higher_order_udf::<()>( + &[ + ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))), + ValueOrLambda::Value(Arc::new(Field::new("", DataType::Int32, false))), + ], + &fun, + ) + .unwrap_err(); + + assert_contains!( + err.to_string(), + "expected a lambda at position 1 but received a value" + ); + } } diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index bd1acd3f3a2e2..c92d434e34abe 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -58,11 +58,6 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { ) } -/// Determine whether the given data type `dt` is `Null`. -pub fn is_null(dt: &DataType) -> bool { - *dt == DataType::Null -} - /// Determine whether the given data type `dt` is a `Timestamp`. pub fn is_timestamp(dt: &DataType) -> bool { matches!(dt, DataType::Timestamp(_, _)) @@ -80,22 +75,3 @@ pub fn is_datetime(dt: &DataType) -> bool { DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) ) } - -/// Determine whether the given data type `dt` is a `Utf8` or `Utf8View` or `LargeUtf8`. -pub fn is_utf8_or_utf8view_or_large_utf8(dt: &DataType) -> bool { - matches!( - dt, - DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 - ) -} - -/// Determine whether the given data type `dt` is a `Decimal`. -pub fn is_decimal(dt: &DataType) -> bool { - matches!( - dt, - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - ) -} diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs index 634558094ae79..48125b661e2ca 100644 --- a/datafusion/expr/src/type_coercion/other.rs +++ b/datafusion/expr/src/type_coercion/other.rs @@ -17,38 +17,58 @@ use arrow::datatypes::DataType; -use super::binary::comparison_coercion; +use super::binary::{comparison_coercion, type_union_coercion}; + +/// Fold `coerce_fn` over `types`, starting from `initial_type`. +fn fold_coerce( + initial_type: &DataType, + types: &[DataType], + coerce_fn: fn(&DataType, &DataType) -> Option, +) -> Option { + types + .iter() + .try_fold(initial_type.clone(), |left_type, right_type| { + coerce_fn(&left_type, right_type) + }) +} /// Attempts to coerce the types of `list_types` to be comparable with the -/// `expr_type`. -/// Returns the common data type for `expr_type` and `list_types` +/// `expr_type` for IN list predicates. +/// Returns the common data type for `expr_type` and `list_types`. +/// +/// Uses comparison coercion because `x IN (a, b)` is semantically equivalent +/// to `x = a OR x = b`. pub fn get_coerce_type_for_list( expr_type: &DataType, list_types: &[DataType], ) -> Option { - list_types - .iter() - .try_fold(expr_type.clone(), |left_type, right_type| { - comparison_coercion(&left_type, right_type) - }) + fold_coerce(expr_type, list_types, comparison_coercion) +} + +/// Find a common coerceable type for `CASE expr WHEN val1 WHEN val2 ...` +/// conditions. Returns the common type for `case_type` and all `when_types`. +/// +/// Uses comparison coercion because `CASE expr WHEN val` is semantically +/// equivalent to `expr = val`. +pub fn get_coerce_type_for_case_when( + when_types: &[DataType], + case_type: &DataType, +) -> Option { + fold_coerce(case_type, when_types, comparison_coercion) } -/// Find a common coerceable type for all `when_or_then_types` as well -/// and the `case_or_else_type`, if specified. -/// Returns the common data type for `when_or_then_types` and `case_or_else_type` +/// Find a common coerceable type for CASE THEN/ELSE result expressions. +/// Returns the common data type for `then_types` and `else_type`. +/// +/// Uses type union coercion because the result branches must be brought to a +/// common type (like UNION), not compared. pub fn get_coerce_type_for_case_expression( - when_or_then_types: &[DataType], - case_or_else_type: Option<&DataType>, + then_types: &[DataType], + else_type: Option<&DataType>, ) -> Option { - let case_or_else_type = match case_or_else_type { - None => when_or_then_types[0].clone(), - Some(data_type) => data_type.clone(), + let (initial_type, remaining) = match else_type { + None => then_types.split_first()?, + Some(data_type) => (data_type, then_types), }; - when_or_then_types - .iter() - .try_fold(case_or_else_type, |left_type, right_type| { - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. - comparison_coercion(&left_type, right_type) - }) + fold_coerce(initial_type, remaining, type_union_coercion) } diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e2ae697deedfb..54957c273abcc 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -26,23 +26,24 @@ use std::vec; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; +use datafusion_common::{Result, ScalarValue, Statistics, exec_err, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; +use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use crate::expr::{ + AggregateFunction, AggregateFunctionParams, ExprListDisplay, WindowFunctionParams, schema_name_from_exprs, schema_name_from_exprs_comma_separated_without_space, - schema_name_from_sorts, AggregateFunction, AggregateFunctionParams, ExprListDisplay, - WindowFunctionParams, + schema_name_from_sorts, }; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; use crate::groups_accumulator::GroupsAccumulator; use crate::udf_eq::UdfEq; -use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::{expr_vec_fmt, Accumulator, Expr}; +use crate::utils::format_state_name; +use crate::{Accumulator, Expr, expr_vec_fmt}; use crate::{Documentation, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). @@ -74,8 +75,8 @@ use crate::{Documentation, Signature}; /// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function /// [`Accumulator`]: Accumulator /// [`create_udaf`]: crate::expr_fn::create_udaf -/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs -/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udaf.rs +/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs #[derive(Debug, Clone, PartialOrd)] pub struct AggregateUDF { inner: Arc, @@ -83,7 +84,7 @@ pub struct AggregateUDF { impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.inner.dyn_eq(other.inner.as_any()) + self.inner.dyn_eq(other.inner.as_ref() as &dyn Any) } } @@ -294,13 +295,28 @@ impl AggregateUDF { self.inner.reverse_expr() } - /// Do the function rewrite + /// Returns this aggregate function's simplification hook, if any. /// /// See [`AggregateUDFImpl::simplify`] for more details. pub fn simplify(&self) -> Option { self.inner.simplify() } + /// Rewrite aggregate to have simpler arguments + /// + /// See [`AggregateUDFImpl::simplify_expr_op_literal`] for more details + pub fn simplify_expr_op_literal( + &self, + agg_function: &AggregateFunction, + arg: &Expr, + op: Operator, + lit: &Expr, + arg_is_left: bool, + ) -> Result> { + self.inner + .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left) + } + /// Returns true if the function is max, false if the function is min /// None in all other cases, used in certain optimizations for /// or aggregate @@ -360,7 +376,7 @@ where /// See [`advanced_udaf.rs`] for a full example with complete implementation and /// [`AggregateUDF`] for other available options. /// -/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udaf.rs /// /// # Basic Example /// ``` @@ -399,7 +415,6 @@ where /// /// /// Implement the AggregateUDFImpl trait for GeoMeanUdf /// impl AggregateUDFImpl for GeoMeanUdf { -/// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "geo_mean" } /// fn signature(&self) -> &Signature { &self.signature } /// fn return_type(&self, args: &[DataType]) -> Result { @@ -427,10 +442,7 @@ where /// // Call the function `geo_mean(col)` /// let expr = geometric_mean.call(vec![col("a")]); /// ``` -pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { - /// Returns this object as an [`Any`] trait object - fn as_any(&self) -> &dyn Any; - +pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any { /// Returns this function's name fn name(&self) -> &str; @@ -565,11 +577,12 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let fields = vec![args - .return_field - .as_ref() - .clone() - .with_name(format_state_name(args.name, "value"))]; + let fields = vec![ + args.return_field + .as_ref() + .clone() + .with_name(format_state_name(args.name, "value")), + ]; Ok(fields .into_iter() @@ -650,26 +663,34 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { AggregateOrderSensitivity::HardRequirement } - /// Optionally apply per-UDaF simplification / rewrite rules. + /// Returns an optional hook for simplifying this user-defined aggregate. + /// + /// Use this hook to apply function-specific rewrites during optimization. + /// The default implementation returns `None`. /// - /// This can be used to apply function specific simplification rules during - /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default - /// implementation does nothing. + /// For example, `percentile_cont(x, 0.0)` and `percentile_cont(x, 1.0)` can + /// be rewritten to `MIN(x)` or `MAX(x)` depending on the `ORDER BY` + /// direction. /// - /// Note that DataFusion handles simplifying arguments and "constant - /// folding" (replacing a function call with constant arguments such as - /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such - /// optimizations manually for specific UDFs. + /// DataFusion already simplifies arguments and performs constant folding + /// (for example, `my_add(1, 2) -> 3`). For nested expressions, the optimizer + /// runs simplification in multiple passes, so arguments are typically + /// simplified before this hook is invoked. As a result, UDF implementations + /// usually do not need to handle argument simplification themselves. + /// + /// See configuration `datafusion.optimizer.max_passes` for details on how many + /// optimization passes may be applied. /// /// # Returns /// - /// [None] if simplify is not defined or, + /// `None` if simplify is not defined. /// - /// Or, a closure with two arguments: - /// * 'aggregate_function': [AggregateFunction] for which simplified has been invoked - /// * 'info': [crate::simplify::SimplifyInfo] + /// Or, a closure ([`AggregateFunctionSimplification`]) invoked with: + /// * `aggregate_function`: [AggregateFunction] with already simplified + /// arguments + /// * `info`: [crate::simplify::SimplifyContext] /// - /// closure returns simplified [Expr] or an error. + /// The closure returns a simplified [Expr] or an error. /// /// # Notes /// @@ -682,6 +703,74 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { None } + /// Rewrite the aggregate to have simpler arguments + /// + /// This query pattern is not common in most real workloads, and most + /// aggregate implementations can safely ignore it. This API is included in + /// DataFusion because it is important for ClickBench Q29. See backstory + /// on + /// + /// # Rewrite Overview + /// + /// The idea is to rewrite multiple aggregates with "complex arguments" into + /// ones with simpler arguments that can be optimized by common subexpression + /// elimination (CSE). At a high level the rewrite looks like + /// + /// * `Aggregate(SUM(x + 1), SUM(x + 2), ...)` + /// + /// Into + /// + /// * `Aggregate(SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...)` + /// + /// While this rewrite may seem worse (slower) than the original as it + /// computes *more* aggregate expressions, the common subexpression + /// elimination (CSE) can then reduce the number of distinct aggregates the + /// query actually needs to compute with a rewrite like + /// + /// * `Projection(_A + 1*_B, _A + 2*_B)` + /// * ` Aggregate(_A = SUM(x), _B = COUNT(x))` + /// + /// This optimization is extremely important for ClickBench Q29, which has 90 + /// such expressions for some reason, and so this optimization results in + /// only two aggregates being needed. The DataFusion optimizer will invoke + /// this method when it detects multiple aggregates in a query that share + /// arguments of the form ` `. + /// + /// # API + /// + /// If `agg_function` supports the rewrite, it should return a semantically + /// equivalent expression (likely with more aggregate expressions, but + /// simpler arguments) + /// + /// This is only called when: + /// 1. There are no "special" aggregate params (filters, null handling, etc) + /// 2. Aggregate functions with exactly one [`Expr`] argument + /// 3. There are no volatile expressions + /// + /// Arguments + /// * `agg_function`: the original aggregate function detected with complex + /// arguments. + /// * `arg`: The common argument shared across multiple aggregates (e.g. `x` + /// in the example above) + /// * `op`: the operator between the common argument and the literal (e.g. + /// `+` in `x + 1` or `1 + x`) + /// * `lit`: the literal argument (e.g. `1` or `2` in the example above) + /// * `arg_is_left`: whether the common argument is on the left or right of + /// the operator (e.g. `true` for `x + 1` and false for `1 + x`) + /// + /// The default implementation returns `None`, which is what most aggregates + /// should do. + fn simplify_expr_op_literal( + &self, + _agg_function: &AggregateFunction, + _arg: &Expr, + _op: Operator, + _lit: &Expr, + _arg_is_left: bool, + ) -> Result> { + Ok(None) + } + /// Returns the reverse expression of the aggregate function. fn reverse_expr(&self) -> ReversedUDAF { ReversedUDAF::NotSupported @@ -818,9 +907,28 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync { } } +impl dyn AggregateUDFImpl { + /// Returns `true` if the implementation is of type `T`. + /// + /// Works correctly when called on `Arc` via auto-deref. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Attempts to downcast to a concrete type `T`, returning `None` if the + /// implementation is not of that type. + /// + /// Works correctly when called on `Arc` via auto-deref, + /// unlike `(&arc as &dyn Any).downcast_ref::()` which would attempt to + /// downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} + impl PartialEq for dyn AggregateUDFImpl { fn eq(&self, other: &Self) -> bool { - self.dyn_eq(other.as_any()) + self.dyn_eq(other) } } @@ -1138,10 +1246,6 @@ impl AliasedAggregateUDFImpl { #[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method impl AggregateUDFImpl for AliasedAggregateUDFImpl { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.inner.name() } @@ -1234,6 +1338,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.simplify() } + fn simplify_expr_op_literal( + &self, + agg_function: &AggregateFunction, + arg: &Expr, + op: Operator, + lit: &Expr, + arg_is_left: bool, + ) -> Result> { + self.inner + .simplify_expr_op_literal(agg_function, arg, op, lit, arg_is_left) + } + fn reverse_expr(&self) -> ReversedUDAF { self.inner.reverse_expr() } @@ -1310,7 +1426,6 @@ mod test { use datafusion_functions_aggregate_common::accumulator::{ AccumulatorArgs, StateFieldsArgs, }; - use std::any::Any; use std::cmp::Ordering; use std::hash::{DefaultHasher, Hash, Hasher}; @@ -1332,9 +1447,6 @@ mod test { } impl AggregateUDFImpl for AMeanUdf { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "a" } @@ -1372,9 +1484,6 @@ mod test { } impl AggregateUDFImpl for BMeanUdf { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "b" } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 92caf5427d637..6a3aa31a8609f 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -19,23 +19,44 @@ use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; -use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::preimage::PreimageResult; +use crate::simplify::{ExprSimplifyResult, SimplifyContext}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; use crate::{ColumnarValue, Documentation, Expr, Signature}; use arrow::datatypes::{DataType, Field, FieldRef}; +#[cfg(debug_assertions)] +use datafusion_common::assert_or_internal_err; use datafusion_common::config::ConfigOptions; -use datafusion_common::{ - assert_or_internal_err, not_impl_err, ExprSchema, Result, ScalarValue, -}; +use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use std::any::Any; use std::cmp::Ordering; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; +/// Describes how a struct-producing UDF's output fields correspond to its +/// input arguments. This enables the optimizer to propagate orderings +/// through struct projections (e.g., so that sorting by a struct field +/// can be recognized as equivalent to sorting by the source column). +/// +/// See [`ScalarUDFImpl::struct_field_mapping`] for details. +pub struct StructFieldMapping { + /// The UDF used to construct field access expressions on the output. + /// For example, the `get_field` UDF for accessing struct fields. + pub field_accessor: Arc, + /// For each output field: the literal arguments to pass to the + /// `field_accessor` UDF (after the base expression), and the index + /// of the corresponding input argument that produces the field's value. + /// + /// For `named_struct('a', col1, 'b', col2)`, this would be: + /// `[(["a"], 1), (["b"], 3)]` — field `"a"` comes from arg index 1. + pub fields: Vec<(Vec, usize)>, +} + /// Logical representation of a Scalar User Defined Function. /// /// A scalar function produces a single row output for each row of input. This @@ -56,8 +77,8 @@ use std::sync::Arc; /// compatibility with the older API. /// /// [`create_udf`]: crate::expr_fn::create_udf -/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udf.rs -/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// [`simple_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udf.rs +/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs #[derive(Debug, Clone)] pub struct ScalarUDF { inner: Arc, @@ -65,7 +86,7 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.inner.dyn_eq(other.inner.as_any()) + self.inner.as_ref().dyn_eq(other.inner.as_ref() as &dyn Any) } } @@ -91,7 +112,8 @@ impl PartialOrd for ScalarUDF { "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \ The functions compare as equal, but they are not equal based on general properties that \ the PartialOrd implementation observes,", - self.name(), other.name() + self.name(), + other.name() ); Some(cmp) } @@ -214,23 +236,35 @@ impl ScalarUDF { self.inner.return_field_from_args(args) } - /// Do the function rewrite + /// Returns this scalar function's simplification result. /// /// See [`ScalarUDFImpl::simplify`] for more details. pub fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { self.inner.simplify(args, info) } #[deprecated(since = "50.0.0", note = "Use `return_field_from_args` instead.")] pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { - #[allow(deprecated)] + #[expect(deprecated)] self.inner.is_nullable(args, schema) } + /// Return a preimage + /// + /// See [`ScalarUDFImpl::preimage`] for more details. + pub fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + self.inner.preimage(args, lit_expr, info) + } + /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_with_args`] for details. @@ -246,7 +280,7 @@ impl ScalarUDF { let expected_type = return_field.data_type(); assert_or_internal_err!( result_data_type == *expected_type, - "Function '{}' returned value of type '{:?}' while the following type was promised at planning time and expected: '{:?}'", + "Function '{}' returned value of type '{}' while the following type was promised at planning time and expected: '{}'", self.name(), result_data_type, expected_type @@ -290,6 +324,14 @@ impl ScalarUDF { self.inner.evaluate_bounds(inputs) } + /// See [`ScalarUDFImpl::struct_field_mapping`] for more details. + pub fn struct_field_mapping( + &self, + literal_args: &[Option], + ) -> Option { + self.inner.struct_field_mapping(literal_args) + } + /// Updates bounds for child expressions, given a known interval for this /// function. This is used to propagate constraints down through an expression /// tree. @@ -345,7 +387,14 @@ impl ScalarUDF { /// Return true if this function is an async function pub fn as_async(&self) -> Option<&AsyncScalarUDF> { - self.inner().as_any().downcast_ref::() + self.inner().downcast_ref::() + } + + /// Returns placement information for this function. + /// + /// See [`ScalarUDFImpl::placement`] for more details. + pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.inner.placement(args) } } @@ -412,7 +461,7 @@ pub struct ReturnFieldArgs<'a> { /// See [`advanced_udf.rs`] for a full example with complete implementation and /// [`ScalarUDF`] for other available options. /// -/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs +/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udf.rs /// /// # Basic Example /// ``` @@ -449,7 +498,6 @@ pub struct ReturnFieldArgs<'a> { /// /// /// Implement the ScalarUDFImpl trait for AddOne /// impl ScalarUDFImpl for AddOne { -/// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "add_one" } /// fn signature(&self) -> &Signature { &self.signature } /// fn return_type(&self, args: &[DataType]) -> Result { @@ -473,10 +521,7 @@ pub struct ReturnFieldArgs<'a> { /// // Call the function `add_one(col)` /// let expr = add_one.call(vec![col("a")]); /// ``` -pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { - /// Returns this object as an [`Any`] trait object - fn as_any(&self) -> &dyn Any; - +pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any { /// Returns this function's name fn name(&self) -> &str; @@ -542,7 +587,7 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// /// If you provide an implementation for [`Self::return_field_from_args`], /// DataFusion will not call `return_type` (this function). While it is - /// valid to to put [`unimplemented!()`] or [`unreachable!()`], it is + /// valid to put [`unimplemented!()`] or [`unreachable!()`], it is /// recommended to return [`DataFusionError::Internal`] instead, which /// reduces the severity of symptoms if bugs occur (an error rather than a /// panic). @@ -609,7 +654,7 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { /// // report output is only nullable if any one of the arguments are nullable /// let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, true)); + /// let field = Arc::new(Field::new("ignored_name", DataType::Int32, nullable)); /// Ok(field) /// } /// # } @@ -690,11 +735,116 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { Ok(ExprSimplifyResult::Original(args)) } + /// Returns a single contiguous preimage for this function and the specified + /// scalar expression, if any. + /// + /// Currently only applies to `=, !=, >, >=, <, <=, is distinct from, is not distinct from` predicates + /// # Return Value + /// + /// Implementations should return a half-open interval: inclusive lower + /// bound and exclusive upper bound. This is slightly different from normal + /// [`Interval`] semantics where the upper bound is closed (inclusive). + /// Typically this means the upper endpoint must be adjusted to the next + /// value not included in the preimage. See the Half-Open Intervals section + /// below for more details. + /// + /// # Background + /// + /// Inspired by the [ClickHouse Paper], a "preimage rewrite" transforms a + /// predicate containing a function call into a predicate containing an + /// equivalent set of input literal (constant) values. The resulting + /// predicate can often be further optimized by other rewrites (see + /// Examples). + /// + /// From the paper: + /// + /// > some functions can compute the preimage of a given function result. + /// > This is used to replace comparisons of constants with function calls + /// > on the key columns by comparing the key column value with the preimage. + /// > For example, `toYear(k) = 2024` can be replaced by + /// > `k >= 2024-01-01 && k < 2025-01-01` + /// + /// For example, given an expression like + /// ```sql + /// date_part('YEAR', k) = 2024 + /// ``` + /// + /// The interval `[2024-01-01, 2025-12-31`]` contains all possible input + /// values (preimage values) for which the function `date_part(YEAR, k)` + /// produces the output value `2024` (image value). Returning the interval + /// (note upper bound adjusted up) `[2024-01-01, 2025-01-01]` the expression + /// can be rewritten to + /// + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' + /// ``` + /// + /// which is a simpler and a more canonical form, making it easier for other + /// optimizer passes to recognize and apply further transformations. + /// + /// # Examples + /// + /// Case 1: + /// + /// Original: + /// ```sql + /// date_part('YEAR', k) = 2024 AND k >= '2024-06-01' + /// ``` + /// + /// After preimage rewrite: + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' AND k >= '2024-06-01' + /// ``` + /// + /// Since this form is much simpler, the optimizer can combine and simplify + /// sub-expressions further into: + /// ```sql + /// k >= '2024-06-01' AND k < '2025-01-01' + /// ``` + /// + /// Case 2: + /// + /// For min/max pruning, simpler predicates such as: + /// ```sql + /// k >= '2024-01-01' AND k < '2025-01-01' + /// ``` + /// are much easier for the pruner to reason about. See [PruningPredicate] + /// for the backgrounds of predicate pruning. + /// + /// The trade-off with the preimage rewrite is that evaluating the rewritten + /// form might be slightly more expensive than evaluating the original + /// expression. In practice, this cost is usually outweighed by the more + /// aggressive optimization opportunities it enables. + /// + /// # Half-Open Intervals + /// + /// The preimage API uses half-open intervals, which makes the rewrite + /// easier to implement by avoiding calculations to adjust the upper bound. + /// For example, if a function returns its input unchanged and the desired + /// output is the single value `5`, a closed interval could be represented + /// as `[5, 5]`, but then the rewrite would require adjusting the upper + /// bound to `6` to create a proper range predicate. With a half-open + /// interval, the same range is represented as `[5, 6)`, which already + /// forms a valid predicate. + /// + /// [PruningPredicate]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html + /// [ClickHouse Paper]: https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf + /// [image]: https://en.wikipedia.org/wiki/Image_(mathematics)#Image_of_an_element + /// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image + fn preimage( + &self, + _args: &[Expr], + _lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + Ok(PreimageResult::None) + } + /// Returns true if some of this `exprs` subexpressions may not be evaluated /// and thus any side effects (like divide by zero) may not be encountered. /// @@ -838,6 +988,25 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + /// For struct-producing functions, return how output fields map to input + /// arguments. This enables the optimizer to propagate orderings through + /// struct projections. + /// + /// `literal_args[i]` is `Some(value)` if argument `i` is a known literal, + /// allowing extraction of field names from arguments like + /// `named_struct('field_name', value, ...)`. + /// + /// For example, `named_struct('a', col1, 'b', col2)` would return a + /// mapping indicating that output field `'a'` (accessed via + /// `get_field(output, 'a')`) corresponds to input argument `col1` at + /// index 1, and field `'b'` corresponds to `col2` at index 3. + fn struct_field_mapping( + &self, + _literal_args: &[Option], + ) -> Option { + None + } + /// Returns the documentation for this Scalar UDF. /// /// Documentation can be accessed programmatically as well as generating @@ -845,6 +1014,39 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns placement information for this function. + /// + /// This is used by optimizers to make decisions about expression placement, + /// such as whether to push expressions down through projections. + /// + /// The default implementation returns [`ExpressionPlacement::KeepInPlace`], + /// meaning the expression should be kept where it is in the plan. + /// + /// Override this method to indicate that the function can be pushed down + /// closer to the data source. + fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { + ExpressionPlacement::KeepInPlace + } +} + +impl dyn ScalarUDFImpl { + /// Returns `true` if the implementation is of type `T`. + /// + /// Works correctly when called on `Arc` via auto-deref. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Attempts to downcast to a concrete type `T`, returning `None` if the + /// implementation is not of that type. + /// + /// Works correctly when called on `Arc` via auto-deref, + /// unlike `(&arc as &dyn Any).downcast_ref::()` which would attempt to + /// downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -871,10 +1073,6 @@ impl AliasedScalarUDFImpl { #[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method impl ScalarUDFImpl for AliasedScalarUDFImpl { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.inner.name() } @@ -901,7 +1099,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { } fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { - #[allow(deprecated)] + #[expect(deprecated)] self.inner.is_nullable(args, schema) } @@ -920,11 +1118,20 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { self.inner.simplify(args, info) } + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + self.inner.preimage(args, lit_expr, info) + } + fn conditional_arguments<'a>( &self, args: &'a [Expr], @@ -948,6 +1155,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.propagate_constraints(interval, inputs) } + fn struct_field_mapping( + &self, + literal_args: &[Option], + ) -> Option { + self.inner.struct_field_mapping(literal_args) + } + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { self.inner.output_ordering(inputs) } @@ -963,6 +1177,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.inner.placement(args) + } } #[cfg(test)] @@ -978,10 +1196,6 @@ mod tests { signature: Signature, } impl ScalarUDFImpl for TestScalarUDFImpl { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.name } diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs index 6664495267129..8766b483137f4 100644 --- a/datafusion/expr/src/udf_eq.rs +++ b/datafusion/expr/src/udf_eq.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use crate::{AggregateUDFImpl, HigherOrderUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use std::any::Any; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::ops::Deref; @@ -26,7 +27,7 @@ use std::sync::Arc; /// /// If you want to just compare pointers for equality, use [`super::ptr_eq::PtrEq`]. #[derive(Clone)] -#[allow(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. +#[expect(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. pub struct UdfEq(Ptr); impl PartialEq for UdfEq @@ -83,7 +84,19 @@ trait UdfPointer: Deref { impl UdfPointer for Arc { fn equals(&self, other: &(dyn ScalarUDFImpl + '_)) -> bool { - self.as_ref().dyn_eq(other.as_any()) + self.as_ref().dyn_eq(other as &dyn Any) + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.as_ref().dyn_hash(hasher); + hasher.finish() + } +} + +impl UdfPointer for Arc { + fn equals(&self, other: &Self::Target) -> bool { + self.as_ref().dyn_eq(other) } fn hash_value(&self) -> u64 { @@ -95,7 +108,7 @@ impl UdfPointer for Arc { impl UdfPointer for Arc { fn equals(&self, other: &(dyn AggregateUDFImpl + '_)) -> bool { - self.as_ref().dyn_eq(other.as_any()) + self.as_ref().dyn_eq(other) } fn hash_value(&self) -> u64 { @@ -107,7 +120,7 @@ impl UdfPointer for Arc { impl UdfPointer for Arc { fn equals(&self, other: &(dyn WindowUDFImpl + '_)) -> bool { - self.as_ref().dyn_eq(other.as_any()) + self.as_ref().dyn_eq(other as &dyn Any) } fn hash_value(&self) -> u64 { @@ -124,7 +137,6 @@ mod tests { use arrow::datatypes::DataType; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::signature::{Signature, Volatility}; - use std::any::Any; use std::hash::DefaultHasher; #[derive(Debug, PartialEq, Eq, Hash)] @@ -133,10 +145,6 @@ mod tests { name: &'static str, } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.name } diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 3220fdcbcad70..5a5daca28a918 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -31,9 +31,9 @@ use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; use crate::udf_eq::UdfEq; use crate::{ - function::WindowFunctionSimplification, Expr, PartitionEvaluator, Signature, + Expr, PartitionEvaluator, Signature, function::WindowFunctionSimplification, }; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{Result, not_impl_err}; use datafusion_doc::Documentation; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_functions_window_common::expr::ExpressionArgs; @@ -66,8 +66,8 @@ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// /// [`PartitionEvaluator`]: crate::PartitionEvaluator /// [`create_udwf`]: crate::expr_fn::create_udwf -/// [`simple_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs -/// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// [`simple_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/simple_udwf.rs +/// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udwf.rs #[derive(Debug, Clone, PartialOrd)] pub struct WindowUDF { inner: Arc, @@ -82,7 +82,7 @@ impl Display for WindowUDF { impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.inner.dyn_eq(other.inner.as_any()) + self.inner.dyn_eq(other.inner.as_ref() as &dyn Any) } } @@ -157,7 +157,7 @@ impl WindowUDF { self.inner.signature() } - /// Do the function rewrite + /// Returns this window function's simplification hook, if any. /// /// See [`WindowUDFImpl::simplify`] for more details. pub fn simplify(&self) -> Option { @@ -237,10 +237,9 @@ where /// [`WindowUDF`] for other available options. /// /// -/// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs +/// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/udf/advanced_udwf.rs /// # Basic Example /// ``` -/// # use std::any::Any; /// # use std::sync::LazyLock; /// # use arrow::datatypes::{DataType, Field, FieldRef}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; @@ -277,7 +276,6 @@ where /// /// /// Implement the WindowUDFImpl trait for SmoothIt /// impl WindowUDFImpl for SmoothIt { -/// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "smooth_it" } /// fn signature(&self) -> &Signature { &self.signature } /// // The actual implementation would smooth the window @@ -314,10 +312,7 @@ where /// .build() /// .unwrap(); /// ``` -pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync { - /// Returns this object as an [`Any`] trait object - fn as_any(&self) -> &dyn Any; - +pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any { /// Returns this function's name fn name(&self) -> &str; @@ -344,25 +339,28 @@ pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync { partition_evaluator_args: PartitionEvaluatorArgs, ) -> Result>; - /// Optionally apply per-UDWF simplification / rewrite rules. + /// Returns an optional hook for simplifying this user-defined window + /// function. /// - /// This can be used to apply function specific simplification rules during - /// optimization. The default implementation does nothing. + /// Use this hook to apply function-specific rewrites during optimization. + /// The default implementation returns `None`. /// - /// Note that DataFusion handles simplifying arguments and "constant - /// folding" (replacing a function call with constant arguments such as - /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such - /// optimizations manually for specific UDFs. + /// DataFusion already simplifies arguments and performs constant folding + /// (for example, `my_add(1, 2) -> 3`), so there is usually no need to + /// implement those optimizations manually for specific UDFs. /// /// Example: - /// `advanced_udwf.rs`: + /// `advanced_udwf.rs`: /// /// # Returns - /// [None] if simplify is not defined or, + /// `None` if simplify is not defined. + /// + /// Or, a closure ([`WindowFunctionSimplification`]) invoked with: + /// * `window_function`: [WindowFunction] with already simplified + /// arguments + /// * `info`: [crate::simplify::SimplifyContext] /// - /// Or, a closure with two arguments: - /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked - /// * 'info': [crate::simplify::SimplifyInfo] + /// The closure returns a simplified [Expr] or an error. /// /// # Notes /// The returned expression must have the same schema as the original @@ -433,6 +431,25 @@ pub trait WindowUDFImpl: Debug + DynEq + DynHash + Send + Sync { } } +impl dyn WindowUDFImpl { + /// Returns `true` if the implementation is of type `T`. + /// + /// Works correctly when called on `Arc` via auto-deref. + pub fn is(&self) -> bool { + (self as &dyn Any).is::() + } + + /// Attempts to downcast to a concrete type `T`, returning `None` if the + /// implementation is not of that type. + /// + /// Works correctly when called on `Arc` via auto-deref, + /// unlike `(&arc as &dyn Any).downcast_ref::()` which would attempt to + /// downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + (self as &dyn Any).downcast_ref() + } +} + /// the effect this function will have on the limit pushdown pub enum LimitEffect { /// Does not affect the limit (i.e. this is causal) @@ -459,7 +476,7 @@ pub enum ReversedUDWF { impl PartialEq for dyn WindowUDFImpl { fn eq(&self, other: &Self) -> bool { - self.dyn_eq(other.as_any()) + self.dyn_eq(other as &dyn Any) } } @@ -499,10 +516,6 @@ impl AliasedWindowUDFImpl { #[warn(clippy::missing_trait_methods)] // Delegates, so it should implement every single trait method impl WindowUDFImpl for AliasedWindowUDFImpl { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.inner.name() } @@ -567,7 +580,6 @@ mod test { use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - use std::any::Any; use std::cmp::Ordering; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; @@ -591,9 +603,6 @@ mod test { /// Implement the WindowUDFImpl trait for AddOne impl WindowUDFImpl for AWindowUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "a" } @@ -634,9 +643,6 @@ mod test { /// Implement the WindowUDFImpl trait for AddOne impl WindowUDFImpl for BWindowUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "b" } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8e8483bc2a35f..22abb454d4e6b 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; use crate::expr_rewriter::strip_outer_reference; use crate::{ - and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, + BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, and, }; use datafusion_expr_common::signature::{Signature, TypeSignature}; @@ -34,15 +34,15 @@ use datafusion_common::tree_node::{ }; use datafusion_common::utils::get_at_indices; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap, - Result, TableReference, + Column, DFSchema, DFSchemaRef, HashMap, Result, TableReference, internal_err, + plan_err, }; #[cfg(not(feature = "sql"))] -use crate::expr::{ExceptSelectItem, ExcludeSelectItem}; +use crate::sql::{ExceptSelectItem, ExcludeSelectItem, Ident, ObjectName}; use indexmap::IndexSet; #[cfg(feature = "sql")] -use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem}; +use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, Ident, ObjectName}; pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; @@ -66,6 +66,23 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { } } +/// Internal helper that generates indices for powerset subsets using bitset iteration. +/// Returns an iterator of index vectors, where each vector contains the indices +/// of elements to include in that subset. +fn powerset_indices(len: usize) -> impl Iterator> { + (0..(1 << len)).map(move |mask| { + let mut indices = vec![]; + let mut bitset = mask; + while bitset > 0 { + let rightmost: u64 = bitset & !(bitset - 1); + let idx = rightmost.trailing_zeros() as usize; + indices.push(idx); + bitset &= bitset - 1; + } + indices + }) +} + /// The [power set] (or powerset) of a set S is the set of all subsets of S, \ /// including the empty set and S itself. /// @@ -83,33 +100,23 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { /// and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}. /// /// [power set]: https://en.wikipedia.org/wiki/Power_set -fn powerset(slice: &[T]) -> Result>, String> { +pub fn powerset(slice: &[T]) -> Result>> { if slice.len() >= 64 { - return Err("The size of the set must be less than 64.".into()); + return plan_err!("The size of the set must be less than 64"); } - let mut v = Vec::new(); - for mask in 0..(1 << slice.len()) { - let mut ss = vec![]; - let mut bitset = mask; - while bitset > 0 { - let rightmost: u64 = bitset & !(bitset - 1); - let idx = rightmost.trailing_zeros(); - let item = slice.get(idx as usize).unwrap(); - ss.push(item); - // zero the trailing bit - bitset &= bitset - 1; - } - v.push(ss); - } - Ok(v) + Ok(powerset_indices(slice.len()) + .map(|indices| indices.iter().map(|&idx| &slice[idx]).collect()) + .collect()) } /// check the number of expressions contained in the grouping_set fn check_grouping_set_size_limit(size: usize) -> Result<()> { let max_grouping_set_size = 65535; if size > max_grouping_set_size { - return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"); + return plan_err!( + "The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}" + ); } Ok(()) @@ -119,7 +126,9 @@ fn check_grouping_set_size_limit(size: usize) -> Result<()> { fn check_grouping_sets_size_limit(size: usize) -> Result<()> { let max_grouping_sets_size = 4096; if size > max_grouping_sets_size { - return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"); + return plan_err!( + "The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}" + ); } Ok(()) @@ -207,8 +216,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { grouping_sets.iter().map(|e| e.iter().collect()).collect() } Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => { - let grouping_sets = powerset(group_exprs) - .map_err(|e| plan_datafusion_err!("{}", e))?; + let grouping_sets = powerset(group_exprs)?; check_grouping_sets_size_limit(grouping_sets.len())?; grouping_sets } @@ -304,10 +312,14 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::OuterReferenceColumn { .. } => {} + | Expr::OuterReferenceColumn { .. } + | Expr::HigherOrderFunction(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => {} } Ok(TreeNodeRecursion::Continue) }) @@ -327,11 +339,32 @@ fn get_excluded_columns( idents.push(&excepts.first_element); idents.extend(&excepts.additional_elements); } + // Declared outside the `if let` so `idents.extend(exclude_owned.iter())` + // below can borrow references that outlive the inner scope. + let exclude_owned: Vec; if let Some(exclude) = opt_exclude { - match exclude { - ExcludeSelectItem::Single(ident) => idents.push(ident), - ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner), - } + let object_name_to_ident = |name: &ObjectName| -> Result { + if name.0.len() != 1 { + return plan_err!( + "EXCLUDE with multi-part identifiers is not supported: {name}" + ); + } + let part = &name.0[0]; + let Some(ident) = part.as_ident() else { + return plan_err!( + "EXCLUDE with non-identifier name part is not supported: {part}" + ); + }; + Ok(ident.clone()) + }; + exclude_owned = match exclude { + ExcludeSelectItem::Single(name) => vec![object_name_to_ident(name)?], + ExcludeSelectItem::Multiple(names) => names + .iter() + .map(object_name_to_ident) + .collect::>>()?, + }; + idents.extend(exclude_owned.iter()); } // Excluded columns should be unique let n_elem = idents.len(); @@ -373,30 +406,39 @@ fn get_exprs_except_skipped( } } -/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice -/// (once for each join side), but an unqualified wildcard should include it only once. -/// This function returns the columns that should be excluded. +/// When a JOIN has a USING clause, the join columns appear in the output +/// schema once per side (for inner/outer joins) or once total (for semi/anti +/// joins). An unqualified wildcard should include each USING column only once. +/// This function returns the duplicate columns that should be excluded. fn exclude_using_columns(plan: &LogicalPlan) -> Result> { - let using_columns = plan.using_columns()?; - let excluded = using_columns - .into_iter() - // For each USING JOIN condition, only expand to one of each join column in projection - .flat_map(|cols| { - let mut cols = cols.into_iter().collect::>(); - // sort join columns to make sure we consistently keep the same - // qualified column - cols.sort(); - let mut out_column_names: HashSet = HashSet::new(); - cols.into_iter().filter_map(move |c| { - if out_column_names.contains(&c.name) { - Some(c) - } else { - out_column_names.insert(c.name); - None - } - }) - }) - .collect::>(); + let output_columns: HashSet<_> = plan.schema().columns().iter().cloned().collect(); + let mut excluded = HashSet::new(); + for cols in plan.using_columns()? { + // `using_columns()` returns join columns from both sides regardless of + // the join type. For semi/anti joins, only one side's columns appear in + // the output schema. Filter to output columns so that columns from the + // non-output side don't participate in the deduplication process below + // and displace real output columns. + let mut cols: Vec<_> = cols + .into_iter() + .filter(|c| output_columns.contains(c)) + .collect(); + + // Sort so we keep the same qualified column, regardless of HashSet + // iteration order. + cols.sort(); + + // Keep only one column per name from the columns set, adding any + // duplicates to the excluded set. + let mut seen_names = HashSet::new(); + for col in cols { + if seen_names.contains(col.name.as_str()) { + excluded.insert(col); // exclude columns with already seen name + } else { + seen_names.insert(col.name.clone()); // mark column name as seen + } + } + } Ok(excluded) } @@ -929,6 +971,7 @@ pub fn find_valid_equijoin_key_pair( /// round(Float32) /// ``` #[expect(clippy::needless_pass_by_value)] +#[deprecated(since = "53.0.0", note = "Internal function")] pub fn generate_signature_error_msg( func_name: &str, func_signature: Signature, @@ -943,9 +986,31 @@ pub fn generate_signature_error_msg( .join("\n"); format!( - "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", - func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures - ) + "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}", + func_name, + TypeSignature::join_types(input_expr_types, ", "), + candidate_signatures + ) +} + +/// Creates a detailed error message for a function with wrong signature. +/// +/// For example, a query like `select round(3.14, 1.1);` would yield: +/// ```text +/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts. +/// Candidate functions: +/// round(Float64, Int64) +/// round(Float32, Int64) +/// round(Float64) +/// round(Float32) +/// ``` +pub(crate) fn generate_signature_error_message( + func_name: &str, + func_signature: &Signature, + input_expr_types: &[DataType], +) -> String { + #[expect(deprecated)] + generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types) } /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` @@ -1276,14 +1341,13 @@ pub fn collect_subquery_cols( mod tests { use super::*; use crate::{ - col, cube, + Cast, ExprFunctionExt, WindowFunctionDefinition, col, cube, expr::WindowFunction, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::{max_udaf, min_udaf, sum_udaf}, - Cast, ExprFunctionExt, WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; - use datafusion_expr_common::signature::{TypeSignature, Volatility}; + use datafusion_expr_common::signature::Volatility; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1725,7 +1789,8 @@ mod tests { .expect("valid parameter names"); // Generate error message with only 1 argument provided - let error_msg = generate_signature_error_msg("substr", sig, &[DataType::Utf8]); + let error_msg = + generate_signature_error_message("substr", &sig, &[DataType::Utf8]); assert!( error_msg.contains("str: Utf8, start_pos: Int64"), @@ -1744,11 +1809,112 @@ mod tests { Volatility::Immutable, ); - let error_msg = generate_signature_error_msg("my_func", sig, &[DataType::Int32]); + let error_msg = + generate_signature_error_message("my_func", &sig, &[DataType::Int32]); assert!( error_msg.contains("Any, Any"), "Expected 'Any, Any' without parameter names, got: {error_msg}" ); } + + #[test] + fn test_signature_error_msg_exact() { + use insta::assert_snapshot; + + let sig = Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float32, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float64]), + TypeSignature::Exact(vec![DataType::Float32]), + ], + Volatility::Immutable, + ); + let msg = generate_signature_error_message( + "round", + &sig, + &[DataType::Float64, DataType::Float64], + ); + assert_snapshot!(msg, @r" + No function matches the given name and argument types 'round(Float64, Float64)'. You might need to add explicit type casts. + Candidate functions: + round(Float64, Int64) + round(Float32, Int64) + round(Float64) + round(Float32) + "); + } + + #[test] + fn test_signature_error_msg_coercible() { + use datafusion_common::types::NativeType; + use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; + use insta::assert_snapshot; + + let sig = Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native( + datafusion_common::types::logical_float64(), + ), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(datafusion_common::types::logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, + ); + let msg = generate_signature_error_message( + "round", + &sig, + &[DataType::Utf8, DataType::Utf8], + ); + assert_snapshot!(msg, @r" + No function matches the given name and argument types 'round(Utf8, Utf8)'. You might need to add explicit type casts. + Candidate functions: + round(Float64, Int64) + "); + } + + #[test] + fn test_signature_error_msg_with_names_coercible() { + use datafusion_common::types::NativeType; + use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; + use insta::assert_snapshot; + + let sig = Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native( + datafusion_common::types::logical_string(), + )), + Coercion::new_exact(TypeSignatureClass::Native( + datafusion_common::types::logical_int64(), + )), + Coercion::new_implicit( + TypeSignatureClass::Native(datafusion_common::types::logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "string".to_string(), + "start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"); + + let msg = generate_signature_error_message("substr", &sig, &[DataType::Int32]); + assert_snapshot!(msg, @r" + No function matches the given name and argument types 'substr(Int32)'. You might need to add explicit type casts. + Candidate functions: + substr(string: String, start_pos: Int64, length: Int64) + "); + } } diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 5fb2916c34e95..a61d9d689ae7a 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -27,7 +27,7 @@ use crate::{expr::Sort, lit}; use std::fmt::{self, Formatter}; use std::hash::Hash; -use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, plan_err}; #[cfg(feature = "sql")] use sqlparser::ast::{self, ValueWithSpan}; @@ -131,12 +131,10 @@ impl TryFrom for WindowFrame { "Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING" )? } - } else if let WindowFrameBound::Preceding(val) = &end_bound { - if val.is_null() { - plan_err!( - "Invalid window frame: end bound cannot be UNBOUNDED PRECEDING" - )? - } + } else if let WindowFrameBound::Preceding(val) = &end_bound + && val.is_null() + { + plan_err!("Invalid window frame: end bound cannot be UNBOUNDED PRECEDING")? }; let units = value.units.into(); @@ -254,15 +252,14 @@ impl WindowFrame { // one column. However, an ORDER BY clause may be absent or have // more than one column when the start/end bounds are UNBOUNDED or // CURRENT ROW. - WindowFrameUnits::Range if self.free_range() => { + WindowFrameUnits::Range if self.free_range() && order_by.is_empty() => { // If an ORDER BY clause is absent, it is equivalent to an // ORDER BY clause with constant value as sort key. If an // ORDER BY clause is present but has more than one column, // it is unchanged. Note that this follows PostgreSQL behavior. - if order_by.is_empty() { - order_by.push(lit(1u64).sort(true, false)); - } + order_by.push(lit(1u64).sort(true, false)); } + WindowFrameUnits::Range if self.free_range() => {} WindowFrameUnits::Range if order_by.len() != 1 => { return plan_err!("RANGE requires exactly one ORDER BY column"); } @@ -375,9 +372,10 @@ fn convert_frame_bound_to_scalar_value( match units { // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { - ast::Expr::Value(ValueWithSpan{value: ast::Value::Number(value, false), span: _}) => { - Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) - }, + ast::Expr::Value(ValueWithSpan { + value: ast::Value::Number(value, false), + span: _, + }) => Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?), ast::Expr::Interval(ast::Interval { value, leading_field: None, @@ -386,11 +384,12 @@ fn convert_frame_bound_to_scalar_value( fractional_seconds_precision: None, }) => { let value = match *value { - ast::Expr::Value(ValueWithSpan{value: ast::Value::SingleQuotedString(item), span: _}) => item, + ast::Expr::Value(ValueWithSpan { + value: ast::Value::SingleQuotedString(item), + span: _, + }) => item, e => { - return exec_err!( - "INTERVAL expression cannot be {e:?}" - ); + return exec_err!("INTERVAL expression cannot be {e:?}"); } }; Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) @@ -402,18 +401,22 @@ fn convert_frame_bound_to_scalar_value( // ... instead for RANGE it could be anything depending on the type of the ORDER BY clause, // so we use a ScalarValue::Utf8. ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v { - ast::Expr::Value(ValueWithSpan{value: ast::Value::Number(value, false), span: _}) => value, + ast::Expr::Value(ValueWithSpan { + value: ast::Value::Number(value, false), + span: _, + }) => value, ast::Expr::Interval(ast::Interval { value, leading_field, .. }) => { let result = match *value { - ast::Expr::Value(ValueWithSpan{value: ast::Value::SingleQuotedString(item), span: _}) => item, + ast::Expr::Value(ValueWithSpan { + value: ast::Value::SingleQuotedString(item), + span: _, + }) => item, e => { - return exec_err!( - "INTERVAL expression cannot be {e:?}" - ); + return exec_err!("INTERVAL expression cannot be {e:?}"); } }; if let Some(leading_field) = leading_field { @@ -604,8 +607,16 @@ mod tests { last_field: None, leading_precision: None, }))); - test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); - test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound_err!( + Rows, + number.clone(), + "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ); + test_bound_err!( + Groups, + number.clone(), + "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ); test_bound!( Range, number.clone(), diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index cdfb18ee1ddd7..f8d4609d3690c 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -23,14 +23,13 @@ use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use arrow::{ array::ArrayRef, - compute::{concat, concat_batches, SortOptions}, + compute::{SortOptions, concat, concat_batches}, datatypes::{DataType, SchemaRef}, record_batch::RecordBatch, }; use datafusion_common::{ - internal_datafusion_err, internal_err, + Result, ScalarValue, internal_datafusion_err, internal_err, utils::{compare_rows, get_row_at_idx, search_in_slice}, - Result, ScalarValue, }; /// Holds the state of evaluating a window function @@ -170,7 +169,7 @@ impl WindowFrameContext { // comparison of rows. WindowFrameContext::Range { window_frame, - ref mut state, + state, } => state.calculate_range( window_frame, last_range, @@ -183,7 +182,7 @@ impl WindowFrameContext { // or position of NULLs do not impact inequality. WindowFrameContext::Groups { window_frame, - ref mut state, + state, } => state.calculate_range(window_frame, range_columns, length, idx), } } @@ -205,14 +204,14 @@ impl WindowFrameContext { WindowFrameBound::Following(ScalarValue::UInt64(None)) => { return internal_err!( "Frame start cannot be UNBOUNDED FOLLOWING '{window_frame:?}'" - ) + ); } WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => { std::cmp::min(idx + n as usize, length) } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return internal_err!("Rows should be UInt64") + return internal_err!("Rows should be UInt64"); } }; let end = match window_frame.end_bound { @@ -220,7 +219,7 @@ impl WindowFrameContext { WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => { return internal_err!( "Frame end cannot be UNBOUNDED PRECEDING '{window_frame:?}'" - ) + ); } WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => { if idx >= n as usize { @@ -237,7 +236,7 @@ impl WindowFrameContext { } // ERRONEOUS FRAMES WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => { - return internal_err!("Rows should be UInt64") + return internal_err!("Rows should be UInt64"); } }; Ok(Range { start, end }) @@ -397,6 +396,11 @@ impl WindowFrameStateRange { length: usize, ) -> Result { let current_row_values = get_row_at_idx(range_columns, idx)?; + let search_start = if SIDE { + last_range.start + } else { + last_range.end + }; let end_range = if let Some(delta) = delta { let is_descending: bool = self .sort_options @@ -408,34 +412,40 @@ impl WindowFrameStateRange { })? .descending; - current_row_values - .iter() - .map(|value| { - if value.is_null() { - return Ok(value.clone()); + // On overflow the boundary exceeds the type's range and is + // effectively unbounded within the partition. Collapse to the + // partition edge rather than feeding `search_in_slice` a + // wrapped-around target: PRECEDING searches reach `search_start`, + // FOLLOWING searches reach `length`. + let unbounded_edge = if SEARCH_SIDE { search_start } else { length }; + let mut targets = Vec::with_capacity(current_row_values.len()); + for value in ¤t_row_values { + if value.is_null() { + targets.push(value.clone()); + continue; + } + let target = if SEARCH_SIDE == is_descending { + match value.add_checked(delta) { + Ok(v) => v, + Err(_) => return Ok(unbounded_edge), } - if SEARCH_SIDE == is_descending { - // TODO: Handle positive overflows. - value.add(delta) - } else if value.is_unsigned() && value < delta { - // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. - // If we decide to implement a "default" construction mechanism for ScalarValue, - // change the following statement to use that. - value.sub(value) - } else { - // TODO: Handle negative overflows. - value.sub(delta) + } else if value.is_unsigned() && value < delta { + // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue. + // If we decide to implement a "default" construction mechanism for ScalarValue, + // change the following statement to use that. + value.sub(value)? + } else { + match value.sub_checked(delta) { + Ok(v) => v, + Err(_) => return Ok(unbounded_edge), } - }) - .collect::>>()? + }; + targets.push(target); + } + targets } else { current_row_values }; - let search_start = if SIDE { - last_range.start - } else { - last_range.end - }; let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| { let cmp = compare_rows(current, target, &self.sort_options)?; Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() }) diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 428855a61698c..778e6a24bf00e 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -41,7 +41,6 @@ workspace = true name = "datafusion_functions_aggregate" [dependencies] -ahash = { workspace = true } arrow = { workspace = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } @@ -51,9 +50,10 @@ datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +foldhash = "0.2" half = { workspace = true } log = { workspace = true } -paste = "1.0.14" +num-traits = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -75,3 +75,23 @@ harness = false [[bench]] harness = false name = "min_max_bytes" + +[[bench]] +name = "approx_distinct" +harness = false + +[[bench]] +name = "first_last" +harness = false + +[[bench]] +name = "count_distinct" +harness = false + +[[bench]] +name = "median" +harness = false + +[[bench]] +name = "percentile_cont" +harness = false diff --git a/datafusion/functions-aggregate/benches/approx_distinct.rs b/datafusion/functions-aggregate/benches/approx_distinct.rs new file mode 100644 index 0000000000000..44b45431e3eb1 --- /dev/null +++ b/datafusion/functions-aggregate/benches/approx_distinct.rs @@ -0,0 +1,331 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, Int8Array, Int16Array, Int64Array, StringArray, StringViewArray, + UInt8Array, UInt16Array, +}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, EmitTo, GroupsAccumulator, +}; +use datafusion_functions_aggregate::approx_distinct::ApproxDistinct; +use datafusion_physical_expr::GroupsAccumulatorAdapter; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::col; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +const BATCH_SIZE: usize = 8192; +const SHORT_STRING_LENGTH: usize = 8; +const LONG_STRING_LENGTH: usize = 20; + +// Grouped (high-cardinality `GROUP BY`) benchmark parameters. +const N_GROUPS: usize = 50_000; +const AVG_ROWS_PER_GROUP: usize = 8; +const STRING_POOL_SIZE: usize = 100_000; + +fn prepare_accumulator(data_type: DataType) -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", data_type, true)])); + let expr = col("f", &schema).unwrap(); + let accumulator_args = AccumulatorArgs { + return_field: Field::new("f", DataType::UInt64, true).into(), + schema: &schema, + expr_fields: &[expr.return_field(&schema).unwrap()], + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "approx_distinct(f)", + is_distinct: false, + exprs: &[expr], + }; + ApproxDistinct::new().accumulator(accumulator_args).unwrap() +} + +/// Creates an Int64Array where values are drawn from `0..n_distinct`. +fn create_i64_array(n_distinct: usize) -> Int64Array { + let mut rng = StdRng::seed_from_u64(42); + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..n_distinct as i64))) + .collect() +} + +fn create_u8_array(n_distinct: usize) -> UInt8Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = n_distinct.min(256) as u8; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..max_val))) + .collect() +} + +fn create_i8_array(n_distinct: usize) -> Int8Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = (n_distinct.min(256) / 2) as i8; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(-max_val..max_val))) + .collect() +} + +fn create_u16_array(n_distinct: usize) -> UInt16Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = n_distinct.min(65536) as u16; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..max_val))) + .collect() +} + +fn create_i16_array(n_distinct: usize) -> Int16Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = (n_distinct.min(65536) / 2) as i16; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(-max_val..max_val))) + .collect() +} + +/// Creates a pool of `n_distinct` random strings of the given length. +fn create_string_pool(n_distinct: usize, string_length: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + (0..n_distinct) + .map(|_| { + (0..string_length) + .map(|_| rng.random_range(b'a'..=b'z') as char) + .collect() + }) + .collect() +} + +/// Creates a StringArray where values are drawn from the given pool. +fn create_string_array(pool: &[String]) -> StringArray { + let mut rng = StdRng::seed_from_u64(99); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())].as_str())) + .collect() +} + +/// Creates a StringViewArray where values are drawn from the given pool. +fn create_string_view_array(pool: &[String]) -> StringViewArray { + let mut rng = StdRng::seed_from_u64(99); + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())].as_str())) + .collect() +} + +fn approx_distinct_benchmark(c: &mut Criterion) { + for pct in [80, 99] { + let n_distinct = BATCH_SIZE * pct / 100; + + // --- Int64 benchmarks --- + let values = Arc::new(create_i64_array(n_distinct)) as ArrayRef; + c.bench_function(&format!("approx_distinct i64 {pct}% distinct"), |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int64); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + for (label, str_len) in + [("short", SHORT_STRING_LENGTH), ("long", LONG_STRING_LENGTH)] + { + let string_pool = create_string_pool(n_distinct, str_len); + + // --- Utf8 benchmarks --- + let values = Arc::new(create_string_array(&string_pool)) as ArrayRef; + c.bench_function( + &format!("approx_distinct utf8 {label} {pct}% distinct"), + |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Utf8); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }, + ); + + // --- Utf8View benchmarks --- + let values = Arc::new(create_string_view_array(&string_pool)) as ArrayRef; + c.bench_function( + &format!("approx_distinct utf8view {label} {pct}% distinct"), + |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Utf8View); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }, + ); + } + } + + // Small integer types + + // UInt8 + let values = Arc::new(create_u8_array(200)) as ArrayRef; + c.bench_function("approx_distinct u8 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::UInt8); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Int8 + let values = Arc::new(create_i8_array(200)) as ArrayRef; + c.bench_function("approx_distinct i8 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int8); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // UInt16 + let values = Arc::new(create_u16_array(50000)) as ArrayRef; + c.bench_function("approx_distinct u16 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::UInt16); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Int16 + let values = Arc::new(create_i16_array(50000)) as ArrayRef; + c.bench_function("approx_distinct i16 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int16); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); +} + +/// Build a `GroupsAccumulator` the same way the aggregate operator does: use the +/// specialized one if the function supports it, otherwise fall back to wrapping +/// the per-group `Accumulator` in a `GroupsAccumulatorAdapter`. +fn prepare_groups_accumulator(data_type: DataType) -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", data_type, true)])); + let expr = col("f", &schema).unwrap(); + let udf = Arc::new(AggregateUDF::from(ApproxDistinct::new())); + let agg = Arc::new( + AggregateExprBuilder::new(udf, vec![expr]) + .schema(schema) + .alias("approx_distinct(f)") + .build() + .unwrap(), + ); + + if agg.groups_accumulator_supported() { + agg.create_groups_accumulator().unwrap() + } else { + let agg = Arc::clone(&agg); + let factory = move || agg.create_accumulator(); + Box::new(GroupsAccumulatorAdapter::new(factory)) + } +} + +fn grouped_total_rows() -> usize { + N_GROUPS * AVG_ROWS_PER_GROUP +} + +/// A random group index in `0..N_GROUPS` for each row of a batch. +fn make_group_indices(rng: &mut StdRng) -> Vec { + (0..BATCH_SIZE) + .map(|_| rng.random_range(0..N_GROUPS)) + .collect() +} + +/// Pre-build all input batches `(values, group_indices)` for the grouped run, so +/// the measured loop only times the accumulator, not data generation. +fn build_grouped_batches(data_type: &DataType) -> Vec<(ArrayRef, Vec)> { + let n_batches = grouped_total_rows().div_ceil(BATCH_SIZE); + let mut rng = StdRng::seed_from_u64(7); + let pool = create_string_pool(STRING_POOL_SIZE, SHORT_STRING_LENGTH); + + (0..n_batches) + .map(|_| { + let group_indices = make_group_indices(&mut rng); + let values: ArrayRef = match data_type { + DataType::Int64 => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(rng.random::())) + .collect::(), + ), + DataType::Utf8 => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())].as_str())) + .collect::(), + ), + DataType::Utf8View => Arc::new( + (0..BATCH_SIZE) + .map(|_| Some(pool[rng.random_range(0..pool.len())].as_str())) + .collect::(), + ), + other => panic!("unsupported grouped bench type: {other}"), + }; + (values, group_indices) + }) + .collect() +} + +/// Benchmark grouped `approx_distinct` over many groups. Each iteration feeds all batches into a +/// fresh accumulator and emits the result for every group. +fn approx_distinct_grouped_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("approx_distinct_grouped"); + group.sample_size(10); + + for data_type in [DataType::Int64, DataType::Utf8, DataType::Utf8View] { + let batches = build_grouped_batches(&data_type); + let label = format!("{data_type:?} {N_GROUPS} groups"); + group.bench_function(&label, |b| { + b.iter(|| { + let mut acc = prepare_groups_accumulator(data_type.clone()); + for (values, group_indices) in &batches { + acc.update_batch( + std::slice::from_ref(values), + group_indices, + None, + N_GROUPS, + ) + .unwrap(); + } + black_box(acc.evaluate(EmitTo::All).unwrap()); + }) + }); + } + + group.finish(); +} + +criterion_group!( + benches, + approx_distinct_benchmark, + approx_distinct_grouped_benchmark +); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/array_agg.rs b/datafusion/functions-aggregate/benches/array_agg.rs index 83b0c4a4c659c..b0d8148c3ea65 100644 --- a/datafusion/functions-aggregate/benches/array_agg.rs +++ b/datafusion/functions-aggregate/benches/array_agg.rs @@ -20,29 +20,30 @@ use std::sync::Arc; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, ListArray, NullBufferBuilder, - PrimitiveArray, }; use arrow::datatypes::{Field, Int64Type}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::Accumulator; use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator; use arrow::buffer::OffsetBuffer; -use rand::distr::{Distribution, StandardUniform}; -use rand::prelude::StdRng; +use arrow::util::bench_util::create_primitive_array; use rand::Rng; use rand::SeedableRng; +use rand::distr::{Distribution, StandardUniform}; +use rand::prelude::StdRng; /// Returns fixed seedable RNG pub fn seedable_rng() -> StdRng { StdRng::seed_from_u64(42) } +#[expect(clippy::needless_pass_by_value)] fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { let list_item_data_type = values.as_list::().values().data_type().clone(); c.bench_function(name, |b| { b.iter(|| { - #[allow(clippy::unit_arg)] + #[expect(clippy::unit_arg)] black_box( ArrayAggAccumulator::try_new(&list_item_data_type, false) .unwrap() @@ -53,24 +54,6 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) { }); } -pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray -where - T: ArrowPrimitiveType, - StandardUniform: Distribution, -{ - let mut rng = seedable_rng(); - - (0..size) - .map(|_| { - if rng.random::() < null_density { - None - } else { - Some(rng.random()) - } - }) - .collect() -} - /// Create List array with the given item data type, null density, null locations and zero length lists density /// Creates a random (but fixed-seeded) array of a given size and null density pub fn create_list_array( diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 53484652fd251..48f71858c1204 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -30,7 +30,7 @@ use datafusion_expr::{Accumulator, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; use datafusion_physical_expr::expressions::col; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; fn prepare_group_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); @@ -76,6 +76,7 @@ fn prepare_accumulator() -> Box { count_fn.accumulator(accumulator_args).unwrap() } +#[expect(clippy::needless_pass_by_value)] fn convert_to_state_bench( c: &mut Criterion, name: &str, @@ -129,7 +130,7 @@ fn count_benchmark(c: &mut Criterion) { let mut accumulator = prepare_accumulator(); c.bench_function("count low cardinality dict 20% nulls, no filter", |b| { b.iter(|| { - #[allow(clippy::unit_arg)] + #[expect(clippy::unit_arg)] black_box( accumulator .update_batch(std::slice::from_ref(&values)) diff --git a/datafusion/functions-aggregate/benches/count_distinct.rs b/datafusion/functions-aggregate/benches/count_distinct.rs new file mode 100644 index 0000000000000..4d9e8c5b67b31 --- /dev/null +++ b/datafusion/functions-aggregate/benches/count_distinct.rs @@ -0,0 +1,459 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, + UInt16Array, UInt32Array, +}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, AggregateUDFImpl, EmitTo}; +use datafusion_functions_aggregate::count::Count; +use datafusion_physical_expr::expressions::col; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +const BATCH_SIZE: usize = 8192; + +fn prepare_accumulator(data_type: DataType) -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", data_type, true)])); + let expr = col("f", &schema).unwrap(); + let accumulator_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Int64, true).into(), + schema: &schema, + expr_fields: &[expr.return_field(&schema).unwrap()], + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "count(distinct f)", + is_distinct: true, + exprs: &[expr], + }; + Count::new().accumulator(accumulator_args).unwrap() +} + +fn create_i64_array(n_distinct: usize) -> Int64Array { + let mut rng = StdRng::seed_from_u64(42); + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..n_distinct as i64))) + .collect() +} + +fn create_u8_array(n_distinct: usize) -> UInt8Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = n_distinct.min(256) as u8; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..max_val))) + .collect() +} + +fn create_i8_array(n_distinct: usize) -> Int8Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = (n_distinct.min(256) / 2) as i8; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(-max_val..max_val))) + .collect() +} + +fn create_u16_array(n_distinct: usize) -> UInt16Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = n_distinct.min(65536) as u16; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..max_val))) + .collect() +} + +fn create_i16_array(n_distinct: usize) -> Int16Array { + let mut rng = StdRng::seed_from_u64(42); + let max_val = (n_distinct.min(65536) / 2) as i16; + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(-max_val..max_val))) + .collect() +} + +fn create_u32_array(n_distinct: usize) -> UInt32Array { + let mut rng = StdRng::seed_from_u64(42); + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..n_distinct as u32))) + .collect() +} + +fn create_i32_array(n_distinct: usize) -> Int32Array { + let mut rng = StdRng::seed_from_u64(42); + (0..BATCH_SIZE) + .map(|_| Some(rng.random_range(0..n_distinct as i32))) + .collect() +} + +fn prepare_args(data_type: DataType) -> (Arc, AccumulatorArgs<'static>) { + let schema = Arc::new(Schema::new(vec![Field::new("f", data_type, true)])); + let schema_leaked: &'static Schema = Box::leak(Box::new((*schema).clone())); + let expr = col("f", schema_leaked).unwrap(); + let expr_leaked: &'static _ = Box::leak(Box::new(expr)); + let return_field: Arc = Field::new("f", DataType::Int64, true).into(); + let return_field_leaked: &'static _ = Box::leak(Box::new(return_field.clone())); + let expr_field = expr_leaked.return_field(schema_leaked).unwrap(); + let expr_field_leaked: &'static _ = Box::leak(Box::new(expr_field)); + + let accumulator_args = AccumulatorArgs { + return_field: return_field_leaked.clone(), + schema: schema_leaked, + expr_fields: std::slice::from_ref(expr_field_leaked), + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "count(distinct f)", + is_distinct: true, + exprs: std::slice::from_ref(expr_leaked), + }; + (schema, accumulator_args) +} + +fn count_distinct_benchmark(c: &mut Criterion) { + for pct in [80, 99] { + let n_distinct = BATCH_SIZE * pct / 100; + + // Int64 + let values = Arc::new(create_i64_array(n_distinct)) as ArrayRef; + c.bench_function(&format!("count_distinct i64 {pct}% distinct"), |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int64); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + } + + // Small integer types + + // UInt8 + let values = Arc::new(create_u8_array(200)) as ArrayRef; + c.bench_function("count_distinct u8 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::UInt8); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Int8 + let values = Arc::new(create_i8_array(200)) as ArrayRef; + c.bench_function("count_distinct i8 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int8); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // UInt16 + let values = Arc::new(create_u16_array(50000)) as ArrayRef; + c.bench_function("count_distinct u16 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::UInt16); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Int16 + let values = Arc::new(create_i16_array(50000)) as ArrayRef; + c.bench_function("count_distinct i16 bitmap", |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int16); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // 32-bit integer types + for pct in [80, 99] { + let n_distinct = BATCH_SIZE * pct / 100; + + // UInt32 + let values = Arc::new(create_u32_array(n_distinct)) as ArrayRef; + c.bench_function(&format!("count_distinct u32 {pct}% distinct"), |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::UInt32); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + + // Int32 + let values = Arc::new(create_i32_array(n_distinct)) as ArrayRef; + c.bench_function(&format!("count_distinct i32 {pct}% distinct"), |b| { + b.iter(|| { + let mut accumulator = prepare_accumulator(DataType::Int32); + accumulator + .update_batch(std::slice::from_ref(&values)) + .unwrap() + }) + }); + } +} + +/// Create group indices with uniform distribution +fn create_uniform_groups(num_groups: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + (0..BATCH_SIZE) + .map(|_| rng.random_range(0..num_groups)) + .collect() +} + +/// Create group indices with skewed distribution (80% in 20% of groups) +fn create_skewed_groups(num_groups: usize) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + let hot_groups = (num_groups / 5).max(1); + (0..BATCH_SIZE) + .map(|_| { + if rng.random_range(0..100) < 80 { + rng.random_range(0..hot_groups) + } else { + rng.random_range(0..num_groups) + } + }) + .collect() +} + +fn count_distinct_groups_benchmark(c: &mut Criterion) { + let count_fn = Count::new(); + + let group_counts = [100, 1000, 10000]; + let cardinalities = [("low", 20), ("mid", 80), ("high", 99)]; + let distributions = ["uniform", "skewed"]; + + // i64 benchmarks + for num_groups in group_counts { + for (card_name, distinct_pct) in cardinalities { + for dist in distributions { + let name = format!("i64_g{num_groups}_{card_name}_{dist}"); + let n_distinct = BATCH_SIZE * distinct_pct / 100; + let values = Arc::new(create_i64_array(n_distinct)) as ArrayRef; + let group_indices = if dist == "uniform" { + create_uniform_groups(num_groups) + } else { + create_skewed_groups(num_groups) + }; + + let (_schema, args) = prepare_args(DataType::Int64); + + if count_fn.groups_accumulator_supported(args.clone()) { + c.bench_function(&format!("count_distinct_groups {name}"), |b| { + b.iter(|| { + let mut acc = + count_fn.create_groups_accumulator(args.clone()).unwrap(); + acc.update_batch( + std::slice::from_ref(&values), + &group_indices, + None, + num_groups, + ) + .unwrap(); + acc.evaluate(EmitTo::All).unwrap() + }) + }); + } else { + let arr = values.as_any().downcast_ref::().unwrap(); + let mut group_rows: Vec> = vec![Vec::new(); num_groups]; + for (idx, &group_idx) in group_indices.iter().enumerate() { + if arr.is_valid(idx) { + group_rows[group_idx].push(arr.value(idx)); + } + } + let group_arrays: Vec = group_rows + .iter() + .map(|rows| Arc::new(Int64Array::from(rows.clone())) as ArrayRef) + .collect(); + + c.bench_function(&format!("count_distinct_groups {name}"), |b| { + b.iter(|| { + let mut accumulators: Vec<_> = (0..num_groups) + .map(|_| prepare_accumulator(DataType::Int64)) + .collect(); + + for (group_idx, batch) in group_arrays.iter().enumerate() { + if !batch.is_empty() { + accumulators[group_idx] + .update_batch(std::slice::from_ref(batch)) + .unwrap(); + } + } + + let _results: Vec<_> = accumulators + .iter_mut() + .map(|acc| acc.evaluate().unwrap()) + .collect(); + }) + }); + } + } + } + } + + // i32 benchmarks + for num_groups in group_counts { + for (card_name, distinct_pct) in cardinalities { + for dist in distributions { + let name = format!("i32_g{num_groups}_{card_name}_{dist}"); + let n_distinct = BATCH_SIZE * distinct_pct / 100; + let values = Arc::new(create_i32_array(n_distinct)) as ArrayRef; + let group_indices = if dist == "uniform" { + create_uniform_groups(num_groups) + } else { + create_skewed_groups(num_groups) + }; + + let (_schema, args) = prepare_args(DataType::Int32); + + if count_fn.groups_accumulator_supported(args.clone()) { + c.bench_function(&format!("count_distinct_groups {name}"), |b| { + b.iter(|| { + let mut acc = + count_fn.create_groups_accumulator(args.clone()).unwrap(); + acc.update_batch( + std::slice::from_ref(&values), + &group_indices, + None, + num_groups, + ) + .unwrap(); + acc.evaluate(EmitTo::All).unwrap() + }) + }); + } else { + let arr = values.as_any().downcast_ref::().unwrap(); + let mut group_rows: Vec> = vec![Vec::new(); num_groups]; + for (idx, &group_idx) in group_indices.iter().enumerate() { + if arr.is_valid(idx) { + group_rows[group_idx].push(arr.value(idx)); + } + } + let group_arrays: Vec = group_rows + .iter() + .map(|rows| Arc::new(Int32Array::from(rows.clone())) as ArrayRef) + .collect(); + + c.bench_function(&format!("count_distinct_groups {name}"), |b| { + b.iter(|| { + let mut accumulators: Vec<_> = (0..num_groups) + .map(|_| prepare_accumulator(DataType::Int32)) + .collect(); + + for (group_idx, batch) in group_arrays.iter().enumerate() { + if !batch.is_empty() { + accumulators[group_idx] + .update_batch(std::slice::from_ref(batch)) + .unwrap(); + } + } + + let _results: Vec<_> = accumulators + .iter_mut() + .map(|acc| acc.evaluate().unwrap()) + .collect(); + }) + }); + } + } + } + } + + // u32 benchmarks + for num_groups in group_counts { + for (card_name, distinct_pct) in cardinalities { + for dist in distributions { + let name = format!("u32_g{num_groups}_{card_name}_{dist}"); + let n_distinct = BATCH_SIZE * distinct_pct / 100; + let values = Arc::new(create_u32_array(n_distinct)) as ArrayRef; + let group_indices = if dist == "uniform" { + create_uniform_groups(num_groups) + } else { + create_skewed_groups(num_groups) + }; + + let (_schema, args) = prepare_args(DataType::UInt32); + + if count_fn.groups_accumulator_supported(args.clone()) { + c.bench_function(&format!("count_distinct_groups {name}"), |b| { + b.iter(|| { + let mut acc = + count_fn.create_groups_accumulator(args.clone()).unwrap(); + acc.update_batch( + std::slice::from_ref(&values), + &group_indices, + None, + num_groups, + ) + .unwrap(); + acc.evaluate(EmitTo::All).unwrap() + }) + }); + } else { + let arr = values.as_any().downcast_ref::().unwrap(); + let mut group_rows: Vec> = vec![Vec::new(); num_groups]; + for (idx, &group_idx) in group_indices.iter().enumerate() { + if arr.is_valid(idx) { + group_rows[group_idx].push(arr.value(idx)); + } + } + let group_arrays: Vec = group_rows + .iter() + .map(|rows| Arc::new(UInt32Array::from(rows.clone())) as ArrayRef) + .collect(); + + c.bench_function(&format!("count_distinct_groups {name}"), |b| { + b.iter(|| { + let mut accumulators: Vec<_> = (0..num_groups) + .map(|_| prepare_accumulator(DataType::UInt32)) + .collect(); + + for (group_idx, batch) in group_arrays.iter().enumerate() { + if !batch.is_empty() { + accumulators[group_idx] + .update_batch(std::slice::from_ref(batch)) + .unwrap(); + } + } + + let _results: Vec<_> = accumulators + .iter_mut() + .map(|acc| acc.evaluate().unwrap()) + .collect(); + }) + }); + } + } + } + } +} + +criterion_group!( + benches, + count_distinct_benchmark, + count_distinct_groups_benchmark +); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/first_last.rs b/datafusion/functions-aggregate/benches/first_last.rs new file mode 100644 index 0000000000000..1d18e1c7dcd44 --- /dev/null +++ b/datafusion/functions-aggregate/benches/first_last.rs @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray, Int64Array}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Int64Type, Schema}; +use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; +use datafusion_common::instant::Instant; +use std::hint::black_box; +use std::sync::Arc; + +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, function::AccumulatorArgs, +}; +use datafusion_functions_aggregate::first_last::{ + FirstValue, LastValue, TrivialFirstValueAccumulator, TrivialLastValueAccumulator, +}; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr::expressions::col; + +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; + +fn prepare_groups_accumulator(is_first: bool) -> Box { + let schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int64, true), + Field::new("ord", DataType::Int64, true), + ])); + + let order_expr = col("ord", &schema).unwrap(); + let sort_expr = PhysicalSortExpr { + expr: order_expr, + options: SortOptions::default(), + }; + + let value_field: Arc = Field::new("value", DataType::Int64, true).into(); + let accumulator_args = AccumulatorArgs { + return_field: Arc::clone(&value_field), + schema: &schema, + expr_fields: &[value_field], + ignore_nulls: false, + order_bys: std::slice::from_ref(&sort_expr), + is_reversed: false, + name: if is_first { + "FIRST_VALUE(value ORDER BY ord)" + } else { + "LAST_VALUE(value ORDER BY ord)" + }, + is_distinct: false, + exprs: &[col("value", &schema).unwrap()], + }; + + if is_first { + FirstValue::new() + .create_groups_accumulator(accumulator_args) + .unwrap() + } else { + LastValue::new() + .create_groups_accumulator(accumulator_args) + .unwrap() + } +} + +fn create_trivial_accumulator( + is_first: bool, + ignore_nulls: bool, +) -> Box { + if is_first { + Box::new( + TrivialFirstValueAccumulator::try_new(&DataType::Int64, ignore_nulls) + .unwrap(), + ) + } else { + Box::new( + TrivialLastValueAccumulator::try_new(&DataType::Int64, ignore_nulls).unwrap(), + ) + } +} + +#[expect(clippy::needless_pass_by_value)] +#[expect(clippy::too_many_arguments)] +fn evaluate_bench( + c: &mut Criterion, + is_first: bool, + emit_to: EmitTo, + name: &str, + values: ArrayRef, + ord: ArrayRef, + opt_filter: Option<&BooleanArray>, + num_groups: usize, +) { + let n = values.len(); + let group_indices: Vec = (0..n).map(|i| i % num_groups).collect(); + + c.bench_function(name, |b| { + b.iter_batched( + || { + let mut accumulator = prepare_groups_accumulator(is_first); + accumulator + .update_batch( + &[Arc::clone(&values), Arc::clone(&ord)], + &group_indices, + opt_filter, + num_groups, + ) + .unwrap(); + accumulator + }, + |mut accumulator| { + black_box(accumulator.evaluate(emit_to).unwrap()); + }, + BatchSize::SmallInput, + ) + }); +} + +#[expect(clippy::needless_pass_by_value)] +fn update_bench( + c: &mut Criterion, + is_first: bool, + name: &str, + values: ArrayRef, + ord: ArrayRef, + opt_filter: Option<&BooleanArray>, + num_groups: usize, +) { + let n = values.len(); + let group_indices: Vec = (0..n).map(|i| i % num_groups).collect(); + + // Initialize with worst-case ordering so update_batch forces rows comparison for all groups. + let worst_ord: ArrayRef = Arc::new(Int64Array::from(vec![ + if is_first { + i64::MAX + } else { + i64::MIN + }; + n + ])); + + c.bench_function(name, |b| { + b.iter_batched( + || { + let mut accumulator = prepare_groups_accumulator(is_first); + accumulator + .update_batch( + &[Arc::clone(&values), Arc::clone(&worst_ord)], + &group_indices, + None, // no filter: ensure all groups are initialised + num_groups, + ) + .unwrap(); + accumulator + }, + |mut accumulator| { + for _ in 0..100 { + #[expect(clippy::unit_arg)] + black_box( + accumulator + .update_batch( + &[Arc::clone(&values), Arc::clone(&ord)], + &group_indices, + opt_filter, + num_groups, + ) + .unwrap(), + ); + } + }, + BatchSize::SmallInput, + ) + }); +} + +#[expect(clippy::needless_pass_by_value)] +fn merge_bench( + c: &mut Criterion, + is_first: bool, + name: &str, + values: ArrayRef, + ord: ArrayRef, + opt_filter: Option<&BooleanArray>, + num_groups: usize, +) { + let n = values.len(); + let group_indices: Vec = (0..n).map(|i| i % num_groups).collect(); + let is_set: ArrayRef = Arc::new(BooleanArray::from(vec![true; n])); + + // Initialize with worst-case ordering so update_batch forces rows comparison for all groups. + let worst_ord: ArrayRef = Arc::new(Int64Array::from(vec![ + if is_first { + i64::MAX + } else { + i64::MIN + }; + n + ])); + + c.bench_function(name, |b| { + b.iter_batched( + || { + // Prebuild accumulator + let mut accumulator = prepare_groups_accumulator(is_first); + accumulator + .update_batch( + &[Arc::clone(&values), Arc::clone(&worst_ord)], + &group_indices, + opt_filter, + num_groups, + ) + .unwrap(); + accumulator + }, + |mut accumulator| { + for _ in 0..100 { + #[expect(clippy::unit_arg)] + black_box( + accumulator + .merge_batch( + &[ + Arc::clone(&values), + Arc::clone(&ord), + Arc::clone(&is_set), + ], + &group_indices, + opt_filter, + num_groups, + ) + .unwrap(), + ); + } + }, + BatchSize::SmallInput, + ) + }); +} + +#[expect(clippy::needless_pass_by_value)] +fn trivial_update_bench( + c: &mut Criterion, + is_first: bool, + ignore_nulls: bool, + name: &str, + values: ArrayRef, +) { + c.bench_function(name, |b| { + b.iter_custom(|iters| { + // The bench is way too fast, so apply scaling factor + let mut accumulators: Vec> = (0..iters * 100) + .map(|_| create_trivial_accumulator(is_first, ignore_nulls)) + .collect(); + let start = Instant::now(); + for acc in &mut accumulators { + #[expect(clippy::unit_arg)] + black_box(acc.update_batch(&[Arc::clone(&values)]).unwrap()); + } + start.elapsed() + }) + }); +} + +fn first_last_benchmark(c: &mut Criterion) { + const N: usize = 65536; + const NUM_GROUPS: usize = 1024; + + assert_eq!(N % NUM_GROUPS, 0); + + for is_first in [true, false] { + for pct in [0, 90] { + let fn_name = if is_first { + "first_value" + } else { + "last_value" + }; + + let null_density = (pct as f32) / 100.0; + let values = Arc::new(create_primitive_array::(N, null_density)) + as ArrayRef; + let ord = Arc::new(create_primitive_array::(N, null_density)) + as ArrayRef; + + for with_filter in [false, true] { + let filter = create_boolean_array(N, 0.0, 0.5); + let opt_filter = if with_filter { Some(&filter) } else { None }; + + evaluate_bench( + c, + is_first, + EmitTo::First(2), + &format!( + "{fn_name} evaluate_bench nulls={pct}%, filter={with_filter}, first(2)" + ), + values.clone(), + ord.clone(), + opt_filter, + NUM_GROUPS, + ); + evaluate_bench( + c, + is_first, + EmitTo::All, + &format!( + "{fn_name} evaluate_bench nulls={pct}%, filter={with_filter}, all" + ), + values.clone(), + ord.clone(), + opt_filter, + NUM_GROUPS, + ); + + update_bench( + c, + is_first, + &format!("{fn_name} update_bench nulls={pct}%, filter={with_filter}"), + values.clone(), + ord.clone(), + opt_filter, + NUM_GROUPS, + ); + merge_bench( + c, + is_first, + &format!("{fn_name} merge_bench nulls={pct}%, filter={with_filter}"), + values.clone(), + ord.clone(), + opt_filter, + NUM_GROUPS, + ); + } + + for ignore_nulls in [false, true] { + trivial_update_bench( + c, + is_first, + ignore_nulls, + &format!( + "{fn_name} trivial_update_bench nulls={pct}%, ignore_nulls={ignore_nulls}" + ), + values.clone(), + ); + } + } + } +} + +criterion_group!(benches, first_last_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/median.rs b/datafusion/functions-aggregate/benches/median.rs new file mode 100644 index 0000000000000..0f5f70c7b47f4 --- /dev/null +++ b/datafusion/functions-aggregate/benches/median.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, AggregateUDFImpl}; +use datafusion_functions_aggregate::median::Median; +use datafusion_physical_expr::expressions::col; + +const STEP_SIZE: usize = 128; +const SLIDES_PER_ITER: usize = 32; +const WINDOW_SIZES: [usize; 3] = [256, 4096, 16384]; + +fn prepare_accumulator() -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float64, true)])); + let expr = col("f", &schema).unwrap(); + let accumulator_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + expr_fields: &[expr.return_field(&schema).unwrap()], + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "median(f)", + is_distinct: false, + exprs: &[expr], + }; + Median::new().accumulator(accumulator_args).unwrap() +} + +fn stream_array(len: usize, null_stride: Option) -> ArrayRef { + let values = (0..len) + .map(|idx| { + if null_stride.is_some_and(|stride| idx % stride == 0) { + None + } else { + Some(idx as f64) + } + }) + .collect::>(); + Arc::new(Float64Array::from(values)) as ArrayRef +} + +/// Benchmark the sliding window cycle: retract + update + evaluate +fn sliding_window_bench( + c: &mut Criterion, + name: &str, + window_size: usize, + stream: &ArrayRef, +) { + c.bench_function(name, |b| { + b.iter_batched( + || { + let mut accumulator = prepare_accumulator(); + let initial = stream.slice(0, window_size); + accumulator + .update_batch(std::slice::from_ref(&initial)) + .unwrap(); + accumulator + }, + |mut accumulator| { + for slide in 0..SLIDES_PER_ITER { + let offset = slide * STEP_SIZE; + let retract = stream.slice(offset, STEP_SIZE); + let update = stream.slice(offset + window_size, STEP_SIZE); + accumulator + .retract_batch(std::slice::from_ref(&retract)) + .unwrap(); + accumulator + .update_batch(std::slice::from_ref(&update)) + .unwrap(); + black_box(accumulator.evaluate().unwrap()); + } + }, + BatchSize::SmallInput, + ) + }); +} + +fn median_benchmark(c: &mut Criterion) { + for window_size in WINDOW_SIZES { + let stream_len = window_size + STEP_SIZE * SLIDES_PER_ITER; + let stream_no_nulls = stream_array(stream_len, None); + let stream_with_nulls = stream_array(stream_len, Some(10)); + + sliding_window_bench( + c, + &format!("median sliding_window f64 no_nulls window_size={window_size}"), + window_size, + &stream_no_nulls, + ); + + sliding_window_bench( + c, + &format!("median sliding_window f64 with_nulls window_size={window_size}"), + window_size, + &stream_with_nulls, + ); + } +} + +criterion_group!(benches, median_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/min_max_bytes.rs b/datafusion/functions-aggregate/benches/min_max_bytes.rs index 6d76ff2d0366d..9f4eb0f0c6246 100644 --- a/datafusion/functions-aggregate/benches/min_max_bytes.rs +++ b/datafusion/functions-aggregate/benches/min_max_bytes.rs @@ -29,8 +29,8 @@ use arrow::{ array::{ArrayRef, StringArray}, datatypes::{DataType, Field, Schema}, }; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use datafusion_expr::{function::AccumulatorArgs, GroupsAccumulator}; +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use datafusion_expr::{GroupsAccumulator, function::AccumulatorArgs}; use datafusion_functions_aggregate::min_max; use datafusion_physical_expr::expressions::col; diff --git a/datafusion/functions-aggregate/benches/percentile_cont.rs b/datafusion/functions-aggregate/benches/percentile_cont.rs new file mode 100644 index 0000000000000..05119441e1b10 --- /dev/null +++ b/datafusion/functions-aggregate/benches/percentile_cont.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Float64Array}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{Accumulator, AggregateUDFImpl}; +use datafusion_functions_aggregate::percentile_cont::PercentileCont; +use datafusion_physical_expr::expressions::{col, lit}; + +const STEP_SIZE: usize = 128; +const SLIDES_PER_ITER: usize = 32; +const WINDOW_SIZES: [usize; 3] = [256, 4096, 16384]; + +fn prepare_accumulator() -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float64, true)])); + let value_expr = col("f", &schema).unwrap(); + let percentile_expr = lit(0.5_f64); + let value_field = value_expr.return_field(&schema).unwrap(); + let percentile_field = percentile_expr.return_field(&schema).unwrap(); + let accumulator_args = AccumulatorArgs { + return_field: Field::new("f", DataType::Float64, true).into(), + schema: &schema, + expr_fields: &[value_field, percentile_field], + ignore_nulls: false, + order_bys: &[], + is_reversed: false, + name: "percentile_cont(f, 0.5)", + is_distinct: false, + exprs: &[value_expr, percentile_expr], + }; + PercentileCont::new().accumulator(accumulator_args).unwrap() +} + +fn stream_array(len: usize, null_stride: Option) -> ArrayRef { + let values = (0..len) + .map(|idx| { + if null_stride.is_some_and(|stride| idx % stride == 0) { + None + } else { + Some(idx as f64) + } + }) + .collect::>(); + Arc::new(Float64Array::from(values)) as ArrayRef +} + +/// Benchmark the sliding window cycle: retract + update + evaluate +fn sliding_window_bench( + c: &mut Criterion, + name: &str, + window_size: usize, + stream: &ArrayRef, +) { + c.bench_function(name, |b| { + b.iter_batched( + || { + let mut accumulator = prepare_accumulator(); + let initial = stream.slice(0, window_size); + accumulator + .update_batch(std::slice::from_ref(&initial)) + .unwrap(); + accumulator + }, + |mut accumulator| { + for slide in 0..SLIDES_PER_ITER { + let offset = slide * STEP_SIZE; + let retract = stream.slice(offset, STEP_SIZE); + let update = stream.slice(offset + window_size, STEP_SIZE); + accumulator + .retract_batch(std::slice::from_ref(&retract)) + .unwrap(); + accumulator + .update_batch(std::slice::from_ref(&update)) + .unwrap(); + black_box(accumulator.evaluate().unwrap()); + } + }, + BatchSize::SmallInput, + ) + }); +} + +fn percentile_cont_benchmark(c: &mut Criterion) { + for window_size in WINDOW_SIZES { + let stream_len = window_size + STEP_SIZE * SLIDES_PER_ITER; + let stream_no_nulls = stream_array(stream_len, None); + let stream_with_nulls = stream_array(stream_len, Some(10)); + + sliding_window_bench( + c, + &format!( + "percentile_cont sliding_window f64 no_nulls window_size={window_size}" + ), + window_size, + &stream_no_nulls, + ); + + sliding_window_bench( + c, + &format!( + "percentile_cont sliding_window f64 with_nulls window_size={window_size}" + ), + window_size, + &stream_with_nulls, + ); + } +} + +criterion_group!(benches, percentile_cont_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index d85f0686224b3..52998179024c1 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -22,11 +22,11 @@ use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, Field, Int64Type, Schema}; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; -use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use datafusion_expr::{AggregateUDFImpl, GroupsAccumulator, function::AccumulatorArgs}; use datafusion_functions_aggregate::sum::Sum; use datafusion_physical_expr::expressions::col; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; fn prepare_accumulator(data_type: &DataType) -> Box { let field = Field::new("f", data_type.clone(), true).into(); @@ -47,6 +47,7 @@ fn prepare_accumulator(data_type: &DataType) -> Box { sum_fn.create_groups_accumulator(accumulator_args).unwrap() } +#[expect(clippy::needless_pass_by_value)] fn convert_to_state_bench( c: &mut Criterion, name: &str, diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 58c2a5489d6a2..38b902964f546 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -17,35 +17,43 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use crate::hyperloglog::HyperLogLog; -use arrow::array::{BinaryArray, StringViewArray}; +use crate::hyperloglog::{HLL_HASH_STATE, HyperLogLog, NUM_REGISTERS, count_from_hashes}; +use arrow::array::{Array, BinaryArray, StringViewArray}; use arrow::array::{ - GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + AsArray, BinaryBuilder, BooleanArray, GenericBinaryArray, GenericStringArray, + OffsetSizeTrait, PrimitiveArray, UInt64Array, }; +use arrow::buffer::NullBuffer; use arrow::datatypes::{ - ArrowPrimitiveType, Date32Type, Date64Type, FieldRef, Int16Type, Int32Type, - Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + ArrowPrimitiveType, Date32Type, Date64Type, FieldRef, Int32Type, Int64Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, UInt32Type, UInt64Type, }; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; use datafusion_common::ScalarValue; use datafusion_common::{ - downcast_value, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, - Result, + DataFusionError, Result, downcast_value, internal_datafusion_err, internal_err, + not_impl_err, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature, + Volatility, }; +use datafusion_functions_aggregate_common::aggregate::count_distinct::{ + Bitmap65536DistinctCountAccumulator, Bitmap65536DistinctCountAccumulatorI16, + BoolArray256DistinctCountAccumulator, BoolArray256DistinctCountAccumulatorI8, + BooleanDistinctCountAccumulator, +}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_nulls; use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator; use datafusion_macros::user_doc; -use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::hash::Hash; +use std::hash::{BuildHasher, Hash}; use std::marker::PhantomData; +use std::sync::Arc; make_udaf_expr_and_func!( ApproxDistinct, @@ -55,14 +63,14 @@ make_udaf_expr_and_func!( approx_distinct_udaf ); -impl From<&HyperLogLog> for ScalarValue { +impl From<&HyperLogLog> for ScalarValue { fn from(v: &HyperLogLog) -> ScalarValue { let values = v.as_ref().to_vec(); ScalarValue::Binary(Some(values)) } } -impl TryFrom<&[u8]> for HyperLogLog { +impl TryFrom<&[u8]> for HyperLogLog { type Error = DataFusionError; fn try_from(v: &[u8]) -> Result> { let arr: [u8; 16384] = v.try_into().map_err(|_| { @@ -72,7 +80,7 @@ impl TryFrom<&[u8]> for HyperLogLog { } } -impl TryFrom<&ScalarValue> for HyperLogLog { +impl TryFrom<&ScalarValue> for HyperLogLog { type Error = DataFusionError; fn try_from(v: &ScalarValue) -> Result> { if let ScalarValue::Binary(Some(slice)) = v { @@ -85,6 +93,36 @@ impl TryFrom<&ScalarValue> for HyperLogLog { } } +#[derive(Debug)] +struct ApproxDistinctBitmapWrapper { + inner: A, +} + +impl Accumulator for ApproxDistinctBitmapWrapper { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.inner.update_batch(values) + } + + fn evaluate(&mut self) -> Result { + match self.inner.evaluate()? { + ScalarValue::Int64(Some(v)) => Ok(ScalarValue::UInt64(Some(v as u64))), + other => internal_err!("unexpected: {other}"), + } + } + + fn size(&self) -> usize { + self.inner.size() + } + + fn state(&mut self) -> Result> { + self.inner.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.inner.merge_batch(states) + } +} + #[derive(Debug)] struct NumericHLLAccumulator where @@ -99,7 +137,6 @@ where T: ArrowPrimitiveType, T::Native: Hash, { - /// new approx_distinct accumulator pub fn new() -> Self { Self { hll: HyperLogLog::new(), @@ -112,7 +149,7 @@ struct StringHLLAccumulator where T: OffsetSizeTrait, { - hll: HyperLogLog, + hll: HyperLogLog, phantom_data: PhantomData, } @@ -120,7 +157,6 @@ impl StringHLLAccumulator where T: OffsetSizeTrait, { - /// new approx_distinct accumulator pub fn new() -> Self { Self { hll: HyperLogLog::new(), @@ -130,22 +166,14 @@ where } #[derive(Debug)] -struct StringViewHLLAccumulator -where - T: OffsetSizeTrait, -{ - hll: HyperLogLog, - phantom_data: PhantomData, +struct StringViewHLLAccumulator { + hll: HyperLogLog, } -impl StringViewHLLAccumulator -where - T: OffsetSizeTrait, -{ +impl StringViewHLLAccumulator { pub fn new() -> Self { Self { hll: HyperLogLog::new(), - phantom_data: PhantomData, } } } @@ -155,7 +183,7 @@ struct BinaryHLLAccumulator where T: OffsetSizeTrait, { - hll: HyperLogLog>, + hll: HyperLogLog<[u8]>, phantom_data: PhantomData, } @@ -163,7 +191,6 @@ impl BinaryHLLAccumulator where T: OffsetSizeTrait, { - /// new approx_distinct accumulator pub fn new() -> Self { Self { hll: HyperLogLog::new(), @@ -213,23 +240,42 @@ where let array: &GenericBinaryArray = downcast_value!(values[0], GenericBinaryArray, T); // flatten because we would skip nulls - self.hll - .extend(array.into_iter().flatten().map(|v| v.to_vec())); + self.hll.extend(array.into_iter().flatten()); Ok(()) } default_accumulator_impl!(); } -impl Accumulator for StringViewHLLAccumulator -where - T: OffsetSizeTrait, -{ +impl Accumulator for StringViewHLLAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let array: &StringViewArray = downcast_value!(values[0], StringViewArray); - // flatten because we would skip nulls - self.hll - .extend(array.iter().flatten().map(|s| s.to_string())); + + if array.data_buffers().is_empty() { + // Fast path: with no data buffers every value is inline, so they all + // take the u128 path — no need to check the length per row. + for (i, &view) in array.views().iter().enumerate() { + if !array.is_null(i) { + self.hll.add_hashed(HLL_HASH_STATE.hash_one(view)); + } + } + } else { + // Mixed batch: decide per row by length. Short strings still use the + // u128 path so they match how they'd be hashed in an all-inline + // batch; only the genuinely out-of-line strings materialize a &str. + for (i, &view) in array.views().iter().enumerate() { + if array.is_null(i) { + continue; + } + // The low 32 bits of the u128 view encode the string length. + if (view as u32) <= 12 { + self.hll.add_hashed(HLL_HASH_STATE.hash_one(view)); + } else { + self.hll.add(array.value(i)); + } + } + } + Ok(()) } @@ -244,8 +290,7 @@ where let array: &GenericStringArray = downcast_value!(values[0], GenericStringArray, T); // flatten because we would skip nulls - self.hll - .extend(array.into_iter().flatten().map(|i| i.to_string())); + self.hll.extend(array.into_iter().flatten()); Ok(()) } @@ -267,6 +312,455 @@ where default_accumulator_impl!(); } +/// Maximum number of distinct hashes kept in the sparse representation of a +/// per-group sketch before it is promoted to a dense [`HyperLogLog`]. +/// +/// A dense sketch always occupies [`NUM_REGISTERS`] (16 KiB) regardless of how +/// many values it has seen. The vast majority of groups in a high-cardinality +/// `GROUP BY` only observe a handful of distinct values, so keeping their state +/// as a small list of hashes saves a huge amount of memory (both while +/// aggregating and when serializing the partial state for the final phase). +const SPARSE_LIMIT: usize = 256; + +/// Per-group HyperLogLog state used by [`HllGroupsAccumulator`]. +/// +/// Starts out as a compact list of the (deduplicated) hashes observed for the +/// group and only switches to a full dense [`HyperLogLog`] once it has seen more +/// than [`SPARSE_LIMIT`] distinct values. Folding the stored hashes into a dense +/// sketch produces exactly the same registers as adding the original values one +/// by one, so the cardinality estimate is identical to the per-group +/// [`Accumulator`] path. +#[derive(Clone, Debug)] +enum GroupHll { + /// Distinct hashes seen so far. May contain duplicates between compactions. + Sparse(Vec), + Dense(Box>), +} + +impl Default for GroupHll { + fn default() -> Self { + GroupHll::Sparse(Vec::new()) + } +} + +/// Fold a slice of pre-computed hashes into a fresh [`HyperLogLog`] sketch. +fn fold_sparse_to_hll(hashes: &[u64]) -> HyperLogLog { + let mut hll = HyperLogLog::::new(); + for &h in hashes { + hll.add_hashed(h); + } + hll +} + +impl GroupHll { + /// Add a pre-computed hash, returning the change in heap-allocated bytes so + /// the accumulator can track its memory usage incrementally. + #[inline] + fn add_hash(&mut self, hash: u64) -> isize { + match self { + GroupHll::Dense(hll) => { + hll.add_hashed(hash); + 0 + } + GroupHll::Sparse(v) => { + let cap_before = v.capacity(); + v.push(hash); + if v.len() >= 2 * SPARSE_LIMIT { + return self.compact_or_promote(cap_before); + } + ((v.capacity() - cap_before) * size_of::()) as isize + } + } + } + + /// Deduplicate the sparse hash list and, if it still exceeds + /// [`SPARSE_LIMIT`] distinct values, promote it to a dense sketch. + #[cold] + fn compact_or_promote(&mut self, cap_before: usize) -> isize { + let GroupHll::Sparse(v) = self else { + return 0; + }; + v.sort_unstable(); + v.dedup(); + if v.len() > SPARSE_LIMIT { + // cap_before is the capacity already reflected in allocated_bytes. + // Any reallocation caused by the triggering push was never counted and + // is also freed here, so the two cancel out. + *self = GroupHll::Dense(Box::new(fold_sparse_to_hll(v))); + (NUM_REGISTERS as isize) - ((cap_before * size_of::()) as isize) + } else { + // Account for any Vec growth caused by the triggering push. + // sort/dedup do not reallocate, so v.capacity() is the post-push capacity. + ((v.capacity() - cap_before) * size_of::()) as isize + } + } + + /// Merge a serialized state (produced by [`Self::serialize`] or by the + /// per-group [`Accumulator`]) into this sketch. + fn merge_serialized(&mut self, bytes: &[u8]) -> Result { + if bytes.is_empty() { + return Ok(0); + } + if bytes.len() == NUM_REGISTERS { + let other: HyperLogLog = bytes.try_into()?; + Ok(self.merge_dense(&other)) + } else { + if !bytes.len().is_multiple_of(size_of::()) { + return internal_err!( + "approx_distinct: malformed sparse state: length {} is not a multiple of {}", + bytes.len(), + size_of::() + ); + } + if bytes.len() > SPARSE_LIMIT * size_of::() { + return internal_err!( + "approx_distinct: malformed sparse state: length {} exceeds sparse limit {}", + bytes.len(), + SPARSE_LIMIT * size_of::() + ); + } + let mut delta = 0; + for chunk in bytes.chunks_exact(size_of::()) { + let h = u64::from_le_bytes(chunk.try_into().unwrap()); + delta += self.add_hash(h); + } + Ok(delta) + } + } + + /// Merge a dense sketch into this one, promoting to dense if necessary. + fn merge_dense(&mut self, other: &HyperLogLog) -> isize { + match self { + GroupHll::Dense(hll) => { + hll.merge(other); + 0 + } + GroupHll::Sparse(v) => { + let cap_before = v.capacity(); + let mut hll = other.clone(); + for &h in v.iter() { + hll.add_hashed(h); + } + *self = GroupHll::Dense(Box::new(hll)); + (NUM_REGISTERS as isize) - ((cap_before * size_of::()) as isize) + } + } + } + + /// The approximate number of distinct values seen by this group. + fn count(&self) -> u64 { + match self { + GroupHll::Dense(hll) => hll.count() as u64, + // Estimate directly from the stored hashes; this produces exactly the + // same value as folding them into a dense sketch but avoids + // allocating and scanning a 16 KiB register array for every group. + GroupHll::Sparse(v) => count_from_hashes(v) as u64, + } + } + + /// Heap bytes held by this sketch. Mirrors the deltas accrued in + /// [`Self::add_hash`] / [`Self::merge_dense`] so emitting a group can + /// precisely reverse them. + fn heap_bytes(&self) -> usize { + match self { + GroupHll::Sparse(v) => v.capacity() * size_of::(), + GroupHll::Dense(_) => NUM_REGISTERS, + } + } + + /// Serialize the sketch into `scratch` (which is cleared first). A dense + /// sketch is written as its raw [`NUM_REGISTERS`] registers (wire-compatible + /// with the per-group [`Accumulator`]); a sparse sketch is written as its + /// distinct hashes in little-endian order unless it has crossed + /// [`SPARSE_LIMIT`], in which case it is emitted as dense state so the final + /// merge path accepts it. + fn serialize(&mut self, scratch: &mut Vec) { + scratch.clear(); + match self { + GroupHll::Dense(hll) => { + let registers: &[u8] = (**hll).as_ref(); + scratch.extend_from_slice(registers); + } + GroupHll::Sparse(v) => { + v.sort_unstable(); + v.dedup(); + if v.len() > SPARSE_LIMIT { + scratch.extend_from_slice(fold_sparse_to_hll(v).as_ref()); + } else { + for &h in v.iter() { + scratch.extend_from_slice(&h.to_le_bytes()); + } + } + } + } + } +} + +/// Computes HyperLogLog hashes for the rows of an input array, type by type. +/// +/// The hashing matches the per-group [`Accumulator`] implementations exactly so +/// that the grouped and ungrouped paths produce identical estimates. +trait HllValueHasher: Send + Sync + 'static { + /// Invoke `f(row_index, hash)` for every row that is valid according to + /// `nulls`. `nulls = None` means every row is valid (caller has + /// pre-combined value-nulls and filter into a single buffer). + fn for_each_hash( + array: &dyn Array, + nulls: Option<&NullBuffer>, + f: impl FnMut(usize, u64), + ); +} + +struct NumericHasher(PhantomData); + +impl HllValueHasher for NumericHasher +where + T: ArrowPrimitiveType + Send + Sync + 'static, + T::Native: Hash, +{ + #[inline] + fn for_each_hash( + array: &dyn Array, + nulls: Option<&NullBuffer>, + mut f: impl FnMut(usize, u64), + ) { + let array: &PrimitiveArray = array.as_primitive::(); + match nulls { + None => { + for (i, v) in array.values().iter().enumerate() { + f(i, HLL_HASH_STATE.hash_one(v)); + } + } + Some(nulls) => { + for i in 0..array.len() { + if nulls.is_valid(i) { + f(i, HLL_HASH_STATE.hash_one(array.value(i))); + } + } + } + } + } +} + +struct Utf8Hasher(PhantomData); + +impl HllValueHasher for Utf8Hasher { + #[inline] + fn for_each_hash( + array: &dyn Array, + nulls: Option<&NullBuffer>, + mut f: impl FnMut(usize, u64), + ) { + let array: &GenericStringArray = array.as_string::(); + for i in 0..array.len() { + if nulls.is_none_or(|n| n.is_valid(i)) { + f(i, HLL_HASH_STATE.hash_one(array.value(i))); + } + } + } +} + +struct Utf8ViewHasher; + +impl HllValueHasher for Utf8ViewHasher { + #[inline] + fn for_each_hash( + array: &dyn Array, + nulls: Option<&NullBuffer>, + mut f: impl FnMut(usize, u64), + ) { + let array: &StringViewArray = array.as_string_view(); + // Mirror `StringViewHLLAccumulator`: hash the raw inline view when all + // strings are stored inline (≤ 12 bytes), avoiding `&str` materialization. + if array.data_buffers().is_empty() { + let views = array.views(); + for i in 0..array.len() { + if nulls.is_none_or(|n| n.is_valid(i)) { + f(i, HLL_HASH_STATE.hash_one(views[i])); + } + } + } else { + // Mixed batch: short strings (≤ 12 bytes) are still inline and must + // be hashed as the raw u128 view to match the all-inline fast path. + let views = array.views(); + for i in 0..array.len() { + if nulls.is_none_or(|n| n.is_valid(i)) { + let view = views[i]; + if (view as u32) <= 12 { + f(i, HLL_HASH_STATE.hash_one(view)); + } else { + f(i, HLL_HASH_STATE.hash_one(array.value(i))); + } + } + } + } + } +} + +struct BinaryHasher(PhantomData); + +impl HllValueHasher for BinaryHasher { + #[inline] + fn for_each_hash( + array: &dyn Array, + nulls: Option<&NullBuffer>, + mut f: impl FnMut(usize, u64), + ) { + let array: &GenericBinaryArray = array.as_binary::(); + for i in 0..array.len() { + if nulls.is_none_or(|n| n.is_valid(i)) { + f(i, HLL_HASH_STATE.hash_one(array.value(i))); + } + } + } +} + +/// A [`GroupsAccumulator`] for `approx_distinct` that keeps one adaptive +/// (sparse → dense) HyperLogLog sketch per group. +/// +/// This is dramatically faster than the generic `GroupsAccumulatorAdapter` +/// fallback for high-cardinality `GROUP BY`s: it processes the whole input in a +/// single vectorized pass (no per-group `take`/slice and no dynamic dispatch), +/// and the sparse representation avoids allocating a 16 KiB sketch for every +/// group when most groups only see a few distinct values. +/// +/// +/// # Example +/// +/// For `SELECT k, approx_distinct(v) FROM t GROUP BY k`, each group owns one +/// independent sketch: +/// +/// ```text +/// group state +/// a Sparse([h1, h2, h3, h2]) +/// b Dense(HLL registers) +/// ... +/// ``` +/// +/// Group `a` has fewer than [`SPARSE_LIMIT`] distinct hashes, so it stays in +/// the sparse representation. Before emitting state or estimating the count, the +/// hash list is sorted and deduplicated to `[h1, h2, h3]`, then those hashes are +/// interpreted exactly as if they had been added to a dense [`HyperLogLog`]. +/// +/// Group `b` has crossed the sparse limit, so its hashes have already been +/// replayed into a dense sketch. New values for `b` update the dense registers +/// directly, and serialized state is the raw [`NUM_REGISTERS`]-byte register +/// array. +struct HllGroupsAccumulator { + /// Per-group sketches, indexed by `group_index`. + groups: Vec, + /// Incrementally maintained estimate of heap bytes used by `groups`. + allocated_bytes: usize, + phantom: PhantomData, +} + +impl HllGroupsAccumulator { + fn new() -> Self { + Self { + groups: Vec::new(), + allocated_bytes: 0, + phantom: PhantomData, + } + } + + #[inline] + fn ensure_groups(&mut self, total_num_groups: usize) { + if total_num_groups > self.groups.len() { + self.groups.resize_with(total_num_groups, GroupHll::default); + } + } + + #[inline] + fn apply_delta(&mut self, delta: isize) { + self.allocated_bytes = + (self.allocated_bytes as isize).saturating_add(delta).max(0) as usize; + } +} + +impl GroupsAccumulator for HllGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.ensure_groups(total_num_groups); + let groups = &mut self.groups; + let mut delta: isize = 0; + // Pre-combine value-nulls and filter into one mask so the callback + // needs no per-row branching. + let filter_nulls = opt_filter.map(filter_to_nulls); + let value_nulls = values[0].logical_nulls(); + let combined_nulls = + NullBuffer::union(filter_nulls.as_ref(), value_nulls.as_ref()); + H::for_each_hash(values[0].as_ref(), combined_nulls.as_ref(), |row, hash| { + delta += groups[group_indices[row]].add_hash(hash); + }); + self.apply_delta(delta); + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + self.ensure_groups(total_num_groups); + let states = downcast_value!(values[0], BinaryArray); + let mut delta: isize = 0; + for (row, &group_index) in group_indices.iter().enumerate() { + if states.is_valid(row) { + delta += self.groups[group_index].merge_serialized(states.value(row))?; + } + } + self.apply_delta(delta); + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let groups = emit_to.take_needed(&mut self.groups); + let mut freed = 0; + let counts: UInt64Array = groups + .iter() + .map(|g| { + freed += g.heap_bytes(); + Some(g.count()) + }) + .collect(); + // The emitted groups have been removed; reclaim their tracked bytes. + self.allocated_bytes = self.allocated_bytes.saturating_sub(freed); + Ok(Arc::new(counts)) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + let mut groups = emit_to.take_needed(&mut self.groups); + let mut builder = BinaryBuilder::new(); + let mut scratch: Vec = Vec::new(); + let mut freed = 0; + for g in groups.iter_mut() { + freed += g.heap_bytes(); + g.serialize(&mut scratch); + builder.append_value(&scratch); + } + // The emitted groups have been removed; reclaim their tracked bytes. + self.allocated_bytes = self.allocated_bytes.saturating_sub(freed); + Ok(vec![Arc::new(builder.finish())]) + } + + fn size(&self) -> usize { + self.groups.capacity() * size_of::() + self.allocated_bytes + } +} + impl Debug for ApproxDistinct { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("ApproxDistinct") @@ -309,11 +803,46 @@ impl ApproxDistinct { } } -impl AggregateUDFImpl for ApproxDistinct { - fn as_any(&self) -> &dyn Any { - self +#[cold] +fn get_fixed_domain_approx_accumulator( + data_type: &DataType, +) -> Result> { + match data_type { + DataType::Boolean => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: BooleanDistinctCountAccumulator::new(), + })), + DataType::UInt8 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: BoolArray256DistinctCountAccumulator::new(), + })), + DataType::Int8 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: BoolArray256DistinctCountAccumulatorI8::new(), + })), + DataType::UInt16 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: Bitmap65536DistinctCountAccumulator::new(), + })), + DataType::Int16 => Ok(Box::new(ApproxDistinctBitmapWrapper { + inner: Bitmap65536DistinctCountAccumulatorI16::new(), + })), + _ => internal_err!("unsupported small int type: {}", data_type), } +} +#[cold] +fn get_fixed_domain_state_field( + name: &str, + data_type: &DataType, +) -> Result> { + Ok(vec![ + Field::new_list( + format_state_name(name, "approx_distinct"), + Field::new_list_field(data_type.clone(), true), + false, + ) + .into(), + ]) +} + +impl AggregateUDFImpl for ApproxDistinct { fn name(&self) -> &str { "approx_distinct" } @@ -326,21 +855,38 @@ impl AggregateUDFImpl for ApproxDistinct { Ok(DataType::UInt64) } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::UInt64(Some(0))) + } + + fn is_nullable(&self) -> bool { + false + } + fn state_fields(&self, args: StateFieldsArgs) -> Result> { - if args.input_fields[0].data_type().is_null() { - Ok(vec![Field::new( - format_state_name(args.name, self.name()), - DataType::Null, - true, - ) - .into()]) - } else { - Ok(vec![Field::new( - format_state_name(args.name, "hll_registers"), - DataType::Binary, - false, - ) - .into()]) + let data_type = args.input_fields[0].data_type(); + match data_type { + DataType::Null => Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Null, + true, + ) + .into(), + ]), + DataType::Boolean + | DataType::UInt8 + | DataType::Int8 + | DataType::UInt16 + | DataType::Int16 => get_fixed_domain_state_field(args.name, data_type), + _ => Ok(vec![ + Field::new( + format_state_name(args.name, "hll_registers"), + DataType::Binary, + false, + ) + .into(), + ]), } } @@ -348,15 +894,15 @@ impl AggregateUDFImpl for ApproxDistinct { let data_type = acc_args.expr_fields[0].data_type(); let accumulator: Box = match data_type { - // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL - // TODO support for boolean (trivial case) - // https://github.com/apache/datafusion/issues/1109 - DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), + DataType::Boolean + | DataType::UInt8 + | DataType::Int8 + | DataType::UInt16 + | DataType::Int16 => { + return get_fixed_domain_approx_accumulator(data_type); + } DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), DataType::Date32 => Box::new(NumericHLLAccumulator::::new()), @@ -387,7 +933,7 @@ impl AggregateUDFImpl for ApproxDistinct { } DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), - DataType::Utf8View => Box::new(StringViewHLLAccumulator::::new()), + DataType::Utf8View => Box::new(StringViewHLLAccumulator::new()), DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), DataType::Null => { @@ -395,8 +941,84 @@ impl AggregateUDFImpl for ApproxDistinct { } other => { return not_impl_err!( - "Support for 'approx_distinct' for data type {other} is not implemented" - ) + "Support for 'approx_distinct' for data type {other} is not implemented" + ); + } + }; + Ok(accumulator) + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + is_hll_groups_type(args.expr_fields[0].data_type()) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let data_type = args.expr_fields[0].data_type(); + let accumulator: Box = match data_type { + DataType::UInt32 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::UInt64 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Int32 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Int64 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Date32 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Date64 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Time32(TimeUnit::Second) => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Time32(TimeUnit::Millisecond) => Box::new(HllGroupsAccumulator::< + NumericHasher, + >::new()), + DataType::Time64(TimeUnit::Microsecond) => Box::new(HllGroupsAccumulator::< + NumericHasher, + >::new()), + DataType::Time64(TimeUnit::Nanosecond) => Box::new(HllGroupsAccumulator::< + NumericHasher, + >::new()), + DataType::Timestamp(TimeUnit::Second, _) => Box::new(HllGroupsAccumulator::< + NumericHasher, + >::new()), + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Box::new(HllGroupsAccumulator::< + NumericHasher, + >::new()) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Box::new(HllGroupsAccumulator::< + NumericHasher, + >::new()) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + HllGroupsAccumulator::>::new(), + ), + DataType::Utf8 => Box::new(HllGroupsAccumulator::>::new()), + DataType::LargeUtf8 => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::Utf8View => Box::new(HllGroupsAccumulator::::new()), + DataType::Binary => { + Box::new(HllGroupsAccumulator::>::new()) + } + DataType::LargeBinary => { + Box::new(HllGroupsAccumulator::>::new()) + } + other => { + return not_impl_err!( + "GroupsAccumulator for 'approx_distinct' is not implemented for data type {other}" + ); } }; Ok(accumulator) @@ -406,3 +1028,265 @@ impl AggregateUDFImpl for ApproxDistinct { self.doc() } } + +/// Returns true for the data types backed by the HyperLogLog +/// [`HllGroupsAccumulator`]. The fixed-domain types (booleans / small ints) and +/// `Null` fall back to the per-group [`Accumulator`] path. +fn is_hll_groups_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::UInt32 + | DataType::UInt64 + | DataType::Int32 + | DataType::Int64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Timestamp(TimeUnit::Second, _) + | DataType::Timestamp(TimeUnit::Millisecond, _) + | DataType::Timestamp(TimeUnit::Microsecond, _) + | DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Binary + | DataType::LargeBinary + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, Int64Array, StringViewArray}; + use std::sync::Arc; + + // A string longer than the 12-byte inline limit + const LONG: &str = "this string is definitely longer than twelve bytes"; + + fn h(v: u64) -> u64 { + HLL_HASH_STATE.hash_one(v) + } + + /// Reference count: fold the given distinct hashes straight into a dense + /// HyperLogLog. The grouped sketch must agree with this exactly. + fn reference_count(hashes: &[u64]) -> u64 { + let mut hll = HyperLogLog::::new(); + for &hash in hashes { + hll.add_hashed(hash); + } + hll.count() as u64 + } + + fn serialize(g: &mut GroupHll) -> Vec { + let mut buf = Vec::new(); + g.serialize(&mut buf); + buf + } + + fn distinct_count(acc: &mut StringViewHLLAccumulator) -> u64 { + match acc.evaluate().unwrap() { + ScalarValue::UInt64(Some(v)) => v, + other => panic!("unexpected evaluate result: {other:?}"), + } + } + + #[test] + fn sparse_stays_sparse_for_small_groups() { + let mut g = GroupHll::default(); + let hashes: Vec = (0..50).map(h).collect(); + for &hash in &hashes { + g.add_hash(hash); + } + // duplicates must not change the estimate or trigger promotion + for &hash in &hashes { + g.add_hash(hash); + } + assert!( + matches!(g, GroupHll::Sparse(_)), + "small group must be sparse" + ); + assert_eq!(g.count(), reference_count(&hashes)); + // sparse serialized state is far smaller than a dense 16 KiB sketch + // and must not exceed the sparse limit contract enforced by merge_serialized + let serialized = serialize(&mut g); + assert!(serialized.len() < NUM_REGISTERS); + assert!(serialized.len() <= SPARSE_LIMIT * size_of::()); + } + + #[test] + fn promotes_to_dense_for_large_groups() { + let mut g = GroupHll::default(); + let hashes: Vec = (0..(SPARSE_LIMIT as u64 * 4)).map(h).collect(); + for &hash in &hashes { + g.add_hash(hash); + } + assert!(matches!(g, GroupHll::Dense(_)), "large group must be dense"); + assert_eq!(g.count(), reference_count(&hashes)); + } + + #[test] + fn serialize_then_merge_roundtrips() { + for n in [0u64, 10, SPARSE_LIMIT as u64 * 4] { + let hashes: Vec = (0..n).map(h).collect(); + let mut src = GroupHll::default(); + for &hash in &hashes { + src.add_hash(hash); + } + let bytes = serialize(&mut src); + let mut dst = GroupHll::default(); + dst.merge_serialized(&bytes).unwrap(); + assert_eq!(dst.count(), reference_count(&hashes), "n = {n}"); + } + } + + #[test] + fn sparse_limit_group_serializes_as_mergeable_sparse_state() { + let hashes: Vec = (0..SPARSE_LIMIT as u64).map(h).collect(); + let mut src = GroupHll::default(); + for &hash in &hashes { + src.add_hash(hash); + } + assert!(matches!(src, GroupHll::Sparse(_))); + + let bytes = serialize(&mut src); + assert_eq!(bytes.len(), SPARSE_LIMIT * size_of::()); + + let mut dst = GroupHll::default(); + dst.merge_serialized(&bytes).unwrap(); + assert_eq!(dst.count(), reference_count(&hashes)); + } + + #[test] + fn medium_sparse_group_serializes_as_mergeable_dense_state() { + let n = SPARSE_LIMIT as u64 + 44; + let hashes: Vec = (0..n).map(h).collect(); + let mut src = GroupHll::default(); + for &hash in &hashes { + src.add_hash(hash); + } + assert!( + matches!(src, GroupHll::Sparse(_)), + "group should not promote during update before the compaction threshold" + ); + + let bytes = serialize(&mut src); + assert_eq!(bytes.len(), NUM_REGISTERS); + + let mut dst = GroupHll::default(); + dst.merge_serialized(&bytes).unwrap(); + assert_eq!(dst.count(), reference_count(&hashes)); + } + + #[test] + fn merge_combines_disjoint_groups() { + // sparse + sparse, sparse + dense, dense + dense + let left: Vec = (0..100).map(h).collect(); + let right: Vec = (100..(SPARSE_LIMIT as u64 * 4)).map(h).collect(); + let all: Vec = left.iter().chain(right.iter()).copied().collect(); + + let mut a = GroupHll::default(); + for &hash in &left { + a.add_hash(hash); + } + let mut b = GroupHll::default(); + for &hash in &right { + b.add_hash(hash); + } + let b_bytes = serialize(&mut b); + a.merge_serialized(&b_bytes).unwrap(); + assert_eq!(a.count(), reference_count(&all)); + } + + #[test] + fn empty_group_counts_zero() { + let mut g = GroupHll::default(); + assert_eq!(g.count(), 0); + let bytes = serialize(&mut g); + assert!(bytes.is_empty()); + let mut dst = GroupHll::default(); + dst.merge_serialized(&bytes).unwrap(); + assert_eq!(dst.count(), 0); + } + + /// `approx_distinct(v) FILTER (WHERE nullable_bool)` — a NULL filter row + /// must not be counted (null filter is treated the same as false). + #[test] + fn update_batch_nullable_filter_excludes_null_filter_rows() { + let values: ArrayRef = Arc::new(Int64Array::from(vec![1i64, 2, 3, 4, 5])); + // row 0: filter=true, row 1: filter=NULL, row 2: filter=false, + // row 3: filter=NULL, row 4: filter=true + let filter = + BooleanArray::from(vec![Some(true), None, Some(false), None, Some(true)]); + + let mut acc = HllGroupsAccumulator::>::new(); + // put all rows in group 0 + let group_indices = vec![0usize; 5]; + acc.update_batch(&[values], &group_indices, Some(&filter), 1) + .unwrap(); + + // Only rows 0 and 4 (values 1 and 5) should be counted. + let result = acc.evaluate(EmitTo::All).unwrap(); + let counts = result.as_any().downcast_ref::().unwrap(); + // reference: hash 1 and 5 into a dense sketch + let expected = reference_count(&[h(1), h(5)]); + assert_eq!(counts.value(0), expected); + } + + /// Regression: a short (≤ 12-byte) Utf8View string must hash identically + /// in an all-inline batch and in a mixed batch that also contains a long + /// string (which forces a data buffer). + #[test] + fn utf8view_groups_short_string_hashed_consistently_across_batches() { + // Batch 1: all-inline (no data buffers) — "aaa" is hashed as u128 view. + let batch1: ArrayRef = Arc::new(StringViewArray::from(vec!["aaa", "bbb"])); + assert!(batch1.as_string_view().data_buffers().is_empty()); + + // Batch 2: mixed — LONG forces a data buffer; "aaa" must still be + // hashed as u128 view so it matches its appearance in batch 1. + let batch2: ArrayRef = Arc::new(StringViewArray::from(vec!["aaa", LONG])); + assert!(!batch2.as_string_view().data_buffers().is_empty()); + + let group_indices = vec![0usize, 0]; + let mut acc = HllGroupsAccumulator::::new(); + acc.update_batch(&[batch1], &group_indices, None, 1) + .unwrap(); + acc.update_batch(&[batch2], &group_indices, None, 1) + .unwrap(); + + // True distinct values: {"aaa", "bbb", LONG} == 3. + let result = acc.evaluate(EmitTo::All).unwrap(); + let counts = result.as_any().downcast_ref::().unwrap(); + assert_eq!(counts.value(0), 3); + } + + /// Regression: a short (≤ 12-byte) Utf8View string must hash identically + /// regardless of which batch it appears in — all-inline or mixed. + #[test] + fn utf8view_acc_split_batches_match_single_mixed_batch() { + // Multiset: {"aaa" x2, "bbb", LONG}, so 3 distinct values. + let mixed: ArrayRef = + Arc::new(StringViewArray::from(vec!["aaa", "bbb", LONG, "aaa"])); + let mut acc_single = StringViewHLLAccumulator::new(); + acc_single.update_batch(&[mixed]).unwrap(); + + // Same multiset, but split so "aaa" lands in both an all-inline batch + // and a batch with a data buffer (forced by LONG). + let inline_only: ArrayRef = Arc::new(StringViewArray::from(vec!["aaa", "bbb"])); + let with_buffer: ArrayRef = Arc::new(StringViewArray::from(vec!["aaa", LONG])); + assert!(inline_only.as_string_view().data_buffers().is_empty()); + assert!(!with_buffer.as_string_view().data_buffers().is_empty()); + + let mut acc_split = StringViewHLLAccumulator::new(); + acc_split.update_batch(&[inline_only]).unwrap(); + acc_split.update_batch(&[with_buffer]).unwrap(); + + assert_eq!( + distinct_count(&mut acc_single), + distinct_count(&mut acc_split) + ); + assert_eq!(distinct_count(&mut acc_single), 3); + } +} diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index 530dbf3e43c79..162dc224f2ccb 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -19,16 +19,17 @@ use arrow::datatypes::DataType::{Float64, UInt64}; use arrow::datatypes::{DataType, Field, FieldRef}; -use std::any::Any; +use datafusion_common::types::NativeType; +use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator; use std::fmt::Debug; use std::sync::Arc; -use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_common::{Result, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -57,20 +58,11 @@ make_udaf_expr_and_func!( ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ApproxMedian { signature: Signature, } -impl Debug for ApproxMedian { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("ApproxMedian") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for ApproxMedian { fn default() -> Self { Self::new() @@ -81,33 +73,46 @@ impl ApproxMedian { /// Create a new APPROX_MEDIAN aggregate function pub fn new() -> Self { Self { - signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::one_of( + vec![TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + )])], + Volatility::Immutable, + ), } } } impl AggregateUDFImpl for ApproxMedian { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { - Ok(vec![ - Field::new(format_state_name(args.name, "max_size"), UInt64, false), - Field::new(format_state_name(args.name, "sum"), Float64, false), - Field::new(format_state_name(args.name, "count"), UInt64, false), - Field::new(format_state_name(args.name, "max"), Float64, false), - Field::new(format_state_name(args.name, "min"), Float64, false), - Field::new_list( - format_state_name(args.name, "centroids"), - Field::new_list_field(Float64, true), - false, - ), - ] - .into_iter() - .map(Arc::new) - .collect()) + if args.input_fields[0].data_type().is_null() { + Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Null, + true, + ) + .into(), + ]) + } else { + Ok(vec![ + Field::new(format_state_name(args.name, "max_size"), UInt64, false), + Field::new(format_state_name(args.name, "sum"), Float64, false), + Field::new(format_state_name(args.name, "count"), Float64, false), + Field::new(format_state_name(args.name, "max"), Float64, false), + Field::new(format_state_name(args.name, "min"), Float64, false), + Field::new_list( + format_state_name(args.name, "centroids"), + Field::new_list_field(Float64, true), + false, + ), + ] + .into_iter() + .map(Arc::new) + .collect()) + } } fn name(&self) -> &str { @@ -119,9 +124,6 @@ impl AggregateUDFImpl for ApproxMedian { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("ApproxMedian requires numeric input types"); - } Ok(arg_types[0].clone()) } @@ -132,10 +134,14 @@ impl AggregateUDFImpl for ApproxMedian { ); } - Ok(Box::new(ApproxPercentileAccumulator::new( - 0.5_f64, - acc_args.expr_fields[0].data_type().clone(), - ))) + if acc_args.expr_fields[0].data_type().is_null() { + Ok(Box::new(NoopAccumulator::default())) + } else { + Ok(Box::new(ApproxPercentileAccumulator::new( + 0.5_f64, + acc_args.expr_fields[0].data_type().clone(), + ))) + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index ce1b149ba0c5e..ea8fea1b1bc29 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -15,36 +15,30 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::mem::size_of_val; use std::sync::Arc; -use arrow::array::Array; +use arrow::array::{Array, Float16Array}; use arrow::compute::{filter, is_not_null}; use arrow::datatypes::FieldRef; use arrow::{ - array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, + array::{ArrayRef, Float32Array, Float64Array}, datatypes::{DataType, Field}, }; +use datafusion_common::types::{NativeType, logical_float64}; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result, - ScalarValue, + DataFusionError, Result, ScalarValue, downcast_value, internal_err, not_impl_err, + plan_err, }; use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, - Volatility, -}; -use datafusion_functions_aggregate_common::tdigest::{ - TDigest, TryIntoF64, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; +use datafusion_functions_aggregate_common::tdigest::{DEFAULT_MAX_SIZE, TDigest}; use datafusion_macros::user_doc; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -121,20 +115,11 @@ An alternate syntax is also supported: description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct ApproxPercentileCont { signature: Signature, } -impl Debug for ApproxPercentileCont { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("ApproxPercentileCont") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for ApproxPercentileCont { fn default() -> Self { Self::new() @@ -144,22 +129,44 @@ impl Default for ApproxPercentileCont { impl ApproxPercentileCont { /// Create a new [`ApproxPercentileCont`] aggregate function. pub fn new() -> Self { - let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - // Additionally accept an integer number of centroids for T-Digest - for int in INTEGERS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - DataType::Float64, - int.clone(), - ])) - } - } - Self { - signature: Signature::one_of(variants, Volatility::Immutable), - } + let signature = Signature::one_of( + vec![ + // 2 args - numeric, percentile (float) + TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + ]), + // 3 args - numeric, percentile (float), number of centroid for T-Digest (integer) + TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![TypeSignatureClass::Numeric], + NativeType::Int64, + ), + ]), + ], + Volatility::Immutable, + ); + Self { signature } } pub(crate) fn create_accumulator( @@ -189,18 +196,13 @@ impl ApproxPercentileCont { let data_type = args.expr_fields[0].data_type(); let accumulator: ApproxPercentileAccumulator = match data_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 => { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { if let Some(max_size) = tdigest_max_size { - ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size) + ApproxPercentileAccumulator::new_with_max_size( + percentile, + data_type.clone(), + max_size, + ) } else { ApproxPercentileAccumulator::new(percentile, data_type.clone()) } @@ -208,7 +210,7 @@ impl ApproxPercentileCont { other => { return not_impl_err!( "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" - ) + ); } }; @@ -237,19 +239,14 @@ fn validate_input_max_size_expr(expr: &Arc) -> Result { return plan_err!( "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", sv.data_type() - ) - }, + ); + } }; Ok(max_size) } impl AggregateUDFImpl for ApproxPercentileCont { - fn as_any(&self) -> &dyn Any { - self - } - - #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -266,7 +263,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { ), Field::new( format_state_name(args.name, "count"), - DataType::UInt64, + DataType::Float64, false, ), Field::new( @@ -304,6 +301,13 @@ impl AggregateUDFImpl for ApproxPercentileCont { } fn return_type(&self, arg_types: &[DataType]) -> Result { + // Defensive: the public signature already restricts callers to 2 or 3 + // arguments. This guards against aggregate planning accidentally + // feeding state-field types (e.g. from `PartialReduce`) back into + // `return_type`, which would otherwise silently choose the wrong type. + if arg_types.len() > 3 { + return plan_err!("approx_percentile_cont requires at most 3 arguments"); + } if !arg_types[0].is_numeric() { return plan_err!("approx_percentile_cont requires numeric input types"); } @@ -372,83 +376,19 @@ impl ApproxPercentileAccumulator { match values.data_type() { DataType::Float64 => { let array = downcast_value!(values, Float64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + Ok(array.values().iter().copied().collect::>()) } DataType::Float32 => { let array = downcast_value!(values, Float32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int64 => { - let array = downcast_value!(values, Int64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int32 => { - let array = downcast_value!(values, Int32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int16 => { - let array = downcast_value!(values, Int16Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int8 => { - let array = downcast_value!(values, Int8Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt64 => { - let array = downcast_value!(values, UInt64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt32 => { - let array = downcast_value!(values, UInt32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt16 => { - let array = downcast_value!(values, UInt16Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + Ok(array.values().iter().map(|v| *v as f64).collect::>()) } - DataType::UInt8 => { - let array = downcast_value!(values, UInt8Array); + DataType::Float16 => { + let array = downcast_value!(values, Float16Array); Ok(array .values() .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .map(|v| v.to_f64()) + .collect::>()) } e => internal_err!( "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" @@ -475,7 +415,7 @@ impl Accumulator for ApproxPercentileAccumulator { } fn evaluate(&mut self) -> Result { - if self.digest.count() == 0 { + if self.digest.count() == 0.0 { return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); @@ -483,14 +423,7 @@ impl Accumulator for ApproxPercentileAccumulator { // These acceptable return types MUST match the validation in // ApproxPercentile::create_accumulator. Ok(match &self.return_type { - DataType::Int8 => ScalarValue::Int8(Some(q as i8)), - DataType::Int16 => ScalarValue::Int16(Some(q as i16)), - DataType::Int32 => ScalarValue::Int32(Some(q as i32)), - DataType::Int64 => ScalarValue::Int64(Some(q as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), + DataType::Float16 => ScalarValue::Float16(Some(half::f16::from_f64(q))), DataType::Float32 => ScalarValue::Float32(Some(q as f32)), DataType::Float64 => ScalarValue::Float64(Some(q)), v => unreachable!("unexpected return type {}", v), @@ -551,8 +484,8 @@ mod tests { ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000); + assert_eq!(accumulator.digest.count(), 50_000.0); accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000); + assert_eq!(accumulator.digest.count(), 100_000.0); } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index adf07dede3d8b..6ada47fb38040 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -25,13 +24,13 @@ use arrow::compute::{and, filter, is_not_null}; use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; -use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_common::types::{NativeType, logical_float64}; +use datafusion_common::{Result, not_impl_err, plan_err}; use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; -use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, + Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; @@ -111,20 +110,12 @@ An alternative syntax is also supported: description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, } -impl Debug for ApproxPercentileContWithWeight { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ApproxPercentileContWithWeight") - .field("signature", &self.signature) - .finish() - } -} - impl Default for ApproxPercentileContWithWeight { fn default() -> Self { Self::new() @@ -134,36 +125,60 @@ impl Default for ApproxPercentileContWithWeight { impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. pub fn new() -> Self { - let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); - // Accept any numeric value paired with weight and float64 percentile - for num in NUMERICS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - num.clone(), - DataType::Float64, - ])); - // Additionally accept an integer number of centroids for T-Digest - for int in INTEGERS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - num.clone(), - DataType::Float64, - int.clone(), - ])); - } - } + let signature = Signature::one_of( + vec![ + // 3 args - numeric, weight (float), percentile (float) + TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + ]), + // 4 args - numeric, weight (float), percentile (float), centroid (integer) + TypeSignature::Coercible(vec![ + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Integer, + vec![TypeSignatureClass::Numeric], + NativeType::Int64, + ), + ]), + ], + Volatility::Immutable, + ); Self { - signature: Signature::one_of(variants, Immutable), + signature, approx_percentile_cont: ApproxPercentileCont::new(), } } } impl AggregateUDFImpl for ApproxPercentileContWithWeight { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "approx_percentile_cont_with_weight" } @@ -184,7 +199,9 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } if arg_types[2] != DataType::Float64 { - return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); + return plan_err!( + "approx_percentile_cont_with_weight requires float64 percentile input types" + ); } if arg_types.len() == 4 && !arg_types[3].is_integer() { return plan_err!( @@ -251,7 +268,6 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { Ok(Box::new(accumulator)) } - #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialized /// state. fn state_fields(&self, args: StateFieldsArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 4f5797c308f9b..8ed3fbf8c3d26 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -18,26 +18,32 @@ //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] use std::cmp::Ordering; -use std::collections::{HashSet, VecDeque}; +use std::collections::{HashMap, VecDeque}; use std::mem::{size_of, size_of_val, take}; use std::sync::Arc; use arrow::array::{ - new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, + Array, ArrayRef, AsArray, BooleanArray, ListArray, NullBufferBuilder, StructArray, + UInt32Array, new_empty_array, }; -use arrow::compute::{filter, SortOptions}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::compute::{SortOptions, filter}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::{ - compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder, + SingleRowListArrayBuilder, compare_rows, get_row_at_idx, take_function_args, +}; +use datafusion_common::{ + Result, ScalarValue, assert_eq_or_internal_err, exec_err, internal_err, }; -use datafusion_common::{assert_eq_or_internal_err, exec_err, Result, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature, + Volatility, }; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filter_to_nulls; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_functions_aggregate_common::utils::ordering_fields; @@ -92,10 +98,6 @@ impl Default for ArrayAgg { } impl AggregateUDFImpl for ArrayAgg { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "array_agg" } @@ -113,22 +115,26 @@ impl AggregateUDFImpl for ArrayAgg { fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { - return Ok(vec![Field::new_list( - format_state_name(args.name, "distinct_array_agg"), + return Ok(vec![ + Field::new_list( + format_state_name(args.name, "distinct_array_agg"), + // See COMMENTS.md to understand why nullable is set to true + Field::new_list_field(args.input_fields[0].data_type().clone(), true), + true, + ) + .into(), + ]); + } + + let mut fields = vec![ + Field::new_list( + format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, ) - .into()]); - } - - let mut fields = vec![Field::new_list( - format_state_name(args.name, "array_agg"), - // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_fields[0].data_type().clone(), true), - true, - ) - .into()]; + .into(), + ]; if args.ordering_fields.is_empty() { return Ok(fields); @@ -224,6 +230,23 @@ impl AggregateUDFImpl for ArrayAgg { datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct && args.order_bys.is_empty() + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let field = &args.expr_fields[0]; + let data_type = field.data_type().clone(); + let ignore_nulls = args.ignore_nulls && field.is_nullable(); + Ok(Box::new(ArrayAggGroupsAccumulator::new( + data_type, + ignore_nulls, + ))) + } + fn supports_null_handling_clause(&self) -> bool { true } @@ -235,18 +258,22 @@ impl AggregateUDFImpl for ArrayAgg { #[derive(Debug)] pub struct ArrayAggAccumulator { - values: Vec, + values: VecDeque, datatype: DataType, ignore_nulls: bool, + /// Number of elements already consumed (retracted) from the front array. + /// Used by sliding window frames to avoid copying on partial retract. + front_offset: usize, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result { Ok(Self { - values: vec![], + values: VecDeque::new(), datatype: datatype.clone(), ignore_nulls, + front_offset: 0, }) } @@ -335,7 +362,7 @@ impl Accumulator for ArrayAggAccumulator { }; if !val.is_empty() { - self.values.push(val) + self.values.push_back(val) } Ok(()) @@ -355,12 +382,12 @@ impl Accumulator for ArrayAggAccumulator { Some(values) => { // Make sure we don't insert empty lists if !values.is_empty() { - self.values.push(values); + self.values.push_back(values); } } None => { for arr in list_arr.iter().flatten() { - self.values.push(arr); + self.values.push_back(arr); } } } @@ -373,19 +400,71 @@ impl Accumulator for ArrayAggAccumulator { } fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); + if self.values.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + + let element_arrays: Vec = self + .values + .iter() + .enumerate() + .map(|(i, a)| { + if i == 0 && self.front_offset > 0 { + a.slice(self.front_offset, a.len() - self.front_offset) + } else { + Arc::clone(a) + } + }) + .collect(); + + let element_refs: Vec<&dyn Array> = + element_arrays.iter().map(|a| a.as_ref()).collect(); - if element_arrays.is_empty() { + if element_refs.iter().all(|a| a.is_empty()) { return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } - let concated_array = arrow::compute::concat(&element_arrays)?; + let concated_array = arrow::compute::concat(&element_refs)?; Ok(SingleRowListArrayBuilder::new(concated_array).build_list_scalar()) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + assert_eq_or_internal_err!(values.len(), 1, "expects single batch"); + + let val = &values[0]; + let mut to_retract = if self.ignore_nulls { + val.len() - val.logical_null_count() + } else { + val.len() + }; + + while to_retract > 0 { + let Some(front) = self.values.front() else { + break; + }; + let available = front.len() - self.front_offset; + if to_retract >= available { + self.values.pop_front(); + to_retract -= available; + self.front_offset = 0; + } else { + self.front_offset += to_retract; + to_retract = 0; + } + } + + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { size_of_val(self) + (size_of::() * self.values.capacity()) @@ -411,8 +490,336 @@ impl Accumulator for ArrayAggAccumulator { } #[derive(Debug)] -struct DistinctArrayAggAccumulator { - values: HashSet, +struct ArrayAggGroupsAccumulator { + datatype: DataType, + ignore_nulls: bool, + /// Source arrays — input arrays (from update_batch) or list backing + /// arrays (from merge_batch). + batches: Vec, + /// Per-batch list of (group_idx, row_idx) pairs. + batch_entries: Vec>, + /// Total number of groups tracked. + num_groups: usize, +} + +impl ArrayAggGroupsAccumulator { + fn new(datatype: DataType, ignore_nulls: bool) -> Self { + Self { + datatype, + ignore_nulls, + batches: Vec::new(), + batch_entries: Vec::new(), + num_groups: 0, + } + } + + fn clear_state(&mut self) { + // `size()` measures Vec capacity rather than len, so allocate new + // buffers instead of using `clear()`. + self.batches = Vec::new(); + self.batch_entries = Vec::new(); + self.num_groups = 0; + } + + fn compact_retained_state(&mut self, emit_groups: usize) -> Result<()> { + // EmitTo::First is used to recover from memory pressure. Simply + // removing emitted entries in place is not enough because mixed batches + // would continue to pin their original Array arrays, even if only a few + // retained rows remain. + // + // Rebuild the retained state from scratch so fully emitted batches are + // dropped, mixed batches are compacted to arrays containing only the + // surviving rows, and retained metadata is right-sized. + let emit_groups = emit_groups as u32; + let old_batches = take(&mut self.batches); + let old_batch_entries = take(&mut self.batch_entries); + + let mut batches = Vec::new(); + let mut batch_entries = Vec::new(); + + for (batch, entries) in old_batches.into_iter().zip(old_batch_entries) { + let retained_len = entries.iter().filter(|(g, _)| *g >= emit_groups).count(); + + if retained_len == 0 { + continue; + } + + if retained_len == entries.len() { + // Nothing was emitted from this batch, so we keep the existing + // array and only renumber the remaining group IDs so that they + // start from 0. + let mut retained_entries = entries; + for (g, _) in &mut retained_entries { + *g -= emit_groups; + } + retained_entries.shrink_to_fit(); + batches.push(batch); + batch_entries.push(retained_entries); + continue; + } + + let mut retained_entries = Vec::with_capacity(retained_len); + let mut retained_rows = Vec::with_capacity(retained_len); + + for (g, r) in entries { + if g >= emit_groups { + // Compute the new `(group_idx, row_idx)` pair for a + // retained row. `group_idx` is renumbered to start from + // 0, and `row_idx` points into the new dense batch we are + // building. + retained_entries.push((g - emit_groups, retained_rows.len() as u32)); + retained_rows.push(r); + } + } + + debug_assert_eq!(retained_entries.len(), retained_len); + debug_assert_eq!(retained_rows.len(), retained_len); + + let batch = if retained_len == batch.len() { + batch + } else { + // Compact mixed batches so retained rows no longer pin the + // original array. + let retained_rows = UInt32Array::from(retained_rows); + arrow::compute::take(batch.as_ref(), &retained_rows, None)? + }; + + batches.push(batch); + batch_entries.push(retained_entries); + } + + self.batches = batches; + self.batch_entries = batch_entries; + self.num_groups -= emit_groups as usize; + + Ok(()) + } +} + +impl GroupsAccumulator for ArrayAggGroupsAccumulator { + /// Store a reference to the input batch, plus a `(group_idx, row_idx)` pair + /// for every row. + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let input = &values[0]; + + self.num_groups = self.num_groups.max(total_num_groups); + + let nulls = if self.ignore_nulls { + input.logical_nulls() + } else { + None + }; + + let mut entries = Vec::new(); + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + // Skip filtered rows + if let Some(filter) = opt_filter + && (filter.is_null(row_idx) || !filter.value(row_idx)) + { + continue; + } + + // Skip null values when ignore_nulls is set + if let Some(ref nulls) = nulls + && nulls.is_null(row_idx) + { + continue; + } + + entries.push((group_idx as u32, row_idx as u32)); + } + + // We only need to record the batch if it was non-empty. + if !entries.is_empty() { + self.batches.push(Arc::clone(input)); + self.batch_entries.push(entries); + } + + Ok(()) + } + + /// Produce a `ListArray` ordered by group index: the list at + /// position N contains the aggregated values for group N. + /// + /// Uses a counting sort to rearrange the stored `(group, row)` + /// entries into group order, then calls `interleave` to gather + /// the values into a flat array that backs the output `ListArray`. + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let emit_groups = match emit_to { + EmitTo::All => self.num_groups, + EmitTo::First(n) => n, + }; + + // Step 1: Count entries per group. For EmitTo::First(n), only groups + // 0..n are counted; the rest are retained to be emitted in the future. + let mut counts = vec![0u32; emit_groups]; + for entries in &self.batch_entries { + for &(g, _) in entries { + let g = g as usize; + if g < emit_groups { + counts[g] += 1; + } + } + } + + // Step 2: Do a prefix sum over the counts and use it to build ListArray + // offsets, null buffer, and write positions for the counting sort. + let mut offsets = Vec::::with_capacity(emit_groups + 1); + offsets.push(0); + let mut nulls_builder = NullBufferBuilder::new(emit_groups); + let mut write_positions = Vec::with_capacity(emit_groups); + let mut cur_offset = 0u32; + for &count in &counts { + if count == 0 { + nulls_builder.append_null(); + } else { + nulls_builder.append_non_null(); + } + write_positions.push(cur_offset); + cur_offset += count; + offsets.push(cur_offset as i32); + } + let total_rows = cur_offset as usize; + + // Step 3: Scatter entries into group order using the counting sort. The + // batch index is implicit from the outer loop position. + let flat_values = if total_rows == 0 { + new_empty_array(&self.datatype) + } else { + let mut interleave_indices = vec![(0usize, 0usize); total_rows]; + for (batch_idx, entries) in self.batch_entries.iter().enumerate() { + for &(g, r) in entries { + let g = g as usize; + if g < emit_groups { + let wp = write_positions[g] as usize; + interleave_indices[wp] = (batch_idx, r as usize); + write_positions[g] += 1; + } + } + } + + let sources: Vec<&dyn Array> = + self.batches.iter().map(|b| b.as_ref()).collect(); + arrow::compute::interleave(&sources, &interleave_indices)? + }; + + // Step 4: Release state for emitted groups. + match emit_to { + EmitTo::All => self.clear_state(), + EmitTo::First(_) => self.compact_retained_state(emit_groups)?, + } + + let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + let field = Arc::new(Field::new_list_field(self.datatype.clone(), true)); + let result = ListArray::new(field, offsets, flat_values, nulls_builder.finish()); + + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.evaluate(emit_to)?]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + let input_list = values[0].as_list::(); + + self.num_groups = self.num_groups.max(total_num_groups); + + // Push the ListArray's backing values array as a single batch. + let list_values = input_list.values(); + let list_offsets = input_list.offsets(); + + let mut entries = Vec::new(); + + for (row_idx, &group_idx) in group_indices.iter().enumerate() { + if input_list.is_null(row_idx) { + continue; + } + let start = list_offsets[row_idx] as u32; + let end = list_offsets[row_idx + 1] as u32; + for pos in start..end { + entries.push((group_idx as u32, pos)); + } + } + + if !entries.is_empty() { + self.batches.push(Arc::clone(list_values)); + self.batch_entries.push(entries); + } + + Ok(()) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + assert_eq!(values.len(), 1, "one argument to convert_to_state"); + + let input = &values[0]; + + // Each row becomes a 1-element list: offsets are [0, 1, 2, ..., n]. + let offsets = OffsetBuffer::from_repeated_length(1, input.len()); + + // Filtered rows become null list entries, which merge_batch will skip. + let filter_nulls = opt_filter.map(filter_to_nulls); + + // With ignore_nulls, null values also become null list entries. Without + // ignore_nulls, null values stay as [NULL] so merge_batch retains them. + let nulls = if self.ignore_nulls { + let logical = input.logical_nulls(); + NullBuffer::union(filter_nulls.as_ref(), logical.as_ref()) + } else { + filter_nulls + }; + + let field = Arc::new(Field::new_list_field(self.datatype.clone(), true)); + let list_array = ListArray::new(field, offsets, Arc::clone(input), nulls); + + Ok(vec![Arc::new(list_array)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.batches + .iter() + .map(|arr| arr.to_data().get_slice_memory_size().unwrap_or_default()) + .sum::() + + self.batches.capacity() * size_of::() + + self + .batch_entries + .iter() + .map(|e| e.capacity() * size_of::<(u32, u32)>()) + .sum::() + + self.batch_entries.capacity() * size_of::>() + } +} + +#[derive(Debug)] +pub struct DistinctArrayAggAccumulator { + // Value → live refcount. Multiset state lets `retract_batch` correctly + // drop a duplicate occurrence while keeping the key alive if other + // copies remain in the current window frame. + values: HashMap, datatype: DataType, sort_options: Option, ignore_nulls: bool, @@ -425,7 +832,7 @@ impl DistinctArrayAggAccumulator { ignore_nulls: bool, ) -> Result { Ok(Self { - values: HashSet::new(), + values: HashMap::new(), datatype: datatype.clone(), sort_options, ignore_nulls, @@ -454,8 +861,8 @@ impl Accumulator for DistinctArrayAggAccumulator { if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { for i in 0..val.len() { if nulls.is_none_or(|nulls| nulls.is_valid(i)) { - self.values - .insert(ScalarValue::try_from_array(val, i)?.compacted()); + let key = ScalarValue::try_from_array(val, i)?.compacted(); + *self.values.entry(key).or_insert(0) += 1; } } } @@ -470,6 +877,12 @@ impl Accumulator for DistinctArrayAggAccumulator { assert_eq_or_internal_err!(states.len(), 1, "expects single state"); + // The DISTINCT state schema is `List` — partial accumulators + // ship the set of values they saw, not multiplicities. Re-ingesting + // each element here makes the merged counts represent "partitions + // that emitted this value," which is fine because `evaluate` only + // reads keys. Refcount semantics for retract are only valid within + // a single accumulator instance (window execution). states[0] .as_list::() .iter() @@ -478,7 +891,7 @@ impl Accumulator for DistinctArrayAggAccumulator { } fn evaluate(&mut self) -> Result { - let mut values: Vec = self.values.iter().cloned().collect(); + let mut values: Vec = self.values.keys().cloned().collect(); if values.is_empty() { return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); } @@ -514,8 +927,50 @@ impl Accumulator for DistinctArrayAggAccumulator { Ok(ScalarValue::List(arr)) } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + assert_eq_or_internal_err!(values.len(), 1, "expects single batch"); + + let val = &values[0]; + let nulls = if self.ignore_nulls { + val.logical_nulls() + } else { + None + }; + let nulls = nulls.as_ref(); + + for i in 0..val.len() { + if nulls.is_some_and(|nulls| !nulls.is_valid(i)) { + continue; + } + let key = ScalarValue::try_from_array(val, i)?; + match self.values.get_mut(&key) { + Some(count) => { + *count -= 1; + if *count == 0 { + self.values.remove(&key); + } + } + None => { + return internal_err!( + "DistinctArrayAggAccumulator::retract_batch: value not present in state" + ); + } + } + } + + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + fn size(&self) -> usize { - size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + size_of_val(self) + ScalarValue::size_of_hashmap(&self.values) - size_of_val(&self.values) + self.datatype.size() - size_of_val(&self.datatype) @@ -606,7 +1061,13 @@ impl OrderSensitiveArrayAggAccumulator { } else { (0..fields.len()) .map(|i| { - let column_values = self.ordering_values.iter().map(|x| x[i].clone()); + let column_values: Box> = if self + .reverse + { + Box::new(self.ordering_values.iter().rev().map(|x| x[i].clone())) + } else { + Box::new(self.ordering_values.iter().map(|x| x[i].clone())) + }; ScalarValue::iter_to_array(column_values) }) .collect::>()? @@ -796,13 +1257,11 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { mod tests { use super::*; use arrow::array::{ListBuilder, StringBuilder}; - use arrow::datatypes::{FieldRef, Schema}; + use arrow::datatypes::Schema; use datafusion_common::cast::as_generic_string_array; use datafusion_common::internal_err; - use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; - use std::sync::Arc; + use datafusion_physical_expr::expressions::Column; #[test] fn no_duplicates_no_distinct() -> Result<()> { @@ -1071,7 +1530,7 @@ mod tests { acc2.update_batch(&[data(["b", "c", "a"])])?; acc1 = merge(acc1, acc2)?; - assert_eq!(acc1.size(), 266); + assert_eq!(acc1.size(), 282); Ok(()) } @@ -1088,8 +1547,8 @@ mod tests { acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?; acc1 = merge(acc1, acc2)?; - // without compaction, the size is 16660 - assert_eq!(acc1.size(), 1660); + // without compaction, the size is 16684 + assert_eq!(acc1.size(), 1684); Ok(()) } @@ -1107,11 +1566,120 @@ mod tests { ])])?; // without compaction, the size is 17112 - assert_eq!(acc.size(), 2184); + assert_eq!(acc.size(), 2224); Ok(()) } + // Reproduces the bug where `state()` emits reversed values but non-reversed + // orderings when the optimizer sets is_input_pre_ordered=true + reverse=true + // (DESC aggregate with ASC pre-sorted input). The partial states are fed into + // a final accumulator via merge_batch; without the fix the ordering keys and + // values are mismatched so the final sort produces wrong order. + #[test] + fn desc_order_partial_final_merge_correct() -> Result<()> { + use arrow::array::Int64Array; + use datafusion_physical_expr::expressions::Column; + + let schema = Schema::new(vec![ + Field::new("val", DataType::Int64, true), + Field::new("ord", DataType::Int64, true), + ]); + let ord_expr = Arc::new( + Column::new_with_schema("ord", &schema).expect("column not in schema"), + ) as Arc; + + // ordering_req for partial = [ord ASC] (reversed, because input is pre-sorted ASC + // and the user wants DESC — the optimizer reverses the requirement) + let asc_opts = SortOptions { + descending: false, + nulls_first: false, + }; + let desc_opts = SortOptions { + descending: true, + nulls_first: false, + }; + + let asc_ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + Arc::clone(&ord_expr), + asc_opts, + )]) + .unwrap(); + let desc_ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + Arc::clone(&ord_expr), + desc_opts, + )]) + .unwrap(); + + let ordering_dtype = DataType::Int64; + + // Partial acc A: sees rows [0,1,2] arriving in ASC order (pre-ordered). + // is_input_pre_ordered=true, reverse=true, ordering_req=[ASC]. + let mut partial_a = OrderSensitiveArrayAggAccumulator::try_new( + &DataType::Int64, + std::slice::from_ref(&ordering_dtype), + asc_ordering.clone(), + /*is_input_pre_ordered=*/ true, + /*reverse=*/ true, + /*ignore_nulls=*/ false, + )?; + let vals_a = Arc::new(Int64Array::from(vec![0i64, 1, 2])) as ArrayRef; + let ords_a = Arc::new(Int64Array::from(vec![0i64, 1, 2])) as ArrayRef; + partial_a.update_batch(&[vals_a, ords_a])?; + let state_a = partial_a + .state()? + .iter() + .map(|v| v.to_array()) + .collect::>>()?; + + // Partial acc B: sees rows [3,4,5] arriving in ASC order. + let mut partial_b = OrderSensitiveArrayAggAccumulator::try_new( + &DataType::Int64, + std::slice::from_ref(&ordering_dtype), + asc_ordering, + /*is_input_pre_ordered=*/ true, + /*reverse=*/ true, + /*ignore_nulls=*/ false, + )?; + let vals_b = Arc::new(Int64Array::from(vec![3i64, 4, 5])) as ArrayRef; + let ords_b = Arc::new(Int64Array::from(vec![3i64, 4, 5])) as ArrayRef; + partial_b.update_batch(&[vals_b, ords_b])?; + let state_b = partial_b + .state()? + .iter() + .map(|v| v.to_array()) + .collect::>>()?; + + // Final acc: not optimized — ordering_req=[DESC], reverse=false. + let mut final_acc = OrderSensitiveArrayAggAccumulator::try_new( + &DataType::Int64, + std::slice::from_ref(&ordering_dtype), + desc_ordering, + /*is_input_pre_ordered=*/ false, + /*reverse=*/ false, + /*ignore_nulls=*/ false, + )?; + final_acc.merge_batch(&state_a)?; + final_acc.merge_batch(&state_b)?; + let result = final_acc.evaluate()?; + + let ScalarValue::List(list) = result else { + return datafusion_common::internal_err!("expected List"); + }; + let result_vals: Vec = list + .values() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + + // Expected DESC: [5, 4, 3, 2, 1, 0] + assert_eq!(result_vals, vec![5i64, 4, 3, 2, 1, 0]); + Ok(()) + } + struct ArrayAggAccumulatorBuilder { return_field: FieldRef, distinct: bool, @@ -1223,4 +1791,803 @@ mod tests { acc1.merge_batch(&intermediate_state)?; Ok(acc1) } + + // ---- GroupsAccumulator tests ---- + + use arrow::array::Int32Array; + + fn list_array_to_i32_vecs(list: &ListArray) -> Vec>>> { + (0..list.len()) + .map(|i| { + if list.is_null(i) { + None + } else { + let arr = list.value(i); + let vals: Vec> = arr + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(); + Some(vals) + } + }) + .collect() + } + + fn eval_i32_lists( + acc: &mut ArrayAggGroupsAccumulator, + emit_to: EmitTo, + ) -> Result>>>> { + let result = acc.evaluate(emit_to)?; + Ok(list_array_to_i32_vecs(result.as_list::())) + } + + #[test] + fn groups_accumulator_multiple_batches() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + // First batch + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + acc.update_batch(&[values], &[0, 1, 0], None, 2)?; + + // Second batch + let values: ArrayRef = Arc::new(Int32Array::from(vec![4, 5])); + acc.update_batch(&[values], &[1, 0], None, 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1), Some(3), Some(5)])); + assert_eq!(vals[1], Some(vec![Some(2), Some(4)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + acc.update_batch(&[values], &[0, 1, 2], None, 3)?; + + // Emit first 2 groups + let vals = eval_i32_lists(&mut acc, EmitTo::First(2))?; + assert_eq!(vals.len(), 2); + assert_eq!(vals[0], Some(vec![Some(10)])); + assert_eq!(vals[1], Some(vec![Some(20)])); + + // Remaining group (was index 2, now shifted to 0) + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals.len(), 1); + assert_eq!(vals[0], Some(vec![Some(30)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first_frees_batches() -> Result<()> { + // Batch 0 has rows only for group 0; batch 1 has rows for + // both groups. After emitting group 0, batch 0 should be + // dropped entirely and batch 1 should be compacted to the + // retained row(s). + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let batch0: ArrayRef = Arc::new(Int32Array::from(vec![10, 20])); + acc.update_batch(&[batch0], &[0, 0], None, 2)?; + + let batch1: ArrayRef = Arc::new(Int32Array::from(vec![30, 40])); + acc.update_batch(&[batch1], &[0, 1], None, 2)?; + + assert_eq!(acc.batches.len(), 2); + assert!(!acc.batches[0].is_empty()); + assert!(!acc.batches[1].is_empty()); + + // Emit group 0. Batch 0 is only referenced by group 0, so it + // should be removed. Batch 1 is mixed, so it should be compacted + // to contain only the retained row for group 1. + let vals = eval_i32_lists(&mut acc, EmitTo::First(1))?; + assert_eq!(vals[0], Some(vec![Some(10), Some(20), Some(30)])); + + assert_eq!(acc.batches.len(), 1); + let retained = acc.batches[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(retained.values(), &[40]); + assert_eq!(acc.batch_entries, vec![vec![(0, 0)]]); + + // Emit remaining group 1 + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(40)])); + + assert!(acc.batches.is_empty()); + assert_eq!(acc.size(), 0); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_first_compacts_mixed_batches() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let batch: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40])); + acc.update_batch(&[batch], &[0, 1, 0, 1], None, 2)?; + + let size_before = acc.size(); + let vals = eval_i32_lists(&mut acc, EmitTo::First(1))?; + assert_eq!(vals[0], Some(vec![Some(10), Some(30)])); + + assert_eq!(acc.num_groups, 1); + assert_eq!(acc.batches.len(), 1); + let retained = acc.batches[0] + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(retained.values(), &[20, 40]); + assert_eq!(acc.batch_entries, vec![vec![(0, 0), (0, 1)]]); + assert!(acc.size() < size_before); + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(20), Some(40)])); + assert_eq!(acc.size(), 0); + + Ok(()) + } + + #[test] + fn groups_accumulator_emit_all_releases_capacity() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let batch: ArrayRef = Arc::new(Int32Array::from_iter_values(0..64)); + acc.update_batch( + &[batch], + &(0..64).map(|i| i % 4).collect::>(), + None, + 4, + )?; + + assert!(acc.size() > 0); + let _ = eval_i32_lists(&mut acc, EmitTo::All)?; + + assert_eq!(acc.size(), 0); + assert_eq!(acc.batches.capacity(), 0); + assert_eq!(acc.batch_entries.capacity(), 0); + + Ok(()) + } + + #[test] + fn groups_accumulator_null_groups() -> Result<()> { + // Groups that never receive values should produce null + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![1])); + // Only group 0 gets a value, groups 1 and 2 are empty + acc.update_batch(&[values], &[0], None, 3)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals, vec![Some(vec![Some(1)]), None, None]); + + Ok(()) + } + + #[test] + fn groups_accumulator_ignore_nulls() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + + let values: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, Some(3), None])); + acc.update_batch(&[values], &[0, 0, 1, 1], None, 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + // Group 0: only non-null value is 1 + assert_eq!(vals[0], Some(vec![Some(1)])); + // Group 1: only non-null value is 3 + assert_eq!(vals[1], Some(vec![Some(3)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_opt_filter() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4])); + // Use a mix of false and null to filter out rows — both should + // be skipped. + let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]); + acc.update_batch(&[values], &[0, 0, 1, 1], Some(&filter), 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1)])); // row 1 filtered (null) + assert_eq!(vals[1], Some(vec![Some(3)])); // row 3 filtered (false) + + Ok(()) + } + + #[test] + fn groups_accumulator_state_merge_roundtrip() -> Result<()> { + // Accumulator 1: update_batch, then merge, then update_batch again. + // Verifies that values appear in chronological insertion order. + let mut acc1 = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + acc1.update_batch(&[values], &[0, 1], None, 2)?; + + // Accumulator 2 + let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + let values: ArrayRef = Arc::new(Int32Array::from(vec![3, 4])); + acc2.update_batch(&[values], &[0, 1], None, 2)?; + + // Merge acc2's state into acc1 + let state = acc2.state(EmitTo::All)?; + acc1.merge_batch(&state, &[0, 1], None, 2)?; + + // Another update_batch on acc1 after the merge + let values: ArrayRef = Arc::new(Int32Array::from(vec![5, 6])); + acc1.update_batch(&[values], &[0, 1], None, 2)?; + + // Each group's values in insertion order: + // group 0: update(1), merge(3), update(5) → [1, 3, 5] + // group 1: update(2), merge(4), update(6) → [2, 4, 6] + let vals = eval_i32_lists(&mut acc1, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1), Some(3), Some(5)])); + assert_eq!(vals[1], Some(vec![Some(2), Some(4), Some(6)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state() -> Result<()> { + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(10), None, Some(30)])); + let state = acc.convert_to_state(&[values], None)?; + + assert_eq!(state.len(), 1); + let vals = list_array_to_i32_vecs(state[0].as_list::()); + assert_eq!( + vals, + vec![ + Some(vec![Some(10)]), + Some(vec![None]), // null preserved inside list, not promoted + Some(vec![Some(30)]), + ] + ); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_with_filter() -> Result<()> { + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30])); + let filter = BooleanArray::from(vec![true, false, true]); + let state = acc.convert_to_state(&[values], Some(&filter))?; + + let vals = list_array_to_i32_vecs(state[0].as_list::()); + assert_eq!( + vals, + vec![ + Some(vec![Some(10)]), + None, // filtered + Some(vec![Some(30)]), + ] + ); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_merge_preserves_nulls() -> Result<()> { + // Verifies that null values survive the convert_to_state -> merge_batch + // round-trip when ignore_nulls is false (default null handling). + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); + let state = acc.convert_to_state(&[values], None)?; + + // Feed state into a new accumulator via merge_batch + let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + acc2.merge_batch(&state, &[0, 0, 1], None, 2)?; + + // Group 0 received rows 0 ([1]) and 1 ([NULL]) → [1, NULL] + let vals = eval_i32_lists(&mut acc2, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1), None])); + // Group 1 received row 2 ([3]) → [3] + assert_eq!(vals[1], Some(vec![Some(3)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_convert_to_state_merge_ignore_nulls() -> Result<()> { + // Verifies that null values are dropped in the convert_to_state -> + // merge_batch round-trip when ignore_nulls is true. + let acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + + let values: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, Some(3), None])); + let state = acc.convert_to_state(&[values], None)?; + + let list = state[0].as_list::(); + // Rows 0 and 2 are valid lists; rows 1 and 3 are null list entries + assert!(!list.is_null(0)); + assert!(list.is_null(1)); + assert!(!list.is_null(2)); + assert!(list.is_null(3)); + + // Feed state into a new accumulator via merge_batch + let mut acc2 = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + acc2.merge_batch(&state, &[0, 0, 1, 1], None, 2)?; + + // Group 0: received [1] and null (skipped) → [1] + let vals = eval_i32_lists(&mut acc2, EmitTo::All)?; + assert_eq!(vals[0], Some(vec![Some(1)])); + // Group 1: received [3] and null (skipped) → [3] + assert_eq!(vals[1], Some(vec![Some(3)])); + + Ok(()) + } + + #[test] + fn groups_accumulator_all_groups_empty() -> Result<()> { + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, false); + + // Create groups but don't add any values (all filtered out) + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let filter = BooleanArray::from(vec![false, false]); + acc.update_batch(&[values], &[0, 1], Some(&filter), 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals, vec![None, None]); + + Ok(()) + } + + #[test] + fn groups_accumulator_ignore_nulls_all_null_group() -> Result<()> { + // When ignore_nulls is true and a group receives only nulls, + // it should produce a null output + let mut acc = ArrayAggGroupsAccumulator::new(DataType::Int32, true); + + let values: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1), None])); + acc.update_batch(&[values], &[0, 1, 0], None, 2)?; + + let vals = eval_i32_lists(&mut acc, EmitTo::All)?; + assert_eq!(vals[0], None); // group 0 got only nulls, all filtered + assert_eq!(vals[1], Some(vec![Some(1)])); // group 1 got value 1 + + Ok(()) + } + + // ---- retract_batch tests ---- + + #[test] + fn retract_basic_sliding_window() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + + // Simulate ROWS BETWEEN 1 PRECEDING AND CURRENT ROW over [A, B, C, D] + // Row 1: frame = [A] + acc.update_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A"]); + + // Row 2: frame = [A, B] + acc.update_batch(&[data(["B"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Row 3: frame = [B, C] — A leaves + acc.update_batch(&[data(["C"])])?; + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B", "C"]); + + // Row 4: frame = [C, D] — B leaves + acc.update_batch(&[data(["D"])])?; + acc.retract_batch(&[data(["B"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["C", "D"]); + + Ok(()) + } + + #[test] + fn retract_multi_element_across_arrays() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + + // First batch: 3 elements + acc.update_batch(&[data(["A", "B", "C"])])?; + // Second batch: 1 element + acc.update_batch(&[data(["D"])])?; + + assert_eq!( + print_nulls(str_arr(acc.evaluate()?)?), + vec!["A", "B", "C", "D"] + ); + + // Partial retract from front array: A leaves + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B", "C", "D"]); + + // Retract spanning two arrays: B, C (rest of first array) + D (second array) + acc.retract_batch(&[data(["B", "C", "D"])])?; + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list after full retract, got {result:?}" + ); + + Ok(()) + } + + #[test] + fn retract_with_nulls_preserved() -> Result<()> { + // ignore_nulls = false: NULLs are stored and counted for retract + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + + acc.update_batch(&[data([Some("A"), None, Some("C")])])?; + assert_eq!( + print_nulls(str_arr(acc.evaluate()?)?), + vec!["A", "NULL", "C"] + ); + + // Retract 2 elements: A and NULL both leave + acc.retract_batch(&[data([Some("A"), None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["C"]); + + Ok(()) + } + + #[test] + fn retract_with_ignore_nulls() -> Result<()> { + // ignore_nulls = true: NULLs are NOT stored by update_batch, + // so retract must only count non-null values + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, true)?; + + // update_batch with [A, NULL, C] → stores only [A, C] (NULL filtered) + acc.update_batch(&[data([Some("A"), None, Some("C")])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "C"]); + + // retract_batch receives the original values including NULL: [A, NULL] + // But only 1 non-null value (A) should be retracted + acc.retract_batch(&[data([Some("A"), None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["C"]); + + // retract_batch with [NULL, C] — only C (1 non-null) retracted + acc.retract_batch(&[data([None, Some("C")])])?; + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list after full retract, got {result:?}" + ); + + Ok(()) + } + + #[test] + fn retract_ignore_nulls_all_nulls_batch() -> Result<()> { + // When ignore_nulls = true and retract batch is all NULLs, nothing is retracted + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, true)?; + + acc.update_batch(&[data([Some("A"), Some("B")])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract batch of all NULLs: to_retract = 0, nothing changes + acc.retract_batch(&[data::, 3>([None, None, None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + Ok(()) + } + + #[test] + fn retract_empty_accumulator() -> Result<()> { + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + + // Retract on empty accumulator should be a no-op + acc.retract_batch(&[data(["A"])])?; + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list for empty accumulator, got {result:?}" + ); + + Ok(()) + } + + #[test] + fn retract_front_offset_partial_consume() -> Result<()> { + // Reproduces the RANGE BETWEEN 2 PRECEDING AND 2 FOLLOWING scenario: + // ts: 1, 2, 3, 4, 100 + // + // Row 1 (ts=1): update [A,B,C] (3 elements, ts in [-1,3]) + // Row 2 (ts=2): update [D] (ts=4 enters) + // Row 3 (ts=3): no change (same frame [0..4)) + // Row 4 (ts=4): retract [A] (ts=1 leaves, partial consume) + // Row 5 (ts=100): retract [B,C,D] (3-element retract spanning arrays) + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + + // Row 1: update_batch(["A","B","C"]) + acc.update_batch(&[data(["A", "B", "C"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B", "C"]); + + // Row 2: update_batch(["D"]) + acc.update_batch(&[data(["D"])])?; + assert_eq!( + print_nulls(str_arr(acc.evaluate()?)?), + vec!["A", "B", "C", "D"] + ); + + // Row 4: retract_batch(["A"]) — partial consume, front_offset = 1 + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B", "C", "D"]); + + // Row 5: update_batch(["E"]), then retract_batch(["B","C","D"]) + // retract spans: ["A","B","C"] (offset=1, 2 remaining) + ["D"] (1 element) + acc.update_batch(&[data(["E"])])?; + acc.retract_batch(&[data(["B", "C", "D"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["E"]); + + Ok(()) + } + + #[test] + fn retract_update_after_full_drain() -> Result<()> { + // Verify accumulator works correctly after being fully drained + let mut acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + + acc.update_batch(&[data(["A", "B"])])?; + acc.retract_batch(&[data(["A", "B"])])?; + + // Accumulator is empty now + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list, got {result:?}" + ); + + // New values should work normally after drain + acc.update_batch(&[data(["X", "Y"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["X", "Y"]); + + acc.retract_batch(&[data(["X"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["Y"]); + + Ok(()) + } + + #[test] + fn retract_supports_retract_batch() -> Result<()> { + let acc = ArrayAggAccumulator::try_new(&DataType::Utf8, false)?; + assert!(acc.supports_retract_batch()); + + let acc_ignore = ArrayAggAccumulator::try_new(&DataType::Utf8, true)?; + assert!(acc_ignore.supports_retract_batch()); + + Ok(()) + } + + #[test] + fn retract_ignore_nulls_logical_vs_physical() -> Result<()> { + // Regression test: DictionaryArray where logical nulls differ from physical nulls. + // Manually construct a DictionaryArray where all indices are valid + // (physical null_count = 0) but some point to null dictionary values + // (logical_null_count > 0). + use arrow::array::{DictionaryArray, Int32Array, StringArray}; + + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let mut acc = ArrayAggAccumulator::try_new(&dict_type, true)?; + + // Dictionary values: ["hello", NULL, "world"] + // Keys: [0, 1, 2, 1] — all valid, but keys 1 and 3 point to null value + let values = StringArray::from(vec![Some("hello"), None, Some("world")]); + let keys = Int32Array::from(vec![0, 1, 2, 1]); + let dict_array: ArrayRef = Arc::new(DictionaryArray::new(keys, Arc::new(values))); + + // Confirm the divergence this test exists to exercise + assert_eq!( + dict_array.null_count(), + 0, + "physical nulls: none in keys bitmap" + ); + assert_eq!( + dict_array.logical_null_count(), + 2, + "logical nulls: keys pointing to null values" + ); + + // update_batch uses logical_nulls() → stores only ["hello", "world"] + acc.update_batch(std::slice::from_ref(&dict_array))?; + + // Verify 2 elements stored + let result = acc.evaluate()?; + match &result { + ScalarValue::List(arr) => { + let values = arr.value(0); + assert_eq!(values.len(), 2); + } + other => panic!("expected List, got {other:?}"), + } + + // retract_batch with same array: should retract 2 (logical non-nulls), not 4 (len) or 0 (physical non-nulls would be len-0=4) + acc.retract_batch(&[dict_array])?; + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list after full retract, got {result:?}" + ); + + Ok(()) + } + + #[test] + fn retract_ignore_nulls_dict_partial() -> Result<()> { + // Partial retraction with DictionaryArray where logical != physical nulls. + // Manually construct so keys are all valid but some point to null values. + use arrow::array::{DictionaryArray, Int32Array, StringArray}; + + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let mut acc = ArrayAggAccumulator::try_new(&dict_type, true)?; + + // update with ["A", "B", "C"] (no nulls) + let values = StringArray::from(vec!["A", "B", "C"]); + let keys = Int32Array::from(vec![0, 1, 2]); + let update_array: ArrayRef = + Arc::new(DictionaryArray::new(keys, Arc::new(values))); + acc.update_batch(&[update_array])?; + + // retract with dict ["A", NULL, NULL]: + // keys [0, 1, 1] all valid → physical null_count = 0 + // keys 1,2 point to null value → logical_null_count = 2 + // non-null count = 3 - 2 = 1 → retract 1 element + let values = StringArray::from(vec![Some("A"), None]); + let keys = Int32Array::from(vec![0, 1, 1]); + let retract_array: ArrayRef = + Arc::new(DictionaryArray::new(keys, Arc::new(values))); + + assert_eq!( + retract_array.null_count(), + 0, + "physical nulls: none in keys bitmap" + ); + assert_eq!( + retract_array.logical_null_count(), + 2, + "logical nulls: keys pointing to null values" + ); + + acc.retract_batch(&[retract_array])?; + + // Should have retracted only 1 element, leaving ["B", "C"] + let result = acc.evaluate()?; + match &result { + ScalarValue::List(arr) => { + let values = arr.value(0); + assert_eq!(values.len(), 2); + } + other => panic!("expected List with 2 elements, got {other:?}"), + } + + Ok(()) + } + + // ---- DistinctArrayAggAccumulator retract_batch tests ---- + + // Build a DISTINCT accumulator with ascending sort so evaluate output is + // deterministic regardless of HashMap iteration order. + fn distinct_acc(ignore_nulls: bool) -> Result { + DistinctArrayAggAccumulator::try_new( + &DataType::Utf8, + Some(SortOptions::default()), + ignore_nulls, + ) + } + + #[test] + fn distinct_retract_duplicate_remains() -> Result<()> { + // Canonical regression for the HashSet-can't-retract bug: a value + // that appears multiple times in-frame must survive retraction of + // a single occurrence. + let mut acc = distinct_acc(false)?; + + // Feed [A, A, B] across two batches to exercise multi-batch state. + acc.update_batch(&[data(["A", "A"])])?; + acc.update_batch(&[data(["B"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract a single A — the other A is still in the frame. + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract the remaining A — only B left. + acc.retract_batch(&[data(["A"])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B"]); + + Ok(()) + } + + #[test] + fn distinct_retract_full_removal() -> Result<()> { + let mut acc = distinct_acc(false)?; + + acc.update_batch(&[data(["A", "B"])])?; + acc.retract_batch(&[data(["A", "B"])])?; + + let result = acc.evaluate()?; + assert!( + matches!(&result, ScalarValue::List(arr) if arr.is_null(0)), + "expected null list after full retract, got {result:?}" + ); + + Ok(()) + } + + #[test] + fn distinct_retract_ignore_nulls_skips() -> Result<()> { + // ignore_nulls=true: NULL never enters state on update, so retract + // must also skip NULL — otherwise we'd error on the missing key. + let mut acc = distinct_acc(true)?; + + acc.update_batch(&[data([Some("A"), None, Some("B")])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A", "B"]); + + // Retract [A, NULL] — the NULL is skipped, only A is removed. + acc.retract_batch(&[data([Some("A"), None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["B"]); + + Ok(()) + } + + #[test] + fn distinct_retract_null_tracked() -> Result<()> { + // ignore_nulls=false: NULL enters state with a refcount and must + // retract symmetrically; the NULL key must be removed at zero + // (else evaluate still emits a NULL element). + let mut acc = distinct_acc(false)?; + + acc.update_batch(&[data([Some("A"), None, None])])?; + // With nulls_first=true (SortOptions default), NULL sorts before A. + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["NULL", "A"]); + + // Retract one NULL — count drops to 1, key still present. + acc.retract_batch(&[data::, 1>([None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["NULL", "A"]); + + // Retract the remaining NULL — key is removed. + acc.retract_batch(&[data::, 1>([None])])?; + assert_eq!(print_nulls(str_arr(acc.evaluate()?)?), vec!["A"]); + + Ok(()) + } + + #[test] + fn distinct_supports_retract_batch() -> Result<()> { + let acc = distinct_acc(false)?; + assert!(acc.supports_retract_batch()); + + let acc_ignore = distinct_acc(true)?; + assert!(acc_ignore.supports_retract_batch()); + + Ok(()) + } + + #[test] + fn distinct_merge_then_evaluate_regression() -> Result<()> { + // Non-window path: state -> merge_batch -> evaluate must still + // produce the union of distinct values across partitions. + let mut acc1 = distinct_acc(false)?; + let mut acc2 = distinct_acc(false)?; + + acc1.update_batch(&[data(["A", "A", "B"])])?; + acc2.update_batch(&[data(["A", "C"])])?; + + let state = acc2.state()?; + let state_arrs: Vec = state + .into_iter() + .map(|sv| sv.to_array_of_size(1)) + .collect::>>()?; + acc1.merge_batch(&state_arrs)?; + + assert_eq!(print_nulls(str_arr(acc1.evaluate()?)?), vec!["A", "B", "C"]); + + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index f4b3e598c31d2..ddeb9b0870a16 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -24,15 +24,15 @@ use arrow::array::{ use arrow::compute::sum; use arrow::datatypes::{ - i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, - DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, - UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, - DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, - DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, + ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, + DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DataType, + Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, i256, }; -use datafusion_common::types::{logical_float64, NativeType}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::types::{NativeType, logical_float64}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -50,7 +50,6 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls: use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; -use std::any::Any; use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -127,10 +126,6 @@ impl Default for Avg { } impl AggregateUDFImpl for Avg { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "avg" } @@ -310,12 +305,14 @@ impl AggregateUDFImpl for Avg { }; // Similar to datafusion_functions_aggregate::sum::Sum::state_fields // since the accumulator uses DistinctSumAccumulator internally. - Ok(vec![Field::new_list( - format_state_name(args.name, "avg distinct"), - Field::new_list_field(dt, true), - false, - ) - .into()]) + Ok(vec![ + Field::new_list( + format_state_name(args.name, "avg distinct"), + Field::new_list_field(dt, true), + false, + ) + .into(), + ]) } else { Ok(vec![ Field::new( @@ -522,9 +519,16 @@ impl Accumulator for AvgAccumulator { } fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Float64( - self.sum.map(|f| f / self.count as f64), - )) + // In sliding-window mode `retract_batch` can bring `count` back to 0 + // while `sum` remains `Some(..)` (possibly zero or a floating-point + // residual). Guard against that so the frame with no non-NULL values + // yields NULL rather than NaN / ±Inf. + let avg = if self.count == 0 { + None + } else { + self.sum.map(|f| f / self.count as f64) + }; + Ok(ScalarValue::Float64(avg)) } fn size(&self) -> usize { @@ -587,17 +591,23 @@ impl Accumulator for DecimalAvgAccumu } fn evaluate(&mut self) -> Result { - let v = self - .sum - .map(|v| { - DecimalAverager::::try_new( - self.sum_scale, - self.target_precision, - self.target_scale, - )? - .avg(v, T::Native::from_usize(self.count as usize).unwrap()) - }) - .transpose()?; + // `count == 0` can occur in sliding-window mode after `retract_batch` + // removes every contributing value. Return NULL rather than dividing + // by zero (which would panic for integer decimal types). + let v = if self.count == 0 { + None + } else { + self.sum + .map(|v| { + DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )? + .avg(v, T::Native::from_usize(self.count as usize).unwrap()) + }) + .transpose()? + }; ScalarValue::new_primitive::( v, @@ -673,7 +683,14 @@ impl Accumulator for DurationAvgAccumulator { } fn evaluate(&mut self) -> Result { - let avg = self.sum.map(|sum| sum / self.count as i64); + // Guard against `count == 0` which can happen in sliding-window mode + // after every contributing value has been retracted. Without this + // check we would integer-divide by zero. + let avg = if self.count == 0 { + None + } else { + self.sum.map(|sum| sum / self.count as i64) + }; match self.result_unit { TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)), @@ -752,7 +769,7 @@ impl Accumulator for DurationAvgAccumulator { struct AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send, + F: Fn(T::Native, u64) -> Result + Send + 'static, { /// The type of the internal sum sum_data_type: DataType, @@ -776,7 +793,7 @@ where impl AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send, + F: Fn(T::Native, u64) -> Result + Send + 'static, { pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { debug!( @@ -798,7 +815,7 @@ where impl GroupsAccumulator for AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send, + F: Fn(T::Native, u64) -> Result + Send + 'static, { fn update_batch( &mut self, @@ -819,7 +836,8 @@ where opt_filter, total_num_groups, |group_index, new_value| { - let sum = &mut self.sums[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); self.counts[group_index] += 1; @@ -834,12 +852,16 @@ where let sums = emit_to.take_needed(&mut self.sums); let nulls = self.null_state.build(emit_to); - assert_eq!(nulls.len(), sums.len()); + if let Some(nulls) = &nulls { + assert_eq!(nulls.len(), sums.len()); + } assert_eq!(counts.len(), sums.len()); // don't evaluate averages with null inputs to avoid errors on null values - let array: PrimitiveArray = if nulls.null_count() > 0 { + let array: PrimitiveArray = if let Some(nulls) = &nulls + && nulls.null_count() > 0 + { let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()) .with_data_type(self.return_data_type.clone()); let iter = sums.into_iter().zip(counts).zip(nulls.iter()); @@ -855,10 +877,10 @@ where } else { let averages: Vec = sums .into_iter() - .zip(counts.into_iter()) + .zip(counts) .map(|(sum, count)| (self.avg_fn)(sum, count)) .collect::>>()?; - PrimitiveArray::new(averages.into(), Some(nulls)) // no copy + PrimitiveArray::new(averages.into(), nulls) // no copy .with_data_type(self.return_data_type.clone()) }; @@ -868,7 +890,6 @@ where // return arrays for sums and counts fn state(&mut self, emit_to: EmitTo) -> Result> { let nulls = self.null_state.build(emit_to); - let nulls = Some(nulls); let counts = emit_to.take_needed(&mut self.counts); let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy @@ -902,7 +923,9 @@ where opt_filter, total_num_groups, |group_index, partial_count| { - self.counts[group_index] += partial_count; + // SAFETY: group_index is guaranteed to be in bounds + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count += partial_count; }, ); @@ -914,7 +937,8 @@ where opt_filter, total_num_groups, |group_index, new_value: ::Native| { - let sum = &mut self.sums[group_index]; + // SAFETY: group_index is guaranteed to be in bounds + let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); }, ); diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 28e5f7e37503e..d730a6c1cb3eb 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -17,21 +17,20 @@ //! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators -use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::hash::Hash; use std::mem::{size_of, size_of_val}; -use ahash::RandomState; -use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; +use arrow::array::{Array, ArrayRef, AsArray, downcast_integer}; use arrow::datatypes::{ - ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowNumericType, DataType, Field, FieldRef, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; +use datafusion_common::hash_utils::RandomState; use datafusion_common::cast::as_list_array; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -240,10 +239,6 @@ impl BitwiseOperation { } impl AggregateUDFImpl for BitwiseOperation { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.func_name } @@ -262,30 +257,36 @@ impl AggregateUDFImpl for BitwiseOperation { fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.input_fields[0].data_type().is_null() { - Ok(vec![Field::new( - format_state_name(args.name, self.name()), - DataType::Null, - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Null, + true, + ) + .into(), + ]) } else if self.operation == BitwiseOperationType::Xor && args.is_distinct { - Ok(vec![Field::new_list( - format_state_name( - args.name, - format!("{} distinct", self.name()).as_str(), - ), - // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type().clone(), true), - false, - ) - .into()]) + Ok(vec![ + Field::new_list( + format_state_name( + args.name, + format!("{} distinct", self.name()).as_str(), + ), + // See COMMENTS.md to understand why nullable is set to true + Field::new_list_field(args.return_type().clone(), true), + false, + ) + .into(), + ]) } else { - Ok(vec![Field::new( - format_state_name(args.name, self.name()), - args.return_field.data_type().clone(), - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + args.return_field.data_type().clone(), + true, + ) + .into(), + ]) } } diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index ff389bb419e2e..3b900f1655ec1 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -17,7 +17,6 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; use std::mem::size_of_val; use arrow::array::ArrayRef; @@ -28,10 +27,10 @@ use arrow::datatypes::Field; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::internal_err; -use datafusion_common::{downcast_value, not_impl_err}; use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{downcast_value, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; +use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, Signature, Volatility, @@ -114,11 +113,7 @@ pub struct BoolAnd { impl BoolAnd { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } @@ -130,10 +125,6 @@ impl Default for BoolAnd { } impl AggregateUDFImpl for BoolAnd { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bool_and" } @@ -151,12 +142,14 @@ impl AggregateUDFImpl for BoolAnd { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - Ok(vec![Field::new( - format_state_name(args.name, self.name()), - DataType::Boolean, - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Boolean, + true, + ) + .into(), + ]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { @@ -249,11 +242,7 @@ pub struct BoolOr { impl BoolOr { fn new() -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Boolean], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Boolean], Volatility::Immutable), } } } @@ -265,10 +254,6 @@ impl Default for BoolOr { } impl AggregateUDFImpl for BoolOr { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bool_or" } @@ -286,12 +271,14 @@ impl AggregateUDFImpl for BoolOr { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - Ok(vec![Field::new( - format_state_name(args.name, self.name()), - DataType::Boolean, - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Boolean, + true, + ) + .into(), + ]) } fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 4cac159ebb3a7..2621fcf0bf3c7 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -17,14 +17,13 @@ //! [`Correlation`]: correlation sample aggregations. -use std::any::Any; use std::fmt::Debug; use std::mem::size_of_val; use std::sync::Arc; use arrow::array::{ - downcast_array, Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, - UInt64Array, + Array, AsArray, BooleanArray, Float64Array, NullBufferBuilder, UInt64Array, + downcast_array, }; use arrow::compute::{and, filter, is_not_null}; use arrow::datatypes::{FieldRef, Float64Type, UInt64Type}; @@ -40,9 +39,9 @@ use crate::covariance::CovarianceAccumulator; use crate::stddev::StddevAccumulator; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, utils::format_state_name, - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; @@ -96,11 +95,6 @@ impl Correlation { } impl AggregateUDFImpl for Correlation { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "corr" } @@ -214,15 +208,14 @@ impl Accumulator for CorrelationAccumulator { return Ok(ScalarValue::Float64(None)); } - if let ScalarValue::Float64(Some(c)) = covar { - if let ScalarValue::Float64(Some(s1)) = stddev1 { - if let ScalarValue::Float64(Some(s2)) = stddev2 { - if s1 == 0_f64 || s2 == 0_f64 { - return Ok(ScalarValue::Float64(None)); - } else { - return Ok(ScalarValue::Float64(Some(c / s1 / s2))); - } - } + if let ScalarValue::Float64(Some(c)) = covar + && let ScalarValue::Float64(Some(s1)) = stddev1 + && let ScalarValue::Float64(Some(s2)) = stddev2 + { + if s1 == 0_f64 || s2 == 0_f64 { + return Ok(ScalarValue::Float64(None)); + } else { + return Ok(ScalarValue::Float64(Some(c / s1 / s2))); } } @@ -368,7 +361,7 @@ fn accumulate_correlation_states( /// where: /// n = number of observations /// sum_x = sum of x values -/// sum_y = sum of y values +/// sum_y = sum of y values /// sum_xy = sum of (x * y) /// sum_xx = sum of x^2 values /// sum_yy = sum of y^2 values @@ -412,11 +405,15 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let n = match emit_to { - EmitTo::All => self.count.len(), - EmitTo::First(n) => n, - }; - + // Drain the state vectors for the groups being emitted + let counts = emit_to.take_needed(&mut self.count); + let sum_xs = emit_to.take_needed(&mut self.sum_x); + let sum_ys = emit_to.take_needed(&mut self.sum_y); + let sum_xys = emit_to.take_needed(&mut self.sum_xy); + let sum_xxs = emit_to.take_needed(&mut self.sum_xx); + let sum_yys = emit_to.take_needed(&mut self.sum_yy); + + let n = counts.len(); let mut values = Vec::with_capacity(n); let mut nulls = NullBufferBuilder::new(n); @@ -428,14 +425,13 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { // result should be `Null` (according to PostgreSQL's behavior). // - However, if any of the accumulated values contain NaN, the result should // be NaN regardless of the count (even for single-row groups). - // for i in 0..n { - let count = self.count[i]; - let sum_x = self.sum_x[i]; - let sum_y = self.sum_y[i]; - let sum_xy = self.sum_xy[i]; - let sum_xx = self.sum_xx[i]; - let sum_yy = self.sum_yy[i]; + let count = counts[i]; + let sum_x = sum_xs[i]; + let sum_y = sum_ys[i]; + let sum_xy = sum_xys[i]; + let sum_xx = sum_xxs[i]; + let sum_yy = sum_yys[i]; // If BOTH sum_x AND sum_y are NaN, then both input values are NaN → return NaN // If only ONE of them is NaN, then only one input value is NaN → return NULL @@ -471,18 +467,21 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn state(&mut self, emit_to: EmitTo) -> Result> { - let n = match emit_to { - EmitTo::All => self.count.len(), - EmitTo::First(n) => n, - }; + // Drain the state vectors for the groups being emitted + let count = emit_to.take_needed(&mut self.count); + let sum_x = emit_to.take_needed(&mut self.sum_x); + let sum_y = emit_to.take_needed(&mut self.sum_y); + let sum_xy = emit_to.take_needed(&mut self.sum_xy); + let sum_xx = emit_to.take_needed(&mut self.sum_xx); + let sum_yy = emit_to.take_needed(&mut self.sum_yy); Ok(vec![ - Arc::new(UInt64Array::from(self.count[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_x[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_y[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_xy[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_xx[0..n].to_vec())), - Arc::new(Float64Array::from(self.sum_yy[0..n].to_vec())), + Arc::new(UInt64Array::from(count)), + Arc::new(Float64Array::from(sum_x)), + Arc::new(Float64Array::from(sum_y)), + Arc::new(Float64Array::from(sum_xy)), + Arc::new(Float64Array::from(sum_xx)), + Arc::new(Float64Array::from(sum_yy)), ]) } @@ -509,7 +508,10 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { let partial_sum_xx = values[4].as_primitive::(); let partial_sum_yy = values[5].as_primitive::(); - assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage"); + assert!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); accumulate_correlation_states( group_indices, @@ -535,19 +537,18 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { } fn size(&self) -> usize { - size_of_val(&self.count) - + size_of_val(&self.sum_x) - + size_of_val(&self.sum_y) - + size_of_val(&self.sum_xy) - + size_of_val(&self.sum_xx) - + size_of_val(&self.sum_yy) + self.count.capacity() * size_of::() + + self.sum_x.capacity() * size_of::() + + self.sum_y.capacity() * size_of::() + + self.sum_xy.capacity() * size_of::() + + self.sum_xx.capacity() * size_of::() + + self.sum_yy.capacity() * size_of::() } } #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, UInt64Array}; #[test] fn test_accumulate_correlation_states() { diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index a291e8e21eb0f..eab36d4951a9c 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -15,33 +15,38 @@ // specific language governing permissions and limitations // under the License. -use ahash::RandomState; use arrow::{ array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray}, buffer::BooleanBuffer, compute, datatypes::{ DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, - FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, + Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + UInt8Type, UInt16Type, UInt32Type, UInt64Type, }, }; +use datafusion_common::hash_utils::RandomState; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, stats::Precision, - utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue, + HashMap, Result, ScalarValue, downcast_value, exec_err, internal_err, not_impl_err, + stats::Precision, utils::expr::COUNT_STAR_EXPANSION, }; use datafusion_expr::{ - expr::WindowFunction, - function::{AccumulatorArgs, StateFieldsArgs}, - utils::format_state_name, Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, GroupsAccumulator, ReversedUDAF, SetMonotonicity, Signature, StatisticsArgs, TypeSignature, Volatility, WindowFunctionDefinition, + expr::WindowFunction, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, }; +use datafusion_functions_aggregate_common::aggregate::count_distinct::PrimitiveDistinctCountGroupsAccumulator; use datafusion_functions_aggregate_common::aggregate::{ + count_distinct::Bitmap65536DistinctCountAccumulator, + count_distinct::Bitmap65536DistinctCountAccumulatorI16, + count_distinct::BoolArray256DistinctCountAccumulator, + count_distinct::BoolArray256DistinctCountAccumulatorI8, count_distinct::BytesDistinctCountAccumulator, count_distinct::BytesViewDistinctCountAccumulator, count_distinct::DictionaryCountAccumulator, @@ -147,20 +152,11 @@ pub fn count_all_window() -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Count { signature: Signature, } -impl Debug for Count { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("Count") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Count { fn default() -> Self { Self::new() @@ -179,31 +175,23 @@ impl Count { } fn get_count_accumulator(data_type: &DataType) -> Box { match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - DataType::Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - DataType::Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), + // HashSet-based accumulator for larger integer types DataType::Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( data_type, )), DataType::Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( data_type, )), - DataType::UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - DataType::UInt16 => Box::new( - PrimitiveDistinctCountAccumulator::::new(data_type), - ), DataType::UInt32 => Box::new( PrimitiveDistinctCountAccumulator::::new(data_type), ), DataType::UInt64 => Box::new( PrimitiveDistinctCountAccumulator::::new(data_type), ), + // Small int types - cold path + DataType::UInt8 | DataType::Int8 | DataType::UInt16 | DataType::Int16 => { + get_small_int_accumulator(data_type).unwrap() + } DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< Decimal128Type, >::new(data_type)), @@ -279,11 +267,19 @@ fn get_count_accumulator(data_type: &DataType) -> Box { } } -impl AggregateUDFImpl for Count { - fn as_any(&self) -> &dyn std::any::Any { - self +/// Uses optimized bitmap accumulators but separated to keep hot path small +#[cold] +fn get_small_int_accumulator(data_type: &DataType) -> Result> { + match data_type { + DataType::UInt8 => Ok(Box::new(BoolArray256DistinctCountAccumulator::new())), + DataType::Int8 => Ok(Box::new(BoolArray256DistinctCountAccumulatorI8::new())), + DataType::UInt16 => Ok(Box::new(Bitmap65536DistinctCountAccumulator::new())), + DataType::Int16 => Ok(Box::new(Bitmap65536DistinctCountAccumulatorI16::new())), + _ => exec_err!("unsupported accumulator for datatype: {}", data_type), } +} +impl AggregateUDFImpl for Count { fn name(&self) -> &str { "count" } @@ -307,20 +303,24 @@ impl AggregateUDFImpl for Count { &dtype => dtype.clone(), }; - Ok(vec![Field::new_list( - format_state_name(args.name, "count distinct"), - // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(dtype, true), - false, - ) - .into()]) + Ok(vec![ + Field::new_list( + format_state_name(args.name, "count distinct"), + // See COMMENTS.md to understand why nullable is set to true + Field::new_list_field(dtype, true), + false, + ) + .into(), + ]) } else { - Ok(vec![Field::new( - format_state_name(args.name, "count"), - DataType::Int64, - false, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, "count"), + DataType::Int64, + false, + ) + .into(), + ]) } } @@ -345,20 +345,33 @@ impl AggregateUDFImpl for Count { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - if args.is_distinct { + if args.exprs.len() != 1 { return false; } - args.exprs.len() == 1 + if !args.is_distinct { + return true; + } + matches!( + args.expr_fields[0].data_type(), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) } fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) + if !args.is_distinct { + return Ok(Box::new(CountGroupsAccumulator::new())); + } + create_distinct_count_groups_accumulator(&args) } fn reverse_expr(&self) -> ReversedUDAF { @@ -370,32 +383,39 @@ impl AggregateUDFImpl for Count { } fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + let [expr] = statistics_args.exprs else { + return None; + }; + let col_stats = &statistics_args.statistics.column_statistics; + if statistics_args.is_distinct { + // Only column references can be resolved from statistics; + // expressions like casts or literals are not supported. + let col_expr = expr.downcast_ref::()?; + if let Precision::Exact(dc) = col_stats[col_expr.index()].distinct_count { + let dc = i64::try_from(dc).ok()?; + return Some(ScalarValue::Int64(Some(dc))); + } return None; } - if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { - if statistics_args.exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = statistics_args.exprs[0] - .as_any() - .downcast_ref::() - { - let current_val = &statistics_args.statistics.column_statistics - [col_expr.index()] - .null_count; - if let &Precision::Exact(val) = current_val { - return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); - } - } else if let Some(lit_expr) = statistics_args.exprs[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(ScalarValue::Int64(Some(num_rows as i64))); - } - } + + let Precision::Exact(num_rows) = statistics_args.statistics.num_rows else { + return None; + }; + + // TODO optimize with exprs other than Column + if let Some(col_expr) = expr.downcast_ref::() { + if let Precision::Exact(val) = col_stats[col_expr.index()].null_count { + let count = i64::try_from(num_rows - val).ok()?; + return Some(ScalarValue::Int64(Some(count))); } + } else if let Some(lit_expr) = expr.downcast_ref::() + && lit_expr.value() == &COUNT_STAR_EXPANSION + { + let num_rows = i64::try_from(num_rows).ok()?; + return Some(ScalarValue::Int64(Some(num_rows))); } + None } @@ -424,6 +444,43 @@ impl AggregateUDFImpl for Count { } } +#[cold] +fn create_distinct_count_groups_accumulator( + args: &AccumulatorArgs, +) -> Result> { + let data_type = args.expr_fields[0].data_type(); + match data_type { + DataType::Int8 => Ok(Box::new( + PrimitiveDistinctCountGroupsAccumulator::::new(), + )), + DataType::Int16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int16Type, + >::new())), + DataType::Int32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int32Type, + >::new())), + DataType::Int64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + Int64Type, + >::new())), + DataType::UInt8 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt8Type, + >::new())), + DataType::UInt16 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt16Type, + >::new())), + DataType::UInt32 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt32Type, + >::new())), + DataType::UInt64 => Ok(Box::new(PrimitiveDistinctCountGroupsAccumulator::< + UInt64Type, + >::new())), + _ => not_impl_err!( + "GroupsAccumulator not supported for COUNT(DISTINCT) with {}", + data_type + ), + } +} + // DistinctCountAccumulator does not support retract_batch and sliding window // this is a specialized accumulator for distinct count that supports retract_batch // and sliding window. @@ -466,12 +523,12 @@ impl Accumulator for SlidingDistinctCountAccumulator { let arr = &values[0]; for i in 0..arr.len() { let v = ScalarValue::try_from_array(arr, i)?; - if !v.is_null() { - if let Some(cnt) = self.counts.get_mut(&v) { - *cnt -= 1; - if *cnt == 0 { - self.counts.remove(&v); - } + if !v.is_null() + && let Some(cnt) = self.counts.get_mut(&v) + { + *cnt -= 1; + if *cnt == 0 { + self.counts.remove(&v); } } } @@ -595,7 +652,9 @@ impl GroupsAccumulator for CountGroupsAccumulator { values.logical_nulls().as_ref(), opt_filter, |group_index| { - self.counts[group_index] += 1; + // SAFETY: group_index is guaranteed to be in bounds + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count += 1; }, ); @@ -854,7 +913,7 @@ mod tests { datatypes::{DataType, Field, Int32Type, Schema}, }; use datafusion_expr::function::AccumulatorArgs; - use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; + use datafusion_physical_expr::{PhysicalExpr, expressions::Column}; use std::sync::Arc; /// Helper function to create a dictionary array with non-null keys but some null values /// Returns a dictionary array where: diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 7e34ffbaad01b..18d602ab33940 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -17,20 +17,14 @@ //! [`CovarianceSample`]: covariance sample aggregations. -use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, Float64Array, UInt64Array}, - compute::kernels::cast, - datatypes::{DataType, Field}, -}; -use datafusion_common::{ - downcast_value, plan_err, unwrap_or_internal_err, Result, ScalarValue, -}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, function::{AccumulatorArgs, StateFieldsArgs}, - type_coercion::aggregates::NUMERICS, utils::format_state_name, - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; @@ -69,21 +63,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovarianceSample { signature: Signature, aliases: Vec, } -impl Debug for CovarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovarianceSample { fn default() -> Self { Self::new() @@ -94,16 +79,15 @@ impl CovarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("covar")], - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } impl AggregateUDFImpl for CovarianceSample { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "covar_samp" } @@ -112,11 +96,7 @@ impl AggregateUDFImpl for CovarianceSample { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -165,20 +145,11 @@ impl AggregateUDFImpl for CovarianceSample { standard_argument(name = "expression1", prefix = "First"), standard_argument(name = "expression2", prefix = "Second") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct CovariancePopulation { signature: Signature, } -impl Debug for CovariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("CovariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for CovariancePopulation { fn default() -> Self { Self::new() @@ -188,16 +159,15 @@ impl Default for CovariancePopulation { impl CovariancePopulation { pub fn new() -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), } } } impl AggregateUDFImpl for CovariancePopulation { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "covar_pop" } @@ -206,11 +176,7 @@ impl AggregateUDFImpl for CovariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -304,30 +270,15 @@ impl Accumulator for CovarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, }; - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); let new_count = self.count + 1; let delta1 = value1 - self.mean1; let new_mean1 = delta1 / new_count as f64 + self.mean1; @@ -345,29 +296,14 @@ impl Accumulator for CovarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; + let values1 = as_float64_array(&values[0])?; + let values2 = as_float64_array(&values[1])?; - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); + for (value1, value2) in values1.iter().zip(values2) { + let (value1, value2) = match (value1, value2) { + (Some(a), Some(b)) => (a, b), + _ => continue, + }; let new_count = self.count - 1; let delta1 = self.mean1 - value1; @@ -386,10 +322,10 @@ impl Accumulator for CovarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means1 = as_float64_array(&states[1])?; + let means2 = as_float64_array(&states[2])?; + let cs = as_float64_array(&states[3])?; for i in 0..counts.len() { let c = counts.value(i); diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index bedf124d3095c..1935f29c4cfe8 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -17,33 +17,30 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. -use std::any::Any; use std::fmt::Debug; use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; -use arrow::array::{ - Array, ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, BooleanBufferBuilder, - PrimitiveArray, -}; -use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder}; +use arrow::buffer::BooleanBuffer; use arrow::compute::{self, LexicographicalComparator, SortColumn, SortOptions}; use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, Time32MillisecondType, Time32SecondType, + DataType, Date32Type, Date64Type, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, Int8Type, + Int16Type, Int32Type, Int64Type, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, }; use datafusion_common::cast::as_boolean_array; use datafusion_common::utils::{compare_rows, extract_row_at_idx_to_buf, get_row_at_idx}; use datafusion_common::{ - arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, arrow_datafusion_err, internal_err, + not_impl_err, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; +use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, @@ -52,6 +49,10 @@ use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; use datafusion_physical_expr_common::sort_expr::LexOrdering; +mod state; + +use state::{BytesValueState, PrimitiveValueState, ValueState}; + create_func!(FirstValue, first_value_udaf); create_func!(LastValue, last_value_udaf); @@ -75,6 +76,142 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { .unwrap() } +fn create_groups_accumulator_helper( + args: &AccumulatorArgs, + is_first: bool, + state: S, +) -> Result> { + let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { + return internal_err!("Groups accumulator must have an ordering."); + }; + + let ordering_dtypes = ordering + .iter() + .map(|e| e.expr.data_type(args.schema)) + .collect::>>()?; + + Ok(Box::new(FirstLastGroupsAccumulator::try_new( + state, + ordering, + args.ignore_nulls, + &ordering_dtypes, + is_first, + )?)) +} + +fn create_groups_accumulator( + args: &AccumulatorArgs, + is_first: bool, + function_name: &str, +) -> Result> { + let data_type = args.return_field.data_type(); + + macro_rules! instantiate_primitive { + ($t:ty) => { + create_groups_accumulator_helper( + args, + is_first, + PrimitiveValueState::<$t>::new(data_type.clone()), + ) + }; + } + + match data_type { + DataType::Int8 => instantiate_primitive!(Int8Type), + DataType::Int16 => instantiate_primitive!(Int16Type), + DataType::Int32 => instantiate_primitive!(Int32Type), + DataType::Int64 => instantiate_primitive!(Int64Type), + DataType::UInt8 => instantiate_primitive!(UInt8Type), + DataType::UInt16 => instantiate_primitive!(UInt16Type), + DataType::UInt32 => instantiate_primitive!(UInt32Type), + DataType::UInt64 => instantiate_primitive!(UInt64Type), + DataType::Float16 => instantiate_primitive!(Float16Type), + DataType::Float32 => instantiate_primitive!(Float32Type), + DataType::Float64 => instantiate_primitive!(Float64Type), + + DataType::Decimal32(_, _) => instantiate_primitive!(Decimal32Type), + DataType::Decimal64(_, _) => instantiate_primitive!(Decimal64Type), + DataType::Decimal128(_, _) => instantiate_primitive!(Decimal128Type), + DataType::Decimal256(_, _) => instantiate_primitive!(Decimal256Type), + + DataType::Timestamp(TimeUnit::Second, _) => { + instantiate_primitive!(TimestampSecondType) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + instantiate_primitive!(TimestampMillisecondType) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + instantiate_primitive!(TimestampMicrosecondType) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + instantiate_primitive!(TimestampNanosecondType) + } + + DataType::Date32 => instantiate_primitive!(Date32Type), + DataType::Date64 => instantiate_primitive!(Date64Type), + DataType::Time32(TimeUnit::Second) => instantiate_primitive!(Time32SecondType), + DataType::Time32(TimeUnit::Millisecond) => { + instantiate_primitive!(Time32MillisecondType) + } + DataType::Time64(TimeUnit::Microsecond) => { + instantiate_primitive!(Time64MicrosecondType) + } + DataType::Time64(TimeUnit::Nanosecond) => { + instantiate_primitive!(Time64NanosecondType) + } + + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Binary + | DataType::LargeBinary + | DataType::BinaryView => create_groups_accumulator_helper( + args, + is_first, + BytesValueState::try_new(data_type.clone())?, + ), + + _ => internal_err!( + "GroupsAccumulator not supported for {}({})", + function_name, + data_type + ), + } +} + +fn groups_accumulator_supported(args: &AccumulatorArgs) -> bool { + use DataType::*; + !args.order_bys.is_empty() + && matches!( + args.return_field.data_type(), + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView + ) +} + #[user_doc( doc_section(label = "General Functions"), description = "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", @@ -89,22 +226,12 @@ pub fn last_value(expression: Expr, order_by: Vec) -> Expr { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct FirstValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for FirstValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for FirstValue { fn default() -> Self { Self::new() @@ -121,10 +248,6 @@ impl FirstValue { } impl AggregateUDFImpl for FirstValue { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "first_value" } @@ -133,8 +256,20 @@ impl AggregateUDFImpl for FirstValue { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + not_impl_err!("Not called because the return_field_from_args is implemented") + } + + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + // Preserve metadata from the first argument field + Ok(Arc::new( + Field::new( + self.name(), + arg_fields[0].data_type().clone(), + true, // always nullable, there may be no rows + ) + .with_metadata(arg_fields[0].metadata().clone()), + )) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -159,12 +294,14 @@ impl AggregateUDFImpl for FirstValue { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let mut fields = vec![Field::new( - format_state_name(args.name, "first_value"), - args.return_type().clone(), - true, - ) - .into()]; + let mut fields = vec![ + Field::new( + format_state_name(args.name, "first_value"), + args.return_type().clone(), + true, + ) + .into(), + ]; fields.extend(args.ordering_fields.iter().cloned()); fields.push( Field::new( @@ -178,110 +315,14 @@ impl AggregateUDFImpl for FirstValue { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - use DataType::*; - !args.order_bys.is_empty() - && matches!( - args.return_field.data_type(), - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal32(_, _) - | Decimal64(_, _) - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + groups_accumulator_supported(&args) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - fn create_accumulator( - args: &AccumulatorArgs, - ) -> Result> { - let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { - return internal_err!("Groups accumulator must have an ordering."); - }; - - let ordering_dtypes = ordering - .iter() - .map(|e| e.expr.data_type(args.schema)) - .collect::>>()?; - - FirstPrimitiveGroupsAccumulator::::try_new( - ordering, - args.ignore_nulls, - args.return_field.data_type(), - &ordering_dtypes, - true, - ) - .map(|acc| Box::new(acc) as _) - } - - match args.return_field.data_type() { - DataType::Int8 => create_accumulator::(&args), - DataType::Int16 => create_accumulator::(&args), - DataType::Int32 => create_accumulator::(&args), - DataType::Int64 => create_accumulator::(&args), - DataType::UInt8 => create_accumulator::(&args), - DataType::UInt16 => create_accumulator::(&args), - DataType::UInt32 => create_accumulator::(&args), - DataType::UInt64 => create_accumulator::(&args), - DataType::Float16 => create_accumulator::(&args), - DataType::Float32 => create_accumulator::(&args), - DataType::Float64 => create_accumulator::(&args), - - DataType::Decimal32(_, _) => create_accumulator::(&args), - DataType::Decimal64(_, _) => create_accumulator::(&args), - DataType::Decimal128(_, _) => create_accumulator::(&args), - DataType::Decimal256(_, _) => create_accumulator::(&args), - - DataType::Timestamp(TimeUnit::Second, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - create_accumulator::(&args) - } - - DataType::Date32 => create_accumulator::(&args), - DataType::Date64 => create_accumulator::(&args), - DataType::Time32(TimeUnit::Second) => { - create_accumulator::(&args) - } - DataType::Time32(TimeUnit::Millisecond) => { - create_accumulator::(&args) - } - - DataType::Time64(TimeUnit::Microsecond) => { - create_accumulator::(&args) - } - DataType::Time64(TimeUnit::Nanosecond) => { - create_accumulator::(&args) - } - - _ => internal_err!( - "GroupsAccumulator not supported for first_value({})", - args.return_field.data_type() - ), - } + create_groups_accumulator(&args, true, self.name()) } fn with_beneficial_ordering( @@ -311,13 +352,9 @@ impl AggregateUDFImpl for FirstValue { } } -// TODO: rename to PrimitiveGroupsAccumulator -struct FirstPrimitiveGroupsAccumulator -where - T: ArrowPrimitiveType + Send, -{ +struct FirstLastGroupsAccumulator { // ================ state =========== - vals: Vec, + state: S, // Stores ordering values, of the aggregator requirement corresponding to first value // of the aggregator. // The `orderings` are stored row-wise, meaning that `orderings[group_idx]` @@ -326,19 +363,16 @@ where // At the beginning, `is_sets[group_idx]` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_sets[group_idx]` flag is_sets: BooleanBufferBuilder, - // null_builder[group_idx] == false => vals[group_idx] is null - null_builder: BooleanBufferBuilder, // size of `self.orderings` // Calculating the memory usage of `self.orderings` using `ScalarValue::size_of_vec` is quite costly. // Therefore, we cache it and compute `size_of` only after each update // to avoid calling `ScalarValue::size_of_vec` by Self.size. size_of_orderings: usize, - // buffer for `get_filtered_min_of_each_group` + // buffer for `get_filtered_extreme_of_each_group` // filter_min_of_each_group_buf.0[group_idx] -> idx_in_val // only valid if filter_min_of_each_group_buf.1[group_idx] == true - // TODO: rename to extreme_of_each_group_buf - min_of_each_group_buf: (Vec, BooleanBufferBuilder), + extreme_of_each_group_buf: (Vec, BooleanBufferBuilder), // =========== option ============ @@ -351,19 +385,14 @@ where sort_options: Vec, // Ignore null values. ignore_nulls: bool, - /// The output type - data_type: DataType, default_orderings: Vec, } -impl FirstPrimitiveGroupsAccumulator -where - T: ArrowPrimitiveType + Send, -{ +impl FirstLastGroupsAccumulator { fn try_new( + state: S, ordering_req: LexOrdering, ignore_nulls: bool, - data_type: &DataType, ordering_dtypes: &[DataType], pick_first_in_group: bool, ) -> Result { @@ -375,17 +404,15 @@ where let sort_options = get_sort_options(&ordering_req); Ok(Self { - null_builder: BooleanBufferBuilder::new(0), ordering_req, sort_options, ignore_nulls, default_orderings, - data_type: data_type.clone(), - vals: Vec::new(), + state, orderings: Vec::new(), is_sets: BooleanBufferBuilder::new(0), size_of_orderings: 0, - min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)), + extreme_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)), pick_first_in_group, }) } @@ -399,7 +426,7 @@ where return Ok(true); } - assert!(new_ordering_values.len() == self.ordering_req.len()); + debug_assert!(new_ordering_values.len() == self.ordering_req.len()); let current_ordering = &self.orderings[group_idx]; compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| { if self.pick_first_in_group { @@ -424,32 +451,8 @@ where result } - fn take_need( - bool_buf_builder: &mut BooleanBufferBuilder, - emit_to: EmitTo, - ) -> BooleanBuffer { - let bool_buf = bool_buf_builder.finish(); - match emit_to { - EmitTo::All => bool_buf, - EmitTo::First(n) => { - // split off the first N values in seen_values - // - // TODO make this more efficient rather than two - // copies and bitwise manipulation - let first_n: BooleanBuffer = bool_buf.iter().take(n).collect(); - // reset the existing buffer - for b in bool_buf.iter().skip(n) { - bool_buf_builder.append(b); - } - first_n - } - } - } - fn resize_states(&mut self, new_size: usize) { - self.vals.resize(new_size, T::default_value()); - - self.null_builder.resize(new_size); + self.state.resize(new_size); if self.orderings.len() < new_size { let current_len = self.orderings.len(); @@ -468,44 +471,43 @@ where self.is_sets.resize(new_size); - self.min_of_each_group_buf.0.resize(new_size, 0); - self.min_of_each_group_buf.1.resize(new_size); + self.extreme_of_each_group_buf.0.resize(new_size, 0); + self.extreme_of_each_group_buf.1.resize(new_size); } fn update_state( &mut self, group_idx: usize, orderings: &[ScalarValue], - new_val: T::Native, - is_null: bool, - ) { - self.vals[group_idx] = new_val; + array: &ArrayRef, + idx: usize, + ) -> Result<()> { + self.state.update(group_idx, array, idx)?; self.is_sets.set_bit(group_idx, true); - self.null_builder.set_bit(group_idx, !is_null); - - assert!(orderings.len() == self.ordering_req.len()); + debug_assert!(orderings.len() == self.ordering_req.len()); let old_size = ScalarValue::size_of_vec(&self.orderings[group_idx]); self.orderings[group_idx].clear(); self.orderings[group_idx].extend_from_slice(orderings); let new_size = ScalarValue::size_of_vec(&self.orderings[group_idx]); self.size_of_orderings = self.size_of_orderings - old_size + new_size; + Ok(()) } fn take_state( &mut self, emit_to: EmitTo, - ) -> (ArrayRef, Vec>, BooleanBuffer) { - emit_to.take_needed(&mut self.min_of_each_group_buf.0); - self.min_of_each_group_buf + ) -> Result<(ArrayRef, Vec>, BooleanBuffer)> { + emit_to.take_needed(&mut self.extreme_of_each_group_buf.0); + self.extreme_of_each_group_buf .1 - .truncate(self.min_of_each_group_buf.0.len()); + .truncate(self.extreme_of_each_group_buf.0.len()); - ( - self.take_vals_and_null_buf(emit_to), + Ok(( + self.state.take(emit_to)?, self.take_orderings(emit_to), - Self::take_need(&mut self.is_sets, emit_to), - ) + state::take_need(&mut self.is_sets, emit_to), + )) } // should be used in test only @@ -519,20 +521,19 @@ where /// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the /// minimum value in `orderings` for each group, using lexicographical comparison. /// Values are filtered using `opt_filter` and `is_set_arr` if provided. - /// TODO: rename to get_filtered_extreme_of_each_group - fn get_filtered_min_of_each_group( + fn get_filtered_extreme_of_each_group( &mut self, orderings: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, - vals: &PrimitiveArray, + vals: &ArrayRef, is_set_arr: Option<&BooleanArray>, ) -> Result> { // Set all values in min_of_each_group_buf.1 to false. - self.min_of_each_group_buf.1.truncate(0); - self.min_of_each_group_buf + self.extreme_of_each_group_buf.1.truncate(0); + self.extreme_of_each_group_buf .1 - .append_n(self.vals.len(), false); + .append_n(self.is_sets.len(), false); // No need to call `clear` since `self.min_of_each_group_buf.0[group_idx]` // is only valid when `self.min_of_each_group_buf.1[group_idx] == true`. @@ -565,48 +566,35 @@ where continue; } - let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx); + let is_valid = self.extreme_of_each_group_buf.1.get_bit(group_idx); if !is_valid { - self.min_of_each_group_buf.1.set_bit(group_idx, true); - self.min_of_each_group_buf.0[group_idx] = idx_in_val; + self.extreme_of_each_group_buf.1.set_bit(group_idx, true); + self.extreme_of_each_group_buf.0[group_idx] = idx_in_val; } else { let ordering = comparator - .compare(self.min_of_each_group_buf.0[group_idx], idx_in_val); + .compare(self.extreme_of_each_group_buf.0[group_idx], idx_in_val); if (ordering.is_gt() && self.pick_first_in_group) || (ordering.is_lt() && !self.pick_first_in_group) { - self.min_of_each_group_buf.0[group_idx] = idx_in_val; + self.extreme_of_each_group_buf.0[group_idx] = idx_in_val; } } } Ok(self - .min_of_each_group_buf + .extreme_of_each_group_buf .0 .iter() .enumerate() - .filter(|(group_idx, _)| self.min_of_each_group_buf.1.get_bit(*group_idx)) + .filter(|(group_idx, _)| self.extreme_of_each_group_buf.1.get_bit(*group_idx)) .map(|(group_idx, idx_in_val)| (group_idx, *idx_in_val)) .collect::>()) } - - fn take_vals_and_null_buf(&mut self, emit_to: EmitTo) -> ArrayRef { - let r = emit_to.take_needed(&mut self.vals); - - let null_buf = NullBuffer::new(Self::take_need(&mut self.null_builder, emit_to)); - - let values = PrimitiveArray::::new(r.into(), Some(null_buf)) // no copy - .with_data_type(self.data_type.clone()); - Arc::new(values) - } } -impl GroupsAccumulator for FirstPrimitiveGroupsAccumulator -where - T: ArrowPrimitiveType + Send, -{ +impl GroupsAccumulator for FirstLastGroupsAccumulator { fn update_batch( &mut self, // e.g. first_value(a order by b): values_and_order_cols will be [a, b] @@ -617,13 +605,13 @@ where ) -> Result<()> { self.resize_states(total_num_groups); - let vals = values_and_order_cols[0].as_primitive::(); + let vals = &values_and_order_cols[0]; let mut ordering_buf = Vec::with_capacity(self.ordering_req.len()); // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible. for (group_idx, idx) in self - .get_filtered_min_of_each_group( + .get_filtered_extreme_of_each_group( &values_and_order_cols[1..], group_indices, opt_filter, @@ -639,12 +627,7 @@ where )?; if self.should_update_state(group_idx, &ordering_buf)? { - self.update_state( - group_idx, - &ordering_buf, - vals.value(idx), - vals.is_null(idx), - ); + self.update_state(group_idx, &ordering_buf, vals, idx)?; } } @@ -652,11 +635,11 @@ where } fn evaluate(&mut self, emit_to: EmitTo) -> Result { - Ok(self.take_state(emit_to).0) + Ok(self.take_state(emit_to)?.0) } fn state(&mut self, emit_to: EmitTo) -> Result> { - let (val_arr, orderings, is_sets) = self.take_state(emit_to); + let (val_arr, orderings, is_sets) = self.take_state(emit_to)?; let mut result = Vec::with_capacity(self.orderings.len() + 2); result.push(val_arr); @@ -667,7 +650,7 @@ where ordering_cols.push(Vec::with_capacity(self.orderings.len())); } for row in orderings.into_iter() { - assert_eq!(row.len(), self.ordering_req.len()); + debug_assert!(row.len() == self.ordering_req.len()); for (col_idx, ordering) in row.into_iter().enumerate() { ordering_cols[col_idx].push(ordering); } @@ -702,9 +685,9 @@ where let is_set_arr = as_boolean_array(is_set_arr)?; - let vals = values[0].as_primitive::(); + let vals = &values[0]; // The overhead of calling `extract_row_at_idx_to_buf` is somewhat high, so we need to minimize its calls as much as possible. - let groups = self.get_filtered_min_of_each_group( + let groups = self.get_filtered_extreme_of_each_group( &val_and_order_cols[1..], group_indices, opt_filter, @@ -716,12 +699,7 @@ where extract_row_at_idx_to_buf(&val_and_order_cols[1..], idx, &mut ordering_buf)?; if self.should_update_state(group_idx, &ordering_buf)? { - self.update_state( - group_idx, - &ordering_buf, - vals.value(idx), - vals.is_null(idx), - ); + self.update_state(group_idx, &ordering_buf, vals, idx)?; } } @@ -729,12 +707,11 @@ where } fn size(&self) -> usize { - self.vals.capacity() * size_of::() - + self.null_builder.capacity() / 8 // capacity is in bits, so convert to bytes - + self.is_sets.capacity() / 8 + self.state.size() + + self.is_sets.capacity() / 8 // capacity is in bits, so convert to bytes + self.size_of_orderings - + self.min_of_each_group_buf.0.capacity() * size_of::() - + self.min_of_each_group_buf.1.capacity() / 8 + + self.extreme_of_each_group_buf.0.capacity() * size_of::() + + self.extreme_of_each_group_buf.1.capacity() / 8 } fn supports_convert_to_state(&self) -> bool { @@ -807,8 +784,7 @@ impl Accumulator for TrivialFirstValueAccumulator { first_idx = Some(0); } if let Some(first_idx) = first_idx { - let mut row = get_row_at_idx(values, first_idx)?; - self.first = row.swap_remove(0); + self.first = ScalarValue::try_from_array(&values[0], first_idx)?; self.first.compact(); self.is_set = true; } @@ -825,11 +801,11 @@ impl Accumulator for TrivialFirstValueAccumulator { let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?; - if let Some(first) = filtered_states.first() { - if !first.is_empty() { - self.first = ScalarValue::try_from_array(first, 0)?; - self.is_set = true; - } + if let Some(first) = filtered_states.first() + && !first.is_empty() + { + self.first = ScalarValue::try_from_array(first, 0)?; + self.is_set = true; } } Ok(()) @@ -854,6 +830,8 @@ pub struct FirstValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // derived from `ordering_req`. + sort_options: Vec, // Stores whether incoming data already satisfies the ordering requirement. is_input_pre_ordered: bool, // Ignore null values. @@ -873,11 +851,13 @@ impl FirstValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>()?; + let sort_options = get_sort_options(&ordering_req); ScalarValue::try_from(data_type).map(|first| Self { first, is_set: false, orderings, ordering_req, + sort_options, is_input_pre_ordered, ignore_nulls, }) @@ -950,12 +930,8 @@ impl Accumulator for FirstValueAccumulator { let row = get_row_at_idx(values, first_idx)?; if !self.is_set || (!self.is_input_pre_ordered - && compare_rows( - &self.orderings, - &row[1..], - &get_sort_options(&self.ordering_req), - )? - .is_gt()) + && compare_rows(&self.orderings, &row[1..], &self.sort_options)? + .is_gt()) { self.update_with_new_row(row); } @@ -983,10 +959,10 @@ impl Accumulator for FirstValueAccumulator { let mut first_row = get_row_at_idx(&filtered_states, first_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; - let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set - || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() + || compare_rows(&self.orderings, first_ordering, &self.sort_options)? + .is_gt() { // Update with first value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -1025,22 +1001,12 @@ impl Accumulator for FirstValueAccumulator { ```"#, standard_argument(name = "expression",) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct LastValue { signature: Signature, is_input_pre_ordered: bool, } -impl Debug for LastValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("LastValue") - .field("name", &self.name()) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - impl Default for LastValue { fn default() -> Self { Self::new() @@ -1057,10 +1023,6 @@ impl LastValue { } impl AggregateUDFImpl for LastValue { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "last_value" } @@ -1069,8 +1031,20 @@ impl AggregateUDFImpl for LastValue { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + not_impl_err!("Not called because the return_field_from_args is implemented") + } + + fn return_field(&self, arg_fields: &[FieldRef]) -> Result { + // Preserve metadata from the first argument field + Ok(Arc::new( + Field::new( + self.name(), + arg_fields[0].data_type().clone(), + true, // always nullable, there may be no rows + ) + .with_metadata(arg_fields[0].metadata().clone()), + )) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -1095,12 +1069,14 @@ impl AggregateUDFImpl for LastValue { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let mut fields = vec![Field::new( - format_state_name(args.name, "last_value"), - args.return_field.data_type().clone(), - true, - ) - .into()]; + let mut fields = vec![ + Field::new( + format_state_name(args.name, "last_value"), + args.return_field.data_type().clone(), + true, + ) + .into(), + ]; fields.extend(args.ordering_fields.iter().cloned()); fields.push( Field::new( @@ -1140,114 +1116,14 @@ impl AggregateUDFImpl for LastValue { } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - use DataType::*; - !args.order_bys.is_empty() - && matches!( - args.return_field.data_type(), - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float16 - | Float32 - | Float64 - | Decimal32(_, _) - | Decimal64(_, _) - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) + groups_accumulator_supported(&args) } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - fn create_accumulator( - args: &AccumulatorArgs, - ) -> Result> - where - T: ArrowPrimitiveType + Send, - { - let Some(ordering) = LexOrdering::new(args.order_bys.to_vec()) else { - return internal_err!("Groups accumulator must have an ordering."); - }; - - let ordering_dtypes = ordering - .iter() - .map(|e| e.expr.data_type(args.schema)) - .collect::>>()?; - - Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( - ordering, - args.ignore_nulls, - args.return_field.data_type(), - &ordering_dtypes, - false, - )?)) - } - - match args.return_field.data_type() { - DataType::Int8 => create_accumulator::(&args), - DataType::Int16 => create_accumulator::(&args), - DataType::Int32 => create_accumulator::(&args), - DataType::Int64 => create_accumulator::(&args), - DataType::UInt8 => create_accumulator::(&args), - DataType::UInt16 => create_accumulator::(&args), - DataType::UInt32 => create_accumulator::(&args), - DataType::UInt64 => create_accumulator::(&args), - DataType::Float16 => create_accumulator::(&args), - DataType::Float32 => create_accumulator::(&args), - DataType::Float64 => create_accumulator::(&args), - - DataType::Decimal32(_, _) => create_accumulator::(&args), - DataType::Decimal64(_, _) => create_accumulator::(&args), - DataType::Decimal128(_, _) => create_accumulator::(&args), - DataType::Decimal256(_, _) => create_accumulator::(&args), - - DataType::Timestamp(TimeUnit::Second, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - create_accumulator::(&args) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - create_accumulator::(&args) - } - - DataType::Date32 => create_accumulator::(&args), - DataType::Date64 => create_accumulator::(&args), - DataType::Time32(TimeUnit::Second) => { - create_accumulator::(&args) - } - DataType::Time32(TimeUnit::Millisecond) => { - create_accumulator::(&args) - } - - DataType::Time64(TimeUnit::Microsecond) => { - create_accumulator::(&args) - } - DataType::Time64(TimeUnit::Nanosecond) => { - create_accumulator::(&args) - } - - _ => { - internal_err!( - "GroupsAccumulator not supported for last_value({})", - args.return_field.data_type() - ) - } - } + create_groups_accumulator(&args, false, self.name()) } } @@ -1299,8 +1175,7 @@ impl Accumulator for TrivialLastValueAccumulator { last_idx = Some(value.len() - 1); } if let Some(last_idx) = last_idx { - let mut row = get_row_at_idx(values, last_idx)?; - self.last = row.swap_remove(0); + self.last = ScalarValue::try_from_array(&values[0], last_idx)?; self.last.compact(); self.is_set = true; } @@ -1314,11 +1189,11 @@ impl Accumulator for TrivialLastValueAccumulator { validate_is_set_flags(flags, "last_value")?; let filtered_states = filter_states_according_to_is_set(&states[0..1], flags)?; - if let Some(last) = filtered_states.last() { - if !last.is_empty() { - self.last = ScalarValue::try_from_array(last, 0)?; - self.is_set = true; - } + if let Some(last) = filtered_states.last() + && !last.is_empty() + { + self.last = ScalarValue::try_from_array(last, 0)?; + self.is_set = true; } Ok(()) } @@ -1344,6 +1219,8 @@ struct LastValueAccumulator { orderings: Vec, // Stores the applicable ordering requirement. ordering_req: LexOrdering, + // derived from `ordering_req`. + sort_options: Vec, // Stores whether incoming data already satisfies the ordering requirement. is_input_pre_ordered: bool, // Ignore null values. @@ -1363,11 +1240,13 @@ impl LastValueAccumulator { .iter() .map(ScalarValue::try_from) .collect::>()?; + let sort_options = get_sort_options(&ordering_req); ScalarValue::try_from(data_type).map(|last| Self { last, is_set: false, orderings, ordering_req, + sort_options, is_input_pre_ordered, ignore_nulls, }) @@ -1440,12 +1319,7 @@ impl Accumulator for LastValueAccumulator { // Update when there is a more recent entry if !self.is_set || self.is_input_pre_ordered - || compare_rows( - &self.orderings, - orderings, - &get_sort_options(&self.ordering_req), - )? - .is_lt() + || compare_rows(&self.orderings, orderings, &self.sort_options)?.is_lt() { self.update_with_new_row(row); } @@ -1473,12 +1347,12 @@ impl Accumulator for LastValueAccumulator { let mut last_row = get_row_at_idx(&filtered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let last_ordering = &last_row[1..is_set_idx]; - let sort_options = get_sort_options(&self.ordering_req); // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set || self.is_input_pre_ordered - || compare_rows(&self.orderings, last_ordering, &sort_options)?.is_lt() + || compare_rows(&self.orderings, last_ordering, &self.sort_options)? + .is_lt() { // Update with last value in the state. Note that we should exclude the // is_set flag from the state. Otherwise, we will end up with a state @@ -1541,11 +1415,11 @@ mod tests { use std::iter::repeat_with; use arrow::{ - array::{BooleanArray, Int64Array, ListArray, StringArray}, + array::{BooleanArray, Int64Array, ListArray, PrimitiveArray, StringArray}, compute::SortOptions, datatypes::Schema, }; - use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; + use datafusion_physical_expr::{PhysicalSortExpr, expressions::col}; use super::*; @@ -1668,10 +1542,10 @@ mod tests { options: SortOptions::default(), }]; - let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + let mut group_acc = FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), sort_keys.into(), true, - &DataType::Int64, &[DataType::Int64], true, )?; @@ -1762,10 +1636,10 @@ mod tests { options: SortOptions::default(), }]; - let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + let mut group_acc = FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), sort_keys.into(), true, - &DataType::Int64, &[DataType::Int64], true, )?; @@ -1843,10 +1717,10 @@ mod tests { options: SortOptions::default(), }]; - let mut group_acc = FirstPrimitiveGroupsAccumulator::::try_new( + let mut group_acc = FirstLastGroupsAccumulator::try_new( + PrimitiveValueState::::new(DataType::Int64), sort_keys.into(), true, - &DataType::Int64, &[DataType::Int64], false, )?; @@ -1967,10 +1841,12 @@ mod tests { let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)]; let result = trivial_accumulator.merge_batch(&trivial_states); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("is_set flags contain nulls")); + assert!( + result + .unwrap_err() + .to_string() + .contains("is_set flags contain nulls") + ); // Test FirstValueAccumulator (with ordering) let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]); @@ -1990,10 +1866,12 @@ mod tests { let ordered_states = vec![value, ordering, corrupted_flag]; let result = ordered_accumulator.merge_batch(&ordered_states); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("is_set flags contain nulls")); + assert!( + result + .unwrap_err() + .to_string() + .contains("is_set flags contain nulls") + ); Ok(()) } @@ -2010,10 +1888,12 @@ mod tests { let trivial_states = vec![Arc::clone(&value), Arc::clone(&corrupted_flag)]; let result = trivial_accumulator.merge_batch(&trivial_states); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("is_set flags contain nulls")); + assert!( + result + .unwrap_err() + .to_string() + .contains("is_set flags contain nulls") + ); // Test LastValueAccumulator (with ordering) let schema = Schema::new(vec![Field::new("ordering", DataType::Int64, false)]); @@ -2033,10 +1913,12 @@ mod tests { let ordered_states = vec![value, ordering, corrupted_flag]; let result = ordered_accumulator.merge_batch(&ordered_states); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("is_set flags contain nulls")); + assert!( + result + .unwrap_err() + .to_string() + .contains("is_set flags contain nulls") + ); Ok(()) } diff --git a/datafusion/functions-aggregate/src/first_last/state.rs b/datafusion/functions-aggregate/src/first_last/state.rs new file mode 100644 index 0000000000000..cd7114bf04f9c --- /dev/null +++ b/datafusion/functions-aggregate/src/first_last/state.rs @@ -0,0 +1,462 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, AsArray, BinaryBuilder, BinaryViewBuilder, + BooleanBufferBuilder, LargeBinaryBuilder, LargeStringBuilder, PrimitiveArray, + StringBuilder, StringViewBuilder, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::EmitTo; + +pub(crate) trait ValueState: Send + Sync { + /// Resizes the state to accommodate `new_size` groups. + fn resize(&mut self, new_size: usize); + /// Updates the state for the specified `group_idx` using the value at `idx` from the provided `array`. + /// + /// Note: While this is not a batch interface, it is not a performance bottleneck. + /// In heavy aggregation benchmarks, the overhead of this method is typically less than 1%. + /// + /// ```sql + /// -- TPC-H SF10 + /// select l_shipmode, last_value(l_partkey order by l_orderkey, l_linenumber, l_comment, l_suppkey, l_tax) + /// from 'benchmarks/data/tpch_sf10/lineitem' + /// group by l_shipmode; + /// + /// -- H2O G1_1e8 + /// select t.id1, first_value(t.id3 order by t.id2, t.id4) as r2 + /// from 'benchmarks/data/h2o/G1_1e8_1e8_100_0.parquet' as t + /// group by t.id1, t.v1; + /// ``` + fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()>; + /// Takes the accumulated state and returns it as an [`ArrayRef`], respecting the `emit_to` strategy. + fn take(&mut self, emit_to: EmitTo) -> Result; + /// Returns the estimated memory size of the state in bytes. + fn size(&self) -> usize; +} + +pub(crate) struct PrimitiveValueState { + /// Values data + vals: Vec, + nulls: BooleanBufferBuilder, + data_type: DataType, +} + +impl PrimitiveValueState { + pub(crate) fn new(data_type: DataType) -> Self { + Self { + vals: vec![], + nulls: BooleanBufferBuilder::new(0), + data_type, + } + } +} + +impl ValueState for PrimitiveValueState { + fn resize(&mut self, new_size: usize) { + self.vals.resize(new_size, T::default_value()); + self.nulls.resize(new_size); + } + + fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()> { + let array = array.as_primitive::(); + self.vals[group_idx] = array.value(idx); + self.nulls.set_bit(group_idx, !array.is_null(idx)); + Ok(()) + } + + fn take(&mut self, emit_to: EmitTo) -> Result { + let values = emit_to.take_needed(&mut self.vals); + let null_buf = NullBuffer::new(take_need(&mut self.nulls, emit_to)); + let array: PrimitiveArray = + PrimitiveArray::::new(values.into(), Some(null_buf)) + .with_data_type(self.data_type.clone()); + Ok(Arc::new(array)) + } + + fn size(&self) -> usize { + self.vals.capacity() * size_of::() + self.nulls.capacity() / 8 + } +} + +/// Stores internal state for "bytes" types (Utf8, Binary, etc.). +/// +/// This implementation is similar to `MinMaxBytesState` in `min_max_bytes.rs`, but +/// it does not reuse it for two main reasons: +/// +/// 1. **Direct Overwrite**: `MinMaxBytesState::update_batch` is tightly coupled +/// with min/max comparison logic, whereas `FirstLast` performs its own comparisons +/// externally (using ordering columns) and only needs a simple interface to +/// unconditionally set/overwrite values for specific groups. +/// 2. **Different NULL Handling**: `MinMaxBytesState` always ignores `NULL` values +/// in the input, while `BytesValueState` needs to support setting `NULL` values +/// to correctly implement `RESPECT NULLS` behavior. +/// +pub(crate) struct BytesValueState { + vals: Vec>>, + data_type: DataType, + /// The sum of the capacities of all vectors in `vals`. + total_capacity: usize, +} + +impl BytesValueState { + pub(crate) fn try_new(data_type: DataType) -> Result { + if !matches!( + data_type, + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Binary + | DataType::LargeBinary + | DataType::BinaryView + ) { + return internal_err!("BytesValueState does not support {}", data_type); + } + Ok(Self { + vals: vec![], + data_type, + total_capacity: 0, + }) + } +} + +impl ValueState for BytesValueState { + fn resize(&mut self, new_size: usize) { + if new_size < self.vals.len() { + for v in self.vals[new_size..].iter().flatten() { + self.total_capacity -= v.capacity(); + } + } + self.vals.resize(new_size, None); + } + + fn update(&mut self, group_idx: usize, array: &ArrayRef, idx: usize) -> Result<()> { + if let Some(v) = &self.vals[group_idx] { + self.total_capacity -= v.capacity(); + } + + if array.is_null(idx) { + self.vals[group_idx] = None; + } else { + let val = match self.data_type { + DataType::Utf8 => array.as_string::().value(idx).as_bytes(), + DataType::LargeUtf8 => array.as_string::().value(idx).as_bytes(), + DataType::Utf8View => array.as_string_view().value(idx).as_bytes(), + DataType::Binary => array.as_binary::().value(idx), + DataType::LargeBinary => array.as_binary::().value(idx), + DataType::BinaryView => array.as_binary_view().value(idx), + _ => { + return internal_err!( + "Unsupported data type for BytesValueState: {}", + self.data_type + ); + } + }; + + if let Some(v) = &mut self.vals[group_idx] { + v.clear(); + v.extend_from_slice(val); + } else { + let v = val.to_vec(); + self.vals[group_idx] = Some(v); + } + + self.vals[group_idx] + .as_ref() + .inspect(|x| self.total_capacity += x.capacity()); + } + Ok(()) + } + + fn take(&mut self, emit_to: EmitTo) -> Result { + let values = emit_to.take_needed(&mut self.vals); + + let (total_len, taken_capacity) = values + .iter() + .flatten() + .fold((0, 0), |(len_acc, cap_acc), v| { + (len_acc + v.len(), cap_acc + v.capacity()) + }); + self.total_capacity -= taken_capacity; + + match self.data_type { + DataType::Utf8 => { + let mut builder = StringBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value( + // SAFETY: The bytes were originally from a valid UTF-8 array in `update` + unsafe { std::str::from_utf8_unchecked(&v) }, + ), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::LargeUtf8 => { + let mut builder = + LargeStringBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value( + // SAFETY: The bytes were originally from a valid UTF-8 array in `update` + unsafe { std::str::from_utf8_unchecked(&v) }, + ), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Utf8View => { + let mut builder = StringViewBuilder::with_capacity(values.len()); + for val in values { + match val { + Some(v) => builder.append_value( + // SAFETY: The bytes were originally from a valid UTF-8 array in `update` + unsafe { std::str::from_utf8_unchecked(&v) }, + ), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::Binary => { + let mut builder = BinaryBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value(&v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::LargeBinary => { + let mut builder = + LargeBinaryBuilder::with_capacity(values.len(), total_len); + for val in values { + match val { + Some(v) => builder.append_value(&v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + DataType::BinaryView => { + let mut builder = BinaryViewBuilder::with_capacity(values.len()); + for val in values { + match val { + Some(v) => builder.append_value(&v), + None => builder.append_null(), + } + } + Ok(Arc::new(builder.finish())) + } + _ => internal_err!( + "Unsupported data type for BytesValueState: {}", + self.data_type + ), + } + } + + fn size(&self) -> usize { + self.vals.capacity() * size_of::>>() + self.total_capacity + } +} + +impl BytesValueState { + #[cfg(test)] + /// For testing only: strictly calculate the sum of capacities of all vectors in `vals`. + fn total_capacity_calculated(&self) -> usize { + self.vals.iter().flatten().map(|v| v.capacity()).sum() + } +} + +pub(crate) fn take_need( + bool_buf_builder: &mut BooleanBufferBuilder, + emit_to: EmitTo, +) -> BooleanBuffer { + let bool_buf = bool_buf_builder.finish(); + match emit_to { + EmitTo::All => bool_buf, + EmitTo::First(n) => { + // split off the first N values in seen_values + // + let first_n: BooleanBuffer = bool_buf.slice(0, n); + // reset the existing buffer + bool_buf_builder.append_buffer(&bool_buf.slice(n, bool_buf.len() - n)); + first_n + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + BinaryArray, BinaryViewArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, + }; + + #[test] + fn test_bytes_value_state_utf8() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8)?; + state.resize(2); + + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("hello"), + Some("world"), + Some("longer_string_than_hello"), + ])); + + state.update(0, &array, 0)?; // group 0 = "hello" + state.update(1, &array, 1)?; // group 1 = "world" + + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + + // Overwrite group 0 with a longer string (checks capacity update logic) + state.update(0, &array, 2)?; + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + + let result = state.take(EmitTo::All)?; + let result = result.as_string::(); + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "longer_string_than_hello"); + assert_eq!(result.value(1), "world"); + + // After take all, size should be 0 (excluding vals vector capacity) + assert_eq!(state.total_capacity, 0); + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + + Ok(()) + } + + #[test] + fn test_bytes_value_state_large_utf8() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::LargeUtf8)?; + state.resize(1); + let array: ArrayRef = Arc::new(LargeStringArray::from(vec!["large_utf8"])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_string::().value(0), "large_utf8"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_utf8_view() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8View)?; + state.resize(1); + let array: ArrayRef = Arc::new(StringViewArray::from(vec!["Utf8View"])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_string_view().value(0), "Utf8View"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_binary() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Binary)?; + state.resize(1); + let array: ArrayRef = Arc::new(BinaryArray::from(vec![b"binary" as &[u8]])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_binary::().value(0), b"binary"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_large_binary() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::LargeBinary)?; + state.resize(1); + let array: ArrayRef = + Arc::new(LargeBinaryArray::from(vec![b"large_binary" as &[u8]])); + state.update(0, &array, 0)?; + let result = state.take(EmitTo::All)?; + assert_eq!(result.as_binary::().value(0), b"large_binary"); + Ok(()) + } + + #[test] + fn test_bytes_value_state_binary_view() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::BinaryView)?; + state.resize(1); + + let data: Vec> = vec![Some(b"long_binary_value_to_test_view")]; + let array: ArrayRef = Arc::new(BinaryViewArray::from(data)); + + state.update(0, &array, 0)?; + + let result = state.take(EmitTo::All)?; + let result = result.as_binary_view(); + assert_eq!(result.value(0), b"long_binary_value_to_test_view"); + + Ok(()) + } + + #[test] + fn test_bytes_value_state_emit_first() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8)?; + state.resize(3); + + let array: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"])); + state.update(0, &array, 0)?; + state.update(1, &array, 1)?; + state.update(2, &array, 2)?; + + let result = state.take(EmitTo::First(2))?; + let result = result.as_string::(); + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "a"); + assert_eq!(result.value(1), "b"); + + // Remaining should be "c" + let result = state.take(EmitTo::All)?; + let result = result.as_string::(); + assert_eq!(result.len(), 1); + assert_eq!(result.value(0), "c"); + + Ok(()) + } + + #[test] + fn test_bytes_value_state_update_null() -> Result<()> { + let mut state = BytesValueState::try_new(DataType::Utf8)?; + state.resize(1); + + let array: ArrayRef = Arc::new(StringArray::from(vec![Some("hello"), None])); + + // group 0 = "hello" + state.update(0, &array, 0)?; + assert_eq!(state.total_capacity, state.total_capacity_calculated()); + assert!(state.total_capacity > 0); + + // group 0 = NULL + state.update(0, &array, 1)?; + assert_eq!( + state.total_capacity, + state.total_capacity_calculated(), + "total_capacity should match calculated capacity after update(NULL)" + ); + assert_eq!(state.total_capacity, 0); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 4d1da1dad5949..a170a7e0c95df 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -17,12 +17,9 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::fmt; - use arrow::datatypes::Field; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{Result, not_impl_err}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; @@ -60,20 +57,11 @@ make_udaf_expr_and_func!( description = "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function." ) )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Grouping { signature: Signature, } -impl fmt::Debug for Grouping { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Grouping") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Grouping { fn default() -> Self { Self::new() @@ -90,10 +78,6 @@ impl Grouping { } impl AggregateUDFImpl for Grouping { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "grouping" } @@ -107,12 +91,14 @@ impl AggregateUDFImpl for Grouping { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - Ok(vec![Field::new( - format_state_name(args.name, "grouping"), - DataType::Int32, - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, "grouping"), + DataType::Int32, + true, + ) + .into(), + ]) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { diff --git a/datafusion/functions-aggregate/src/hyperloglog.rs b/datafusion/functions-aggregate/src/hyperloglog.rs index 3074889eab23c..182fe15cf0f24 100644 --- a/datafusion/functions-aggregate/src/hyperloglog.rs +++ b/datafusion/functions-aggregate/src/hyperloglog.rs @@ -34,7 +34,7 @@ //! //! This module also borrows some code structure from [pdatastructs.rs](https://github.com/crepererum/pdatastructs.rs/blob/3997ed50f6b6871c9e53c4c5e0f48f431405fc63/src/hyperloglog.rs). -use ahash::RandomState; +use std::hash::BuildHasher; use std::hash::Hash; use std::marker::PhantomData; @@ -42,7 +42,7 @@ use std::marker::PhantomData; const HLL_P: usize = 14_usize; /// The number of bits of the hash value used determining the number of leading zeros const HLL_Q: usize = 64_usize - HLL_P; -const NUM_REGISTERS: usize = 1_usize << HLL_P; +pub(crate) const NUM_REGISTERS: usize = 1_usize << HLL_P; /// Mask to obtain index into the registers const HLL_P_MASK: u64 = (NUM_REGISTERS as u64) - 1; @@ -58,15 +58,11 @@ where /// Fixed seed for the hashing so that values are consistent across runs /// /// Note that when we later move on to have serialized HLL register binaries -/// shared across cluster, this SEED will have to be consistent across all +/// shared across cluster, this HLL_HASH_STATE will have to be consistent across all /// parties otherwise we might have corruption. So ideally for later this seed /// shall be part of the serialized form (or stay unchanged across versions). -const SEED: RandomState = RandomState::with_seeds( - 0x885f6cab121d01a3_u64, - 0x71e4379f2976ad8f_u64, - 0xbf30173dd28a8816_u64, - 0x0eaea5d736d733a4_u64, -); +pub(crate) const HLL_HASH_STATE: foldhash::quality::FixedState = + foldhash::quality::FixedState::with_seed(0); impl Default for HyperLogLog where @@ -97,17 +93,26 @@ where } } - /// choice of hash function: ahash is already an dependency + /// choice of hash function: foldhash is already an dependency /// and it fits the requirements of being a 64bit hash with /// reasonable performance. #[inline] fn hash_value(&self, obj: &T) -> u64 { - SEED.hash_one(obj) + HLL_HASH_STATE.hash_one(obj) } /// Adds an element to the HyperLogLog. pub fn add(&mut self, obj: &T) { let hash = self.hash_value(obj); + self.add_hashed(hash); + } + + /// Adds a pre-computed hash value directly to the HyperLogLog. + /// + /// The hash should be computed using [`HLL_HASH_STATE`], the same hasher used + /// by [`Self::add`]. + #[inline] + pub(crate) fn add_hashed(&mut self, hash: u64) { let index = (hash & HLL_P_MASK) as usize; let p = ((hash >> HLL_P) | (1_u64 << HLL_Q)).trailing_zeros() + 1; self.registers[index] = self.registers[index].max(p as u8); @@ -140,16 +145,69 @@ where /// Guess the number of unique elements seen by the HyperLogLog. pub fn count(&self) -> usize { - let histogram = self.get_histogram(); - let m = NUM_REGISTERS as f64; - let mut z = m * hll_tau((m - histogram[HLL_Q + 1] as f64) / m); - for i in histogram[1..=HLL_Q].iter().rev() { - z += *i as f64; - z *= 0.5; + count_from_histogram(&self.get_histogram()) + } +} + +/// Compute `index` and `rho` (register value) for a precomputed hash, exactly as +/// [`HyperLogLog::add_hashed`] does. +#[inline] +pub(crate) fn register_for_hash(hash: u64) -> (usize, u8) { + let index = (hash & HLL_P_MASK) as usize; + let rho = (((hash >> HLL_P) | (1_u64 << HLL_Q)).trailing_zeros() + 1) as u8; + (index, rho) +} + +/// Estimate the cardinality of a set of precomputed hashes without +/// materializing a full [`NUM_REGISTERS`]-byte register array. +/// +/// This is equivalent to adding every hash to a fresh [`HyperLogLog`] via +/// [`HyperLogLog::add_hashed`] and calling [`HyperLogLog::count`], but only does +/// work proportional to the number of hashes. It is used to cheaply estimate the +/// many small groups produced by a high-cardinality `GROUP BY`, where allocating +/// and scanning a 16 KiB sketch per group would dominate the runtime. +/// +/// `hashes` may contain duplicates (duplicate hashes are idempotent). +pub(crate) fn count_from_hashes(hashes: &[u64]) -> usize { + if hashes.is_empty() { + return 0; + } + // For each touched register index keep the maximum rho. Sorting by + // (index, rho) groups equal indices together with the max rho last. + let mut idx_rho: Vec<(usize, u8)> = + hashes.iter().map(|&hash| register_for_hash(hash)).collect(); + idx_rho.sort_unstable(); + + let mut histogram = [0u32; HLL_Q + 2]; + let mut touched = 0u32; + let mut i = 0; + while i < idx_rho.len() { + let index = idx_rho[i].0; + let mut max_rho = idx_rho[i].1; + i += 1; + while i < idx_rho.len() && idx_rho[i].0 == index { + max_rho = idx_rho[i].1; // ascending rho => last is the max + i += 1; } - z += m * hll_sigma(histogram[0] as f64 / m); - (0.5 / 2_f64.ln() * m * m / z).round() as usize + histogram[max_rho as usize] += 1; + touched += 1; + } + // All remaining registers are still zero. + histogram[0] = NUM_REGISTERS as u32 - touched; + count_from_histogram(&histogram) +} + +/// Apply the HyperLogLog cardinality estimator to a register histogram. +#[inline] +fn count_from_histogram(histogram: &[u32; HLL_Q + 2]) -> usize { + let m = NUM_REGISTERS as f64; + let mut z = m * hll_tau((m - histogram[HLL_Q + 1] as f64) / m); + for i in histogram[1..=HLL_Q].iter().rev() { + z += *i as f64; + z *= 0.5; } + z += m * hll_sigma(histogram[0] as f64 / m); + (0.5 / 2_f64.ln() * m * m / z).round() as usize } /// Helper function sigma as defined in diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 5454a902e4b7a..1b9996220d882 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] #![doc( html_logo_url = "https://raw.githubusercontent.com/apache/datafusion/19fe44cf2f30cbdd63d4a4f52c74055163c6cc38/docs/logos/standalone_logo/logo_original.svg", diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 6c6bf72838899..0c919a1e5ea74 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -67,7 +67,6 @@ macro_rules! create_func { create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); }; ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { - paste::paste! { #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { // Singleton instance of [$UDAF], ensures the UDAF is only created once @@ -76,7 +75,6 @@ macro_rules! create_func { std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) }); std::sync::Arc::clone(&INSTANCE) - } } } } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index ef76d1e6ea2db..e7e7d03937f12 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -21,8 +21,8 @@ use std::mem::{size_of, size_of_val}; use std::sync::Arc; use arrow::array::{ - downcast_integer, ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, - PrimitiveBuilder, + ArrowNumericType, BooleanArray, ListArray, PrimitiveArray, PrimitiveBuilder, + downcast_integer, }; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ @@ -39,20 +39,22 @@ use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, Decimal32Type, Decimal64Type, FieldRef, }; +use datafusion_common::types::{NativeType, logical_float64}; use datafusion_common::{ - assert_eq_or_internal_err, internal_datafusion_err, DataFusionError, Result, - ScalarValue, + DataFusionError, Result, ScalarValue, assert_eq_or_internal_err, exec_datafusion_err, + internal_datafusion_err, }; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::{ - function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, - Documentation, Signature, Volatility, + Accumulator, AggregateUDFImpl, Coercion, Documentation, Signature, TypeSignature, + TypeSignatureClass, Volatility, function::AccumulatorArgs, utils::format_state_name, }; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; -use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; +use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable}; use datafusion_macros::user_doc; +use std::collections::HashMap; make_udaf_expr_and_func!( Median, @@ -84,20 +86,11 @@ make_udaf_expr_and_func!( /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Median { signature: Signature, } -impl Debug for Median { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("Median") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Median { fn default() -> Self { Self::new() @@ -107,16 +100,30 @@ impl Default for Median { impl Median { pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + // Integer inputs are coerced to Float64 so the average of the two + // middle values is not truncated. This matches DuckDB / PostgreSQL / Spark. + // Float and Decimal inputs preserve their type. + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Decimal, + )]), + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Float, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Integer], + NativeType::Float64, + )]), + ], + Volatility::Immutable, + ), } } } impl AggregateUDFImpl for Median { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "median" } @@ -138,12 +145,14 @@ impl AggregateUDFImpl for Median { "median" }; - Ok(vec![Field::new( - format_state_name(args.name, state_name), - DataType::List(Arc::new(field)), - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into(), + ]) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { @@ -273,7 +282,12 @@ impl Accumulator for MedianAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = values[0].as_primitive::(); - self.all_values.reserve(values.len() - values.null_count()); + let additional = values.len() - values.null_count(); + self.all_values.try_reserve(additional).map_err(|e| { + exec_datafusion_err!( + "failed to reserve {additional} values for median accumulator: {e}" + ) + })?; self.all_values.extend(values.iter().flatten()); Ok(()) } @@ -287,14 +301,46 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.all_values); - let median = calculate_median::(d); + let median = calculate_median::(&mut self.all_values); ScalarValue::new_primitive::(median, &self.data_type) } fn size(&self) -> usize { size_of_val(self) + self.all_values.capacity() * size_of::() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let mut to_remove: HashMap, usize> = HashMap::new(); + + let arr = values[0].as_primitive::(); + for value in arr.iter().flatten() { + *to_remove.entry(Hashable(value)).or_default() += 1; + } + + let mut i = 0; + while i < self.all_values.len() { + let k = Hashable(self.all_values[i]); + if let Some(count) = to_remove.get_mut(&k) + && *count > 0 + { + self.all_values.swap_remove(i); + *count -= 1; + if *count == 0 { + to_remove.remove(&k); + if to_remove.is_empty() { + break; + } + } + } else { + i += 1; + } + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// The median groups accumulator accumulates the raw input values @@ -441,8 +487,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new().with_data_type(self.data_type.clone()); - for values in emit_group_values { - let median = calculate_median::(values); + for mut values in emit_group_values { + let median = calculate_median::(&mut values); evaluate_result_builder.append_option(median); } @@ -526,11 +572,9 @@ impl Accumulator for DistinctMedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.distinct_values.values) - .into_iter() - .map(|v| v.0) - .collect::>(); - let median = calculate_median::(d); + let mut d: Vec = + self.distinct_values.values.iter().map(|v| v.0).collect(); + let median = calculate_median::(&mut d); ScalarValue::new_primitive::(median, &self.data_type) } @@ -554,9 +598,7 @@ where .unwrap() } -fn calculate_median( - mut values: Vec, -) -> Option { +fn calculate_median(values: &mut [T::Native]) -> Option { let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = values.len(); @@ -566,9 +608,25 @@ fn calculate_median( let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); // Get the maximum of the low (left side after bi-partitioning) let left_max = slice_max::(low); - let median = left_max - .add_wrapping(*high) - .div_wrapping(T::Native::usize_as(2)); + // Calculate median as the average of the two middle values. + // Use checked arithmetic to detect overflow and fall back to safe formula. + let two = T::Native::usize_as(2); + let median = match left_max.add_checked(*high) { + Ok(sum) => sum.div_wrapping(two), + Err(_) => { + // Overflow detected - use safe midpoint formula: + // a/2 + b/2 + ((a%2 + b%2) / 2) + // This avoids overflow by dividing before adding. + let half_left = left_max.div_wrapping(two); + let half_right = (*high).div_wrapping(two); + let rem_left = left_max.mod_wrapping(two); + let rem_right = (*high).mod_wrapping(two); + // The sum of remainders (0, 1, or 2 for unsigned; -2 to 2 for signed) + // divided by 2 gives the correction factor (0 or 1 for unsigned; -1, 0, or 1 for signed) + let correction = rem_left.add_wrapping(rem_right).div_wrapping(two); + half_left.add_wrapping(half_right).add_wrapping(correction) + } + }; Some(median) } else { let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 1a46afefffb3b..f4eaaab853464 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -23,13 +23,13 @@ mod min_max_struct; use arrow::array::ArrayRef; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, - Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DurationSecondType, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use datafusion_common::stats::Precision; -use datafusion_common::{exec_err, internal_err, ColumnStatistics, Result}; +use datafusion_common::{ColumnStatistics, Result, exec_err, internal_err}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_physical_expr::expressions; use std::cmp::Ordering; @@ -46,8 +46,8 @@ use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use crate::min_max::min_max_struct::MinMaxStructAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, - SetMonotonicity, Signature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, SetMonotonicity, Signature, Volatility, + function::AccumulatorArgs, }; use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use datafusion_macros::user_doc; @@ -171,9 +171,8 @@ trait FromColumnStatistics { let col_stats = &statistics_args.statistics.column_statistics; if statistics_args.exprs.len() == 1 { // TODO optimize with exprs other than Column - if let Some(col_expr) = statistics_args.exprs[0] - .as_any() - .downcast_ref::() + if let Some(col_expr) = + statistics_args.exprs[0].downcast_ref::() { return self.value_from_column_statistics( &col_stats[col_expr.index()], @@ -193,20 +192,16 @@ impl FromColumnStatistics for Max { &self, col_stats: &ColumnStatistics, ) -> Option { - if let Precision::Exact(ref val) = col_stats.max_value { - if !val.is_null() { - return Some(val.clone()); - } + if let Precision::Exact(ref val) = col_stats.max_value + && !val.is_null() + { + return Some(val.clone()); } None } } impl AggregateUDFImpl for Max { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "max" } @@ -480,20 +475,16 @@ impl FromColumnStatistics for Min { &self, col_stats: &ColumnStatistics, ) -> Option { - if let Precision::Exact(ref val) = col_stats.min_value { - if !val.is_null() { - return Some(val.clone()); - } + if let Precision::Exact(ref val) = col_stats.min_value + && !val.is_null() + { + return Some(val.clone()); } None } } impl AggregateUDFImpl for Min { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "min" } @@ -1012,12 +1003,13 @@ mod tests { use super::*; use arrow::{ array::{ - DictionaryArray, Float32Array, Int32Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, StringArray, + Array, DictionaryArray, Float32Array, Int8Array, Int32Array, + IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, + PrimitiveArray, StringArray, }, datatypes::{ - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - IntervalYearMonthType, + ArrowDictionaryKeyType, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, }, }; use std::sync::Arc; @@ -1154,7 +1146,6 @@ mod tests { check(&mut max(), &[&[zero, neg_inf]], zero); } - use datafusion_common::Result; use rand::Rng; fn get_random_vec_i32(len: usize) -> Vec { @@ -1268,7 +1259,178 @@ mod tests { let mut max_acc = MaxAccumulator::try_new(&rt_type)?; max_acc.update_batch(&[Arc::clone(&dict_array_ref)])?; let max_result = max_acc.evaluate()?; - assert_eq!(max_result, ScalarValue::Utf8(Some("🦀".to_string()))); + assert_eq!(max_result, ScalarValue::Utf8(Some("d".to_string()))); + Ok(()) + } + + fn dict_scalar(key_type: DataType, inner: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(key_type), Box::new(inner)) + } + + fn utf8_dict_scalar(key_type: DataType, value: &str) -> ScalarValue { + dict_scalar(key_type, ScalarValue::Utf8(Some(value.to_string()))) + } + + fn string_dictionary_batch(values: &[&str], keys: &[Option]) -> ArrayRef { + string_dictionary_batch_with_keys(Int32Array::from(keys.to_vec()), values) + } + + fn string_dictionary_batch_with_keys( + keys: PrimitiveArray, + values: &[&str], + ) -> ArrayRef + where + K: ArrowDictionaryKeyType, + { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef + } + + fn optional_string_dictionary_batch( + values: &[Option<&str>], + keys: &[Option], + ) -> ArrayRef { + let values = Arc::new(StringArray::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn float_dictionary_batch(values: &[f32], keys: &[Option]) -> ArrayRef { + let values = Arc::new(Float32Array::from(values.to_vec())) as ArrayRef; + Arc::new( + DictionaryArray::try_new(Int32Array::from(keys.to_vec()), values).unwrap(), + ) as ArrayRef + } + + fn evaluate_dictionary_accumulator( + mut acc: impl Accumulator, + batches: &[ArrayRef], + ) -> Result { + for batch in batches { + acc.update_batch(&[Arc::clone(batch)])?; + } + acc.evaluate() + } + + fn assert_dictionary_min_max( + dict_type: &DataType, + batches: &[ArrayRef], + expected_min: &str, + expected_max: &str, + ) -> Result<()> { + let key_type = match dict_type { + DataType::Dictionary(key_type, _) => key_type.as_ref().clone(), + other => panic!("expected dictionary type, got {other:?}"), + }; + + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(min_result, utf8_dict_scalar(key_type.clone(), expected_min)); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(dict_type)?, + batches, + )?; + assert_eq!(max_result, utf8_dict_scalar(key_type, expected_max)); + + Ok(()) + } + + #[test] + fn test_min_max_dictionary_without_coercion() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a", "d"], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_with_nulls() -> Result<()> { + let dict_array_ref = string_dictionary_batch( + &["b", "c", "a"], + &[None, Some(0), None, Some(1), Some(2)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "c") + } + + #[test] + fn test_min_max_dictionary_ignores_unreferenced_values() -> Result<()> { + let dict_array_ref = + string_dictionary_batch(&["a", "z", "zz_unused"], &[Some(1), Some(1), None]); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "z", "z") + } + + #[test] + fn test_min_max_dictionary_ignores_referenced_null_values() -> Result<()> { + let dict_array_ref = optional_string_dictionary_batch( + &[Some("b"), None, Some("a"), Some("d")], + &[Some(0), Some(1), Some(2), Some(3)], + ); + let dict_type = dict_array_ref.data_type().clone(); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_multi_batch() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let batch1 = string_dictionary_batch(&["b", "c"], &[Some(0), Some(1)]); + let batch2 = string_dictionary_batch(&["a", "d"], &[Some(0), Some(1)]); + + assert_dictionary_min_max(&dict_type, &[batch1, batch2], "a", "d") + } + + #[test] + fn test_min_max_dictionary_int8_keys() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let dict_array_ref = string_dictionary_batch_with_keys( + Int8Array::from(vec![Some(0), Some(1), Some(2), Some(3)]), + &["b", "c", "a", "d"], + ); + + assert_dictionary_min_max(&dict_type, &[dict_array_ref], "a", "d") + } + + #[test] + fn test_min_max_dictionary_float_with_nans() -> Result<()> { + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Float32)); + let batch1 = float_dictionary_batch(&[0.0, f32::NAN], &[Some(0), Some(1)]); + let batch2 = float_dictionary_batch(&[f32::NEG_INFINITY], &[Some(0)]); + + let min_result = evaluate_dictionary_accumulator( + MinAccumulator::try_new(&dict_type)?, + &[Arc::clone(&batch1), Arc::clone(&batch2)], + )?; + assert_eq!( + min_result, + dict_scalar( + DataType::Int32, + ScalarValue::Float32(Some(f32::NEG_INFINITY)), + ) + ); + + let max_result = evaluate_dictionary_accumulator( + MaxAccumulator::try_new(&dict_type)?, + &[batch1, batch2], + )?; + assert_eq!( + max_result, + dict_scalar(DataType::Int32, ScalarValue::Float32(Some(f32::NAN))) + ); + Ok(()) } } diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index 30b2739c08edc..b56c2106e32b5 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -21,12 +21,14 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use datafusion_common::hash_map::Entry; -use datafusion_common::{internal_err, HashMap, Result}; +use datafusion_common::{HashMap, Result, internal_err}; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; use std::mem::size_of; use std::sync::Arc; +use datafusion_common::utils::split_vec_min_alloc; + /// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], /// [`BinaryArray`], [`StringViewArray`], etc) /// @@ -493,7 +495,7 @@ impl MinMaxBytesState { ) } EmitTo::First(n) => { - let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_min_maxes = split_vec_min_alloc(&mut self.min_max, n); let first_data_capacity: usize = first_min_maxes .iter() .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) diff --git a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs index 8038f2f01d90c..7c94e7f5738be 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs @@ -24,13 +24,14 @@ use arrow::{ datatypes::DataType, }; use datafusion_common::{ - internal_err, + Result, internal_err, scalar::{copy_array_data, partial_cmp_struct}, - Result, }; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; +use datafusion_common::utils::split_vec_min_alloc; + /// Accumulator for MIN/MAX operations on Struct data types. /// /// This accumulator tracks the minimum or maximum struct value encountered @@ -283,7 +284,7 @@ impl MinMaxStructState { ) } EmitTo::First(n) => { - let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_min_maxes = split_vec_min_alloc(&mut self.min_max, n); let first_data_capacity: usize = first_min_maxes .iter() .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 05026940fec45..bddc46e27e15c 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -18,23 +18,22 @@ //! Defines NTH_VALUE aggregate expression which may specify ordering requirement //! that can evaluated at runtime during query execution -use std::any::Any; use std::collections::VecDeque; use std::mem::{size_of, size_of_val}; use std::sync::Arc; -use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; +use arrow::array::{ArrayRef, AsArray, StructArray, new_empty_array}; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; -use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; +use datafusion_common::utils::{SingleRowListArrayBuilder, get_row_at_idx}; use datafusion_common::{ - assert_or_internal_err, exec_err, not_impl_err, Result, ScalarValue, + Result, ScalarValue, assert_or_internal_err, exec_err, not_impl_err, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF, - Signature, SortExpr, Volatility, + Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF, + Signature, SortExpr, Volatility, lit, }; use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; use datafusion_functions_aggregate_common::utils::ordering_fields; @@ -112,10 +111,6 @@ impl Default for NthValueAgg { } impl AggregateUDFImpl for NthValueAgg { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "nth_value" } @@ -130,7 +125,6 @@ impl AggregateUDFImpl for NthValueAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let n = match acc_args.exprs[1] - .as_any() .downcast_ref::() .map(|lit| lit.value()) { @@ -146,7 +140,7 @@ impl AggregateUDFImpl for NthValueAgg { "{} not supported for n: {}", self.name(), &acc_args.exprs[1] - ) + ); } }; @@ -372,7 +366,7 @@ impl NthValueAccumulator { let array = if column_values.is_empty() { new_empty_array(fields[i].data_type()) } else { - ScalarValue::iter_to_array(column_values.into_iter())? + ScalarValue::iter_to_array(column_values)? }; column_wise_ordering_values.push(array); } diff --git a/datafusion/functions-aggregate/src/percentile_cont.rs b/datafusion/functions-aggregate/src/percentile_cont.rs index b46186bdfcab8..714988bde2acf 100644 --- a/datafusion/functions-aggregate/src/percentile_cont.rs +++ b/datafusion/functions-aggregate/src/percentile_cont.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::fmt::{Debug, Formatter}; +use std::collections::HashMap; +use std::fmt::Debug; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -25,30 +26,35 @@ use arrow::array::{ use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::{ array::{Array, ArrayRef, AsArray}, - datatypes::{ - ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, Field, FieldRef, Float16Type, Float32Type, Float64Type, - }, + datatypes::{DataType, Field, FieldRef, Float16Type, Float32Type, Float64Type}, }; +use num_traits::AsPrimitive; + use arrow::array::ArrowNativeTypeOp; +use datafusion_common::internal_err; +use datafusion_common::types::{NativeType, logical_float64}; +use datafusion_functions_aggregate_common::noop_accumulator::NoopAccumulator; +use crate::min_max::{max_udaf, min_udaf}; use datafusion_common::{ - assert_eq_or_internal_err, internal_datafusion_err, plan_err, DataFusionError, - Result, ScalarValue, + Result, ScalarValue, exec_datafusion_err, internal_datafusion_err, + utils::take_function_args, }; -use datafusion_expr::expr::{AggregateFunction, Sort}; -use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, Signature, + TypeSignatureClass, Volatility, }; use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_expr::{ + expr::{AggregateFunction, Sort}, + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + simplify::SimplifyContext, +}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; -use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; +use datafusion_functions_aggregate_common::utils::{GenericDistinctBuffer, Hashable}; use datafusion_macros::user_doc; use crate::utils::validate_percentile_expr; @@ -63,7 +69,10 @@ use crate::utils::validate_percentile_expr; /// The interpolation formula: `lower + (upper - lower) * fraction` /// is computed as: `lower + ((upper - lower) * (fraction * PRECISION)) / PRECISION` /// to avoid floating-point operations on integer types while maintaining precision. -const INTERPOLATION_PRECISION: usize = 1_000_000; +/// +/// The interpolation arithmetic is performed in f64 and then cast back to the +/// native type to avoid overflowing Float16 intermediates. +const INTERPOLATION_PRECISION: f64 = 1_000_000.0; create_func!(PercentileCont, percentile_cont_udaf); @@ -117,21 +126,12 @@ An alternate syntax is also supported: /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct PercentileCont { signature: Signature, aliases: Vec, } -impl Debug for PercentileCont { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("PercentileCont") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for PercentileCont { fn default() -> Self { Self::new() @@ -140,83 +140,30 @@ impl Default for PercentileCont { impl PercentileCont { pub fn new() -> Self { - let mut variants = Vec::with_capacity(NUMERICS.len()); - // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - } Self { - signature: Signature::one_of(variants, Volatility::Immutable) - .with_parameter_names(vec!["expr".to_string(), "percentile".to_string()]) - .expect("valid parameter names for percentile_cont"), + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Float, + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec!["expr", "percentile"]) + .unwrap(), aliases: vec![String::from("quantile_cont")], } } - - fn create_accumulator(&self, args: &AccumulatorArgs) -> Result> { - let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; - - let is_descending = args - .order_bys - .first() - .map(|sort_expr| sort_expr.options.descending) - .unwrap_or(false); - - let percentile = if is_descending { - 1.0 - percentile - } else { - percentile - }; - - macro_rules! helper { - ($t:ty, $dt:expr) => { - if args.is_distinct { - Ok(Box::new(DistinctPercentileContAccumulator::<$t> { - data_type: $dt.clone(), - distinct_values: GenericDistinctBuffer::new($dt), - percentile, - })) - } else { - Ok(Box::new(PercentileContAccumulator::<$t> { - data_type: $dt.clone(), - all_values: vec![], - percentile, - })) - } - }; - } - - let input_dt = args.exprs[0].data_type(args.schema)?; - match input_dt { - // For integer types, use Float64 internally since percentile_cont returns Float64 - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => helper!(Float64Type, DataType::Float64), - DataType::Float16 => helper!(Float16Type, input_dt), - DataType::Float32 => helper!(Float32Type, input_dt), - DataType::Float64 => helper!(Float64Type, input_dt), - DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), - DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), - _ => Err(DataFusionError::NotImplemented(format!( - "PercentileContAccumulator not supported for {} with {}", - args.name, input_dt, - ))), - } - } } impl AggregateUDFImpl for PercentileCont { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "percentile_cont" } @@ -230,134 +177,110 @@ impl AggregateUDFImpl for PercentileCont { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("percentile_cont requires numeric input types"); - } - // PERCENTILE_CONT performs linear interpolation and should return a float type - // For integer inputs, return Float64 (matching PostgreSQL/DuckDB behavior) - // For float inputs, preserve the float type match &arg_types[0] { - DataType::Float16 | DataType::Float32 | DataType::Float64 => { - Ok(arg_types[0].clone()) - } - DataType::Decimal32(_, _) - | DataType::Decimal64(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => Ok(arg_types[0].clone()), - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => Ok(DataType::Float64), - // Shouldn't happen due to signature check, but just in case - dt => plan_err!( - "percentile_cont does not support input type {}, must be numeric", - dt - ), + DataType::Null => Ok(DataType::Float64), + dt => Ok(dt.clone()), } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - //Intermediate state is a list of the elements we have collected so far let input_type = args.input_fields[0].data_type().clone(); - // For integer types, we store as Float64 internally - let storage_type = match &input_type { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::Float64, - _ => input_type, - }; + if input_type.is_null() { + return Ok(vec![ + Field::new( + format_state_name(args.name, self.name()), + DataType::Null, + true, + ) + .into(), + ]); + } - let field = Field::new_list_field(storage_type, true); + let field = Field::new_list_field(input_type, true); let state_name = if args.is_distinct { "distinct_percentile_cont" } else { "percentile_cont" }; - Ok(vec![Field::new( - format_state_name(args.name, state_name), - DataType::List(Arc::new(field)), - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into(), + ]) } - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - self.create_accumulator(&acc_args) + fn accumulator(&self, args: AccumulatorArgs) -> Result> { + let percentile = get_percentile(&args)?; + + let input_dt = args.expr_fields[0].data_type(); + if input_dt.is_null() { + return Ok(Box::new(NoopAccumulator::new(ScalarValue::Float64(None)))); + } + + if args.is_distinct { + match input_dt { + DataType::Float16 => Ok(Box::new(DistinctPercentileContAccumulator::< + Float16Type, + >::new(percentile))), + DataType::Float32 => Ok(Box::new(DistinctPercentileContAccumulator::< + Float32Type, + >::new(percentile))), + DataType::Float64 => Ok(Box::new(DistinctPercentileContAccumulator::< + Float64Type, + >::new(percentile))), + dt => internal_err!("Unsupported datatype for percentile cont: {dt}"), + } + } else { + match input_dt { + DataType::Float16 => Ok(Box::new( + PercentileContAccumulator::::new(percentile), + )), + DataType::Float32 => Ok(Box::new( + PercentileContAccumulator::::new(percentile), + )), + DataType::Float64 => Ok(Box::new( + PercentileContAccumulator::::new(percentile), + )), + dt => internal_err!("Unsupported datatype for percentile cont: {dt}"), + } + } } fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { - !args.is_distinct + !args.is_distinct && !args.expr_fields[0].data_type().is_null() } fn create_groups_accumulator( &self, args: AccumulatorArgs, ) -> Result> { - let num_args = args.exprs.len(); - assert_eq_or_internal_err!( - num_args, - 2, - "percentile_cont should have 2 args, but found num args:{}", - num_args - ); - - let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; - - let is_descending = args - .order_bys - .first() - .map(|sort_expr| sort_expr.options.descending) - .unwrap_or(false); - - let percentile = if is_descending { - 1.0 - percentile - } else { - percentile - }; - - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PercentileContGroupsAccumulator::<$t>::new( - $dt, percentile, - ))) - }; - } + let percentile = get_percentile(&args)?; - let input_dt = args.exprs[0].data_type(args.schema)?; + let input_dt = args.expr_fields[0].data_type(); match input_dt { - // For integer types, use Float64 internally since percentile_cont returns Float64 - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => helper!(Float64Type, DataType::Float64), - DataType::Float16 => helper!(Float16Type, input_dt), - DataType::Float32 => helper!(Float32Type, input_dt), - DataType::Float64 => helper!(Float64Type, input_dt), - DataType::Decimal32(_, _) => helper!(Decimal32Type, input_dt), - DataType::Decimal64(_, _) => helper!(Decimal64Type, input_dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, input_dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, input_dt), - _ => Err(DataFusionError::NotImplemented(format!( - "PercentileContGroupsAccumulator not supported for {} with {}", - args.name, input_dt, - ))), + DataType::Float16 => Ok(Box::new(PercentileContGroupsAccumulator::< + Float16Type, + >::new(percentile))), + DataType::Float32 => Ok(Box::new(PercentileContGroupsAccumulator::< + Float32Type, + >::new(percentile))), + DataType::Float64 => Ok(Box::new(PercentileContGroupsAccumulator::< + Float64Type, + >::new(percentile))), + dt => internal_err!("Unsupported datatype for percentile cont: {dt}"), } } + fn simplify(&self) -> Option { + Some(Box::new(|aggregate_function, info| { + simplify_percentile_cont_aggregate(aggregate_function, info) + })) + } + fn supports_within_group_clause(&self) -> bool { true } @@ -367,6 +290,83 @@ impl AggregateUDFImpl for PercentileCont { } } +fn get_percentile(args: &AccumulatorArgs) -> Result { + let percentile = validate_percentile_expr(&args.exprs[1], "PERCENTILE_CONT")?; + + let is_descending = args + .order_bys + .first() + .map(|sort_expr| sort_expr.options.descending) + .unwrap_or(false); + + let percentile = if is_descending { + 1.0 - percentile + } else { + percentile + }; + + Ok(percentile) +} + +fn simplify_percentile_cont_aggregate( + aggregate_function: AggregateFunction, + info: &SimplifyContext, +) -> Result { + enum PercentileRewriteTarget { + Min, + Max, + } + + let params = &aggregate_function.params; + let [value, percentile] = take_function_args("percentile_cont", ¶ms.args)?; + // + // For simplicity we don't bother with null types (otherwise we'd need to + // cast the return type) + let input_type = info.get_data_type(value)?; + if input_type.is_null() { + return Ok(Expr::AggregateFunction(aggregate_function)); + } + + let is_descending = params + .order_by + .first() + .map(|sort| !sort.asc) + .unwrap_or(false); + + let rewrite_target = match percentile { + Expr::Literal(ScalarValue::Float64(Some(0.0)), _) => { + if is_descending { + PercentileRewriteTarget::Max + } else { + PercentileRewriteTarget::Min + } + } + Expr::Literal(ScalarValue::Float64(Some(1.0)), _) => { + if is_descending { + PercentileRewriteTarget::Min + } else { + PercentileRewriteTarget::Max + } + } + _ => return Ok(Expr::AggregateFunction(aggregate_function)), + }; + + let udaf = match rewrite_target { + PercentileRewriteTarget::Min => min_udaf(), + PercentileRewriteTarget::Max => max_udaf(), + }; + + let rewritten = Expr::AggregateFunction(AggregateFunction::new_udf( + udaf, + vec![value.clone()], + params.distinct, + params.filter.clone(), + vec![], + params.null_treatment, + )); + Ok(rewritten) +} + /// The percentile_cont accumulator accumulates the raw input values /// as native types. /// @@ -374,23 +374,27 @@ impl AggregateUDFImpl for PercentileCont { /// `merge_batch` and a `Vec` of native values that are converted to scalar values /// in the final evaluation step so that we avoid expensive conversions and /// allocations during `update_batch`. -struct PercentileContAccumulator { - data_type: DataType, +#[derive(Debug)] +struct PercentileContAccumulator { all_values: Vec, percentile: f64, } -impl Debug for PercentileContAccumulator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "PercentileContAccumulator({}, percentile={})", - self.data_type, self.percentile - ) +impl PercentileContAccumulator { + fn new(percentile: f64) -> Self { + Self { + all_values: vec![], + percentile, + } } } -impl Accumulator for PercentileContAccumulator { +impl Accumulator for PercentileContAccumulator +where + T: ArrowNumericType + Debug, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ fn state(&mut self) -> Result> { // Convert `all_values` to `ListArray` and return a single List ScalarValue @@ -402,12 +406,11 @@ impl Accumulator for PercentileContAccumulator { let values_array = PrimitiveArray::::new( ScalarBuffer::from(std::mem::take(&mut self.all_values)), None, - ) - .with_data_type(self.data_type.clone()); + ); // Build the result list array let list_array = ListArray::new( - Arc::new(Field::new_list_field(self.data_type.clone(), true)), + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), offsets, Arc::new(values_array), None, @@ -417,36 +420,64 @@ impl Accumulator for PercentileContAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // Cast to target type if needed (e.g., integer to Float64) - let values = if values[0].data_type() != &self.data_type { - arrow::compute::cast(&values[0], &self.data_type)? - } else { - Arc::clone(&values[0]) - }; - - let values = values.as_primitive::(); - self.all_values.reserve(values.len() - values.null_count()); + let values = values[0].as_primitive::(); + let additional = values.len() - values.null_count(); + self.all_values.try_reserve(additional).map_err(|e| { + exec_datafusion_err!( + "failed to reserve {additional} values for percentile_cont accumulator: {e}" + ) + })?; self.all_values.extend(values.iter().flatten()); Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let array = states[0].as_list::(); - for v in array.iter().flatten() { - self.update_batch(&[v])? - } + self.update_batch(&[array.value(0)])?; Ok(()) } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.all_values); - let value = calculate_percentile::(d, self.percentile); - ScalarValue::new_primitive::(value, &self.data_type) + let value = calculate_percentile::(&mut self.all_values, self.percentile); + ScalarValue::new_primitive::(value, &T::DATA_TYPE) } fn size(&self) -> usize { size_of_val(self) + self.all_values.capacity() * size_of::() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let mut to_remove: HashMap, usize> = HashMap::new(); + + let arr = values[0].as_primitive::(); + for value in arr.iter().flatten() { + *to_remove.entry(Hashable(value)).or_default() += 1; + } + + let mut i = 0; + while i < self.all_values.len() { + let k = Hashable(self.all_values[i]); + if let Some(count) = to_remove.get_mut(&k) + && *count > 0 + { + self.all_values.swap_remove(i); + *count -= 1; + if *count == 0 { + to_remove.remove(&k); + if to_remove.is_empty() { + break; + } + } + } else { + i += 1; + } + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// The percentile_cont groups accumulator accumulates the raw input values @@ -457,23 +488,24 @@ impl Accumulator for PercentileContAccumulator { /// will be actually organized as a `Vec>`. #[derive(Debug)] struct PercentileContGroupsAccumulator { - data_type: DataType, group_values: Vec>, percentile: f64, } impl PercentileContGroupsAccumulator { - pub fn new(data_type: DataType, percentile: f64) -> Self { + fn new(percentile: f64) -> Self { Self { - data_type, - group_values: Vec::new(), + group_values: vec![], percentile, } } } -impl GroupsAccumulator - for PercentileContGroupsAccumulator +impl GroupsAccumulator for PercentileContGroupsAccumulator +where + T: ArrowNumericType + Send, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, { fn update_batch( &mut self, @@ -485,14 +517,7 @@ impl GroupsAccumulator // For ordered-set aggregates, we only care about the ORDER BY column (first element) // The percentile parameter is already stored in self.percentile - // Cast to target type if needed (e.g., integer to Float64) - let values_array = if values[0].data_type() != &self.data_type { - arrow::compute::cast(&values[0], &self.data_type)? - } else { - Arc::clone(&values[0]) - }; - - let values = values_array.as_primitive::(); + let values = values[0].as_primitive::(); // Push the `not nulls + not filtered` row into its group self.group_values.resize(total_num_groups, Vec::new()); @@ -555,12 +580,11 @@ impl GroupsAccumulator let flatten_group_values = emit_group_values.into_iter().flatten().collect::>(); let group_values_array = - PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None) - .with_data_type(self.data_type.clone()); + PrimitiveArray::::new(ScalarBuffer::from(flatten_group_values), None); // Build the result list array let result_list_array = ListArray::new( - Arc::new(Field::new_list_field(self.data_type.clone(), true)), + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), offsets, Arc::new(group_values_array), None, @@ -571,13 +595,13 @@ impl GroupsAccumulator fn evaluate(&mut self, emit_to: EmitTo) -> Result { // Emit values - let emit_group_values = emit_to.take_needed(&mut self.group_values); + let mut emit_group_values = emit_to.take_needed(&mut self.group_values); // Calculate percentile for each group let mut evaluate_result_builder = - PrimitiveBuilder::::new().with_data_type(self.data_type.clone()); - for values in emit_group_values { - let value = calculate_percentile::(values, self.percentile); + PrimitiveBuilder::::with_capacity(emit_group_values.len()); + for values in &mut emit_group_values { + let value = calculate_percentile::(values.as_mut_slice(), self.percentile); evaluate_result_builder.append_option(value); } @@ -591,14 +615,7 @@ impl GroupsAccumulator ) -> Result> { assert_eq!(values.len(), 1, "one argument to merge_batch"); - // Cast to target type if needed (e.g., integer to Float64) - let values_array = if values[0].data_type() != &self.data_type { - arrow::compute::cast(&values[0], &self.data_type)? - } else { - Arc::clone(&values[0]) - }; - - let input_array = values_array.as_primitive::(); + let input_array = values[0].as_primitive::(); // Directly convert the input array to states, each row will be // seen as a respective group. @@ -608,8 +625,7 @@ impl GroupsAccumulator // to null. // Reuse values buffer in `input_array` to build `values` in `ListArray` - let values = PrimitiveArray::::new(input_array.values().clone(), None) - .with_data_type(self.data_type.clone()); + let values = PrimitiveArray::::new(input_array.values().clone(), None); // `offsets` in `ListArray`, each row as a list element let offset_end = i32::try_from(input_array.len()).map_err(|e| { @@ -630,7 +646,7 @@ impl GroupsAccumulator let nulls = filtered_null_mask(opt_filter, input_array); let converted_list_array = ListArray::new( - Arc::new(Field::new_list_field(self.data_type.clone(), true)), + Arc::new(Field::new_list_field(T::DATA_TYPE, true)), offsets, Arc::new(values), nulls, @@ -656,11 +672,24 @@ impl GroupsAccumulator #[derive(Debug)] struct DistinctPercentileContAccumulator { distinct_values: GenericDistinctBuffer, - data_type: DataType, percentile: f64, } -impl Accumulator for DistinctPercentileContAccumulator { +impl DistinctPercentileContAccumulator { + fn new(percentile: f64) -> Self { + Self { + distinct_values: GenericDistinctBuffer::new(T::DATA_TYPE), + percentile, + } + } +} + +impl Accumulator for DistinctPercentileContAccumulator +where + T: ArrowNumericType + Debug, + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ fn state(&mut self) -> Result> { self.distinct_values.state() } @@ -674,17 +703,31 @@ impl Accumulator for DistinctPercentileContAccumula } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.distinct_values.values) - .into_iter() - .map(|v| v.0) - .collect::>(); - let value = calculate_percentile::(d, self.percentile); - ScalarValue::new_primitive::(value, &self.data_type) + let mut values: Vec = + self.distinct_values.values.iter().map(|v| v.0).collect(); + let value = calculate_percentile::(&mut values, self.percentile); + ScalarValue::new_primitive::(value, &T::DATA_TYPE) } fn size(&self) -> usize { size_of_val(self) + self.distinct_values.size() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = values[0].as_primitive::(); + for value in arr.iter().flatten() { + self.distinct_values.values.remove(&Hashable(value)); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// Calculate the percentile value for a given set of values. @@ -694,10 +737,18 @@ impl Accumulator for DistinctPercentileContAccumula /// For percentile p and n values: /// - If p * (n-1) is an integer, return the value at that position /// - Otherwise, interpolate between the two closest values +/// +/// Note: This function takes a mutable slice and sorts it in place, but does not +/// consume the data. This is important for window frame queries where evaluate() +/// may be called multiple times on the same accumulator state. fn calculate_percentile( - mut values: Vec, + values: &mut [T::Native], percentile: f64, -) -> Option { +) -> Option +where + T::Native: Copy + AsPrimitive, + f64: AsPrimitive, +{ let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = values.len(); @@ -741,22 +792,47 @@ fn calculate_percentile( let (_, upper_value, _) = values.select_nth_unstable_by(upper_index, cmp); let upper_value = *upper_value; - // Linear interpolation using wrapping arithmetic - // We use wrapping operations here (matching the approach in median.rs) because: - // 1. Both values come from the input data, so diff is bounded by the value range - // 2. fraction is between 0 and 1, and INTERPOLATION_PRECISION is small enough - // to prevent overflow when combined with typical numeric ranges - // 3. The result is guaranteed to be between lower_value and upper_value - // 4. For floating-point types, wrapping ops behave the same as standard ops + // Linear interpolation. + // We compute a quantized interpolation weight using `INTERPOLATION_PRECISION` because: + // 1. Both values come from the input data, so (upper - lower) is bounded by the value range + // 2. fraction is between 0 and 1; quantizing it provides stable, predictable results + // 3. The result is guaranteed to be between lower_value and upper_value (modulo cast rounding) + // 4. Arithmetic is performed in f64 and cast back to avoid overflowing Float16 intermediates let fraction = index - (lower_index as f64); - let diff = upper_value.sub_wrapping(lower_value); - let interpolated = lower_value.add_wrapping( - diff.mul_wrapping(T::Native::usize_as( - (fraction * INTERPOLATION_PRECISION as f64) as usize, - )) - .div_wrapping(T::Native::usize_as(INTERPOLATION_PRECISION)), - ); - Some(interpolated) + let scaled = (fraction * INTERPOLATION_PRECISION) as usize; + let weight = scaled as f64 / INTERPOLATION_PRECISION; + + let lower_f: f64 = lower_value.as_(); + let upper_f: f64 = upper_value.as_(); + let interpolated_f = lower_f + (upper_f - lower_f) * weight; + Some(interpolated_f.as_()) } } } + +#[cfg(test)] +mod tests { + use super::calculate_percentile; + use half::f16; + + #[test] + fn f16_interpolation_does_not_overflow_to_nan() { + // Regression test for https://github.com/apache/datafusion/issues/18945 + // Interpolating between 0 and the max finite f16 value previously overflowed + // intermediate f16 computations and produced NaN. + let mut values = vec![f16::from_f32(0.0), f16::from_f32(65504.0)]; + let result = + calculate_percentile::(&mut values, 0.5) + .expect("non-empty input"); + let result_f = result.to_f32(); + assert!( + !result_f.is_nan(), + "expected non-NaN result, got {result_f}" + ); + // 0.5 percentile should be close to midpoint + assert!( + (result_f - 32752.0).abs() < 1.0, + "unexpected result {result_f}" + ); + } +} diff --git a/datafusion/functions-aggregate/src/planner.rs b/datafusion/functions-aggregate/src/planner.rs index f0e37f6b1dbe4..8a6d9b9bb1e9f 100644 --- a/datafusion/functions-aggregate/src/planner.rs +++ b/datafusion/functions-aggregate/src/planner.rs @@ -19,11 +19,11 @@ use datafusion_common::Result; use datafusion_expr::{ + Expr, expr::{AggregateFunction, AggregateFunctionParams}, expr_rewriter::NamePreserver, planner::{ExprPlanner, PlannerResult, RawAggregateExpr}, utils::COUNT_STAR_EXPANSION, - Expr, }; #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index 045cb99838430..3d5bbf1eda24e 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,25 +17,16 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::Float64Array; use arrow::datatypes::FieldRef; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, - datatypes::Field, -}; -use datafusion_common::{ - downcast_value, plan_err, unwrap_or_internal_err, HashMap, Result, ScalarValue, -}; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{HashMap, Result, ScalarValue}; use datafusion_doc::aggregate_doc_sections::DOC_SECTION_STATISTICAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; -use std::any::Any; use std::fmt::Debug; use std::hash::Hash; use std::mem::size_of_val; @@ -58,26 +49,20 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Regr { signature: Signature, regr_type: RegrType, func_name: &'static str, } -impl Debug for Regr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("regr") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Regr { pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { Self { - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + signature: Signature::exact( + vec![DataType::Float64, DataType::Float64], + Volatility::Immutable, + ), regr_type, func_name, } @@ -85,7 +70,6 @@ impl Regr { } #[derive(Debug, Clone, PartialEq, Hash, Eq)] -#[allow(clippy::upper_case_acronyms)] pub enum RegrType { /// Variant for `regr_slope` aggregate expression /// Returns the slope of the linear regression line for non-null pairs in aggregate columns. @@ -457,10 +441,6 @@ fn get_regr_docs() -> &'static HashMap { } impl AggregateUDFImpl for Regr { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.func_name } @@ -469,18 +449,26 @@ impl AggregateUDFImpl for Regr { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Covariance requires numeric input types"); - } - - if matches!(self.regr_type, RegrType::Count) { + fn return_type(&self, _arg_types: &[DataType]) -> Result { + if self.regr_type == RegrType::Count { Ok(DataType::UInt64) } else { Ok(DataType::Float64) } } + fn default_value(&self, _data_type: &DataType) -> Result { + if self.regr_type == RegrType::Count { + Ok(ScalarValue::UInt64(Some(0))) + } else { + Ok(ScalarValue::Float64(None)) + } + } + + fn is_nullable(&self) -> bool { + self.regr_type != RegrType::Count + } + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } @@ -607,32 +595,18 @@ impl Accumulator for RegrAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // regr_slope(Y, X) calculates k in y = k*x + b - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; - - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None - }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - self.count += 1; let delta_x = value_x - self.mean_x; let delta_y = value_y - self.mean_y; @@ -653,32 +627,18 @@ impl Accumulator for RegrAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values_y = &cast(&values[0], &DataType::Float64)?; - let values_x = &cast(&values[1], &DataType::Float64)?; + let values_y = as_float64_array(&values[0])?; + let values_x = as_float64_array(&values[1])?; - let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); - let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); - - for i in 0..values_y.len() { + for (value_y, value_x) in values_y.iter().zip(values_x) { // skip either x or y is NULL - let value_y = if values_y.is_valid(i) { - arr_y.next() - } else { - None - }; - let value_x = if values_x.is_valid(i) { - arr_x.next() - } else { - None + let (value_y, value_x) = match (value_y, value_x) { + (Some(y), Some(x)) => (y, x), + // skip either x or y is NULL + _ => continue, }; - if value_y.is_none() || value_x.is_none() { - continue; - } // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] - let value_y = unwrap_or_internal_err!(value_y); - let value_x = unwrap_or_internal_err!(value_x); - if self.count > 1 { self.count -= 1; let delta_x = value_x - self.mean_x; @@ -704,12 +664,12 @@ impl Accumulator for RegrAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let count_arr = downcast_value!(states[0], UInt64Array); - let mean_x_arr = downcast_value!(states[1], Float64Array); - let mean_y_arr = downcast_value!(states[2], Float64Array); - let m2_x_arr = downcast_value!(states[3], Float64Array); - let m2_y_arr = downcast_value!(states[4], Float64Array); - let algo_const_arr = downcast_value!(states[5], Float64Array); + let count_arr = as_uint64_array(&states[0])?; + let mean_x_arr = as_float64_array(&states[1])?; + let mean_y_arr = as_float64_array(&states[2])?; + let m2_x_arr = as_float64_array(&states[3])?; + let m2_y_arr = as_float64_array(&states[4])?; + let algo_const_arr = as_float64_array(&states[5])?; for i in 0..count_arr.len() { let count_b = count_arr.value(i); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 782524aa4d0ac..68e38a3b8db07 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -17,8 +17,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::fmt::{Debug, Formatter}; +use std::fmt::Debug; use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; @@ -26,8 +25,8 @@ use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::{internal_err, not_impl_err, Result}; -use datafusion_common::{plan_err, ScalarValue}; +use datafusion_common::ScalarValue; +use datafusion_common::{Result, internal_err, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -62,21 +61,12 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct Stddev { signature: Signature, alias: Vec, } -impl Debug for Stddev { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Stddev") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for Stddev { fn default() -> Self { Self::new() @@ -87,18 +77,13 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), alias: vec!["stddev_samp".to_string()], } } } impl AggregateUDFImpl for Stddev { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "stddev" } @@ -180,20 +165,11 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV_POP population aggregate expression -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct StddevPop { signature: Signature, } -impl Debug for StddevPop { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("StddevPop") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for StddevPop { fn default() -> Self { Self::new() @@ -204,17 +180,12 @@ impl StddevPop { /// Create a new STDDEV_POP aggregate function pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } impl AggregateUDFImpl for StddevPop { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "stddev_pop" } @@ -249,11 +220,7 @@ impl AggregateUDFImpl for StddevPop { Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("StddevPop requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -318,13 +285,8 @@ impl Accumulator for StddevAccumulator { fn evaluate(&mut self) -> Result { let variance = self.variance.evaluate()?; match variance { - ScalarValue::Float64(e) => { - if e.is_none() { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) - } - } + ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)), + ScalarValue::Float64(Some(f)) => Ok(ScalarValue::Float64(Some(f.sqrt()))), _ => internal_err!("Variance should be f64"), } } @@ -396,7 +358,6 @@ mod tests { use datafusion_expr::AggregateUDF; use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; use datafusion_physical_expr::expressions::col; - use std::sync::Arc; #[test] fn stddev_f64_merge_1() -> Result<()> { @@ -473,12 +434,16 @@ mod tests { let mut accum1 = agg1.accumulator(args1)?; let mut accum2 = agg2.accumulator(args2)?; - let value1 = vec![col("a", schema)? - .evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows()))?]; - let value2 = vec![col("a", schema)? - .evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows()))?]; + let value1 = vec![ + col("a", schema)? + .evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows()))?, + ]; + let value2 = vec![ + col("a", schema)? + .evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows()))?, + ]; accum1.update_batch(&value1)?; accum2.update_batch(&value2)?; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 4a040df7b4a3b..f0757818afb93 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -17,26 +17,26 @@ //! [`StringAgg`] accumulator for the `string_agg` function -use std::any::Any; use std::hash::Hash; use std::mem::size_of_val; +use std::sync::Arc; use crate::array_agg::ArrayAgg; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, AsArray, BooleanArray, LargeStringArray}; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::cast::{ - as_generic_string_array, as_string_array, as_string_view_array, -}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue, + Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err, }; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, Signature, + TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions::Literal; @@ -117,6 +117,27 @@ impl StringAgg { array_agg: Default::default(), } } + + /// Extract the delimiter string from the second argument expression. + fn extract_delimiter(args: &AccumulatorArgs) -> Result { + let Some(lit) = args.exprs[1].downcast_ref::() else { + return not_impl_err!("string_agg delimiter must be a string literal"); + }; + + if lit.value().is_null() { + return Ok(String::new()); + } + + match lit.value().try_as_str() { + Some(s) => Ok(s.unwrap_or("").to_string()), + None => { + not_impl_err!( + "string_agg not supported for delimiter \"{}\"", + lit.value() + ) + } + } + } } impl Default for StringAgg { @@ -125,13 +146,11 @@ impl Default for StringAgg { } } -/// If there is no `distinct` and `order by` required by the `string_agg` call, a -/// more efficient accumulator `SimpleStringAggAccumulator` will be used. +/// Three accumulation strategies depending on query shape: +/// - No DISTINCT / ORDER BY with GROUP BY: `StringAggGroupsAccumulator` +/// - No DISTINCT / ORDER BY without GROUP BY: `SimpleStringAggAccumulator` +/// - With DISTINCT or ORDER BY: `StringAggAccumulator` (delegates to `ArrayAgg`) impl AggregateUDFImpl for StringAgg { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "string_agg" } @@ -145,52 +164,26 @@ impl AggregateUDFImpl for StringAgg { } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - // See comments in `impl AggregateUDFImpl ...` for more detail - let no_order_no_distinct = - (args.ordering_fields.is_empty()) && (!args.is_distinct); - if no_order_no_distinct { - // Case `SimpleStringAggAccumulator` - Ok(vec![Field::new( - format_state_name(args.name, "string_agg"), - DataType::LargeUtf8, - true, - ) - .into()]) + if !args.is_distinct && args.ordering_fields.is_empty() { + Ok(vec![ + Field::new( + format_state_name(args.name, "string_agg"), + DataType::LargeUtf8, + true, + ) + .into(), + ]) } else { - // Case `StringAggAccumulator` self.array_agg.state_fields(args) } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() else { - return not_impl_err!( - "The second argument of the string_agg function must be a string literal" - ); - }; - - let delimiter = if lit.value().is_null() { - // If the second argument (the delimiter that joins strings) is NULL, join - // on an empty string. (e.g. [a, b, c] => "abc"). - "" - } else if let Some(lit_string) = lit.value().try_as_str() { - lit_string.unwrap_or("") - } else { - return not_impl_err!( - "StringAgg not supported for delimiter \"{}\"", - lit.value() - ); - }; + let delimiter = Self::extract_delimiter(&acc_args)?; - // See comments in `impl AggregateUDFImpl ...` for more detail - let no_order_no_distinct = - acc_args.order_bys.is_empty() && (!acc_args.is_distinct); - - if no_order_no_distinct { - // simple case (more efficient) - Ok(Box::new(SimpleStringAggAccumulator::new(delimiter))) + if !acc_args.is_distinct && acc_args.order_bys.is_empty() { + Ok(Box::new(SimpleStringAggAccumulator::new(&delimiter))) } else { - // general case let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { return_field: Field::new( "f", @@ -213,7 +206,7 @@ impl AggregateUDFImpl for StringAgg { Ok(Box::new(StringAggAccumulator::new( array_agg_acc, - delimiter, + &delimiter, ))) } } @@ -222,6 +215,18 @@ impl AggregateUDFImpl for StringAgg { datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf()) } + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct && args.order_bys.is_empty() + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let delimiter = Self::extract_delimiter(&args)?; + Ok(Box::new(StringAggGroupsAccumulator::new(delimiter))) + } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -252,7 +257,10 @@ impl Accumulator for StringAggAccumulator { let scalar = self.array_agg_acc.evaluate()?; let ScalarValue::List(list) = scalar else { - return internal_err!("Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", scalar.data_type()); + return internal_err!( + "Expected a DataType::List while evaluating underlying ArrayAggAccumulator, but got {}", + scalar.data_type() + ); }; let string_arr: Vec<_> = match list.value_type() { @@ -272,7 +280,7 @@ impl Accumulator for StringAggAccumulator { return internal_err!( "Expected elements to of type Utf8 or LargeUtf8, but got {}", list.value_type() - ) + ); } }; @@ -310,10 +318,136 @@ fn filter_index(values: &[T], index: usize) -> Vec { .collect::>() } -/// StringAgg accumulator for the simple case (no order or distinct specified) -/// This accumulator is more efficient than `StringAggAccumulator` -/// because it accumulates the string directly, -/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`. +/// GroupsAccumulator for `string_agg` without DISTINCT or ORDER BY. +#[derive(Debug)] +struct StringAggGroupsAccumulator { + /// The delimiter placed between concatenated values. + delimiter: String, + /// Accumulated string per group. `None` means no values have been seen + /// (the group's output will be NULL). + /// A potential improvement is to avoid this String allocation + /// See + values: Vec>, + /// Running total of string data bytes across all groups. + total_data_bytes: usize, +} + +impl StringAggGroupsAccumulator { + fn new(delimiter: String) -> Self { + Self { + delimiter, + values: Vec::new(), + total_data_bytes: 0, + } + } + + fn append_batch<'a>( + &mut self, + iter: impl Iterator>, + group_indices: &[usize], + ) { + for (opt_value, &group_idx) in iter.zip(group_indices.iter()) { + if let Some(value) = opt_value { + match &mut self.values[group_idx] { + Some(existing) => { + let added = self.delimiter.len() + value.len(); + existing.reserve(added); + existing.push_str(&self.delimiter); + existing.push_str(value); + self.total_data_bytes += added; + } + slot @ None => { + *slot = Some(value.to_string()); + self.total_data_bytes += value.len(); + } + } + } + } + } +} + +impl GroupsAccumulator for StringAggGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.values.resize(total_num_groups, None); + let array = apply_filter_as_nulls(&values[0], opt_filter)?; + match array.data_type() { + DataType::Utf8 => { + self.append_batch(array.as_string::().iter(), group_indices) + } + DataType::LargeUtf8 => { + self.append_batch(array.as_string::().iter(), group_indices) + } + DataType::Utf8View => { + self.append_batch(array.as_string_view().iter(), group_indices) + } + other => { + return internal_err!("string_agg unexpected data type: {other}"); + } + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let to_emit = emit_to.take_needed(&mut self.values); + let emitted_bytes: usize = to_emit + .iter() + .filter_map(|opt| opt.as_ref().map(|s| s.len())) + .sum(); + self.total_data_bytes -= emitted_bytes; + + let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit)); + Ok(result) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // State is always LargeUtf8, which update_batch already handles. + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let input = apply_filter_as_nulls(&values[0], opt_filter)?; + let result = if input.data_type() == &DataType::LargeUtf8 { + input + } else { + arrow::compute::cast(&input, &DataType::LargeUtf8)? + }; + Ok(vec![result]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.total_data_bytes + + self.values.capacity() * size_of::>() + + self.delimiter.capacity() + + size_of_val(self) + } +} + +/// Per-row accumulator for `string_agg` without DISTINCT or ORDER BY. Used for +/// non-grouped aggregation; grouped queries use [`StringAggGroupsAccumulator`]. #[derive(Debug)] pub(crate) struct SimpleStringAggAccumulator { delimiter: String, @@ -326,7 +460,7 @@ impl SimpleStringAggAccumulator { pub fn new(delimiter: &str) -> Self { Self { delimiter: delimiter.to_string(), - accumulated_string: "".to_string(), + accumulated_string: String::new(), has_value: false, } } @@ -356,18 +490,11 @@ impl Accumulator for SimpleStringAggAccumulator { })?; match string_arr.data_type() { - DataType::Utf8 => { - let array = as_string_array(string_arr)?; - self.append_strings(array.iter()); - } + DataType::Utf8 => self.append_strings(string_arr.as_string::().iter()), DataType::LargeUtf8 => { - let array = as_generic_string_array::(string_arr)?; - self.append_strings(array.iter()); - } - DataType::Utf8View => { - let array = as_string_view_array(string_arr)?; - self.append_strings(array.iter()); + self.append_strings(string_arr.as_string::().iter()) } + DataType::Utf8View => self.append_strings(string_arr.as_string_view().iter()), other => { return internal_err!( "Planner should ensure string_agg first argument is Utf8-like, found {other}" @@ -379,14 +506,13 @@ impl Accumulator for SimpleStringAggAccumulator { } fn evaluate(&mut self) -> Result { - let result = if self.has_value { - ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) + if self.has_value { + Ok(ScalarValue::LargeUtf8(Some( + self.accumulated_string.clone(), + ))) } else { - ScalarValue::LargeUtf8(None) - }; - - self.has_value = false; - Ok(result) + Ok(ScalarValue::LargeUtf8(None)) + } } fn size(&self) -> usize { @@ -415,7 +541,6 @@ mod tests { use arrow::array::LargeStringArray; use arrow::compute::SortOptions; use arrow::datatypes::{Fields, Schema}; - use datafusion_common::internal_err; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use std::sync::Arc; @@ -659,4 +784,172 @@ mod tests { acc1.merge_batch(&intermediate_state)?; Ok(acc1) } + + // --------------------------------------------------------------- + // Tests for StringAggGroupsAccumulator + // --------------------------------------------------------------- + + fn make_groups_acc(delimiter: &str) -> StringAggGroupsAccumulator { + StringAggGroupsAccumulator::new(delimiter.to_string()) + } + + /// Helper: evaluate and downcast to LargeStringArray + fn evaluate_groups( + acc: &mut StringAggGroupsAccumulator, + emit_to: EmitTo, + ) -> Vec> { + let result = acc.evaluate(emit_to).unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + arr.iter().map(|v| v.map(|s| s.to_string())).collect() + } + + #[test] + fn groups_basic() -> Result<()> { + let mut acc = make_groups_acc(","); + + // 6 rows, 3 groups: group 0 gets "a","d"; group 1 gets "b","e"; group 2 gets "c","f" + let values: ArrayRef = + Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", "f"])); + let group_indices = vec![0, 1, 2, 0, 1, 2]; + acc.update_batch(&[values], &group_indices, None, 3)?; + + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!( + result, + vec![ + Some("a,d".to_string()), + Some("b,e".to_string()), + Some("c,f".to_string()), + ] + ); + Ok(()) + } + + #[test] + fn groups_with_nulls() -> Result<()> { + let mut acc = make_groups_acc("|"); + + // Group 0: "a", NULL, "c" → "a|c" + // Group 1: NULL, "b" → "b" + // Group 2: NULL only → NULL + let values: ArrayRef = Arc::new(LargeStringArray::from(vec![ + Some("a"), + None, + Some("c"), + None, + Some("b"), + None, + ])); + let group_indices = vec![0, 1, 0, 2, 1, 2]; + acc.update_batch(&[values], &group_indices, None, 3)?; + + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!( + result, + vec![Some("a|c".to_string()), Some("b".to_string()), None,] + ); + Ok(()) + } + + #[test] + fn groups_with_filter() -> Result<()> { + let mut acc = make_groups_acc(","); + + let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d"])); + let group_indices = vec![0, 0, 1, 1]; + // Filter: only rows 0 and 3 are included + let filter = BooleanArray::from(vec![true, false, false, true]); + acc.update_batch(&[values], &group_indices, Some(&filter), 2)?; + + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!(result, vec![Some("a".to_string()), Some("d".to_string())]); + Ok(()) + } + + #[test] + fn groups_emit_first() -> Result<()> { + let mut acc = make_groups_acc(","); + + let values: ArrayRef = + Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", "f"])); + let group_indices = vec![0, 1, 2, 0, 1, 2]; + acc.update_batch(&[values], &group_indices, None, 3)?; + + // Emit only the first 2 groups + let result = evaluate_groups(&mut acc, EmitTo::First(2)); + assert_eq!( + result, + vec![Some("a,d".to_string()), Some("b,e".to_string())] + ); + + // Group 2 (now shifted to index 0) should still be intact + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!(result, vec![Some("c,f".to_string())]); + Ok(()) + } + + #[test] + fn groups_merge_batch() -> Result<()> { + let mut acc = make_groups_acc(","); + + // First batch: group 0 = "a", group 1 = "b" + let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"])); + acc.update_batch(&[values], &[0, 1], None, 2)?; + + // Simulate a second accumulator's state (LargeUtf8 partial strings) + let partial_state: ArrayRef = Arc::new(LargeStringArray::from(vec!["c,d", "e"])); + acc.merge_batch(&[partial_state], &[0, 1], None, 2)?; + + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!( + result, + vec![Some("a,c,d".to_string()), Some("b,e".to_string())] + ); + Ok(()) + } + + #[test] + fn groups_empty_groups() -> Result<()> { + let mut acc = make_groups_acc(","); + + // 4 groups total, but only groups 0 and 2 receive values + let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"])); + acc.update_batch(&[values], &[0, 2], None, 4)?; + + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!( + result, + vec![ + Some("a".to_string()), + None, // group 1: never received a value + Some("b".to_string()), + None, // group 3: never received a value + ] + ); + Ok(()) + } + + #[test] + fn groups_multiple_batches() -> Result<()> { + let mut acc = make_groups_acc("|"); + + // Batch 1: 2 groups + let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b"])); + acc.update_batch(&[values], &[0, 1], None, 2)?; + + // Batch 2: same groups, plus a new group + let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["c", "d", "e"])); + acc.update_batch(&[values], &[0, 1, 2], None, 3)?; + + let result = evaluate_groups(&mut acc, EmitTo::All); + assert_eq!( + result, + vec![ + Some("a|c".to_string()), + Some("b|d".to_string()), + Some("e".to_string()), + ] + ); + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index d40709a467cf3..c3c2e5e0b9677 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -17,32 +17,34 @@ //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators -use ahash::RandomState; use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, AsArray}; use arrow::datatypes::Field; use arrow::datatypes::{ - ArrowNativeType, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, DurationMicrosecondType, DurationMillisecondType, - DurationNanosecondType, DurationSecondType, FieldRef, Float64Type, Int64Type, - TimeUnit, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, - DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + ArrowNativeType, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Decimal32Type, + Decimal64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, + DurationMillisecondType, DurationNanosecondType, DurationSecondType, FieldRef, + Float64Type, Int64Type, TimeUnit, UInt64Type, }; +use datafusion_common::hash_utils::RandomState; +use datafusion_common::internal_err; use datafusion_common::types::{ - logical_float64, logical_int16, logical_int32, logical_int64, logical_int8, - logical_uint16, logical_uint32, logical_uint64, logical_uint8, NativeType, + NativeType, logical_float64, logical_int8, logical_int16, logical_int32, + logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64, }; -use datafusion_common::{exec_err, not_impl_err, HashMap, Result, ScalarValue}; +use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr_fn::cast; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; +use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name}; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator, - ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass, - Volatility, + Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator; use datafusion_macros::user_doc; -use std::any::Any; use std::mem::size_of_val; make_udaf_expr_and_func!( @@ -54,7 +56,7 @@ make_udaf_expr_and_func!( ); pub fn sum_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( sum_udaf(), vec![expr], true, @@ -198,10 +200,6 @@ impl Default for Sum { } impl AggregateUDFImpl for Sum { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "sum" } @@ -260,20 +258,24 @@ impl AggregateUDFImpl for Sum { fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { - Ok(vec![Field::new_list( - format_state_name(args.name, "sum distinct"), - // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type().clone(), true), - false, - ) - .into()]) + Ok(vec![ + Field::new_list( + format_state_name(args.name, "sum distinct"), + // See COMMENTS.md to understand why nullable is set to true + Field::new_list_field(args.return_type().clone(), true), + false, + ) + .into(), + ]) } else { - Ok(vec![Field::new( - format_state_name(args.name, "sum"), - args.return_type().clone(), - true, - ) - .into()]) + Ok(vec![ + Field::new( + format_state_name(args.name, "sum"), + args.return_type().clone(), + true, + ) + .into(), + ]) } } @@ -342,6 +344,47 @@ impl AggregateUDFImpl for Sum { _ => SetMonotonicity::NotMonotonic, } } + + /// Implement ClickBench Q29 specific optimization: + /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)` + /// + /// See background on [`AggregateUDFImpl::simplify_expr_op_literal`] + fn simplify_expr_op_literal( + &self, + agg_function: &AggregateFunction, + arg: &Expr, + op: Operator, + lit: &Expr, + // Only support '+' so the order of the args doesn't matter + _arg_is_left: bool, + ) -> Result> { + if op != Operator::Plus { + return Ok(None); + } + + let lit_type = match &lit { + Expr::Literal(value, _) => value.data_type(), + _ => { + return internal_err!( + "Sum::simplify_expr_op_literal got a non literal argument" + ); + } + }; + if lit_type == DataType::Null { + return Ok(None); + } + + // Build up SUM(arg) + let mut sum_agg = agg_function.clone(); + sum_agg.params.args = vec![arg.clone()]; + let sum_agg = Expr::AggregateFunction(sum_agg); + + // COUNT(arg) - cast to the correct type + let count_agg = cast(crate::count::count(arg.clone()), lit_type); + + // SUM(arg) + lit * COUNT(arg) + Ok(Some(sum_agg + (lit.clone() * count_agg))) + } } /// This accumulator computes SUM incrementally @@ -490,25 +533,60 @@ impl SlidingDistinctSumAccumulator { data_type: data_type.clone(), }) } + + fn update_value(&mut self, value: i64) { + let cnt = self.counts.entry(value).or_insert(0); + if *cnt == 0 { + // first occurrence in window + self.sum = self.sum.wrapping_add(value); + } + *cnt += 1; + } + + fn retract_value(&mut self, value: i64) { + if let Some(cnt) = self.counts.get_mut(&value) { + *cnt -= 1; + if *cnt == 0 { + // last copy leaving window + self.sum = self.sum.wrapping_sub(value); + self.counts.remove(&value); + } + } + } + + fn apply_valid_values( + &mut self, + arr: &arrow::array::PrimitiveArray, + mut op: F, + ) where + F: FnMut(&mut Self, i64), + { + if arr.null_count() == 0 { + for &value in arr.values() { + op(self, value); + } + } else { + for (idx, &value) in arr.values().iter().enumerate() { + if arr.is_valid(idx) { + op(self, value); + } + } + } + } } impl Accumulator for SlidingDistinctSumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let arr = values[0].as_primitive::(); - for &v in arr.values() { - let cnt = self.counts.entry(v).or_insert(0); - if *cnt == 0 { - // first occurrence in window - self.sum = self.sum.wrapping_add(v); - } - *cnt += 1; - } + self.apply_valid_values(arr, Self::update_value); Ok(()) } fn evaluate(&mut self) -> Result { // O(1) wrap of running sum - Ok(ScalarValue::Int64(Some(self.sum))) + Ok(ScalarValue::Int64( + (!self.counts.is_empty()).then_some(self.sum), + )) } fn size(&self) -> usize { @@ -538,11 +616,7 @@ impl Accumulator for SlidingDistinctSumAccumulator { if let ScalarValue::Int64(Some(v)) = ScalarValue::try_from_array(&*maybe_inner, idx)? { - let cnt = self.counts.entry(v).or_insert(0); - if *cnt == 0 { - self.sum = self.sum.wrapping_add(v); - } - *cnt += 1; + self.update_value(v); } } } @@ -551,16 +625,7 @@ impl Accumulator for SlidingDistinctSumAccumulator { fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let arr = values[0].as_primitive::(); - for &v in arr.values() { - if let Some(cnt) = self.counts.get_mut(&v) { - *cnt -= 1; - if *cnt == 0 { - // last copy leaving window - self.sum = self.sum.wrapping_sub(v); - self.counts.remove(&v); - } - } - } + self.apply_valid_values(arr, Self::retract_value); Ok(()) } @@ -568,3 +633,53 @@ impl Accumulator for SlidingDistinctSumAccumulator { true } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::Int64Array, + buffer::{NullBuffer, ScalarBuffer}, + }; + use std::sync::Arc; + + #[test] + fn sliding_distinct_sum_ignores_null_slots() -> Result<()> { + let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?; + + let values: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![42, 5, 5]), + Some(NullBuffer::from(vec![false, true, true])), + )); + acc.update_batch(&[values])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(5))); + + let retract: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![42, 5]), + Some(NullBuffer::from(vec![false, true])), + )); + acc.retract_batch(&[retract])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(5))); + + let retract_last: ArrayRef = + Arc::new(Int64Array::new(ScalarBuffer::from(vec![5]), None)); + acc.retract_batch(&[retract_last])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(None)); + + Ok(()) + } + + #[test] + fn sliding_distinct_sum_returns_null_for_all_null_frame() -> Result<()> { + let mut acc = SlidingDistinctSumAccumulator::try_new(&DataType::Int64)?; + + let values: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![99]), + Some(NullBuffer::from(vec![false])), + )); + acc.update_batch(&[values])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(None)); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/utils.rs b/datafusion/functions-aggregate/src/utils.rs index c058b64f95727..6d816e54bdaf2 100644 --- a/datafusion/functions-aggregate/src/utils.rs +++ b/datafusion/functions-aggregate/src/utils.rs @@ -19,7 +19,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::datatypes::Schema; -use datafusion_common::{internal_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err, plan_err}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -54,11 +54,16 @@ pub(crate) fn validate_percentile_expr( let percentile = match scalar_value { ScalarValue::Float32(Some(value)) => value as f64, ScalarValue::Float64(Some(value)) => value, + ScalarValue::Float32(None) | ScalarValue::Float64(None) => { + return plan_err!( + "Percentile value for '{fn_name}' must be Float32 or Float64 literal (got null)" + ); + } sv => { return plan_err!( "Percentile value for '{fn_name}' must be Float32 or Float64 literal (got data type {})", sv.data_type() - ) + ); } }; diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 846c145cb11e7..ce3e00b9ffd91 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -18,20 +18,21 @@ //! [`VarianceSample`]: variance sample aggregations. //! [`VariancePopulation`]: variance population aggregations. -use arrow::datatypes::FieldRef; +use arrow::datatypes::{FieldRef, Float64Type}; use arrow::{ array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, buffer::NullBuffer, - compute::kernels::cast, datatypes::{DataType, Field}, }; -use datafusion_common::{downcast_value, not_impl_err, plan_err, Result, ScalarValue}; +use datafusion_common::cast::{as_float64_array, as_uint64_array}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - function::{AccumulatorArgs, StateFieldsArgs}, - utils::format_state_name, Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, Volatility, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, }; +use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; use datafusion_functions_aggregate_common::{ aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, }; @@ -61,21 +62,12 @@ make_udaf_expr_and_func!( syntax_example = "var(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VarianceSample { signature: Signature, aliases: Vec, } -impl Debug for VarianceSample { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VarianceSample") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VarianceSample { fn default() -> Self { Self::new() @@ -86,16 +78,12 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } impl AggregateUDFImpl for VarianceSample { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "var" } @@ -110,19 +98,35 @@ impl AggregateUDFImpl for VarianceSample { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let name = args.name; - Ok(vec![ - Field::new(format_state_name(name, "count"), DataType::UInt64, true), - Field::new(format_state_name(name, "mean"), DataType::Float64, true), - Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ] - .into_iter() - .map(Arc::new) - .collect()) + match args.is_distinct { + false => Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ] + .into_iter() + .map(Arc::new) + .collect()), + true => { + let field = Field::new_list_field(DataType::Float64, true); + let state_name = "distinct_var"; + Ok(vec![ + Field::new( + format_state_name(name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into(), + ]) + } + } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return not_impl_err!("VAR(DISTINCT) aggregations are not available"); + return Ok(Box::new(DistinctVarianceAccumulator::new( + StatsType::Sample, + ))); } Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) @@ -154,21 +158,12 @@ impl AggregateUDFImpl for VarianceSample { syntax_example = "var_pop(expression)", standard_argument(name = "expression", prefix = "Numeric") )] -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, Hash, Debug)] pub struct VariancePopulation { signature: Signature, aliases: Vec, } -impl Debug for VariancePopulation { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("VariancePopulation") - .field("name", &self.name()) - .field("signature", &self.signature) - .finish() - } -} - impl Default for VariancePopulation { fn default() -> Self { Self::new() @@ -179,16 +174,12 @@ impl VariancePopulation { pub fn new() -> Self { Self { aliases: vec![String::from("var_population")], - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), } } } impl AggregateUDFImpl for VariancePopulation { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "var_pop" } @@ -197,29 +188,43 @@ impl AggregateUDFImpl for VariancePopulation { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if !arg_types[0].is_numeric() { - return plan_err!("Variance requires numeric input types"); - } - + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let name = args.name; - Ok(vec![ - Field::new(format_state_name(name, "count"), DataType::UInt64, true), - Field::new(format_state_name(name, "mean"), DataType::Float64, true), - Field::new(format_state_name(name, "m2"), DataType::Float64, true), - ] - .into_iter() - .map(Arc::new) - .collect()) + match args.is_distinct { + false => { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ] + .into_iter() + .map(Arc::new) + .collect()) + } + true => { + let field = Field::new_list_field(DataType::Float64, true); + let state_name = "distinct_var"; + Ok(vec![ + Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + ) + .into(), + ]) + } + } } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { - return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); + return Ok(Box::new(DistinctVarianceAccumulator::new( + StatsType::Population, + ))); } Ok(Box::new(VarianceAccumulator::try_new( @@ -243,6 +248,7 @@ impl AggregateUDFImpl for VariancePopulation { StatsType::Population, ))) } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -330,10 +336,8 @@ impl Accumulator for VarianceAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { (self.count, self.mean, self.m2) = update(self.count, self.mean, self.m2, value) } @@ -342,10 +346,8 @@ impl Accumulator for VarianceAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { + let arr = as_float64_array(&values[0])?; + for value in arr.iter().flatten() { let new_count = self.count - 1; let delta1 = self.mean - value; let new_mean = delta1 / new_count as f64 + self.mean; @@ -361,9 +363,9 @@ impl Accumulator for VarianceAccumulator { } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means = downcast_value!(states[1], Float64Array); - let m2s = downcast_value!(states[2], Float64Array); + let counts = as_uint64_array(&states[0])?; + let means = as_float64_array(&states[1])?; + let m2s = as_float64_array(&states[2])?; for i in 0..counts.len() { let c = counts.value(i); @@ -498,8 +500,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &cast(&values[0], &DataType::Float64)?; - let values = downcast_value!(values, Float64Array); + let values = as_float64_array(&values[0])?; self.resize(total_num_groups); accumulate(group_indices, values, opt_filter, |group_index, value| { @@ -526,9 +527,9 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); // first batch is counts, second is partial means, third is partial m2s - let partial_counts = downcast_value!(values[0], UInt64Array); - let partial_means = downcast_value!(values[1], Float64Array); - let partial_m2s = downcast_value!(values[2], Float64Array); + let partial_counts = as_uint64_array(&values[0])?; + let partial_means = as_float64_array(&values[1])?; + let partial_m2s = as_float64_array(&values[2])?; self.resize(total_num_groups); Self::merge( @@ -581,6 +582,71 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { } } +#[derive(Debug)] +pub struct DistinctVarianceAccumulator { + distinct_values: GenericDistinctBuffer, + stat_type: StatsType, +} + +impl DistinctVarianceAccumulator { + pub fn new(stat_type: StatsType) -> Self { + Self { + distinct_values: GenericDistinctBuffer::::new(DataType::Float64), + stat_type, + } + } +} + +impl Accumulator for DistinctVarianceAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.distinct_values.update_batch(values) + } + + fn evaluate(&mut self) -> Result { + let values = self + .distinct_values + .values + .iter() + .map(|v| v.0) + .collect::>(); + + let count = match self.stat_type { + StatsType::Sample => { + if !values.is_empty() { + values.len() - 1 + } else { + 0 + } + } + StatsType::Population => values.len(), + }; + + let mean = values.iter().sum::() / values.len() as f64; + let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::(); + + Ok(ScalarValue::Float64(match values.len() { + 0 => None, + 1 => match self.stat_type { + StatsType::Population => Some(0.0), + StatsType::Sample => None, + }, + _ => Some(m2 / count as f64), + })) + } + + fn size(&self) -> usize { + size_of_val(self) + self.distinct_values.size() + } + + fn state(&mut self) -> Result> { + self.distinct_values.state() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.distinct_values.merge_batch(states) + } +} + #[cfg(test)] mod tests { use datafusion_expr::EmitTo; diff --git a/datafusion/functions-nested/Cargo.toml b/datafusion/functions-nested/Cargo.toml index 3a02db7501b73..ed5a89b8e3e72 100644 --- a/datafusion/functions-nested/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -57,17 +57,30 @@ datafusion-functions-aggregate = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-macros = { workspace = true } datafusion-physical-expr-common = { workspace = true } +hashbrown = { workspace = true } itertools = { workspace = true, features = ["use_std"] } +itoa = { workspace = true } log = { workspace = true } -paste = "1.0.14" +memchr = { workspace = true } [dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true, features = ["async_tokio"] } rand = { workspace = true } +# used to test array_transform +datafusion-physical-expr = { workspace = true } [[bench]] harness = false -name = "array_expression" +name = "array_concat" + +[[bench]] +harness = false +name = "array_min_max" + +[[bench]] +harness = false +name = "arrays_zip" [[bench]] harness = false @@ -77,6 +90,50 @@ name = "array_has" harness = false name = "array_reverse" +[[bench]] +harness = false +name = "array_slice" + [[bench]] harness = false name = "map" + +[[bench]] +harness = false +name = "array_remove" + +[[bench]] +harness = false +name = "array_replace" + +[[bench]] +harness = false +name = "array_repeat" + +[[bench]] +harness = false +name = "array_set_ops" + +[[bench]] +harness = false +name = "array_to_string" + +[[bench]] +harness = false +name = "array_position" + +[[bench]] +harness = false +name = "array_sort" + +[[bench]] +harness = false +name = "string_to_array" + +[[bench]] +harness = false +name = "array_resize" + +[[bench]] +harness = false +name = "array_range" diff --git a/datafusion/functions-nested/benches/array_concat.rs b/datafusion/functions-nested/benches/array_concat.rs new file mode 100644 index 0000000000000..75dcc88f14737 --- /dev/null +++ b/datafusion/functions-nested/benches/array_concat.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, ListArray}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +use datafusion_functions_nested::concat::array_concat_inner; + +const SEED: u64 = 42; + +/// Build a `ListArray` with `num_lists` rows, each containing +/// `elements_per_list` random i32 values. Every 10th row is null. +fn make_list_array( + rng: &mut StdRng, + num_lists: usize, + elements_per_list: usize, +) -> ArrayRef { + let total_values = num_lists * elements_per_list; + let values: Vec = (0..total_values).map(|_| rng.random()).collect(); + let values = Arc::new(Int32Array::from(values)); + + let offsets: Vec = (0..=num_lists) + .map(|i| (i * elements_per_list) as i32) + .collect(); + let offsets = OffsetBuffer::new(ScalarBuffer::from(offsets)); + + let nulls: Vec = (0..num_lists).map(|i| i % 10 != 0).collect(); + let nulls = Some(NullBuffer::from(nulls)); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int32, false)), + offsets, + values, + nulls, + )) +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("array_concat"); + + // Benchmark: varying number of rows, 20 elements per list + for num_rows in [100, 1000, 10000] { + let mut rng = StdRng::seed_from_u64(SEED); + let list_a = make_list_array(&mut rng, num_rows, 20); + let list_b = make_list_array(&mut rng, num_rows, 20); + let args: Vec = vec![list_a, list_b]; + + group.bench_with_input(BenchmarkId::new("rows", num_rows), &args, |b, args| { + b.iter(|| black_box(array_concat_inner(args).unwrap())); + }); + } + + // Benchmark: 1000 rows, varying element counts per list + for elements_per_list in [5, 50, 500] { + let mut rng = StdRng::seed_from_u64(SEED); + let list_a = make_list_array(&mut rng, 1000, elements_per_list); + let list_b = make_list_array(&mut rng, 1000, elements_per_list); + let args: Vec = vec![list_a, list_b]; + + group.bench_with_input( + BenchmarkId::new("elements_per_list", elements_per_list), + &args, + |b, args| { + b.iter(|| black_box(array_concat_inner(args).unwrap())); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_has.rs b/datafusion/functions-nested/benches/array_has.rs index a44a80c6ae63e..f5e66d56c0efe 100644 --- a/datafusion/functions-nested/benches/array_has.rs +++ b/datafusion/functions-nested/benches/array_has.rs @@ -15,20 +15,31 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; - -use criterion::{BenchmarkId, Criterion}; -use datafusion_expr::lit; -use datafusion_functions_nested::expr_fn::{ - array_has, array_has_all, array_has_any, make_array, +use arrow::array::{ArrayRef, Int64Array, ListArray, StringArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, }; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::array_has::{ArrayHas, ArrayHasAll, ArrayHasAny}; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; +const NEEDLE_SIZE: usize = 3; // If not explicitly stated, `array` and `array_size` refer to the haystack array. fn criterion_benchmark(c: &mut Criterion) { // Test different array sizes - let array_sizes = vec![1, 10, 100, 1000, 10000]; + let array_sizes = vec![10, 100, 500]; for &size in &array_sizes { bench_array_has(c, size); @@ -41,49 +52,67 @@ fn criterion_benchmark(c: &mut Criterion) { bench_array_has_all_strings(c); bench_array_has_any_strings(c); - // Edge cases - bench_array_has_edge_cases(c); + // Benchmark for array_has_any with one scalar arg + bench_array_has_any_scalar(c); } fn bench_array_has(c: &mut Criterion, array_size: usize) { let mut group = c.benchmark_group("array_has_i64"); - - // Benchmark: element found at beginning - group.bench_with_input( - BenchmarkId::new("found_at_start", array_size), - &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle = lit(0_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }, - ); - - // Benchmark: element found at end + let list_array = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + let arg_fields: Vec> = vec![ + Field::new("arr", list_array.data_type().clone(), false).into(), + Field::new("el", DataType::Int64, false).into(), + ]; + + // Benchmark: element found + let args_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ]; group.bench_with_input( - BenchmarkId::new("found_at_end", array_size), + BenchmarkId::new("found", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle = lit((size - 1) as i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); // Benchmark: element not found + let args_not_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-999))), + ]; group.bench_with_input( BenchmarkId::new("not_found", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle = lit(-1_i64); // Not in array - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); @@ -92,90 +121,190 @@ fn bench_array_has(c: &mut Criterion, array_size: usize) { fn bench_array_has_all(c: &mut Criterion, array_size: usize) { let mut group = c.benchmark_group("array_has_all"); + let haystack = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type = haystack.data_type().clone(); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + let arg_fields: Vec> = vec![ + Field::new("haystack", list_type.clone(), false).into(), + Field::new("needle", list_type.clone(), false).into(), + ]; // Benchmark: all elements found (small needle) + let needle_found = create_int64_list_array(NUM_ROWS, NEEDLE_SIZE, 0.0); + let args_found = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_found), + ]; group.bench_with_input( BenchmarkId::new("all_found_small_needle", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(0_i64), lit(1_i64), lit(2_i64)]); - - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); - // Benchmark: all elements found (medium needle - 10% of haystack) + // Benchmark: not all found (needle contains elements outside haystack range) + let needle_missing = + create_int64_list_array_with_offset(NUM_ROWS, NEEDLE_SIZE, array_size as i64); + let args_missing = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_missing), + ]; group.bench_with_input( - BenchmarkId::new("all_found_medium_needle", array_size), + BenchmarkId::new("not_all_found", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_size = (size / 10).max(1); - let needle = (0..needle_size).map(|i| lit(i as i64)).collect::>(); - let needle_array = make_array(needle); - - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) - }, - ); - - // Benchmark: not all found (early exit) - group.bench_with_input( - BenchmarkId::new("early_exit", array_size), - &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(0_i64), lit(-1_i64)]); // -1 not in array - - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_missing.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); group.finish(); } +const SMALL_ARRAY_SIZE: usize = NEEDLE_SIZE; + fn bench_array_has_any(c: &mut Criterion, array_size: usize) { let mut group = c.benchmark_group("array_has_any"); - - // Benchmark: first element matches (best case) + let first_arr = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type = first_arr.data_type().clone(); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + let arg_fields: Vec> = vec![ + Field::new("first", list_type.clone(), false).into(), + Field::new("second", list_type.clone(), false).into(), + ]; + + // Benchmark: some elements match + let second_match = create_int64_list_array(NUM_ROWS, SMALL_ARRAY_SIZE, 0.0); + let args_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_match), + ]; group.bench_with_input( - BenchmarkId::new("first_match", array_size), + BenchmarkId::new("some_match", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(0_i64), lit(-1_i64), lit(-2_i64)]); - - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); - // Benchmark: last element matches (worst case) + // Benchmark: no match + let second_no_match = create_int64_list_array_with_offset( + NUM_ROWS, + SMALL_ARRAY_SIZE, + array_size as i64, + ); + let args_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_no_match), + ]; group.bench_with_input( - BenchmarkId::new("last_match", array_size), + BenchmarkId::new("no_match", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(-1_i64), lit(-2_i64), lit(0_i64)]); - - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); - // Benchmark: no match + // Benchmark: scalar second arg, some match + let scalar_second_match = create_int64_scalar_list(SMALL_ARRAY_SIZE, 0); + let args_scalar_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_match), + ]; group.bench_with_input( - BenchmarkId::new("no_match", array_size), + BenchmarkId::new("scalar_some_match", array_size), &array_size, - |b, &size| { - let array = (0..size).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![lit(-1_i64), lit(-2_i64), lit(-3_i64)]); + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) + // Benchmark: scalar second arg, no match + let scalar_second_no_match = + create_int64_scalar_list(SMALL_ARRAY_SIZE, array_size as i64); + let args_scalar_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_no_match), + ]; + group.bench_with_input( + BenchmarkId::new("scalar_no_match", array_size), + &array_size, + |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }, ); @@ -184,29 +313,56 @@ fn bench_array_has_any(c: &mut Criterion, array_size: usize) { fn bench_array_has_strings(c: &mut Criterion) { let mut group = c.benchmark_group("array_has_strings"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); - // Benchmark with string arrays (common use case for tickers, tags, etc.) - let sizes = vec![10, 100, 1000]; + let sizes = vec![10, 100, 500]; for &size in &sizes { - group.bench_with_input(BenchmarkId::new("found", size), &size, |b, &size| { - let array = (0..size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(array); - let needle = lit("TICKER0005"); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + let list_array = create_string_list_array(NUM_ROWS, size, NULL_DENSITY); + let arg_fields: Vec> = vec![ + Field::new("arr", list_array.data_type().clone(), false).into(), + Field::new("el", DataType::Utf8, false).into(), + ]; + + let args_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("value_1".to_string()))), + ]; + group.bench_with_input(BenchmarkId::new("found", size), &size, |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }); - group.bench_with_input(BenchmarkId::new("not_found", size), &size, |b, &size| { - let array = (0..size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(array); - let needle = lit("NOTFOUND"); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) + let args_not_found = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("NOTFOUND".to_string()))), + ]; + group.bench_with_input(BenchmarkId::new("not_found", size), &size, |b, _| { + let udf = ArrayHas::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) }); } @@ -215,49 +371,173 @@ fn bench_array_has_strings(c: &mut Criterion) { fn bench_array_has_all_strings(c: &mut Criterion) { let mut group = c.benchmark_group("array_has_all_strings"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); - // Realistic scenario: checking if a portfolio contains certain tickers - let portfolio_size = 100; - let check_sizes = vec![1, 3, 5, 10]; + let sizes = vec![10, 100, 500]; - for &check_size in &check_sizes { - group.bench_with_input( - BenchmarkId::new("all_found", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let checking = (0..check_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let needle_array = make_array(checking); + for &size in &sizes { + let haystack = create_string_list_array(NUM_ROWS, size, NULL_DENSITY); + let list_type = haystack.data_type().clone(); + let arg_fields: Vec> = vec![ + Field::new("haystack", list_type.clone(), false).into(), + Field::new("needle", list_type.clone(), false).into(), + ]; + + let needle_found = create_string_list_array(NUM_ROWS, NEEDLE_SIZE, 0.0); + let args_found = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_found), + ]; + group.bench_with_input(BenchmarkId::new("all_found", size), &size, |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + let needle_missing = + create_string_list_array_with_prefix(NUM_ROWS, NEEDLE_SIZE, "missing_"); + let args_missing = vec![ + ColumnarValue::Array(haystack.clone()), + ColumnarValue::Array(needle_missing), + ]; + group.bench_with_input(BenchmarkId::new("not_all_found", size), &size, |b, _| { + let udf = ArrayHasAll::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_missing.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + } + + group.finish(); +} + +fn bench_array_has_any_strings(c: &mut Criterion) { + let mut group = c.benchmark_group("array_has_any_strings"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + + let sizes = vec![10, 100, 500]; + + for &size in &sizes { + let first_arr = create_string_list_array(NUM_ROWS, size, NULL_DENSITY); + let list_type = first_arr.data_type().clone(); + let arg_fields: Vec> = vec![ + Field::new("first", list_type.clone(), false).into(), + Field::new("second", list_type.clone(), false).into(), + ]; + + let second_match = create_string_list_array(NUM_ROWS, SMALL_ARRAY_SIZE, 0.0); + let args_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_match), + ]; + group.bench_with_input(BenchmarkId::new("some_match", size), &size, |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + + let second_no_match = + create_string_list_array_with_prefix(NUM_ROWS, SMALL_ARRAY_SIZE, "missing_"); + let args_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Array(second_no_match), + ]; + group.bench_with_input(BenchmarkId::new("no_match", size), &size, |b, _| { + let udf = ArrayHasAny::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); + + // Benchmark: scalar second arg, some match + let scalar_second_match = create_string_scalar_list(SMALL_ARRAY_SIZE, "value_"); + let args_scalar_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_match), + ]; + group.bench_with_input( + BenchmarkId::new("scalar_some_match", size), + &size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_all(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); + // Benchmark: scalar second arg, no match + let scalar_second_no_match = + create_string_scalar_list(SMALL_ARRAY_SIZE, "missing_"); + let args_scalar_no_match = vec![ + ColumnarValue::Array(first_arr.clone()), + ColumnarValue::Scalar(scalar_second_no_match), + ]; group.bench_with_input( - BenchmarkId::new("some_missing", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let mut checking = (0..check_size - 1) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - checking.push(lit("NOTFOUND".to_string())); - let needle_array = make_array(checking); - + BenchmarkId::new("scalar_no_match", size), + &size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_all(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_scalar_no_match.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); @@ -266,48 +546,81 @@ fn bench_array_has_all_strings(c: &mut Criterion) { group.finish(); } -fn bench_array_has_any_strings(c: &mut Criterion) { - let mut group = c.benchmark_group("array_has_any_strings"); - - let portfolio_size = 100; - let check_sizes = vec![1, 3, 5, 10]; - - for &check_size in &check_sizes { +/// Benchmarks array_has_any with one scalar arg. Varies the scalar argument +/// size while keeping the columnar array small (3 elements per row). +fn bench_array_has_any_scalar(c: &mut Criterion) { + let mut group = c.benchmark_group("array_has_any_scalar"); + let config_options = Arc::new(ConfigOptions::default()); + let return_field: Arc = Field::new("result", DataType::Boolean, true).into(); + + let array_size = 3; + let scalar_sizes = vec![1, 10, 100, 1000]; + + // i64 benchmarks + let first_arr_i64 = create_int64_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type_i64 = first_arr_i64.data_type().clone(); + let arg_fields_i64: Vec> = vec![ + Field::new("first", list_type_i64.clone(), false).into(), + Field::new("second", list_type_i64.clone(), false).into(), + ]; + + for &scalar_size in &scalar_sizes { + let scalar_arg = create_int64_scalar_list(scalar_size, array_size as i64); + let args = vec![ + ColumnarValue::Array(first_arr_i64.clone()), + ColumnarValue::Scalar(scalar_arg), + ]; group.bench_with_input( - BenchmarkId::new("first_matches", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let mut checking = vec![lit("TICKER0000".to_string())]; - checking.extend((1..check_size).map(|_| lit("NOTFOUND".to_string()))); - let needle_array = make_array(checking); - + BenchmarkId::new("i64_no_match", scalar_size), + &scalar_size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_any(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields_i64.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); + } + // String benchmarks + let first_arr_str = create_string_list_array(NUM_ROWS, array_size, NULL_DENSITY); + let list_type_str = first_arr_str.data_type().clone(); + let arg_fields_str: Vec> = vec![ + Field::new("first", list_type_str.clone(), false).into(), + Field::new("second", list_type_str.clone(), false).into(), + ]; + + for &scalar_size in &scalar_sizes { + let scalar_arg = create_string_scalar_list(scalar_size, "missing_"); + let args = vec![ + ColumnarValue::Array(first_arr_str.clone()), + ColumnarValue::Scalar(scalar_arg), + ]; group.bench_with_input( - BenchmarkId::new("none_match", check_size), - &check_size, - |b, &check_size| { - let portfolio = (0..portfolio_size) - .map(|i| lit(format!("TICKER{i:04}"))) - .collect::>(); - let list_array = make_array(portfolio); - - let checking = (0..check_size) - .map(|i| lit(format!("NOTFOUND{i}"))) - .collect::>(); - let needle_array = make_array(checking); - + BenchmarkId::new("string_no_match", scalar_size), + &scalar_size, + |b, _| { + let udf = ArrayHasAny::new(); b.iter(|| { - black_box(array_has_any(list_array.clone(), needle_array.clone())) + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields_str.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) }) }, ); @@ -316,61 +629,152 @@ fn bench_array_has_any_strings(c: &mut Criterion) { group.finish(); } -fn bench_array_has_edge_cases(c: &mut Criterion) { - let mut group = c.benchmark_group("array_has_edge_cases"); - - // Empty array - group.bench_function("empty_array", |b| { - let list_array = make_array(vec![]); - let needle = lit(1_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); - - // Single element array - found - group.bench_function("single_element_found", |b| { - let list_array = make_array(vec![lit(1_i64)]); - let needle = lit(1_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); - - // Single element array - not found - group.bench_function("single_element_not_found", |b| { - let list_array = make_array(vec![lit(1_i64)]); - let needle = lit(2_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); - - // Array with duplicates - group.bench_function("array_with_duplicates", |b| { - let array = vec![lit(1_i64); 1000]; - let list_array = make_array(array); - let needle = lit(1_i64); - - b.iter(|| black_box(array_has(list_array.clone(), needle.clone()))) - }); +fn create_int64_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - // array_has_all: empty needle - group.bench_function("array_has_all_empty_needle", |b| { - let array = (0..1000).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![]); +/// Like `create_int64_list_array` but values are offset so they won't +/// appear in a standard list array (useful for "not found" benchmarks). +fn create_int64_list_array_with_offset( + num_rows: usize, + array_size: usize, + offset: i64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED + 1); + let values = (0..num_rows * array_size) + .map(|_| Some(rng.random_range(0..array_size as i64) + offset)) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - b.iter(|| black_box(array_has_all(list_array.clone(), needle_array.clone()))) - }); +fn create_string_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let idx = rng.random_range(0..array_size); + Some(format!("value_{idx}")) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - // array_has_any: empty needle - group.bench_function("array_has_any_empty_needle", |b| { - let array = (0..1000).map(|i| lit(i as i64)).collect::>(); - let list_array = make_array(array); - let needle_array = make_array(vec![]); +/// Like `create_string_list_array` but values use a different prefix so +/// they won't appear in a standard string list array. +fn create_string_list_array_with_prefix( + num_rows: usize, + array_size: usize, + prefix: &str, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED + 1); + let values = (0..num_rows * array_size) + .map(|_| { + let idx = rng.random_range(0..array_size); + Some(format!("{prefix}{idx}")) + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} - b.iter(|| black_box(array_has_any(list_array.clone(), needle_array.clone()))) - }); +/// Create a `ScalarValue::List` containing a single list of `size` i64 elements, +/// with values starting at `offset`. +fn create_int64_scalar_list(size: usize, offset: i64) -> ScalarValue { + let values = (0..size as i64) + .map(|i| Some(i + offset)) + .collect::(); + let list = ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(vec![0, size as i32].into()), + Arc::new(values), + None, + ) + .unwrap(); + ScalarValue::List(Arc::new(list)) +} - group.finish(); +/// Create a `ScalarValue::List` containing a single list of `size` string elements, +/// with values like "{prefix}0", "{prefix}1", etc. +fn create_string_scalar_list(size: usize, prefix: &str) -> ScalarValue { + let values = (0..size) + .map(|i| Some(format!("{prefix}{i}"))) + .collect::(); + let list = ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(vec![0, size as i32].into()), + Arc::new(values), + None, + ) + .unwrap(); + ScalarValue::List(Arc::new(list)) } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions-nested/benches/array_min_max.rs b/datafusion/functions-nested/benches/array_min_max.rs new file mode 100644 index 0000000000000..45838da79f95b --- /dev/null +++ b/datafusion/functions-nested/benches/array_min_max.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::{DataType, Field, Int64Type}; +use arrow::util::bench_util::create_primitive_list_array_with_seed; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::min_max::ArrayMax; + +const NUM_ROWS: usize = 8192; +const SEED: u64 = 42; +const LIST_NULL_DENSITY: f64 = 0.1; +const ELEMENT_NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + let udf = ArrayMax::new(); + let config_options = Arc::new(ConfigOptions::default()); + + for list_size in [10, 100, 1000] { + for (label, null_density) in [("nulls", ELEMENT_NULL_DENSITY), ("no_nulls", 0.0)] + { + let list_array: ArrayRef = + Arc::new(create_primitive_list_array_with_seed::( + NUM_ROWS, + LIST_NULL_DENSITY as f32, + null_density as f32, + list_size, + SEED, + )); + let args = vec![ColumnarValue::Array(Arc::clone(&list_array))]; + let arg_fields = + vec![Field::new("arg_0", list_array.data_type().clone(), true).into()]; + let return_field: Arc = Field::new("f", DataType::Int64, true).into(); + + c.bench_with_input( + BenchmarkId::new("array_max", format!("{label}/list_size={list_size}")), + &list_array, + |b, _| { + b.iter(|| { + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap() + }); + }, + ); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_position.rs b/datafusion/functions-nested/benches/array_position.rs new file mode 100644 index 0000000000000..c718b2b725640 --- /dev/null +++ b/datafusion/functions-nested/benches/array_position.rs @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::position::{ArrayPosition, ArrayPositions}; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 10000; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; +const SENTINEL_NEEDLE: i64 = -1; + +fn criterion_benchmark(c: &mut Criterion) { + for size in [10, 100, 500] { + bench_array_position(c, size); + bench_array_positions(c, size); + } +} + +fn bench_array_position(c: &mut Criterion, array_size: usize) { + let mut group = c.benchmark_group("array_position_i64"); + let haystack_found_once = create_haystack_with_sentinel( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + 0, + ); + let haystack_found_many = create_haystack_with_sentinels( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + ); + let haystack_not_found = + create_haystack_without_sentinel(NUM_ROWS, array_size, NULL_DENSITY); + let num_rows = haystack_not_found.len(); + let arg_fields: Vec> = vec![ + Field::new("haystack", haystack_not_found.data_type().clone(), false).into(), + Field::new("needle", DataType::Int64, false).into(), + ]; + let return_field: Arc = Field::new("result", DataType::UInt64, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + let needle = ScalarValue::Int64(Some(SENTINEL_NEEDLE)); + + // Benchmark: one match per row. + let args_found_once = vec![ + ColumnarValue::Array(haystack_found_once.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_once", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_once.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + // Benchmark: many matches per row. + let args_found_many = vec![ + ColumnarValue::Array(haystack_found_many.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_many", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_many.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + // Benchmark: needle is not found in any row. + let args_not_found = vec![ + ColumnarValue::Array(haystack_not_found.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("not_found", array_size), + &array_size, + |b, _| { + let udf = ArrayPosition::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + group.finish(); +} + +fn bench_array_positions(c: &mut Criterion, array_size: usize) { + let mut group = c.benchmark_group("array_positions_i64"); + let haystack_found_once = create_haystack_with_sentinel( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + 0, + ); + let haystack_found_many = create_haystack_with_sentinels( + NUM_ROWS, + array_size, + NULL_DENSITY, + SENTINEL_NEEDLE, + ); + let haystack_not_found = + create_haystack_without_sentinel(NUM_ROWS, array_size, NULL_DENSITY); + let num_rows = haystack_not_found.len(); + let arg_fields: Vec> = vec![ + Field::new("haystack", haystack_not_found.data_type().clone(), false).into(), + Field::new("needle", DataType::Int64, false).into(), + ]; + let return_field: Arc = Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field(DataType::UInt64, true))), + true, + ) + .into(); + let config_options = Arc::new(ConfigOptions::default()); + let needle = ScalarValue::Int64(Some(SENTINEL_NEEDLE)); + + let args_found_once = vec![ + ColumnarValue::Array(haystack_found_once.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_once", array_size), + &array_size, + |b, _| { + let udf = ArrayPositions::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_once.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + let args_found_many = vec![ + ColumnarValue::Array(haystack_found_many.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("found_many", array_size), + &array_size, + |b, _| { + let udf = ArrayPositions::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_found_many.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + let args_not_found = vec![ + ColumnarValue::Array(haystack_not_found.clone()), + ColumnarValue::Scalar(needle.clone()), + ]; + group.bench_with_input( + BenchmarkId::new("not_found", array_size), + &array_size, + |b, _| { + let udf = ArrayPositions::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_not_found.clone(), + arg_fields: arg_fields.clone(), + number_rows: num_rows, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }, + ); + + group.finish(); +} + +fn create_haystack_without_sentinel( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + create_haystack_from_fn(num_rows, array_size, |_, _, rng| { + random_haystack_value(rng, array_size, null_density) + }) +} + +fn create_haystack_with_sentinel( + num_rows: usize, + array_size: usize, + null_density: f64, + sentinel: i64, + sentinel_index: usize, +) -> ArrayRef { + assert!(sentinel_index < array_size); + + create_haystack_from_fn(num_rows, array_size, |_, col, rng| { + if col == sentinel_index { + Some(sentinel) + } else { + random_haystack_value(rng, array_size, null_density) + } + }) +} + +fn create_haystack_with_sentinels( + num_rows: usize, + array_size: usize, + null_density: f64, + sentinel: i64, +) -> ArrayRef { + create_haystack_from_fn(num_rows, array_size, |_, col, rng| { + // Place the sentinel in half the positions to create many matches per row. + if col % 2 == 0 { + Some(sentinel) + } else { + random_haystack_value(rng, array_size, null_density) + } + }) +} + +fn create_haystack_from_fn( + num_rows: usize, + array_size: usize, + mut value_at: F, +) -> ArrayRef +where + F: FnMut(usize, usize, &mut StdRng) -> Option, +{ + let mut rng = StdRng::seed_from_u64(SEED); + let mut values = Vec::with_capacity(num_rows * array_size); + for row in 0..num_rows { + for col in 0..array_size { + values.push(value_at(row, col, &mut rng)); + } + } + let values = values.into_iter().collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn random_haystack_value( + rng: &mut StdRng, + array_size: usize, + null_density: f64, +) -> Option { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range(0..array_size as i64)) + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_range.rs b/datafusion/functions-nested/benches/array_range.rs new file mode 100644 index 0000000000000..1f82cbd2291fd --- /dev/null +++ b/datafusion/functions-nested/benches/array_range.rs @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for range and generate_series functions. + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::{ScalarValue, config::ConfigOptions}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::range::{Range, gen_series_udf}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: &[usize] = &[100, 1000, 10000]; +const INT_STEPS: &[i64] = &[1, 5, 50]; +const INT_DIRECTIONS: &[(&str, bool)] = &[("increasing", true), ("decreasing", false)]; +/// Each row produces at most RANGE_SIZE elements in its Int64 list +const RANGE_SIZE: i64 = 200; +const SEED: u64 = 42; + +fn criterion_benchmark(c: &mut Criterion) { + bench_range_int64(c); + bench_generate_series_int64(c); +} + +// --------------------------------------------------------------------------- +// Int64 – range(start, stop, step) +// --------------------------------------------------------------------------- +fn bench_range_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("range_int64"); + + for &num_rows in NUM_ROWS { + for &(direction, increasing) in INT_DIRECTIONS { + let (start_array, stop_array) = + make_int64_start_stop_arrays(num_rows, RANGE_SIZE, increasing); + + for &step in INT_STEPS { + let step = if increasing { step } else { -step }; + let args = vec![ + ColumnarValue::Array(start_array.clone()), + ColumnarValue::Array(stop_array.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(step))), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("{direction}/step_{step}"), num_rows), + &num_rows, + |b, _| { + let udf = Range::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Arc::new(Field::new( + "start", + DataType::Int64, + true, + )), + Arc::new(Field::new( + "stop", + DataType::Int64, + true, + )), + Arc::new(Field::new( + "step", + DataType::Int64, + true, + )), + ], + number_rows: num_rows, + return_field: Arc::new(Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Int64, + true, + ))), + true, + )), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + } + + group.finish(); +} + +// --------------------------------------------------------------------------- +// Int64 – generate_series(start, stop, step) +// --------------------------------------------------------------------------- +fn bench_generate_series_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("generate_series_int64"); + + for &num_rows in NUM_ROWS { + for &(direction, increasing) in INT_DIRECTIONS { + let (start_array, stop_array) = + make_int64_start_stop_arrays(num_rows, RANGE_SIZE, increasing); + + for &step in INT_STEPS { + let step = if increasing { step } else { -step }; + let args = vec![ + ColumnarValue::Array(start_array.clone()), + ColumnarValue::Array(stop_array.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(step))), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("{direction}/step_{step}"), num_rows), + &num_rows, + |b, _| { + let udf = gen_series_udf(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Arc::new(Field::new( + "start", + DataType::Int64, + true, + )), + Arc::new(Field::new( + "stop", + DataType::Int64, + true, + )), + Arc::new(Field::new( + "step", + DataType::Int64, + true, + )), + ], + number_rows: num_rows, + return_field: Arc::new(Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Int64, + true, + ))), + true, + )), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + } + + group.finish(); +} + +/// Build (start, stop) Int64Arrays where each stop = start + offset, +/// with offset in [1, max_range]. This ensures every row produces a +/// bounded-size list, avoiding OOM from unbounded i64 ranges. +fn make_int64_start_stop_arrays( + num_rows: usize, + max_range: i64, + increasing: bool, +) -> (ArrayRef, ArrayRef) { + let mut rng = StdRng::seed_from_u64(SEED); + let mut starts: Vec = Vec::with_capacity(num_rows); + let mut stops: Vec = Vec::with_capacity(num_rows); + for _ in 0..num_rows { + let s = rng.random_range(0..max_range); + let offset = rng.random_range(1..=max_range); + if increasing { + starts.push(s); + stops.push(s + offset); + } else { + starts.push(s + offset); + stops.push(s); + } + } + ( + Arc::new(Int64Array::from(starts)), + Arc::new(Int64Array::from(stops)), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_remove.rs b/datafusion/functions-nested/benches/array_remove.rs new file mode 100644 index 0000000000000..bfa7357384856 --- /dev/null +++ b/datafusion/functions-nested/benches/array_remove.rs @@ -0,0 +1,553 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, BooleanBuilder, FixedSizeBinaryArray, Int64Builder, + ListArray, ListBuilder, StringBuilder, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand::seq::IndexedRandom; +use std::hint::black_box; +use std::sync::Arc; + +// (num_rows, list_size) +// Settings tuned so benchmarks finish in approx 5 seconds +const SIZES: &[(usize, usize)] = &[(4_000, 10), (10_000, 100), (10_000, 500)]; +const NESTED_SIZES: &[(usize, usize)] = &[(4_000, 10), (3_000, 100), (1_500, 300)]; +const SEED: u64 = 42; +const HAYSTACK_NULL_DENSITY: f64 = 0.1; +const NEEDLE_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_remove_int64(c); + bench_array_remove_n_int64(c); + bench_array_remove_all_int64(c); + + bench_array_remove_int64_nested(c); + bench_array_remove_n_int64_nested(c); + bench_array_remove_all_int64_nested(c); + + bench_array_remove_strings(c); + bench_array_remove_boolean(c); + bench_array_remove_fixed_size_binary(c); +} + +fn bench_array_remove_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_int64"); + + let filler_values = [None, Some(1), Some(2), Some(3), Some(4), Some(5)]; + let needle = 0; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + needle, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + let args = create_args(list_array.clone(), ScalarValue::from(needle)); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_n_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_n_int64"); + + let filler_values = [None, Some(1), Some(2), Some(3), Some(4), Some(5)]; + let needle = 0; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + needle, + &filler_values, + ); + let n = (NEEDLE_DENSITY / 2.0 * list_size as f64) as i64; + let n = 2.max(n); + + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemoveN::new(); + b.iter(|| { + let args = create_args_n( + list_array.clone(), + ScalarValue::from(needle), + ScalarValue::from(n), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_all_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_all_int64"); + + let filler_values = [None, Some(1), Some(2), Some(3), Some(4), Some(5)]; + let needle = 0; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + needle, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemoveAll::new(); + b.iter(|| { + let args = create_args(list_array.clone(), ScalarValue::from(needle)); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_int64_nested(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_int64_nested"); + + let filler_values = [ + None, + Some(vec![Some(1), Some(0), Some(2), Some(0)]), + Some(vec![Some(1)]), + Some(vec![]), + Some(vec![Some(1), Some(0), Some(2), Some(4), None]), + Some(vec![None]), + ]; + let needle = vec![Some(1), Some(0), Some(2), Some(4)]; + let needle_scalar = needle + .iter() + .copied() + .map(ScalarValue::from) + .collect::>(); + let needle_scalar = ScalarValue::List(ScalarValue::new_list_nullable( + &needle_scalar, + &DataType::Int64, + )); + for &(num_rows, list_size) in NESTED_SIZES { + let list_array = + create_nested_i64_list_array(num_rows, list_size, &needle, &filler_values); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + let args = create_args(list_array.clone(), needle_scalar.clone()); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_n_int64_nested(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_n_int64_nested"); + + let filler_values = [ + None, + Some(vec![Some(1), Some(0), Some(2), Some(0)]), + Some(vec![Some(1)]), + Some(vec![]), + Some(vec![Some(1), Some(0), Some(2), Some(4), None]), + Some(vec![None]), + ]; + let needle = vec![Some(1), Some(0), Some(2), Some(4)]; + let needle_scalar = needle + .iter() + .copied() + .map(ScalarValue::from) + .collect::>(); + let needle_scalar = ScalarValue::List(ScalarValue::new_list_nullable( + &needle_scalar, + &DataType::Int64, + )); + for &(num_rows, list_size) in NESTED_SIZES { + let list_array = + create_nested_i64_list_array(num_rows, list_size, &needle, &filler_values); + let n = (NEEDLE_DENSITY / 2.0 * list_size as f64) as i64; + let n = 2.max(n); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemoveN::new(); + b.iter(|| { + let args = create_args_n( + list_array.clone(), + needle_scalar.clone(), + ScalarValue::from(n), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_all_int64_nested(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_all_int64_nested"); + + let filler_values = [ + None, + Some(vec![Some(1), Some(0), Some(2), Some(0)]), + Some(vec![Some(1)]), + Some(vec![]), + Some(vec![Some(1), Some(0), Some(2), Some(4), None]), + Some(vec![None]), + ]; + let needle = vec![Some(1), Some(0), Some(2), Some(4)]; + let needle_scalar = needle + .iter() + .copied() + .map(ScalarValue::from) + .collect::>(); + let needle_scalar = ScalarValue::List(ScalarValue::new_list_nullable( + &needle_scalar, + &DataType::Int64, + )); + for &(num_rows, list_size) in NESTED_SIZES { + let list_array = + create_nested_i64_list_array(num_rows, list_size, &needle, &filler_values); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemoveAll::new(); + b.iter(|| { + let args = create_args(list_array.clone(), needle_scalar.clone()); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_strings(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_strings"); + + let filler_values = [ + None, + Some("neenee"), + Some("notthis"), + Some("value1"), + Some("abc"), + Some("hello"), + ]; + let needle = "needle"; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + needle, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + let args = create_args(list_array.clone(), ScalarValue::from(needle)); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_boolean"); + + let filler_values = [None, Some(false)]; + let needle = true; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + needle, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + let args = create_args(list_array.clone(), ScalarValue::from(needle)); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_remove_fixed_size_binary(c: &mut Criterion) { + let mut group = c.benchmark_group("array_remove_fixed_size_binary"); + + const SIZE: usize = 16; + let filler_values = [ + None, + Some([2_u8; SIZE]), + Some([3_u8; SIZE]), + Some([4_u8; SIZE]), + Some([5_u8; SIZE]), + Some([6_u8; SIZE]), + ]; + let needle = [1_u8; SIZE]; + for &(num_rows, list_size) in SIZES { + let list_array = create_fixed_size_binary_list_array::( + num_rows, + list_size, + needle, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "remove", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = ArrayRemove::new(); + b.iter(|| { + let args = create_args( + list_array.clone(), + ScalarValue::FixedSizeBinary(SIZE as i32, Some(needle.to_vec())), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +#[inline] +fn create_args(haystack: ArrayRef, needle: ScalarValue) -> ScalarFunctionArgs { + let number_rows = haystack.len(); + let haystack_type = haystack.data_type().clone(); + let needle_type = needle.data_type().clone(); + ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(haystack), + ColumnarValue::Scalar(needle), + ], + arg_fields: vec![ + Field::new("haystack", haystack_type.clone(), true).into(), + Field::new("needle", needle_type, true).into(), + ], + number_rows, + return_field: Field::new("result", haystack_type, true).into(), + config_options: Arc::new(ConfigOptions::default()), + } +} + +#[inline] +fn create_args_n( + haystack: ArrayRef, + needle: ScalarValue, + n: ScalarValue, +) -> ScalarFunctionArgs { + let number_rows = haystack.len(); + let haystack_type = haystack.data_type().clone(); + let needle_type = needle.data_type().clone(); + let n_type = n.data_type().clone(); + ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(haystack), + ColumnarValue::Scalar(needle), + ColumnarValue::Scalar(n), + ], + arg_fields: vec![ + Field::new("haystack", haystack_type.clone(), true).into(), + Field::new("needle", needle_type, true).into(), + Field::new("n", n_type, true).into(), + ], + number_rows, + return_field: Field::new("result", haystack_type, true).into(), + config_options: Arc::new(ConfigOptions::default()), + } +} + +fn create_list_array( + num_rows: usize, + list_size: usize, + needle_value: Item, + filler_values: &[Option], +) -> ArrayRef +where + Builder: ArrayBuilder + Default + Extend>, + Item: Copy, +{ + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows) + .map(|_| { + if rng.random_bool(HAYSTACK_NULL_DENSITY) { + None + } else { + let list = (0..list_size) + .map(|_| { + if rng.random_bool(NEEDLE_DENSITY) { + Some(needle_value) + } else { + *filler_values.choose(&mut rng).unwrap() + } + }) + .collect::>(); + Some(list) + } + }) + .collect::>(); + Arc::new(ListArray::from_nested_iter::(values)) +} + +fn create_fixed_size_binary_list_array( + num_rows: usize, + list_size: usize, + needle_value: [u8; SIZE], + filler_values: &[Option<[u8; SIZE]>], +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let mut buffer = Vec::with_capacity(num_rows * list_size); + for _ in 0..num_rows { + for _ in 0..list_size { + if rng.random_bool(NEEDLE_DENSITY) { + buffer.push(Some(needle_value)); + } else { + buffer.push(*filler_values.choose(&mut rng).unwrap()); + } + } + } + let values = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + buffer.into_iter(), + SIZE as i32, + ) + .unwrap(); + + let null_buffer = NullBuffer::from_iter( + (0..num_rows).map(|_| rng.random_bool(1.0 - HAYSTACK_NULL_DENSITY)), + ); + + Arc::new(ListArray::new( + Field::new("item", DataType::FixedSizeBinary(SIZE as i32), true).into(), + OffsetBuffer::from_repeated_length(list_size, num_rows), + Arc::new(values), + Some(null_buffer), + )) +} + +fn create_nested_i64_list_array( + num_rows: usize, + list_size: usize, + needle_value: &[Option], + filler_values: &[Option>>], +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + + let value_builder = Int64Builder::new(); + let inner_builder = ListBuilder::new(value_builder); + let mut outer_builder = ListBuilder::new(inner_builder); + + for _ in 0..num_rows { + if rng.random_bool(HAYSTACK_NULL_DENSITY) { + outer_builder.append(false); + continue; + } + + for _ in 0..list_size { + let inner = outer_builder.values(); + if rng.random_bool(NEEDLE_DENSITY) { + inner.append_value(needle_value.to_vec()); + } else { + inner.append_option(filler_values.choose(&mut rng).unwrap().clone()); + } + } + outer_builder.append(true); + } + + Arc::new(outer_builder.finish()) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_repeat.rs b/datafusion/functions-nested/benches/array_repeat.rs new file mode 100644 index 0000000000000..42372322e2812 --- /dev/null +++ b/datafusion/functions-nested/benches/array_repeat.rs @@ -0,0 +1,407 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field, Int64Type}; +use arrow::util::bench_util::{ + create_boolean_array, create_f64_array, create_primitive_array, + create_primitive_list_array_with_seed, create_string_array_with_max_len, +}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::repeat::ArrayRepeat; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: &[usize] = &[100, 1000, 10000]; +// Must be of type i64 because ArrayRepeat's second argument is Int64 +const REPEAT_COUNTS: &[i64] = &[5, 50]; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + // Test array_repeat with different element types + bench_array_repeat_int64(c); + bench_array_repeat_string(c); + bench_array_repeat_float64(c); + bench_array_repeat_boolean(c); + + // Test array_repeat with list element (nested arrays) + bench_array_repeat_nested_int64_list(c); + bench_array_repeat_nested_string_list(c); +} + +fn bench_array_repeat_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_int64"); + + for &num_rows in NUM_ROWS { + let element_array: ArrayRef = Arc::new(create_primitive_array::( + num_rows, + NULL_DENSITY as f32, + )); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Int64, false).into(), + Field::new("count", DataType::Int64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Int64, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_string(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_string"); + + for &num_rows in NUM_ROWS { + let element_array = Arc::new(create_string_array_with_max_len::( + num_rows, + NULL_DENSITY as f32, + 100, + )); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Utf8, false).into(), + Field::new("count", DataType::Int64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Utf8, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_nested_int64_list(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_nested_int64"); + + for &num_rows in NUM_ROWS { + let list_array: ArrayRef = + Arc::new(create_primitive_list_array_with_seed::( + num_rows, + NULL_DENSITY as f32, + NULL_DENSITY as f32, + 5, + SEED, + )); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new( + "element", + list_array.data_type().clone(), + false, + ) + .into(), + Field::new("count", DataType::Int64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + list_array.data_type().clone(), + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_float64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_float64"); + + for &num_rows in NUM_ROWS { + let element_array = Arc::new(create_f64_array(num_rows, NULL_DENSITY as f32)); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Float64, false) + .into(), + Field::new("count", DataType::Int64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Float64, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_boolean"); + + for &num_rows in NUM_ROWS { + let element_array = Arc::new(create_boolean_array( + num_rows, + NULL_DENSITY as f32, + f32::MAX, + )); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(element_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new("element", DataType::Boolean, false) + .into(), + Field::new("count", DataType::Int64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + DataType::Boolean, + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn bench_array_repeat_nested_string_list(c: &mut Criterion) { + let mut group = c.benchmark_group("array_repeat_nested_string"); + + for &num_rows in NUM_ROWS { + let list_array = create_string_list_array(num_rows, 5, NULL_DENSITY); + + for &repeat_count in REPEAT_COUNTS { + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::from(repeat_count)), + ]; + + group.bench_with_input( + BenchmarkId::new(format!("repeat_{repeat_count}_count"), num_rows), + &num_rows, + |b, _| { + let udf = ArrayRepeat::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: vec![ + Field::new( + "element", + list_array.data_type().clone(), + false, + ) + .into(), + Field::new("count", DataType::Int64, false).into(), + ], + number_rows: num_rows, + return_field: Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field( + list_array.data_type().clone(), + true, + ))), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn create_string_list_array( + num_rows: usize, + array_size: usize, + null_density: f64, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + use arrow::array::StringArray; + + let values = (0..num_rows * array_size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_replace.rs b/datafusion/functions-nested/benches/array_replace.rs new file mode 100644 index 0000000000000..a75b97c3fafca --- /dev/null +++ b/datafusion/functions-nested/benches/array_replace.rs @@ -0,0 +1,589 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, BooleanBuilder, FixedSizeBinaryArray, Int64Builder, + ListArray, ListBuilder, StringBuilder, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions_nested::replace::{ + array_replace_all_udf, array_replace_n_udf, array_replace_udf, +}; +use rand::Rng; +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand::seq::IndexedRandom; +use std::hint::black_box; +use std::sync::Arc; + +// (num_rows, list_size) +const SIZES: &[(usize, usize)] = &[(4_000, 10), (10_000, 100), (10_000, 500)]; +const NESTED_SIZES: &[(usize, usize)] = &[(4_000, 10), (3_000, 100), (1_500, 300)]; +const SEED: u64 = 42; +const HAYSTACK_NULL_DENSITY: f64 = 0.1; +const NEEDLE_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_replace_int64(c); + bench_array_replace_n_int64(c); + bench_array_replace_all_int64(c); + + bench_array_replace_int64_nested(c); + bench_array_replace_n_int64_nested(c); + bench_array_replace_all_int64_nested(c); + + bench_array_replace_strings(c); + bench_array_replace_boolean(c); + bench_array_replace_fixed_size_binary(c); +} + +fn bench_array_replace_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_int64"); + + let filler_values = [None, Some(1), Some(2), Some(3), Some(4), Some(5)]; + let from = 0_i64; + let to = 6_i64; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + from, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + ScalarValue::from(from), + ScalarValue::from(to), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_n_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_n_int64"); + + let filler_values = [None, Some(1), Some(2), Some(3), Some(4), Some(5)]; + let from = 0_i64; + let to = 6_i64; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + from, + &filler_values, + ); + let n = (NEEDLE_DENSITY / 2.0 * list_size as f64) as i64; + let n = 2.max(n); + + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_n_udf(); + b.iter(|| { + let args = create_args_n( + list_array.clone(), + ScalarValue::from(from), + ScalarValue::from(to), + ScalarValue::from(n), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_all_int64(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_all_int64"); + + let filler_values = [None, Some(1), Some(2), Some(3), Some(4), Some(5)]; + let from = 0_i64; + let to = 6_i64; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + from, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_all_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + ScalarValue::from(from), + ScalarValue::from(to), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_int64_nested(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_int64_nested"); + + let filler_values = [ + None, + Some(vec![Some(1), Some(0), Some(2), Some(0)]), + Some(vec![Some(1)]), + Some(vec![]), + Some(vec![Some(1), Some(0), Some(2), Some(4), None]), + Some(vec![None]), + ]; + let from = vec![Some(1), Some(0), Some(2), Some(4)]; + let to = vec![Some(9), Some(8), Some(7)]; + let from_scalar = list_scalar(&from); + let to_scalar = list_scalar(&to); + for &(num_rows, list_size) in NESTED_SIZES { + let list_array = + create_nested_i64_list_array(num_rows, list_size, &from, &filler_values); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + from_scalar.clone(), + to_scalar.clone(), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_n_int64_nested(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_n_int64_nested"); + + let filler_values = [ + None, + Some(vec![Some(1), Some(0), Some(2), Some(0)]), + Some(vec![Some(1)]), + Some(vec![]), + Some(vec![Some(1), Some(0), Some(2), Some(4), None]), + Some(vec![None]), + ]; + let from = vec![Some(1), Some(0), Some(2), Some(4)]; + let to = vec![Some(9), Some(8), Some(7)]; + let from_scalar = list_scalar(&from); + let to_scalar = list_scalar(&to); + for &(num_rows, list_size) in NESTED_SIZES { + let list_array = + create_nested_i64_list_array(num_rows, list_size, &from, &filler_values); + let n = (NEEDLE_DENSITY / 2.0 * list_size as f64) as i64; + let n = 2.max(n); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_n_udf(); + b.iter(|| { + let args = create_args_n( + list_array.clone(), + from_scalar.clone(), + to_scalar.clone(), + ScalarValue::from(n), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_all_int64_nested(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_all_int64_nested"); + + let filler_values = [ + None, + Some(vec![Some(1), Some(0), Some(2), Some(0)]), + Some(vec![Some(1)]), + Some(vec![]), + Some(vec![Some(1), Some(0), Some(2), Some(4), None]), + Some(vec![None]), + ]; + let from = vec![Some(1), Some(0), Some(2), Some(4)]; + let to = vec![Some(9), Some(8), Some(7)]; + let from_scalar = list_scalar(&from); + let to_scalar = list_scalar(&to); + for &(num_rows, list_size) in NESTED_SIZES { + let list_array = + create_nested_i64_list_array(num_rows, list_size, &from, &filler_values); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_all_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + from_scalar.clone(), + to_scalar.clone(), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_strings(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_strings"); + + let filler_values = [ + None, + Some("neenee"), + Some("notthis"), + Some("value1"), + Some("abc"), + Some("hello"), + ]; + let from = "needle"; + let to = "replacement"; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + from, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + ScalarValue::from(from), + ScalarValue::from(to), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_boolean(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_boolean"); + + let filler_values = [None, Some(false)]; + let from = true; + let to = false; + for &(num_rows, list_size) in SIZES { + let list_array = create_list_array::( + num_rows, + list_size, + from, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + ScalarValue::from(from), + ScalarValue::from(to), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +fn bench_array_replace_fixed_size_binary(c: &mut Criterion) { + let mut group = c.benchmark_group("array_replace_fixed_size_binary"); + + const SIZE: usize = 16; + let filler_values = [ + None, + Some([2_u8; SIZE]), + Some([3_u8; SIZE]), + Some([4_u8; SIZE]), + Some([5_u8; SIZE]), + Some([6_u8; SIZE]), + ]; + let from = [1_u8; SIZE]; + let to = [7_u8; SIZE]; + for &(num_rows, list_size) in SIZES { + let list_array = create_fixed_size_binary_list_array::( + num_rows, + list_size, + from, + &filler_values, + ); + group.bench_with_input( + BenchmarkId::new( + "replace", + format!("list size: {list_size}, num_rows: {num_rows}"), + ), + &(list_size, num_rows), + |b, _| { + let udf = array_replace_udf(); + b.iter(|| { + let args = create_args( + list_array.clone(), + ScalarValue::FixedSizeBinary(SIZE as i32, Some(from.to_vec())), + ScalarValue::FixedSizeBinary(SIZE as i32, Some(to.to_vec())), + ); + black_box(udf.invoke_with_args(args).unwrap()) + }) + }, + ); + } + + group.finish(); +} + +#[inline] +fn create_args( + haystack: ArrayRef, + from: ScalarValue, + to: ScalarValue, +) -> ScalarFunctionArgs { + let number_rows = haystack.len(); + let haystack_type = haystack.data_type().clone(); + let from_type = from.data_type().clone(); + let to_type = to.data_type().clone(); + ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(haystack), + ColumnarValue::Scalar(from), + ColumnarValue::Scalar(to), + ], + arg_fields: vec![ + Field::new("haystack", haystack_type.clone(), true).into(), + Field::new("from", from_type, true).into(), + Field::new("to", to_type, true).into(), + ], + number_rows, + return_field: Field::new("result", haystack_type, true).into(), + config_options: Arc::new(ConfigOptions::default()), + } +} + +#[inline] +fn create_args_n( + haystack: ArrayRef, + from: ScalarValue, + to: ScalarValue, + n: ScalarValue, +) -> ScalarFunctionArgs { + let number_rows = haystack.len(); + let haystack_type = haystack.data_type().clone(); + let from_type = from.data_type().clone(); + let to_type = to.data_type().clone(); + let n_type = n.data_type().clone(); + ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(haystack), + ColumnarValue::Scalar(from), + ColumnarValue::Scalar(to), + ColumnarValue::Scalar(n), + ], + arg_fields: vec![ + Field::new("haystack", haystack_type.clone(), true).into(), + Field::new("from", from_type, true).into(), + Field::new("to", to_type, true).into(), + Field::new("n", n_type, true).into(), + ], + number_rows, + return_field: Field::new("result", haystack_type, true).into(), + config_options: Arc::new(ConfigOptions::default()), + } +} + +fn create_list_array( + num_rows: usize, + list_size: usize, + needle_value: Item, + filler_values: &[Option], +) -> ArrayRef +where + Builder: ArrayBuilder + Default + Extend>, + Item: Copy, +{ + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..num_rows) + .map(|_| { + if rng.random_bool(HAYSTACK_NULL_DENSITY) { + None + } else { + let list = (0..list_size) + .map(|_| { + if rng.random_bool(NEEDLE_DENSITY) { + Some(needle_value) + } else { + *filler_values.choose(&mut rng).unwrap() + } + }) + .collect::>(); + Some(list) + } + }) + .collect::>(); + Arc::new(ListArray::from_nested_iter::(values)) +} + +fn create_fixed_size_binary_list_array( + num_rows: usize, + list_size: usize, + needle_value: [u8; SIZE], + filler_values: &[Option<[u8; SIZE]>], +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let mut buffer = Vec::with_capacity(num_rows * list_size); + for _ in 0..num_rows { + for _ in 0..list_size { + if rng.random_bool(NEEDLE_DENSITY) { + buffer.push(Some(needle_value)); + } else { + buffer.push(*filler_values.choose(&mut rng).unwrap()); + } + } + } + let values = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + buffer.into_iter(), + SIZE as i32, + ) + .unwrap(); + + let null_buffer = NullBuffer::from_iter( + (0..num_rows).map(|_| rng.random_bool(1.0 - HAYSTACK_NULL_DENSITY)), + ); + + Arc::new(ListArray::new( + Field::new("item", DataType::FixedSizeBinary(SIZE as i32), true).into(), + OffsetBuffer::from_repeated_length(list_size, num_rows), + Arc::new(values), + Some(null_buffer), + )) +} + +fn create_nested_i64_list_array( + num_rows: usize, + list_size: usize, + needle_value: &[Option], + filler_values: &[Option>>], +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + + let value_builder = Int64Builder::new(); + let inner_builder = ListBuilder::new(value_builder); + let mut outer_builder = ListBuilder::new(inner_builder); + + for _ in 0..num_rows { + if rng.random_bool(HAYSTACK_NULL_DENSITY) { + outer_builder.append(false); + continue; + } + + for _ in 0..list_size { + let inner = outer_builder.values(); + if rng.random_bool(NEEDLE_DENSITY) { + inner.append_value(needle_value.to_vec()); + } else { + inner.append_option(filler_values.choose(&mut rng).unwrap().clone()); + } + } + outer_builder.append(true); + } + + Arc::new(outer_builder.finish()) +} + +fn list_scalar(values: &[Option]) -> ScalarValue { + let values = values + .iter() + .copied() + .map(ScalarValue::from) + .collect::>(); + ScalarValue::List(ScalarValue::new_list_nullable(&values, &DataType::Int64)) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_resize.rs b/datafusion/functions-nested/benches/array_resize.rs new file mode 100644 index 0000000000000..d605ab3a20d3e --- /dev/null +++ b/datafusion/functions-nested/benches/array_resize.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + BenchmarkGroup, Criterion, criterion_group, criterion_main, measurement::WallTime, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::resize::ArrayResize; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1_000; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("array_resize_i64"); + let list_field: Arc = Field::new_list_field(DataType::Int64, true).into(); + let list_data_type = DataType::List(Arc::clone(&list_field)); + let arg_fields = vec![ + Field::new("array", list_data_type.clone(), true).into(), + Field::new("size", DataType::Int64, false).into(), + Field::new("value", DataType::Int64, true).into(), + ]; + let return_field: Arc = Field::new("result", list_data_type, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + let two_arg_fields = arg_fields[..2].to_vec(); + + bench_case( + &mut group, + "grow_uniform_fill_10_to_500", + &[ + ColumnarValue::Array(create_int64_list_array(NUM_ROWS, 10)), + ColumnarValue::Array(repeated_int64_array(500)), + ColumnarValue::Array(repeated_int64_array(7)), + ], + &arg_fields, + &return_field, + &config_options, + ); + + bench_case( + &mut group, + "shrink_uniform_fill_500_to_10", + &[ + ColumnarValue::Array(create_int64_list_array(NUM_ROWS, 500)), + ColumnarValue::Array(repeated_int64_array(10)), + ColumnarValue::Array(repeated_int64_array(7)), + ], + &arg_fields, + &return_field, + &config_options, + ); + + bench_case( + &mut group, + "grow_default_null_fill_10_to_500", + &[ + ColumnarValue::Array(create_int64_list_array(NUM_ROWS, 10)), + ColumnarValue::Array(repeated_int64_array(500)), + ], + &two_arg_fields, + &return_field, + &config_options, + ); + + bench_case( + &mut group, + "grow_variable_fill_10_to_500", + &[ + ColumnarValue::Array(create_int64_list_array(NUM_ROWS, 10)), + ColumnarValue::Array(repeated_int64_array(500)), + ColumnarValue::Array(distinct_fill_array()), + ], + &arg_fields, + &return_field, + &config_options, + ); + + bench_case( + &mut group, + "mixed_grow_shrink_1000x_100", + &[ + ColumnarValue::Array(create_int64_list_array(NUM_ROWS, 100)), + ColumnarValue::Array(mixed_size_array()), + ], + &arg_fields[..2], + &return_field, + &config_options, + ); + + group.finish(); +} + +fn bench_case( + group: &mut BenchmarkGroup<'_, WallTime>, + name: &str, + args: &[ColumnarValue], + arg_fields: &[Arc], + return_field: &Arc, + config_options: &Arc, +) { + let udf = ArrayResize::new(); + group.bench_function(name, |b| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields: arg_fields.to_vec(), + number_rows: NUM_ROWS, + return_field: return_field.clone(), + config_options: config_options.clone(), + }) + .unwrap(), + ) + }) + }); +} + +fn create_int64_list_array(num_rows: usize, list_len: usize) -> ArrayRef { + let values = (0..(num_rows * list_len)) + .map(|v| Some(v as i64)) + .collect::(); + let offsets = (0..=num_rows) + .map(|i| (i * list_len) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +fn repeated_int64_array(value: i64) -> ArrayRef { + Arc::new(Int64Array::from_value(value, NUM_ROWS)) +} + +fn distinct_fill_array() -> ArrayRef { + Arc::new(Int64Array::from_iter((0..NUM_ROWS).map(|i| Some(i as i64)))) +} + +fn mixed_size_array() -> ArrayRef { + Arc::new(Int64Array::from_iter( + (0..NUM_ROWS).map(|i| Some(if i % 2 == 0 { 200_i64 } else { 10_i64 })), + )) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_reverse.rs b/datafusion/functions-nested/benches/array_reverse.rs index 92a65128fe6ba..0c37296188315 100644 --- a/datafusion/functions-nested/benches/array_reverse.rs +++ b/datafusion/functions-nested/benches/array_reverse.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. -#[macro_use] -extern crate criterion; -extern crate arrow; - use std::{hint::black_box, sync::Arc}; -use crate::criterion::Criterion; use arrow::{ array::{ArrayRef, FixedSizeListArray, Int32Array, ListArray, ListViewArray}, buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}, datatypes::{DataType, Field}, }; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_functions_nested::reverse::array_reverse_inner; fn array_reverse(array: &ArrayRef) -> ArrayRef { diff --git a/datafusion/functions-nested/benches/array_set_ops.rs b/datafusion/functions-nested/benches/array_set_ops.rs new file mode 100644 index 0000000000000..d43bbdb577d06 --- /dev/null +++ b/datafusion/functions-nested/benches/array_set_ops.rs @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + criterion_group, criterion_main, {BenchmarkId, Criterion}, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::except::ArrayExcept; +use datafusion_functions_nested::set_ops::{ArrayDistinct, ArrayIntersect, ArrayUnion}; +use rand::SeedableRng; +use rand::prelude::SliceRandom; +use rand::rngs::StdRng; +use std::collections::HashSet; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1000; +const ARRAY_SIZES: &[usize] = &[10, 50, 100]; +const SEED: u64 = 42; +/// Extra rows on each side when building sliced arrays, so the underlying +/// values buffer is much larger than the visible portion. +const SLICE_PADDING: usize = 5000; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_union(c); + bench_array_intersect(c); + bench_array_except(c); + bench_array_distinct(c); + bench_array_union_sliced(c); + bench_array_intersect_sliced(c); + bench_array_distinct_sliced(c); + bench_array_except_sliced(c); +} + +fn invoke_udf(udf: &impl ScalarUDFImpl, array1: &ArrayRef, array2: &ArrayRef) { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(array1.clone()), + ColumnarValue::Array(array2.clone()), + ], + arg_fields: vec![ + Field::new("arr1", array1.data_type().clone(), false).into(), + Field::new("arr2", array2.data_type().clone(), false).into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new("result", array1.data_type().clone(), false).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ); +} + +fn bench_array_union(c: &mut Criterion) { + let mut group = c.benchmark_group("array_union"); + let udf = ArrayUnion::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_intersect(c: &mut Criterion) { + let mut group = c.benchmark_group("array_intersect"); + let udf = ArrayIntersect::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_except(c: &mut Criterion) { + let mut group = c.benchmark_group("array_except"); + let udf = ArrayExcept::new(); + + for (overlap_label, overlap_ratio) in &[("high_overlap", 0.8), ("low_overlap", 0.2)] { + for &array_size in ARRAY_SIZES { + let (array1, array2) = + create_arrays_with_overlap(NUM_ROWS, array_size, *overlap_ratio); + group.bench_with_input( + BenchmarkId::new(*overlap_label, array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &array1, &array2)), + ); + } + } + + group.finish(); +} + +fn bench_array_distinct(c: &mut Criterion) { + let mut group = c.benchmark_group("array_distinct"); + let udf = ArrayDistinct::new(); + + for (duplicate_label, duplicate_ratio) in + &[("high_duplicate", 0.8), ("low_duplicate", 0.2)] + { + for &array_size in ARRAY_SIZES { + let array = + create_array_with_duplicates(NUM_ROWS, array_size, *duplicate_ratio); + group.bench_with_input( + BenchmarkId::new(*duplicate_label, array_size), + &array_size, + |b, _| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array.clone())], + arg_fields: vec![ + Field::new("arr", array.data_type().clone(), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + } + + group.finish(); +} + +fn create_arrays_with_overlap( + num_rows: usize, + array_size: usize, + overlap_ratio: f64, +) -> (ArrayRef, ArrayRef) { + assert!((0.0..=1.0).contains(&overlap_ratio)); + let overlap_count = ((array_size as f64) * overlap_ratio).round() as usize; + + let mut rng = StdRng::seed_from_u64(SEED); + + let mut values1 = Vec::with_capacity(num_rows * array_size); + let mut values2 = Vec::with_capacity(num_rows * array_size); + + for row in 0..num_rows { + let base = (row as i64) * (array_size as i64) * 2; + + for i in 0..array_size { + values1.push(base + i as i64); + } + + let mut positions: Vec = (0..array_size).collect(); + positions.shuffle(&mut rng); + + let overlap_positions: HashSet<_> = + positions[..overlap_count].iter().copied().collect(); + + for i in 0..array_size { + if overlap_positions.contains(&i) { + values2.push(base + i as i64); + } else { + values2.push(base + array_size as i64 + i as i64); + } + } + } + + let values1 = Int64Array::from(values1); + let values2 = Int64Array::from(values2); + + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + let array1 = Arc::new( + ListArray::try_new( + field.clone(), + OffsetBuffer::new(offsets.clone().into()), + Arc::new(values1), + None, + ) + .unwrap(), + ); + + let array2 = Arc::new( + ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values2), + None, + ) + .unwrap(), + ); + + (array1, array2) +} + +fn create_array_with_duplicates( + num_rows: usize, + array_size: usize, + duplicate_ratio: f64, +) -> ArrayRef { + assert!((0.0..=1.0).contains(&duplicate_ratio)); + let unique_count = ((array_size as f64) * (1.0 - duplicate_ratio)).round() as usize; + let duplicate_count = array_size - unique_count; + + let mut rng = StdRng::seed_from_u64(SEED); + let mut values = Vec::with_capacity(num_rows * array_size); + + for row in 0..num_rows { + let base = (row as i64) * (array_size as i64) * 2; + + // Add unique values first + for i in 0..unique_count { + values.push(base + i as i64); + } + + // Fill the rest with duplicates randomly picked from the unique values + let mut unique_indices: Vec = + (0..unique_count).map(|i| base + i as i64).collect(); + unique_indices.shuffle(&mut rng); + + for i in 0..duplicate_count { + values.push(unique_indices[i % unique_count]); + } + } + + let values = Int64Array::from(values); + let field = Arc::new(Field::new("item", DataType::Int64, true)); + + let offsets = (0..=num_rows) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +/// Slice a pair of arrays to the middle `NUM_ROWS` rows from a larger array. +fn slice_pair(arrays: &(ArrayRef, ArrayRef)) -> (ArrayRef, ArrayRef) { + let a1 = arrays.0.slice(SLICE_PADDING, NUM_ROWS); + let a2 = arrays.1.slice(SLICE_PADDING, NUM_ROWS); + (a1, a2) +} + +fn bench_array_union_sliced(c: &mut Criterion) { + let mut group = c.benchmark_group("array_union_sliced"); + let udf = ArrayUnion::new(); + + for &array_size in ARRAY_SIZES { + let (a1, a2) = slice_pair(&create_arrays_with_overlap( + NUM_ROWS + 2 * SLICE_PADDING, + array_size, + 0.5, + )); + group.bench_with_input( + BenchmarkId::from_parameter(array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &a1, &a2)), + ); + } + group.finish(); +} + +fn bench_array_intersect_sliced(c: &mut Criterion) { + let mut group = c.benchmark_group("array_intersect_sliced"); + let udf = ArrayIntersect::new(); + + for &array_size in ARRAY_SIZES { + let (a1, a2) = slice_pair(&create_arrays_with_overlap( + NUM_ROWS + 2 * SLICE_PADDING, + array_size, + 0.5, + )); + group.bench_with_input( + BenchmarkId::from_parameter(array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &a1, &a2)), + ); + } + group.finish(); +} + +fn bench_array_except_sliced(c: &mut Criterion) { + let mut group = c.benchmark_group("array_except_sliced"); + let udf = ArrayExcept::new(); + + for &array_size in ARRAY_SIZES { + let (a1, a2) = slice_pair(&create_arrays_with_overlap( + NUM_ROWS + 2 * SLICE_PADDING, + array_size, + 0.5, + )); + group.bench_with_input( + BenchmarkId::from_parameter(array_size), + &array_size, + |b, _| b.iter(|| invoke_udf(&udf, &a1, &a2)), + ); + } + group.finish(); +} + +fn bench_array_distinct_sliced(c: &mut Criterion) { + let mut group = c.benchmark_group("array_distinct_sliced"); + let udf = ArrayDistinct::new(); + + for &array_size in ARRAY_SIZES { + let array = + create_array_with_duplicates(NUM_ROWS + 2 * SLICE_PADDING, array_size, 0.5) + .slice(SLICE_PADDING, NUM_ROWS); + group.bench_with_input( + BenchmarkId::from_parameter(array_size), + &array_size, + |b, _| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(array.clone())], + arg_fields: vec![ + Field::new("arr", array.data_type().clone(), false) + .into(), + ], + number_rows: NUM_ROWS, + return_field: Field::new( + "result", + array.data_type().clone(), + false, + ) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_slice.rs b/datafusion/functions-nested/benches/array_slice.rs new file mode 100644 index 0000000000000..b95fe47575e53 --- /dev/null +++ b/datafusion/functions-nested/benches/array_slice.rs @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Int64Array, ListArray, ListViewArray, NullBufferBuilder, PrimitiveArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Int64Type}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions_nested::extract::array_slice_udf; +use rand::rngs::StdRng; +use rand::seq::IndexedRandom; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn create_inputs( + rng: &mut StdRng, + size: usize, + child_array_size: usize, + null_density: f32, +) -> (ListArray, ListViewArray) { + let mut nulls_builder = NullBufferBuilder::new(size); + let mut sizes = Vec::with_capacity(size); + + for _ in 0..size { + if rng.random::() < null_density { + nulls_builder.append_null(); + } else { + nulls_builder.append_non_null(); + } + sizes.push(rng.random_range(1..child_array_size)); + } + let nulls = nulls_builder.finish(); + + let length = sizes.iter().sum(); + let values: PrimitiveArray = + (0..length).map(|_| Some(rng.random())).collect(); + let values = Arc::new(values); + + let offsets = OffsetBuffer::from_lengths(sizes.clone()); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets.clone(), + values.clone(), + nulls.clone(), + ); + + let offsets = ScalarBuffer::from(offsets.slice(0, size - 1)); + let sizes = ScalarBuffer::from_iter(sizes.into_iter().map(|v| v as i32)); + let list_view_array = ListViewArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets, + sizes, + values, + nulls, + ); + + (list_array, list_view_array) +} + +/// Create `from`, `to`, and `stride` from an array of strides. +fn random_from_to_stride( + rng: &mut StdRng, + size: i64, + null_density: f32, + stride_choices: &[Option], +) -> (Option, Option, Option) { + let from = if rng.random::() < null_density { + None + } else { + Some(rng.random_range(1..=size)) + }; + + let to = if rng.random::() < null_density { + None + } else { + match from { + Some(from) => Some(rng.random_range(from..=size)), + None => Some(rng.random_range(1..=size)), + } + }; + + let stride = stride_choices.choose(rng).cloned().unwrap_or(None); + + if from.is_none() || to.is_none() || stride.is_none_or(|s| s > 0) { + (from, to, stride) + } else { + // stride < 0, swap from and to + (to, from, stride) + } +} + +fn array_slice_benchmark( + name: &str, + input: ColumnarValue, + mut args: Vec, + c: &mut Criterion, + size: usize, +) { + args.insert(0, input); + + let array_slice = array_slice_udf(); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + >::from(Field::new(format!("arg_{idx}"), arg.data_type(), true)) + }) + .collect::>(); + c.bench_function(name, |b| { + b.iter(|| { + black_box( + array_slice + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new_list_field(args[0].data_type(), true) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rng = &mut StdRng::seed_from_u64(42); + let size = 1_000_000; + let child_array_size = 100; + let null_density = 0.1; + + let (list_array, list_view_array) = + create_inputs(rng, size, child_array_size, null_density); + + let mut array_from = Vec::with_capacity(size); + let mut array_to = Vec::with_capacity(size); + let mut array_stride = Vec::with_capacity(size); + for child_array_size in list_array.offsets().lengths() { + let (from, to, stride) = random_from_to_stride( + rng, + child_array_size as i64, + null_density, + &[None, Some(-2), Some(-1), Some(1), Some(2)], + ); + array_from.push(from); + array_to.push(to); + array_stride.push(stride); + } + + // input + let list_array = ColumnarValue::Array(Arc::new(list_array)); + let list_view_array = ColumnarValue::Array(Arc::new(list_view_array)); + + // args + let array_from = ColumnarValue::Array(Arc::new(Int64Array::from(array_from))); + let array_to = ColumnarValue::Array(Arc::new(Int64Array::from(array_to))); + let array_stride = ColumnarValue::Array(Arc::new(Int64Array::from(array_stride))); + let scalar_from = ColumnarValue::Scalar(ScalarValue::from(1i64)); + let scalar_to = ColumnarValue::Scalar(ScalarValue::from(child_array_size as i64 / 2)); + + for input in [list_array, list_view_array] { + let input_type = input.data_type().to_string(); + + array_slice_benchmark( + &format!("array_slice: input {input_type}, array args"), + input.clone(), + vec![array_from.clone(), array_to.clone(), array_stride.clone()], + c, + size, + ); + + array_slice_benchmark( + &format!("array_slice: input {input_type}, array args, no stride"), + input.clone(), + vec![array_from.clone(), array_to.clone()], + c, + size, + ); + + array_slice_benchmark( + &format!("array_slice: input {input_type}, scalar args, no stride"), + input.clone(), + vec![scalar_from.clone(), scalar_to.clone()], + c, + size, + ); + + for stride in [-2i64, -1i64, 1i64, 2i64] { + // swap from and to if stride < 0 + let (scalar_from, scalar_to) = if stride > 0 { + (scalar_from.clone(), scalar_to.clone()) + } else { + (scalar_to.clone(), scalar_from.clone()) + }; + let scalar_stride = ColumnarValue::Scalar(ScalarValue::from(stride)); + array_slice_benchmark( + &format!("array_slice: input {input_type}, scalar args, stride={stride}"), + input.clone(), + vec![scalar_from, scalar_to, scalar_stride], + c, + size, + ); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_sort.rs b/datafusion/functions-nested/benches/array_sort.rs new file mode 100644 index 0000000000000..940c0396cbb08 --- /dev/null +++ b/datafusion/functions-nested/benches/array_sort.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanBufferBuilder, Int32Array, ListArray, StringArray}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::sort::ArraySort; +use rand::SeedableRng; +use rand::rngs::StdRng; +use rand::seq::SliceRandom; + +const SEED: u64 = 42; +const NUM_ROWS: usize = 8192; + +fn create_int32_list_array( + num_rows: usize, + elements_per_row: usize, + with_nulls: bool, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let total_values = num_rows * elements_per_row; + + let mut values: Vec = (0..total_values as i32).collect(); + values.shuffle(&mut rng); + + let values = Arc::new(Int32Array::from(values)); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + let nulls = if with_nulls { + // Every 10th row is null + Some(NullBuffer::from( + (0..num_rows).map(|i| i % 10 != 0).collect::>(), + )) + } else { + None + }; + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::new(offsets.into()), + values, + nulls, + )) +} + +/// Creates a ListArray where ~10% of elements within each row are null. +fn create_int32_list_array_with_null_elements( + num_rows: usize, + elements_per_row: usize, +) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let total_values = num_rows * elements_per_row; + + let mut values: Vec = (0..total_values as i32).collect(); + values.shuffle(&mut rng); + + // ~10% of elements are null + let mut validity = BooleanBufferBuilder::new(total_values); + for i in 0..total_values { + validity.append(i % 10 != 0); + } + let null_buffer = NullBuffer::from(validity.finish()); + + let values = Arc::new(Int32Array::new(values.into(), Some(null_buffer))); + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::new(offsets.into()), + values, + None, + )) +} + +fn create_string_list_array(num_rows: usize, elements_per_row: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let total_values = num_rows * elements_per_row; + + let mut indices: Vec = (0..total_values).collect(); + indices.shuffle(&mut rng); + let string_values: Vec = + indices.iter().map(|i| format!("value_{i:06}")).collect(); + let values = Arc::new(StringArray::from(string_values)); + + let offsets: Vec = (0..=num_rows) + .map(|i| (i * elements_per_row) as i32) + .collect(); + + Arc::new(ListArray::new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + values, + None, + )) +} + +fn invoke_array_sort(udf: &ArraySort, array: &ArrayRef) -> ColumnarValue { + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(array))], + arg_fields: vec![Field::new("arr", array.data_type().clone(), true).into()], + number_rows: array.len(), + return_field: Field::new("result", array.data_type().clone(), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap() +} + +/// Vary elements_per_row over [5, 20, 100, 1000]: for small arrays, per-row +/// overhead dominates, whereas for larger arrays the sort kernel dominates. +fn bench_array_sort(c: &mut Criterion) { + let mut group = c.benchmark_group("array_sort"); + let udf = ArraySort::new(); + + // Int32 arrays + for &elements_per_row in &[5, 20, 100, 1000] { + let array = create_int32_list_array(NUM_ROWS, elements_per_row, false); + group.bench_with_input( + BenchmarkId::new("int32", elements_per_row), + &elements_per_row, + |b, _| { + b.iter(|| { + black_box(invoke_array_sort(&udf, &array)); + }); + }, + ); + } + + // Int32 with nulls in the outer list (10% null rows), single size + { + let array = create_int32_list_array(NUM_ROWS, 50, true); + group.bench_function("int32_with_nulls", |b| { + b.iter(|| { + black_box(invoke_array_sort(&udf, &array)); + }); + }); + } + + // Int32 with null elements (~10% of elements within rows are null) + for &elements_per_row in &[5, 20, 100, 1000] { + let array = + create_int32_list_array_with_null_elements(NUM_ROWS, elements_per_row); + group.bench_with_input( + BenchmarkId::new("int32_null_elements", elements_per_row), + &elements_per_row, + |b, _| { + b.iter(|| { + black_box(invoke_array_sort(&udf, &array)); + }); + }, + ); + } + + // String arrays + for &elements_per_row in &[5, 20, 100, 1000] { + let array = create_string_list_array(NUM_ROWS, elements_per_row); + group.bench_with_input( + BenchmarkId::new("string", elements_per_row), + &elements_per_row, + |b, _| { + b.iter(|| { + black_box(invoke_array_sort(&udf, &array)); + }); + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_array_sort); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/array_to_string.rs b/datafusion/functions-nested/benches/array_to_string.rs new file mode 100644 index 0000000000000..4b63d705480bf --- /dev/null +++ b/datafusion/functions-nested/benches/array_to_string.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, ListArray, StringArray}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field, Float64Type, Int64Type}; +use arrow::util::bench_util::create_primitive_list_array_with_seed; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::string::ArrayToString; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1000; +const ARRAY_SIZES: &[usize] = &[5, 20, 100]; +const NESTED_ARRAY_SIZE: usize = 3; +const SEED: u64 = 42; +const NULL_DENSITY: f64 = 0.1; + +fn criterion_benchmark(c: &mut Criterion) { + bench_array_to_string(c, "array_to_string_int64", create_int64_list_array); + bench_array_to_string(c, "array_to_string_float64", create_float64_list_array); + bench_array_to_string(c, "array_to_string_string", create_string_list_array); + bench_array_to_string( + c, + "array_to_string_nested_int64", + create_nested_int64_list_array, + ); +} + +fn bench_array_to_string( + c: &mut Criterion, + group_name: &str, + make_array: impl Fn(usize) -> ArrayRef, +) { + let mut group = c.benchmark_group(group_name); + + for &array_size in ARRAY_SIZES { + let list_array = make_array(array_size); + let args = vec![ + ColumnarValue::Array(list_array.clone()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))), + ]; + let arg_fields = vec![ + Field::new("array", list_array.data_type().clone(), true).into(), + Field::new("delimiter", DataType::Utf8, false).into(), + ]; + + group.bench_with_input( + BenchmarkId::from_parameter(array_size), + &array_size, + |b, _| { + let udf = ArrayToString::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: Field::new("result", DataType::Utf8, true) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn create_int64_list_array(array_size: usize) -> ArrayRef { + Arc::new(create_primitive_list_array_with_seed::( + NUM_ROWS, + 0.0, + NULL_DENSITY as f32, + array_size, + SEED, + )) +} + +fn create_nested_int64_list_array(array_size: usize) -> ArrayRef { + let inner = create_int64_list_array(array_size); + let inner_rows = NUM_ROWS; + let outer_rows = inner_rows / NESTED_ARRAY_SIZE; + let offsets = (0..=outer_rows) + .map(|i| (i * NESTED_ARRAY_SIZE) as i32) + .collect::>(); + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", inner.data_type().clone(), true)), + OffsetBuffer::new(offsets.into()), + inner, + None, + ) + .unwrap(), + ) +} + +fn create_float64_list_array(array_size: usize) -> ArrayRef { + Arc::new(create_primitive_list_array_with_seed::( + NUM_ROWS, + 0.0, + NULL_DENSITY as f32, + array_size, + SEED, + )) +} + +fn create_string_list_array(array_size: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let values = (0..NUM_ROWS * array_size) + .map(|_| { + if rng.random::() < NULL_DENSITY { + None + } else { + Some(format!("value_{}", rng.random_range(0..100))) + } + }) + .collect::(); + let offsets = (0..=NUM_ROWS) + .map(|i| (i * array_size) as i32) + .collect::>(); + + Arc::new( + ListArray::try_new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + ) + .unwrap(), + ) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/arrays_zip.rs b/datafusion/functions-nested/benches/arrays_zip.rs new file mode 100644 index 0000000000000..812e5e3dbec8a --- /dev/null +++ b/datafusion/functions-nested/benches/arrays_zip.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, ListArray}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::arrays_zip::ArraysZip; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +const NUM_ROWS: usize = 8192; +const LIST_SIZE: usize = 10; +const SEED: u64 = 42; + +/// Build a ListArray of Int64 with `num_rows` rows, each containing +/// `list_size` elements. If `null_density > 0`, that fraction of +/// rows will be null at the list level. +fn make_list_array( + rng: &mut StdRng, + num_rows: usize, + list_size: usize, + null_density: f64, +) -> ArrayRef { + let total = num_rows * list_size; + let values: Vec = (0..total).map(|_| rng.random_range(0..1000i64)).collect(); + let values_array = Arc::new(Int64Array::from(values)) as ArrayRef; + + let offsets: Vec = (0..=num_rows).map(|i| (i * list_size) as i32).collect(); + + let nulls = if null_density > 0.0 { + let valid: Vec = (0..num_rows) + .map(|_| rng.random::() >= null_density) + .collect(); + Some(NullBuffer::from(valid)) + } else { + None + }; + + Arc::new( + ListArray::try_new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + values_array, + nulls, + ) + .unwrap(), + ) +} + +fn bench_arrays_zip(c: &mut Criterion, name: &str, null_density: f64) { + let mut rng = StdRng::seed_from_u64(SEED); + let arr1 = make_list_array(&mut rng, NUM_ROWS, LIST_SIZE, null_density); + let arr2 = make_list_array(&mut rng, NUM_ROWS, LIST_SIZE, null_density); + let arr3 = make_list_array(&mut rng, NUM_ROWS, LIST_SIZE, null_density); + + let udf = ArraysZip::new(); + let args_vec = vec![ + ColumnarValue::Array(Arc::clone(&arr1)), + ColumnarValue::Array(Arc::clone(&arr2)), + ColumnarValue::Array(Arc::clone(&arr3)), + ]; + let return_type = udf + .return_type(&[ + arr1.data_type().clone(), + arr2.data_type().clone(), + arr3.data_type().clone(), + ]) + .unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + let arg_fields: Vec<_> = (0..3) + .map(|_| Arc::new(Field::new("a", arr1.data_type().clone(), true))) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(name, |b| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args_vec.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("arrays_zip should work"), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + bench_arrays_zip(c, "arrays_zip_perfect_zip_8192", 0.0); + bench_arrays_zip(c, "arrays_zip_10pct_nulls_8192", 0.1); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3197cc55cc957..67e7f314d2515 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -15,48 +15,116 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - -use arrow::array::{Int32Array, ListArray, StringArray}; +use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, Int32Array, ListArray, StringArray, + StringViewArray, +}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; -use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use arrow::datatypes::Field; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs}; use datafusion_functions_nested::map::map_udf; use datafusion_functions_nested::planner::NestedFunctionPlanner; -use rand::prelude::ThreadRng; use rand::Rng; +use rand::prelude::ThreadRng; use std::collections::HashSet; +use std::hash::Hash; use std::hint::black_box; use std::sync::Arc; -fn keys(rng: &mut ThreadRng) -> Vec { - let mut keys = HashSet::with_capacity(1000); +const MAP_ROWS: usize = 1000; +const MAP_KEYS_PER_ROW: usize = 1000; + +fn gen_unique_values( + rng: &mut ThreadRng, + mut make_value: impl FnMut(i32) -> T, +) -> Vec +where + T: Eq + Hash, +{ + let mut values = HashSet::with_capacity(MAP_KEYS_PER_ROW); - while keys.len() < 1000 { - keys.insert(rng.random_range(0..10000).to_string()); + while values.len() < MAP_KEYS_PER_ROW { + values.insert(make_value(rng.random_range(0..10000))); } - keys.into_iter().collect() + values.into_iter().collect() } -fn values(rng: &mut ThreadRng) -> Vec { - let mut values = HashSet::with_capacity(1000); +fn gen_repeat_values(values: &[T], repeats: usize) -> Vec { + let mut repeated = Vec::with_capacity(values.len() * repeats); - while values.len() < 1000 { - values.insert(rng.random_range(0..10000)); + for _ in 0..repeats { + repeated.extend_from_slice(values); } - values.into_iter().collect() + + repeated +} + +fn gen_utf8_values(rng: &mut ThreadRng) -> Vec { + gen_unique_values(rng, |value| value.to_string()) +} + +fn gen_binary_values(rng: &mut ThreadRng) -> Vec> { + gen_unique_values(rng, |value| value.to_le_bytes().to_vec()) +} + +fn gen_primitive_values(rng: &mut ThreadRng) -> Vec { + gen_unique_values(rng, |value| value) +} + +fn list_array(values: ArrayRef, row_count: usize, values_per_row: usize) -> ArrayRef { + let offsets = (0..=row_count) + .map(|index| (index * values_per_row) as i32) + .collect::>(); + Arc::new(ListArray::new( + Arc::new(Field::new_list_field(values.data_type().clone(), true)), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + values, + None, + )) +} + +fn bench_map_case(c: &mut Criterion, name: &str, keys: ArrayRef, values: ArrayRef) { + let number_rows = keys.len(); + let keys = ColumnarValue::Array(keys); + let values = ColumnarValue::Array(values); + + let return_type = map_udf() + .return_type(&[keys.data_type(), values.data_type()]) + .expect("should get return type"); + let arg_fields = vec![ + Field::new("a", keys.data_type(), true).into(), + Field::new("a", values.data_type(), true).into(), + ]; + let return_field = Field::new("f", return_type, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(name, |b| { + b.iter(|| { + black_box( + map_udf() + .invoke_with_args(ScalarFunctionArgs { + args: vec![keys.clone(), values.clone()], + arg_fields: arg_fields.clone(), + number_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("map should work on valid values"), + ); + }); + }); } fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_map_1000", |b| { let mut rng = rand::rng(); - let keys = keys(&mut rng); - let values = values(&mut rng); + let keys = gen_utf8_values(&mut rng); + let values = gen_primitive_values(&mut rng); let mut buffer = Vec::new(); for i in 0..1000 { buffer.push(Expr::Literal( @@ -65,9 +133,7 @@ fn criterion_benchmark(c: &mut Criterion) { )); buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])), None)); } - let planner = NestedFunctionPlanner {}; - b.iter(|| { black_box( planner @@ -77,51 +143,73 @@ fn criterion_benchmark(c: &mut Criterion) { }); }); - c.bench_function("map_1000", |b| { - let mut rng = rand::rng(); - let field = Arc::new(Field::new_list_field(DataType::Utf8, true)); - let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); - let key_list = ListArray::new( - field, - offsets, - Arc::new(StringArray::from(keys(&mut rng))), - None, - ); - let field = Arc::new(Field::new_list_field(DataType::Int32, true)); - let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); - let value_list = ListArray::new( - field, - offsets, - Arc::new(Int32Array::from(values(&mut rng))), - None, - ); - let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); - let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); - - let return_type = map_udf() - .return_type(&[keys.data_type(), values.data_type()]) - .expect("should get return type"); - let arg_fields = vec![ - Field::new("a", keys.data_type(), true).into(), - Field::new("a", values.data_type(), true).into(), - ]; - let return_field = Field::new("f", return_type, true).into(); - let config_options = Arc::new(ConfigOptions::default()); + let mut rng = rand::rng(); + let values = Arc::new(Int32Array::from(gen_repeat_values( + &gen_primitive_values(&mut rng), + MAP_ROWS, + ))) as ArrayRef; + let values = list_array(values, MAP_ROWS, MAP_KEYS_PER_ROW); + let map_cases = [ + ( + "map_1000_utf8", + list_array( + Arc::new(StringArray::from(gen_repeat_values( + &gen_utf8_values(&mut rng), + MAP_ROWS, + ))) as ArrayRef, + MAP_ROWS, + MAP_KEYS_PER_ROW, + ), + ), + ( + "map_1000_binary", + list_array( + Arc::new(BinaryArray::from_iter_values(gen_repeat_values( + &gen_binary_values(&mut rng), + MAP_ROWS, + ))) as ArrayRef, + MAP_ROWS, + MAP_KEYS_PER_ROW, + ), + ), + ( + "map_1000_utf8_view", + list_array( + Arc::new(StringViewArray::from(gen_repeat_values( + &gen_utf8_values(&mut rng), + MAP_ROWS, + ))) as ArrayRef, + MAP_ROWS, + MAP_KEYS_PER_ROW, + ), + ), + ( + "map_1000_binary_view", + list_array( + Arc::new(BinaryViewArray::from_iter_values(gen_repeat_values( + &gen_binary_values(&mut rng), + MAP_ROWS, + ))) as ArrayRef, + MAP_ROWS, + MAP_KEYS_PER_ROW, + ), + ), + ( + "map_1000_int32", + list_array( + Arc::new(Int32Array::from(gen_repeat_values( + &gen_primitive_values(&mut rng), + MAP_ROWS, + ))) as ArrayRef, + MAP_ROWS, + MAP_KEYS_PER_ROW, + ), + ), + ]; - b.iter(|| { - black_box( - map_udf() - .invoke_with_args(ScalarFunctionArgs { - args: vec![keys.clone(), values.clone()], - arg_fields: arg_fields.clone(), - number_rows: 1, - return_field: Arc::clone(&return_field), - config_options: Arc::clone(&config_options), - }) - .expect("map should work on valid values"), - ); - }); - }); + for (name, keys) in map_cases { + bench_map_case(c, name, keys, Arc::clone(&values)); + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions-nested/benches/string_to_array.rs b/datafusion/functions-nested/benches/string_to_array.rs new file mode 100644 index 0000000000000..e403d5e51bac8 --- /dev/null +++ b/datafusion/functions-nested/benches/string_to_array.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, StringArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions_nested::string::StringToArray; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 1000; +const SEED: u64 = 42; + +fn criterion_benchmark(c: &mut Criterion) { + // Single-char delimiter + let comma = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + bench_string_to_array( + c, + "string_to_array_single_char_delim", + create_csv_strings, + &comma, + None, + ); + + // Multi-char delimiter + let double_colon = ColumnarValue::Scalar(ScalarValue::Utf8(Some("::".to_string()))); + bench_string_to_array( + c, + "string_to_array_multi_char_delim", + create_multi_delim_strings, + &double_colon, + None, + ); + + // With null_str argument + let null_str = ColumnarValue::Scalar(ScalarValue::Utf8(Some("NULL".to_string()))); + bench_string_to_array( + c, + "string_to_array_with_null_str", + create_csv_strings_with_nulls, + &comma, + Some(&null_str), + ); + + // NULL delimiter + let null_delim = ColumnarValue::Scalar(ScalarValue::Utf8(None)); + bench_string_to_array( + c, + "string_to_array_null_delim", + create_short_strings, + &null_delim, + None, + ); + + // Columnar delimiter (fall-back path) + bench_string_to_array_columnar_delim(c); +} + +fn bench_string_to_array_columnar_delim(c: &mut Criterion) { + let mut group = c.benchmark_group("string_to_array_columnar_delim"); + + for &num_elements in &[5, 20, 100] { + let string_array = create_csv_strings(num_elements); + let delimiter_array: ArrayRef = + Arc::new(StringArray::from(vec![Some(","); NUM_ROWS])); + + let args = vec![ + ColumnarValue::Array(string_array.clone()), + ColumnarValue::Array(delimiter_array), + ]; + let arg_fields = vec![ + Field::new("str", DataType::Utf8, true).into(), + Field::new("delimiter", DataType::Utf8, false).into(), + ]; + + let return_field = Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + true, + ); + + group.bench_with_input( + BenchmarkId::from_parameter(num_elements), + &num_elements, + |b, _| { + let udf = StringToArray::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone().into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +fn bench_string_to_array( + c: &mut Criterion, + group_name: &str, + make_strings: fn(usize) -> ArrayRef, + delimiter: &ColumnarValue, + null_str: Option<&ColumnarValue>, +) { + let mut group = c.benchmark_group(group_name); + + for &num_elements in &[5, 20, 100] { + let string_array = make_strings(num_elements); + + let mut args = vec![ + ColumnarValue::Array(string_array.clone()), + delimiter.clone(), + ]; + let mut arg_fields = vec![ + Field::new("str", DataType::Utf8, true).into(), + Field::new("delimiter", DataType::Utf8, true).into(), + ]; + if let Some(ns) = null_str { + args.push(ns.clone()); + arg_fields.push(Field::new("null_str", DataType::Utf8, true).into()); + } + + let return_field = Field::new( + "result", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + true, + ); + + group.bench_with_input( + BenchmarkId::from_parameter(num_elements), + &num_elements, + |b, _| { + let udf = StringToArray::new(); + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: NUM_ROWS, + return_field: return_field.clone().into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }, + ); + } + + group.finish(); +} + +/// Creates strings like "val1,val2,val3,...,valN" with `num_elements` elements. +fn create_csv_strings(num_elements: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let strings: StringArray = (0..NUM_ROWS) + .map(|_| { + let parts: Vec = (0..num_elements) + .map(|_| format!("val{}", rng.random_range(0..1000))) + .collect(); + Some(parts.join(",")) + }) + .collect(); + Arc::new(strings) +} + +/// Creates strings like "val1::val2::val3::...::valN". +fn create_multi_delim_strings(num_elements: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let strings: StringArray = (0..NUM_ROWS) + .map(|_| { + let parts: Vec = (0..num_elements) + .map(|_| format!("val{}", rng.random_range(0..1000))) + .collect(); + Some(parts.join("::")) + }) + .collect(); + Arc::new(strings) +} + +/// Creates CSV strings where ~10% of elements are the literal "NULL". +fn create_csv_strings_with_nulls(num_elements: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let strings: StringArray = (0..NUM_ROWS) + .map(|_| { + let parts: Vec = (0..num_elements) + .map(|_| { + if rng.random::() < 0.1 { + "NULL".to_string() + } else { + format!("val{}", rng.random_range(0..1000)) + } + }) + .collect(); + Some(parts.join(",")) + }) + .collect(); + Arc::new(strings) +} + +/// Creates short strings (length = `num_chars`) for the NULL-delimiter +/// (split-into-characters) benchmark. +fn create_short_strings(num_chars: usize) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(SEED); + let strings: StringArray = (0..NUM_ROWS) + .map(|_| { + let s: String = (0..num_chars) + .map(|_| rng.random_range(b'a'..=b'z') as char) + .collect(); + Some(s) + }) + .collect(); + Arc::new(strings) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/src/array_add.rs b/datafusion/functions-nested/src/array_add.rs new file mode 100644 index 0000000000000..c6edf67bf5a93 --- /dev/null +++ b/datafusion/functions-nested/src/array_add.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_add function. + +use crate::utils::{coerce_array_math_arg_types, make_scalar_function}; +use arrow::array::{ + Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder, + OffsetBufferBuilder, OffsetSizeTrait, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{ + DataType, + DataType::{LargeList, List}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayAdd, + array_add, + array1 array2, + "returns the element-wise sum of two numeric arrays.", + array_add_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the element-wise sum of two numeric arrays of equal length, computed as `array1[i] + array2[i]` per position. NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty.", + syntax_example = "array_add(array1, array2)", + sql_example = r#"```sql +> select array_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]); ++---------------------------------------------------------+ +| array_add(List([1.0,2.0,3.0]),List([10.0,20.0,30.0])) | ++---------------------------------------------------------+ +| [11.0, 22.0, 33.0] | ++---------------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayAdd { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayAdd { + fn default() -> Self { + Self::new() + } +} + +impl ArrayAdd { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_add".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayAdd { + fn name(&self) -> &str { + "array_add" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // After `coerce_types`, both args share the same List/LargeList shape. + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + coerce_array_math_arg_types(self.name(), arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_add_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_add_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("array_add", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_array_add::(array1, array2), + (LargeList(_), LargeList(_)) => general_array_add::(array1, array2), + (arg_type1, arg_type2) => exec_err!( + "array_add received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), + } +} + +fn general_array_add( + lhs: &ArrayRef, + rhs: &ArrayRef, +) -> Result { + let lhs = as_generic_list_array::(lhs)?; + let rhs = as_generic_list_array::(rhs)?; + + let lhs_values = as_float64_array(lhs.values())?; + let rhs_values = as_float64_array(rhs.values())?; + let lhs_offsets = lhs.value_offsets(); + let rhs_offsets = rhs.value_offsets(); + + // Row-level validity: a row is valid iff both sides are valid at that row. + let row_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls()); + + let mut out_values: Vec = Vec::with_capacity(lhs_values.len()); + let mut out_inner_nulls = NullBufferBuilder::new(lhs_values.len()); + let mut out_offsets = OffsetBufferBuilder::::new(lhs.len()); + + for row in 0..lhs.len() { + // Whole-row NULL on either side -> NULL output row, no elements. + if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) { + out_offsets.push_length(0); + continue; + } + + let start1 = lhs_offsets[row].as_usize(); + let len1 = lhs.value_length(row).as_usize(); + let start2 = rhs_offsets[row].as_usize(); + let len2 = rhs.value_length(row).as_usize(); + + if len1 != len2 { + return exec_err!( + "array_add requires both list inputs to have the same length per row, got {len1} and {len2} at row {row}" + ); + } + + let l_slice = lhs_values.slice(start1, len1); + let r_slice = rhs_values.slice(start2, len2); + + let l_vals = l_slice.values(); + let r_vals = r_slice.values(); + + for i in 0..len1 { + out_values.push(l_vals[i] + r_vals[i]); + } + + // Per-element validity: position `i` is valid iff both lhs[i] and rhs[i] + // are valid. `NullBuffer::union` returns `None` when both sides are + // entirely valid. + match NullBuffer::union(l_slice.nulls(), r_slice.nulls()) { + Some(nb) => out_inner_nulls.append_buffer(&nb), + None => out_inner_nulls.append_n_non_nulls(len1), + } + + out_offsets.push_length(len1); + } + + let values_array = Arc::new(Float64Array::new( + out_values.into(), + out_inner_nulls.finish(), + )); + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + + Ok(Arc::new(GenericListArray::::try_new( + field, + out_offsets.finish(), + values_array, + row_nulls, + )?)) +} diff --git a/datafusion/functions-nested/src/array_any_match.rs b/datafusion/functions-nested/src/array_any_match.rs new file mode 100644 index 0000000000000..c8ba978881394 --- /dev/null +++ b/datafusion/functions-nested/src/array_any_match.rs @@ -0,0 +1,521 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`datafusion_expr::HigherOrderUDF`] definitions for array_any_match function. + +use arrow::{ + array::{Array, AsArray, BooleanArray, BooleanBuilder, new_null_array}, + buffer::NullBuffer, + compute::take_arrays, + datatypes::{ArrowNativeType, DataType, Field, FieldRef}, +}; +use datafusion_common::{ + Result, exec_datafusion_err, exec_err, plan_err, + utils::{ + adjust_offsets_for_slice, list_values, list_values_row_number, take_function_args, + }, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, + HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda, + Volatility, +}; +use datafusion_macros::user_doc; +use std::{fmt::Debug, sync::Arc}; + +make_higher_order_function_expr_and_func!( + ArrayAnyMatch, + array_any_match, + array lambda, + "returns true if any element in the array satisfies the predicate", + array_any_match_higher_order_function +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns whether any elements of an array match the given predicate. Returns true if one or more elements match, false if none match (including empty arrays), and null if the predicate returns null for some elements and false for all others.", + syntax_example = "any_match(array, predicate)", + sql_example = r#"```sql +> select any_match([1, 2, 3], x -> x > 2); ++----------------------------------+ +| any_match([1, 2, 3], x -> x > 2) | ++----------------------------------+ +| true | ++----------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "predicate", + description = "Lambda predicate that returns a boolean" + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayAnyMatch { + signature: HigherOrderSignature, + aliases: Vec, +} + +impl Default for ArrayAnyMatch { + fn default() -> Self { + Self::new() + } +} + +impl ArrayAnyMatch { + pub fn new() -> Self { + Self { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + aliases: vec![String::from("any_match"), String::from("list_any_match")], + } + } +} + +// Returns Some(true) if any element in [start, end) is true, +// None if no element is true but some are null, +// Some(false) if all are false or range is empty. +fn any_match_for_range( + predicate: &BooleanArray, + start: usize, + end: usize, +) -> Option { + let any_true = (start..end).any(|j| predicate.is_valid(j) && predicate.value(j)); + if any_true { + return Some(true); + } + let any_null = (start..end).any(|j| predicate.is_null(j)); + if any_null { None } else { Some(false) } +} + +impl HigherOrderUDFImpl for ArrayAnyMatch { + fn name(&self) -> &str { + "array_any_match" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + let [list] = arg_types else { + return plan_err!( + "{} function requires 1 value argument, got {}", + self.name(), + arg_types.len() + ); + }; + + let coerced = match list { + DataType::List(_) | DataType::LargeList(_) => list.clone(), + DataType::ListView(field) | DataType::FixedSizeList(field, _) => { + DataType::List(Arc::clone(field)) + } + DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), + _ => { + return plan_err!( + "{} expected a list as first argument, got {}", + self.name(), + list + ); + } + }; + + Ok(vec![coerced]) + } + + fn lambda_parameters( + &self, + _step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + let [list, _] = take_function_args(self.name(), fields)?; + let ValueOrLambda::Value(list) = list else { + return plan_err!("{} expects a value as first argument", self.name()); + }; + + let field = match list.data_type() { + DataType::List(field) => field, + DataType::LargeList(field) => field, + other => return plan_err!("expected list, got {other}"), + }; + + Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone( + field, + )]])) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result> { + let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] = + take_function_args(self.name(), args.arg_fields)? + else { + return plan_err!("{} expects a value as first argument", self.name()); + }; + let nullable = list.is_nullable() || lambda.is_nullable(); + Ok(Arc::new(Field::new("", DataType::Boolean, nullable))) + } + + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { + let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] = + take_function_args(self.name(), &args.args)? + else { + return exec_err!("{} expects a value followed by a lambda", self.name()); + }; + + let list_array = list.to_array(args.number_rows)?; + + // fast path: fully null input — also required for FixedSizeList which can't be + // handled by clear_null_values when fully null + if list_array.null_count() == list_array.len() { + return Ok(ColumnarValue::Array(new_null_array( + args.return_type(), + list_array.len(), + ))); + } + + let list_values = list_values(&list_array)?; + + let values_param = || Ok(Arc::clone(&list_values)); + + let predicate_results = lambda + .evaluate(&[&values_param], |arrays| { + let indices = list_values_row_number(&list_array)?; + Ok(take_arrays(arrays, &indices, None)?) + })? + .into_array(list_values.len())?; + + let predicate_bool = predicate_results + .as_any() + .downcast_ref::() + .ok_or_else(|| { + exec_datafusion_err!( + "{} predicate must return boolean array", + self.name() + ) + })?; + + let mut values = BooleanBuilder::with_capacity(list_array.len()); + + // Maps predicate results (flat over all elements) back to one Boolean per row. + // Uses adjusted offsets so sliced lists index correctly into the predicate array. + macro_rules! process_list { + ($list_typed:expr) => {{ + let offsets = adjust_offsets_for_slice($list_typed); + for i in 0..$list_typed.len() { + let start = offsets[i].as_usize(); + let end = offsets[i + 1].as_usize(); + // any_match_for_range returns None when nulls poison the result; + // null rows produce an empty range and return Some(false), but their + // null bit is preserved by attaching the original null bitmap below. + values.append_option(any_match_for_range(predicate_bool, start, end)); + } + }}; + } + + match list_array.data_type() { + DataType::List(_) => { + process_list!(list_array.as_list::()); + } + DataType::LargeList(_) => { + process_list!(list_array.as_list::()); + } + other => return exec_err!("expected list, got {other}"), + } + + let (boolean_buffer, predicate_nulls) = values.finish().into_parts(); + // Merge: a row is null if the input list row was null or the predicate returned null. + let nulls = NullBuffer::union(list_array.nulls(), predicate_nulls.as_ref()); + Ok(ColumnarValue::Array(Arc::new(BooleanArray::new( + boolean_buffer, + nulls, + )))) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use arrow::{ + array::{ArrayRef, BooleanArray, Int32Array, ListArray, RecordBatch}, + buffer::{NullBuffer, OffsetBuffer}, + datatypes::{DataType, Field}, + }; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{ + Expr, HigherOrderReturnFieldArgs, HigherOrderUDFImpl, ValueOrLambda, col, + execution_props::ExecutionProps, + expr::{HigherOrderFunction, LambdaVariable}, + lambda, lit, + }; + use datafusion_physical_expr::create_physical_expr; + + use crate::array_any_match::{ArrayAnyMatch, array_any_match_higher_order_function}; + + fn run_any_match( + list: impl arrow::array::Array + Clone + 'static, + ) -> Result { + let schema = DFSchema::from_unqualified_fields( + vec![Field::new( + "list", + list.data_type().clone(), + list.is_nullable(), + )] + .into(), + HashMap::new(), + )?; + + create_physical_expr( + &Expr::HigherOrderFunction(HigherOrderFunction::new( + array_any_match_higher_order_function(), + vec![ + col("list"), + lambda( + ["x"], + Expr::LambdaVariable(LambdaVariable::new( + "x".to_string(), + Some(Arc::new(Field::new("x", DataType::Int32, true))), + )) + .gt(lit(2i32)), + ), + ], + )), + &schema, + &ExecutionProps::new(), + )? + .evaluate(&RecordBatch::try_new( + Arc::clone(schema.inner()), + vec![Arc::new(list.clone())], + )?)? + .into_array(list.len()) + } + + fn run_any_match_div( + list: impl arrow::array::Array + Clone + 'static, + ) -> Result { + let schema = DFSchema::from_unqualified_fields( + vec![Field::new( + "list", + list.data_type().clone(), + list.is_nullable(), + )] + .into(), + HashMap::new(), + )?; + + let x = Expr::LambdaVariable(LambdaVariable::new( + "x".to_string(), + Some(Arc::new(Field::new("x", DataType::Int32, true))), + )); + // predicate: (100 / x) > 5 — panics on divide by zero if x == 0 is evaluated + create_physical_expr( + &Expr::HigherOrderFunction(HigherOrderFunction::new( + array_any_match_higher_order_function(), + vec![col("list"), lambda(["x"], (lit(100i32) / x).gt(lit(5i32)))], + )), + &schema, + &ExecutionProps::new(), + )? + .evaluate(&RecordBatch::try_new( + Arc::clone(schema.inner()), + vec![Arc::new(list.clone())], + )?)? + .into_array(list.len()) + } + + fn make_list(values: Vec, offsets: OffsetBuffer) -> ListArray { + make_list_with_nulls(values, offsets, None) + } + + fn make_list_with_nulls( + values: Vec, + offsets: OffsetBuffer, + nulls: Option, + ) -> ListArray { + ListArray::new( + Arc::new(Field::new_list_field(DataType::Int32, true)), + offsets, + Arc::new(Int32Array::from(values)), + nulls, + ) + } + + #[test] + fn test_any_match_some_true() -> Result<()> { + let list = make_list(vec![1, 2, 3], OffsetBuffer::from_lengths(vec![3])); + let result = run_any_match(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![Some(true)]) + ); + Ok(()) + } + + #[test] + fn test_any_match_none_true() -> Result<()> { + let list = make_list(vec![1, 2], OffsetBuffer::from_lengths(vec![2])); + let result = run_any_match(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![Some(false)]) + ); + Ok(()) + } + + #[test] + fn test_any_match_empty_array() -> Result<()> { + let list = make_list(vec![], OffsetBuffer::from_lengths(vec![0])); + let result = run_any_match(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![Some(false)]) + ); + Ok(()) + } + + #[test] + fn test_any_match_multiple_rows() -> Result<()> { + let list = make_list(vec![1, 2, 3, 1, 2], OffsetBuffer::from_lengths(vec![3, 2])); + let result = run_any_match(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![Some(true), Some(false)]) + ); + Ok(()) + } + + #[test] + fn test_any_match_return_field_nullability() -> Result<()> { + for list_nullable in [true, false] { + for lambda_nullable in [true, false] { + let list = Arc::new(Field::new( + "list", + DataType::new_list(DataType::Int32, true), + list_nullable, + )); + let lambda = + Arc::new(Field::new("predicate", DataType::Boolean, lambda_nullable)); + let arg_fields = [ + ValueOrLambda::Value(Arc::clone(&list)), + ValueOrLambda::Lambda(Arc::clone(&lambda)), + ]; + let scalar_arguments = [None, None]; + + let result = ArrayAnyMatch::new().return_field_from_args( + HigherOrderReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }, + )?; + + assert_eq!( + result, + Arc::new(Field::new( + "", + DataType::Boolean, + list_nullable || lambda_nullable, + )) + ); + } + } + + Ok(()) + } + + // Predicate must not be evaluated on elements belonging to null rows. + // The 10 in the null row would satisfy x > 5, but the row result must be None. + #[test] + fn test_any_match_should_not_evaluate_predicate_on_values_underlying_null() + -> Result<()> { + let list = make_list_with_nulls( + vec![1, 2, 10, 1, 2], + OffsetBuffer::from_lengths(vec![3, 2]), + Some(NullBuffer::from(vec![false, true])), + ); + let result = run_any_match(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![None, Some(false)]) + ); + Ok(()) + } + + // Predicate must not be evaluated on elements before the slice offset. + // The 10 before the slice would satisfy x > 5, but it is unreachable. + #[test] + fn test_any_match_on_sliced_list_should_not_evaluate_on_unreachable_values() + -> Result<()> { + let list = make_list( + vec![10, 1, 2, 1, 2], + OffsetBuffer::from_lengths(vec![1, 2, 2]), + ) + .slice(1, 2); + let result = run_any_match(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![Some(false), Some(false)]) + ); + Ok(()) + } + + // 0 in the null row would cause divide by zero if the predicate is evaluated on it. + #[test] + fn test_any_match_does_not_evaluate_predicate_on_null_row_values() -> Result<()> { + let list = make_list_with_nulls( + vec![1, 2, 0, 4, 5], + OffsetBuffer::from_lengths(vec![3, 2]), + Some(NullBuffer::from(vec![false, true])), + ); + let result = run_any_match_div(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![None, Some(true)]) + ); + Ok(()) + } + + // 0 before the slice offset would cause divide by zero if evaluated. + #[test] + fn test_any_match_does_not_evaluate_predicate_on_unreachable_values() -> Result<()> { + let list = make_list( + vec![0, 4, 5, 50, 100], + OffsetBuffer::from_lengths(vec![1, 2, 2]), + ) + .slice(1, 2); + let result = run_any_match_div(list)?; + assert_eq!( + result.as_any().downcast_ref::().unwrap(), + &BooleanArray::from(vec![Some(true), Some(false)]) + ); + Ok(()) + } +} diff --git a/datafusion/functions-nested/src/array_compact.rs b/datafusion/functions-nested/src/array_compact.rs new file mode 100644 index 0000000000000..11be494b5b20f --- /dev/null +++ b/datafusion/functions-nested/src/array_compact.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_compact function. + +use crate::utils::make_scalar_function; +use arrow::array::{ + Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait, + make_array, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{LargeList, List, Null}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayCompact, + array_compact, + array, + "removes null values from the array.", + array_compact_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Removes null values from the array.", + syntax_example = "array_compact(array)", + sql_example = r#"```sql +> select array_compact([1, NULL, 2, NULL, 3]) arr; ++-----------+ +| arr | ++-----------+ +| [1, 2, 3] | ++-----------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayCompact { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayCompact { + fn default() -> Self { + Self::new() + } +} + +impl ArrayCompact { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec!["list_compact".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayCompact { + fn name(&self) -> &str { + "array_compact" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_compact_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// array_compact SQL function +fn array_compact_inner(arg: &[ArrayRef]) -> Result { + let [input_array] = take_function_args("array_compact", arg)?; + + match &input_array.data_type() { + List(field) => { + let array = as_list_array(input_array)?; + compact_list::(array, field) + } + LargeList(field) => { + let array = as_large_list_array(input_array)?; + compact_list::(array, field) + } + Null => Ok(Arc::clone(input_array)), + array_type => exec_err!("array_compact does not support type '{array_type}'."), + } +} + +/// Remove null elements from each row of a list array. +fn compact_list( + list_array: &GenericListArray, + field: &Arc, +) -> Result { + let values = list_array.values(); + + // Fast path: no nulls in values, return input unchanged + if values.null_count() == 0 { + return Ok(Arc::new(list_array.clone())); + } + + let original_data = values.to_data(); + let capacity = original_data.len() - values.null_count(); + let mut offsets = Vec::::with_capacity(list_array.len() + 1); + offsets.push(O::zero()); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(capacity), + ); + + for row_index in 0..list_array.len() { + if list_array.nulls().is_some_and(|n| n.is_null(row_index)) { + offsets.push(offsets[row_index]); + continue; + } + + let start = list_array.offsets()[row_index].as_usize(); + let end = list_array.offsets()[row_index + 1].as_usize(); + let mut copied = 0usize; + + // Batch consecutive non-null elements into single extend() calls + // to reduce per-element overhead. For [1, 2, NULL, 3, 4] this + // produces 2 extend calls (0..2, 3..5) instead of 4 individual ones. + let mut batch_start: Option = None; + for i in start..end { + if values.is_null(i) { + // Null breaks the current batch — flush it + if let Some(bs) = batch_start { + mutable.extend(0, bs, i); + copied += i - bs; + batch_start = None; + } + } else if batch_start.is_none() { + batch_start = Some(i); + } + } + // Flush any remaining batch after the loop + if let Some(bs) = batch_start { + mutable.extend(0, bs, end); + copied += end - bs; + } + + offsets.push(offsets[row_index] + O::usize_as(copied)); + } + + let new_values = make_array(mutable.freeze()); + Ok(Arc::new(GenericListArray::::try_new( + Arc::clone(field), + OffsetBuffer::new(offsets.into()), + new_values, + list_array.nulls().cloned(), + )?)) +} diff --git a/datafusion/functions-nested/src/array_filter.rs b/datafusion/functions-nested/src/array_filter.rs new file mode 100644 index 0000000000000..a1fa8268a31a9 --- /dev/null +++ b/datafusion/functions-nested/src/array_filter.rs @@ -0,0 +1,464 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`datafusion_expr::HigherOrderUDF`] definitions for array_filter function. + +use arrow::{ + array::{ + Array, ArrayRef, AsArray, BooleanArray, LargeListArray, ListArray, + OffsetBufferBuilder, OffsetSizeTrait, new_empty_array, + }, + buffer::{OffsetBuffer, ScalarBuffer}, + compute::{filter as arrow_filter, take_arrays}, + datatypes::{DataType, Field, FieldRef}, +}; +use datafusion_common::{ + Result, ScalarValue, exec_err, + utils::{adjust_offsets_for_slice, list_values_row_number}, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, + HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +use crate::lambda_utils::{ + ListValuesResult, coerce_single_list_arg, extract_list_values, + single_list_lambda_parameters, value_lambda_pair, +}; + +make_higher_order_function_expr_and_func!( + ArrayFilter, + array_filter, + array lambda, + "filters the values of an array using a boolean lambda", + array_filter_higher_order_function +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "filters the values of an array using a boolean lambda", + syntax_example = "array_filter(array, x -> x > 2)", + sql_example = r#"```sql +> select array_filter([1, 2, 3, 4, 5], x -> x > 2); ++--------------------------------------------+ +| array_filter([1, 2, 3, 4, 5], x -> x > 2) | ++--------------------------------------------+ +| [3, 4, 5] | ++--------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "lambda", + description = "Lambda that returns a boolean. Elements for which the lambda returns true are kept." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayFilter { + signature: HigherOrderSignature, + aliases: Vec, +} + +impl Default for ArrayFilter { + fn default() -> Self { + Self::new() + } +} + +impl ArrayFilter { + pub fn new() -> Self { + Self { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + aliases: vec![String::from("list_filter")], + } + } +} + +impl HigherOrderUDFImpl for ArrayFilter { + fn name(&self) -> &str { + "array_filter" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn lambda_parameters( + &self, + _step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + single_list_lambda_parameters(self.name(), fields) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result> { + let (list, _lambda) = value_lambda_pair(self.name(), args.arg_fields)?; + Ok(Arc::new(Field::new( + "", + list.data_type().clone(), + list.is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { + let (list, lambda) = value_lambda_pair(self.name(), &args.args)?; + let list_array = list.to_array(args.number_rows)?; + + let list_values = match extract_list_values(&list_array, args.return_type())? { + ListValuesResult::EarlyReturn(v) => return Ok(v), + ListValuesResult::Values(v) => v, + }; + + let field = match args.return_field.data_type() { + DataType::List(field) | DataType::LargeList(field) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected return_field to be a list, got {}", + self.name(), + args.return_field + ); + } + }; + + let values_param = || Ok(Arc::clone(&list_values)); + let predicate_output = lambda.evaluate(&[&values_param], |arrays| { + let indices = list_values_row_number(&list_array)?; + Ok(take_arrays(arrays, &indices, None)?) + })?; + + // Scalar predicate short-circuit: x -> true or x -> false/null + if let ColumnarValue::Scalar(ScalarValue::Boolean(b)) = &predicate_output { + return match b { + Some(true) => Ok(ColumnarValue::Array(list_array)), + _ => Ok(ColumnarValue::Array(empty_filtered_list( + &list_array, + field, + )?)), + }; + } + + let predicate = predicate_output.into_array(list_values.len())?; + let Some(predicate) = predicate.as_any().downcast_ref::() else { + return exec_err!( + "{} lambda must return boolean, got {}", + self.name(), + predicate.data_type() + ); + }; + + // ListView and LargeListView are coerced to List/LargeList by coerce_value_types. + let filtered_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list::(); + let adjusted_offsets = adjust_offsets_for_slice(list); + let (filtered_values, new_offsets) = + filter_list_values(&list_values, predicate, &adjusted_offsets)?; + Arc::new(ListArray::new( + field, + new_offsets, + filtered_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list::(); + let adjusted_offsets = adjust_offsets_for_slice(large_list); + let (filtered_values, new_offsets) = + filter_list_values(&list_values, predicate, &adjusted_offsets)?; + Arc::new(LargeListArray::new( + field, + new_offsets, + filtered_values, + large_list.nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(filtered_list)) + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + coerce_single_list_arg(self.name(), arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Returns a list array with every non-null sublist emptied, preserving the null buffer. +/// Used for the `x -> false` / `x -> null` scalar predicate short-circuit. +fn empty_filtered_list(list_array: &ArrayRef, field: FieldRef) -> Result { + let n = list_array.len(); + let empty_values = new_empty_array(field.data_type()); + Ok(match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list::(); + Arc::new(ListArray::new( + field, + OffsetBuffer::new(ScalarBuffer::from(vec![0i32; n + 1])), + empty_values, + list.nulls().cloned(), + )) + } + DataType::LargeList(_) => { + let list = list_array.as_list::(); + Arc::new(LargeListArray::new( + field, + OffsetBuffer::new(ScalarBuffer::from(vec![0i64; n + 1])), + empty_values, + list.nulls().cloned(), + )) + } + other => return exec_err!("expected list, got {other}"), + }) +} + +/// Filters flat list values using a boolean predicate, returning filtered values and +/// recomputed per-sublist offsets. Null predicate values are treated as false. +fn filter_list_values( + values: &ArrayRef, + predicate: &BooleanArray, + offsets: &OffsetBuffer, +) -> Result<(ArrayRef, OffsetBuffer)> { + let num_sublists = offsets.len().saturating_sub(1); + let mut builder = OffsetBufferBuilder::::new(num_sublists); + + let has_nulls = predicate.null_count() > 0; + for i in 0..num_sublists { + let start = offsets[i].as_usize(); + let end = offsets[i + 1].as_usize(); + let count = if has_nulls { + (start..end) + .filter(|&j| predicate.is_valid(j) && predicate.value(j)) + .count() + } else { + predicate + .values() + .slice(start, end - start) + .count_set_bits() + }; + builder.push_length(count); + } + + let new_offsets = builder.finish(); + + if new_offsets.last() == offsets.last() { + return Ok((Arc::clone(values), offsets.clone())); + } + + // arrow_filter treats null predicate values as false + let filtered_values = arrow_filter(values.as_ref(), predicate)?; + Ok((filtered_values, new_offsets)) +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{Array, AsArray}, + buffer::{NullBuffer, OffsetBuffer}, + }; + + use crate::array_filter::array_filter_higher_order_function; + use crate::lambda_utils::test_utils::{create_i32_list, eval_hof_on_i32_list, v}; + use datafusion_expr::lit; + + fn keep_greater_than_two( + list: impl Array + Clone + 'static, + ) -> datafusion_common::Result { + eval_hof_on_i32_list( + array_filter_higher_order_function(), + list, + v().gt(lit(2i32)), + ) + } + + #[test] + fn filter_basic() { + let list = create_i32_list( + vec![1, 2, 3, 4, 5], + OffsetBuffer::::from_lengths(vec![5]), + None, + ); + + let res = keep_greater_than_two(list).unwrap(); + let actual = res.as_list::(); + + let expected = create_i32_list( + vec![3, 4, 5], + OffsetBuffer::::from_lengths(vec![3]), + None, + ); + + assert_eq!(actual, &expected); + } + + #[test] + fn filter_multiple_sublists() { + let list = create_i32_list( + vec![1, 5, 2, 4, 3], + OffsetBuffer::::from_lengths(vec![2, 3]), + None, + ); + + let res = keep_greater_than_two(list).unwrap(); + let actual = res.as_list::(); + + // [1,5] -> [5], [2,4,3] -> [4,3] + let expected = create_i32_list( + vec![5, 4, 3], + OffsetBuffer::::from_lengths(vec![1, 2]), + None, + ); + + assert_eq!(actual, &expected); + } + + #[test] + fn filter_on_sliced_list_should_not_evaluate_on_unreachable_values() { + // First sublist [0] is sliced away; sliced array covers sublists [1..3] + let list = create_i32_list( + vec![ + 0, // unreachable after slice — if evaluated, it would appear in output + 1, 5, 2, 4, 3, 7, + ], + OffsetBuffer::::from_lengths(vec![1, 3, 3]), + None, + ) + .slice(1, 2); + + let res = keep_greater_than_two(list).unwrap(); + let actual = res.as_list::(); + + // [1,5,2] -> [5], [4,3,7] -> [4,3,7] + let expected = create_i32_list( + vec![5, 4, 3, 7], + OffsetBuffer::::from_lengths(vec![1, 3]), + None, + ); + + assert_eq!(actual, &expected); + } + + #[test] + fn filter_should_not_be_evaluated_on_values_underlying_null() { + // The null sublist (index 1) contains values that would pass the predicate + // if evaluated. We verify they do NOT appear in the output. + let list = create_i32_list( + vec![1, 5, 99, 100, 3, 7], + OffsetBuffer::::from_lengths(vec![2, 2, 2]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + let res = keep_greater_than_two(list).unwrap(); + let actual = res.as_list::(); + + // sublist 0: [1,5] -> [5] + // sublist 1: null -> null (empty range, null bit) + // sublist 2: [3,7] -> [3,7] + let expected = create_i32_list( + vec![5, 3, 7], + OffsetBuffer::::from_lengths(vec![1, 0, 2]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + assert_eq!(actual.data_type(), expected.data_type()); + assert_eq!(actual, &expected); + } + + #[test] + fn filter_all_filtered_out() { + let list = + create_i32_list(vec![1, 2], OffsetBuffer::::from_lengths(vec![2]), None); + + let res = keep_greater_than_two(list).unwrap(); + let actual = res.as_list::(); + + let expected = create_i32_list( + vec![0i32; 0], + OffsetBuffer::::from_lengths(vec![0]), + None, + ); + + assert_eq!(actual, &expected); + } + + #[test] + fn filter_nothing_filtered_reuses_values() { + let list = create_i32_list( + vec![3, 4, 5], + OffsetBuffer::::from_lengths(vec![3]), + None, + ); + // all elements > 2, so nothing is filtered — values buffer should be reused + let res = keep_greater_than_two(list.clone()).unwrap(); + assert_eq!(res.as_list::(), &list); + } + + #[test] + fn scalar_true_predicate_returns_original_list() { + let list = create_i32_list( + vec![1, 2, 3], + OffsetBuffer::::from_lengths(vec![3]), + None, + ); + // x -> true: every element kept, should return list unchanged + let res = eval_hof_on_i32_list( + array_filter_higher_order_function(), + list.clone(), + lit(true), + ) + .unwrap(); + assert_eq!(res.as_list::(), &list); + } + + #[test] + fn scalar_false_predicate_returns_empty_sublists() { + let list = create_i32_list( + vec![1, 2, 3, 4], + OffsetBuffer::::from_lengths(vec![2, 2]), + None, + ); + // x -> false: every sublist emptied + let res = + eval_hof_on_i32_list(array_filter_higher_order_function(), list, lit(false)) + .unwrap(); + let actual = res.as_list::(); + let expected = create_i32_list( + vec![0i32; 0], + OffsetBuffer::::from_lengths(vec![0, 0]), + None, + ); + assert_eq!(actual, &expected); + } +} diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 8ae8c42b79d5e..04818258f040b 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -17,18 +17,22 @@ //! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. -use arrow::array::{Array, ArrayRef, BooleanArray, Datum, Scalar}; -use arrow::buffer::BooleanBuffer; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder, Datum, Scalar, + StringArrayType, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::datatypes::DataType; use arrow::row::{RowConverter, Rows, SortField}; use datafusion_common::cast::{as_fixed_size_list_array, as_generic_list_array}; use datafusion_common::utils::string_utils::string_array_to_vec; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - in_list, ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, in_list, }; use datafusion_macros::user_doc; use datafusion_physical_expr_common::datum::compare_with_eq; @@ -37,7 +41,8 @@ use itertools::Itertools; use crate::make_array::make_array_udf; use crate::utils::make_scalar_function; -use std::any::Any; +use hashbrown::HashSet; +use std::ops::Range; use std::sync::Arc; // Create static instances of ScalarUDFs for each function @@ -55,7 +60,7 @@ make_udf_expr_and_func!(ArrayHasAll, ); make_udf_expr_and_func!(ArrayHasAny, array_has_any, - haystack_array needle_array, // arg names + first_array second_array, // arg names "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_any_udf // internal function name ); @@ -107,9 +112,6 @@ impl ArrayHas { } impl ScalarUDFImpl for ArrayHas { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_has" } @@ -125,7 +127,7 @@ impl ScalarUDFImpl for ArrayHas { fn simplify( &self, mut args: Vec, - _info: &dyn datafusion_expr::simplify::SimplifyInfo, + _info: &datafusion_expr::simplify::SimplifyContext, ) -> Result { let [haystack, needle] = take_function_args(self.name(), &mut args)?; @@ -136,7 +138,7 @@ impl ScalarUDFImpl for ArrayHas { return Ok(ExprSimplifyResult::Simplified(Expr::Literal( ScalarValue::Boolean(None), None, - ))) + ))); } Expr::Literal( // FixedSizeList gets coerced to List @@ -176,10 +178,7 @@ impl ScalarUDFImpl for ArrayHas { Ok(ExprSimplifyResult::Original(args)) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?; if first_arg.data_type().is_null() { // Always return null if the first argument is null @@ -262,7 +261,7 @@ impl<'a> TryFrom<&'a dyn Array> for ArrayWrapper<'a> { DataType::FixedSizeList(_, _) => Ok(ArrayWrapper::FixedSizeList( as_fixed_size_list_array(value)?, )), - _ => exec_err!("array_has does not support type '{:?}'.", value.data_type()), + _ => exec_err!("array_has does not support type '{}'.", value.data_type()), } } } @@ -303,10 +302,8 @@ impl<'a> ArrayWrapper<'a> { fn offsets(&self) -> Box + 'a> { match self { ArrayWrapper::FixedSizeList(arr) => { - let offsets = (0..=arr.len()) - .step_by(arr.value_length() as usize) - .collect::>(); - Box::new(offsets.into_iter()) + let value_length = arr.value_length() as usize; + Box::new((0..=arr.len()).map(move |i| i * value_length)) } ArrayWrapper::List(arr) => { Box::new(arr.offsets().iter().map(|o| (*o) as usize)) @@ -316,34 +313,41 @@ impl<'a> ArrayWrapper<'a> { } } } + + fn nulls(&self) -> Option<&NullBuffer> { + match self { + ArrayWrapper::FixedSizeList(arr) => arr.nulls(), + ArrayWrapper::List(arr) => arr.nulls(), + ArrayWrapper::LargeList(arr) => arr.nulls(), + } + } } fn array_has_dispatch_for_array<'a>( haystack: ArrayWrapper<'a>, needle: &ArrayRef, ) -> Result { - let mut boolean_builder = BooleanArray::builder(haystack.len()); + let combined_nulls = NullBuffer::union(haystack.nulls(), needle.nulls()); + let mut result = BooleanBufferBuilder::new(haystack.len()); for (i, arr) in haystack.iter().enumerate() { - if arr.is_none() || needle.is_null(i) { - boolean_builder.append_null(); + if combined_nulls.as_ref().is_some_and(|n| n.is_null(i)) { + result.append(false); continue; } let arr = arr.unwrap(); let is_nested = arr.data_type().is_nested(); let needle_row = Scalar::new(needle.slice(i, 1)); let eq_array = compare_with_eq(&arr, &needle_row, is_nested)?; - boolean_builder.append_value(eq_array.true_count() > 0); + result.append(eq_array.has_true()); } - Ok(Arc::new(boolean_builder.finish())) + Ok(Arc::new(BooleanArray::new(result.finish(), combined_nulls))) } fn array_has_dispatch_for_scalar( haystack: ArrayWrapper<'_>, needle: &dyn Datum, ) -> Result { - let values = haystack.values(); - let is_nested = values.data_type().is_nested(); // If first argument is empty list (second argument is non-null), return false // i.e. array_has([], non-null element) -> false if haystack.len() == 0 { @@ -352,67 +356,125 @@ fn array_has_dispatch_for_scalar( None, ))); } - let eq_array = compare_with_eq(values, needle, is_nested)?; - let mut final_contained = vec![None; haystack.len()]; - // Check validity buffer to distinguish between null and empty arrays + // For sliced ListArrays, values() returns the full underlying array but + // only elements between the first and last offset are visible. + let offsets: Vec = haystack.offsets().collect(); + let first_offset = offsets[0]; + let visible_values = haystack + .values() + .slice(first_offset, offsets[offsets.len() - 1] - first_offset); + + let is_nested = visible_values.data_type().is_nested(); + let eq_array = compare_with_eq(&visible_values, needle, is_nested)?; + + // When a haystack element is null, `eq()` returns null (not false). + // In Arrow, a null BooleanArray entry has validity=0 but an + // undefined value bit that may happen to be 1. Since set_indices() + // operates on the raw value buffer and ignores validity, we AND the + // values with the validity bitmap to clear any undefined bits at + // null positions. This ensures set_indices() only yields positions + // where the comparison genuinely returned true. + let eq_bits = match eq_array.nulls() { + Some(nulls) => eq_array.values() & nulls.inner(), + None => eq_array.values().clone(), + }; + let validity = match &haystack { ArrayWrapper::FixedSizeList(arr) => arr.nulls(), ArrayWrapper::List(arr) => arr.nulls(), ArrayWrapper::LargeList(arr) => arr.nulls(), }; + let mut matches = eq_bits.set_indices().peekable(); + let mut result = BooleanBufferBuilder::new(haystack.len()); + result.append_n(haystack.len(), false); - for (i, (start, end)) in haystack.offsets().tuple_windows().enumerate() { - let length = end - start; + // Match positions are relative to visible_values (0-based), so + // subtract first_offset from each offset when comparing. + for (i, window) in offsets.windows(2).enumerate() { + let end = window[1] - first_offset; - // Check if the array at this position is null - if let Some(validity_buffer) = validity { - if !validity_buffer.is_valid(i) { - final_contained[i] = None; // null array -> null result - continue; - } + let has_match = matches.peek().is_some_and(|&p| p < end); + + // Advance past all match positions in this row's range. + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); } - // For non-null arrays: length is 0 for empty arrays - if length == 0 { - final_contained[i] = Some(false); // empty array -> false - } else { - let sliced_array = eq_array.slice(start, length); - final_contained[i] = Some(sliced_array.true_count() > 0); + if has_match && validity.is_none_or(|v| v.is_valid(i)) { + result.set_bit(i, true); } } - Ok(Arc::new(BooleanArray::from(final_contained))) + // A null haystack row always produces a null output, so we can + // reuse the haystack's null buffer directly. + Ok(Arc::new(BooleanArray::new( + result.finish(), + validity.cloned(), + ))) } fn array_has_all_inner(args: &[ArrayRef]) -> Result { array_has_all_and_any_inner(args, ComparisonType::All) } +/// Number of rows to process at a time when doing batched row conversion. This +/// amortizes the row conversion overhead over more rows, but making this too +/// large can cause cache pressure for large arrays. See +/// for context. +const ROW_CONVERSION_CHUNK_SIZE: usize = 512; + // General row comparison for array_has_all and array_has_any fn general_array_has_for_all_and_any<'a>( haystack: ArrayWrapper<'a>, needle: ArrayWrapper<'a>, comparison_type: ComparisonType, ) -> Result { - let mut boolean_builder = BooleanArray::builder(haystack.len()); + let num_rows = haystack.len(); let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; - for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = converter.convert_columns(&[sub_arr])?; - boolean_builder.append_value(general_array_has_all_and_any_kernel( - &arr_values, - &sub_arr_values, + let h_offsets: Vec = haystack.offsets().collect(); + let n_offsets: Vec = needle.offsets().collect(); + + let combined_nulls = NullBuffer::union(haystack.nulls(), needle.nulls()); + let mut result = BooleanBufferBuilder::new(num_rows); + + for chunk_start in (0..num_rows).step_by(ROW_CONVERSION_CHUNK_SIZE) { + let chunk_end = (chunk_start + ROW_CONVERSION_CHUNK_SIZE).min(num_rows); + + // For efficiency with sliced arrays, only process the visible elements, + // not the entire underlying buffer. + let h_elem_start = h_offsets[chunk_start]; + let h_elem_end = h_offsets[chunk_end]; + let n_elem_start = n_offsets[chunk_start]; + let n_elem_end = n_offsets[chunk_end]; + + let h_vals = haystack + .values() + .slice(h_elem_start, h_elem_end - h_elem_start); + let n_vals = needle + .values() + .slice(n_elem_start, n_elem_end - n_elem_start); + + let chunk_h_rows = converter.convert_columns(&[h_vals])?; + let chunk_n_rows = converter.convert_columns(&[n_vals])?; + + for i in chunk_start..chunk_end { + if combined_nulls.as_ref().is_some_and(|n| n.is_null(i)) { + result.append(false); + continue; + } + result.append(general_array_has_all_and_any_kernel( + &chunk_h_rows, + (h_offsets[i] - h_elem_start)..(h_offsets[i + 1] - h_elem_start), + &chunk_n_rows, + (n_offsets[i] - n_elem_start)..(n_offsets[i + 1] - n_elem_start), comparison_type, )); - } else { - boolean_builder.append_null(); } } - Ok(Arc::new(boolean_builder.finish())) + Ok(Arc::new(BooleanArray::new(result.finish(), combined_nulls))) } // String comparison for array_has_all and array_has_any @@ -421,25 +483,50 @@ fn array_has_all_and_any_string_internal<'a>( needle: ArrayWrapper<'a>, comparison_type: ComparisonType, ) -> Result { - let mut boolean_builder = BooleanArray::builder(haystack.len()); - for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { - match (arr, sub_arr) { - (Some(arr), Some(sub_arr)) => { - let haystack_array = string_array_to_vec(&arr); - let needle_array = string_array_to_vec(&sub_arr); - boolean_builder.append_value(array_has_string_kernel( - &haystack_array, - &needle_array, - comparison_type, - )); - } - (_, _) => { - boolean_builder.append_null(); + let num_rows = haystack.len(); + + let h_offsets: Vec = haystack.offsets().collect(); + let n_offsets: Vec = needle.offsets().collect(); + + let combined_nulls = NullBuffer::union(haystack.nulls(), needle.nulls()); + let mut result = BooleanBufferBuilder::new(num_rows); + + for chunk_start in (0..num_rows).step_by(ROW_CONVERSION_CHUNK_SIZE) { + let chunk_end = (chunk_start + ROW_CONVERSION_CHUNK_SIZE).min(num_rows); + + let h_elem_start = h_offsets[chunk_start]; + let h_elem_end = h_offsets[chunk_end]; + let n_elem_start = n_offsets[chunk_start]; + let n_elem_end = n_offsets[chunk_end]; + + let h_vals = haystack + .values() + .slice(h_elem_start, h_elem_end - h_elem_start); + let n_vals = needle + .values() + .slice(n_elem_start, n_elem_end - n_elem_start); + + let chunk_h_strings = string_array_to_vec(h_vals.as_ref()); + let chunk_n_strings = string_array_to_vec(n_vals.as_ref()); + + for i in chunk_start..chunk_end { + if combined_nulls.as_ref().is_some_and(|n| n.is_null(i)) { + result.append(false); + continue; } + let h_start = h_offsets[i] - h_elem_start; + let h_end = h_offsets[i + 1] - h_elem_start; + let n_start = n_offsets[i] - n_elem_start; + let n_end = n_offsets[i + 1] - n_elem_start; + result.append(array_has_string_kernel( + &chunk_h_strings[h_start..h_end], + &chunk_n_strings[n_start..n_end], + comparison_type, + )); } } - Ok(Arc::new(boolean_builder.finish())) + Ok(Arc::new(BooleanArray::new(result.finish(), combined_nulls))) } fn array_has_all_and_any_dispatch<'a>( @@ -476,6 +563,219 @@ fn array_has_any_inner(args: &[ArrayRef]) -> Result { array_has_all_and_any_inner(args, ComparisonType::Any) } +/// Fast path for `array_has_any` when exactly one argument is a scalar. +fn array_has_any_with_scalar( + columnar_arg: &ColumnarValue, + scalar_arg: &ScalarValue, +) -> Result { + if scalar_arg.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + // Convert the scalar to a 1-element ListArray, then extract the inner values + let scalar_array = scalar_arg.to_array_of_size(1)?; + let scalar_list: ArrayWrapper = scalar_array.as_ref().try_into()?; + let offsets: Vec = scalar_list.offsets().collect(); + let scalar_values = scalar_list + .values() + .slice(offsets[0], offsets[1] - offsets[0]); + + // If scalar list is empty, result is always false + if scalar_values.is_empty() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)))); + } + + match scalar_values.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_any_with_scalar_string(columnar_arg, &scalar_values) + } + _ => array_has_any_with_scalar_general(columnar_arg, &scalar_values), + } +} + +/// When the scalar argument has more elements than this, the scalar fast path +/// builds a HashSet for O(1) lookups. At or below this threshold, it falls +/// back to a linear scan, since hashing every columnar element is more +/// expensive than a linear scan over a short array. +const SCALAR_SMALL_THRESHOLD: usize = 8; + +/// String-specialized scalar fast path for `array_has_any`. +fn array_has_any_with_scalar_string( + columnar_arg: &ColumnarValue, + scalar_values: &ArrayRef, +) -> Result { + let (col_arr, is_scalar_output) = match columnar_arg { + ColumnarValue::Array(arr) => (Arc::clone(arr), false), + ColumnarValue::Scalar(s) => (s.to_array_of_size(1)?, true), + }; + + let col_list: ArrayWrapper = col_arr.as_ref().try_into()?; + let col_values = col_list.values(); + let col_offsets: Vec = col_list.offsets().collect(); + let col_nulls = col_list.nulls(); + + let scalar_lookup = ScalarStringLookup::new(scalar_values); + let has_null_scalar = scalar_values.null_count() > 0; + + let result = match col_values.data_type() { + DataType::Utf8 => array_has_any_string_inner( + col_values.as_string::(), + &col_offsets, + col_nulls, + has_null_scalar, + &scalar_lookup, + ), + DataType::LargeUtf8 => array_has_any_string_inner( + col_values.as_string::(), + &col_offsets, + col_nulls, + has_null_scalar, + &scalar_lookup, + ), + DataType::Utf8View => array_has_any_string_inner( + col_values.as_string_view(), + &col_offsets, + col_nulls, + has_null_scalar, + &scalar_lookup, + ), + _ => unreachable!("array_has_any_with_scalar_string called with non-string type"), + }; + + if is_scalar_output { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } else { + Ok(ColumnarValue::Array(result)) + } +} + +/// Pre-computed lookup structure for the scalar string fastpath. +enum ScalarStringLookup<'a> { + /// Large scalar: HashSet for O(1) lookups. + Set(HashSet<&'a str>), + /// Small scalar: Vec for linear scan. + List(Vec>), +} + +impl<'a> ScalarStringLookup<'a> { + fn new(scalar_values: &'a ArrayRef) -> Self { + let strings = string_array_to_vec(scalar_values.as_ref()); + if strings.len() > SCALAR_SMALL_THRESHOLD { + ScalarStringLookup::Set(strings.into_iter().flatten().collect()) + } else { + ScalarStringLookup::List(strings) + } + } + + fn contains(&self, value: &str) -> bool { + match self { + ScalarStringLookup::Set(set) => set.contains(value), + ScalarStringLookup::List(list) => list.contains(&Some(value)), + } + } +} + +/// Inner implementation of the string scalar fast path, generic over string +/// array type to allow direct element access by index. +fn array_has_any_string_inner<'a, C: StringArrayType<'a> + Copy>( + col_strings: C, + col_offsets: &[usize], + col_nulls: Option<&NullBuffer>, + has_null_scalar: bool, + scalar_lookup: &ScalarStringLookup<'_>, +) -> ArrayRef { + let num_rows = col_offsets.len() - 1; + let mut result = BooleanBufferBuilder::new(num_rows); + + for i in 0..num_rows { + if col_nulls.is_some_and(|v| v.is_null(i)) { + result.append(false); + continue; + } + let start = col_offsets[i]; + let end = col_offsets[i + 1]; + let found = (start..end).any(|j| { + if col_strings.is_null(j) { + has_null_scalar + } else { + scalar_lookup.contains(col_strings.value(j)) + } + }); + result.append(found); + } + + Arc::new(BooleanArray::new(result.finish(), col_nulls.cloned())) +} + +/// General scalar fast path for `array_has_any`, using RowConverter for +/// type-erased comparison. +fn array_has_any_with_scalar_general( + columnar_arg: &ColumnarValue, + scalar_values: &ArrayRef, +) -> Result { + let converter = + RowConverter::new(vec![SortField::new(scalar_values.data_type().clone())])?; + let scalar_rows = converter.convert_columns(&[Arc::clone(scalar_values)])?; + + let (col_arr, is_scalar_output) = match columnar_arg { + ColumnarValue::Array(arr) => (Arc::clone(arr), false), + ColumnarValue::Scalar(s) => (s.to_array_of_size(1)?, true), + }; + + let col_list: ArrayWrapper = col_arr.as_ref().try_into()?; + let col_rows = converter.convert_columns(&[Arc::clone(col_list.values())])?; + let col_offsets: Vec = col_list.offsets().collect(); + let col_nulls = col_list.nulls(); + + let mut result = BooleanBufferBuilder::new(col_list.len()); + let num_scalar = scalar_rows.num_rows(); + + if num_scalar > SCALAR_SMALL_THRESHOLD { + // Large scalar: build HashSet for O(1) lookups + let scalar_set: HashSet> = (0..num_scalar) + .map(|i| Box::from(scalar_rows.row(i).as_ref())) + .collect(); + + for i in 0..col_list.len() { + if col_nulls.is_some_and(|v| v.is_null(i)) { + result.append(false); + continue; + } + let start = col_offsets[i]; + let end = col_offsets[i + 1]; + let found = + (start..end).any(|j| scalar_set.contains(col_rows.row(j).as_ref())); + result.append(found); + } + } else { + // Small scalar: linear scan avoids HashSet hashing overhead + for i in 0..col_list.len() { + if col_nulls.is_some_and(|v| v.is_null(i)) { + result.append(false); + continue; + } + let start = col_offsets[i]; + let end = col_offsets[i + 1]; + let found = (start..end) + .any(|j| (0..num_scalar).any(|k| col_rows.row(j) == scalar_rows.row(k))); + result.append(found); + } + } + + let output: ArrayRef = + Arc::new(BooleanArray::new(result.finish(), col_nulls.cloned())); + + if is_scalar_output { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &output, 0, + )?)) + } else { + Ok(ColumnarValue::Array(output)) + } +} + #[user_doc( doc_section(label = "Array Functions"), description = "Returns true if all elements of sub-array exist in array.", @@ -519,9 +819,6 @@ impl ArrayHasAll { } impl ScalarUDFImpl for ArrayHasAll { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_has_all" } @@ -534,10 +831,7 @@ impl ScalarUDFImpl for ArrayHasAll { Ok(DataType::Boolean) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_has_all_inner)(&args.args) } @@ -552,8 +846,8 @@ impl ScalarUDFImpl for ArrayHasAll { #[user_doc( doc_section(label = "Array Functions"), - description = "Returns true if any elements exist in both arrays.", - syntax_example = "array_has_any(array, sub-array)", + description = "Returns true if the arrays have any elements in common.", + syntax_example = "array_has_any(array1, array2)", sql_example = r#"```sql > select array_has_any([1, 2, 3], [3, 4]); +------------------------------------------+ @@ -563,11 +857,11 @@ impl ScalarUDFImpl for ArrayHasAll { +------------------------------------------+ ```"#, argument( - name = "array", + name = "array1", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ), argument( - name = "sub-array", + name = "array2", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ) )] @@ -593,9 +887,6 @@ impl ArrayHasAny { } impl ScalarUDFImpl for ArrayHasAny { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_has_any" } @@ -608,11 +899,16 @@ impl ScalarUDFImpl for ArrayHasAny { Ok(DataType::Boolean) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_has_any_inner)(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [first_arg, second_arg] = take_function_args(self.name(), &args.args)?; + + // If either argument is scalar, use the fast path. + match (&first_arg, &second_arg) { + (cv, ColumnarValue::Scalar(scalar)) | (ColumnarValue::Scalar(scalar), cv) => { + array_has_any_with_scalar(cv, scalar) + } + _ => make_scalar_function(array_has_any_inner)(&args.args), + } } fn aliases(&self) -> &[String] { @@ -652,19 +948,22 @@ fn array_has_string_kernel( fn general_array_has_all_and_any_kernel( haystack_rows: &Rows, + h_range: Range, needle_rows: &Rows, + mut n_range: Range, comparison_type: ComparisonType, ) -> bool { + let h_start = h_range.start; + let h_end = h_range.end; + match comparison_type { - ComparisonType::All => needle_rows.iter().all(|needle_row| { - haystack_rows - .iter() - .any(|haystack_row| haystack_row == needle_row) + ComparisonType::All => n_range.all(|ni| { + let needle_row = needle_rows.row(ni); + (h_start..h_end).any(|hi| haystack_rows.row(hi) == needle_row) }), - ComparisonType::Any => needle_rows.iter().any(|needle_row| { - haystack_rows - .iter() - .any(|haystack_row| haystack_row == needle_row) + ComparisonType::Any => n_range.any(|ni| { + let needle_row = needle_rows.row(ni); + (h_start..h_end).any(|hi| haystack_rows.row(hi) == needle_row) }), } } @@ -675,22 +974,26 @@ mod tests { use arrow::datatypes::Int32Type; use arrow::{ - array::{create_array, Array, ArrayRef, AsArray, Int32Array, ListArray}, + array::{ + Array, ArrayRef, AsArray, FixedSizeListArray, Int32Array, ListArray, + create_array, + }, buffer::OffsetBuffer, datatypes::{DataType, Field}, }; use datafusion_common::{ - config::ConfigOptions, utils::SingleRowListArrayBuilder, DataFusionError, - ScalarValue, + DataFusionError, ScalarValue, config::ConfigOptions, + utils::SingleRowListArrayBuilder, }; + use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - col, execution_props::ExecutionProps, lit, simplify::ExprSimplifyResult, - ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl, + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl, col, lit, + simplify::ExprSimplifyResult, }; use crate::expr_fn::make_array; - use super::ArrayHas; + use super::{ArrayHas, ArrayHasAll, ArrayHasAny}; #[test] fn test_simplify_array_has_to_in_list() { @@ -701,8 +1004,7 @@ mod tests { .build_list_scalar()); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) = ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) @@ -725,8 +1027,7 @@ mod tests { let haystack = make_array(vec![lit(1), lit(2), lit(3)]); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(Expr::InList(in_list))) = ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) @@ -749,8 +1050,7 @@ mod tests { let haystack = Expr::Literal(ScalarValue::Null, None); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(simplified)) = ArrayHas::new().simplify(vec![haystack, needle], &context) else { @@ -767,8 +1067,7 @@ mod tests { let haystack = Expr::Literal(ScalarValue::List(Arc::new(haystack)), None); let needle = col("c"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = SimplifyContext::default(); let Ok(ExprSimplifyResult::Simplified(simplified)) = ArrayHas::new().simplify(vec![haystack, needle], &context) else { @@ -783,8 +1082,7 @@ mod tests { let haystack = col("c1"); let needle = col("c2"); - let props = ExecutionProps::new(); - let context = datafusion_expr::simplify::SimplifyContext::new(&props); + let context = SimplifyContext::default(); let Ok(ExprSimplifyResult::Original(args)) = ArrayHas::new().simplify(vec![haystack, needle.clone()], &context) @@ -830,6 +1128,52 @@ mod tests { Ok(()) } + #[test] + fn test_array_has_sliced_list() -> Result<(), DataFusionError> { + // [[10, 20], [30, 40], [50, 60], [70, 80]] → slice(1,2) → [[30, 40], [50, 60]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(30), Some(40)]), + Some(vec![Some(50), Some(60)]), + Some(vec![Some(70), Some(80)]), + ]); + let sliced = list.slice(1, 2); + let haystack_field = + Arc::new(Field::new("haystack", sliced.data_type().clone(), true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new("return", DataType::Boolean, true)); + + // Search for elements that exist only in sliced-away rows: + // 10 is in the prefix row, 70 is in the suffix row. + let invoke = |needle: i32| -> Result { + ArrayHas::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(sliced.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))), + ], + arg_fields: vec![ + Arc::clone(&haystack_field), + Arc::clone(&needle_field), + ], + number_rows: 2, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(2) + }; + + let output = invoke(10)?.as_boolean().clone(); + assert!(!output.value(0)); + assert!(!output.value(1)); + + let output = invoke(70)?.as_boolean().clone(); + assert!(!output.value(0)); + assert!(!output.value(1)); + + Ok(()) + } + #[test] fn test_array_has_list_null_haystack() -> Result<(), DataFusionError> { let haystack_field = Arc::new(Field::new("haystack", DataType::Null, true)); @@ -859,4 +1203,112 @@ mod tests { Ok(()) } + + /// Invoke a two-argument list UDF with the given arrays and assert the + /// boolean output matches `expected`. + fn invoke_and_assert( + udf: &dyn ScalarUDFImpl, + haystack: &ArrayRef, + needle: ArrayRef, + expected: &[Option], + ) { + let num_rows = haystack.len(); + let list_type = haystack.data_type(); + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(haystack)), + ColumnarValue::Array(needle), + ], + arg_fields: vec![ + Arc::new(Field::new("haystack", list_type.clone(), false)), + Arc::new(Field::new("needle", list_type.clone(), false)), + ], + number_rows: num_rows, + return_field: Arc::new(Field::new("return", DataType::Boolean, true)), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + let output = result.into_array(num_rows).unwrap(); + assert_eq!(output.as_boolean().iter().collect::>(), expected); + } + + #[test] + fn test_sliced_list_offsets() { + // Full rows: + // row 0: [1, 2] (not visible after slicing) + // row 1: [11, 12] (visible row 0) + // row 2: [21, 22] (visible row 1) + // row 3: [31, 32] (not visible after slicing) + let field: Arc = Arc::new(Field::new("item", DataType::Int32, false)); + let full_values = Arc::new(Int32Array::from(vec![1, 2, 11, 12, 21, 22, 31, 32])); + let full_offsets = OffsetBuffer::new(vec![0, 2, 4, 6, 8].into()); + let full = ListArray::new(Arc::clone(&field), full_offsets, full_values, None); + let sliced_haystack: ArrayRef = Arc::new(full.slice(1, 2)); + + // array_has_all: needle row 0 = [11], row 1 = [21] + let needle_all: ArrayRef = Arc::new(ListArray::new( + Arc::clone(&field), + OffsetBuffer::new(vec![0, 1, 2].into()), + Arc::new(Int32Array::from(vec![11, 21])), + None, + )); + invoke_and_assert( + &ArrayHasAll::new(), + &sliced_haystack, + needle_all, + &[Some(true), Some(true)], + ); + + // array_has_any: needle row 0 = [99, 11], row 1 = [99, 21] + let needle_any: ArrayRef = Arc::new(ListArray::new( + field, + OffsetBuffer::new(vec![0, 2, 4].into()), + Arc::new(Int32Array::from(vec![99, 11, 99, 21])), + None, + )); + invoke_and_assert( + &ArrayHasAny::new(), + &sliced_haystack, + needle_any, + &[Some(true), Some(true)], + ); + } + + #[test] + fn test_sliced_fixed_size_list_offsets() { + // Same logical data as test_sliced_list_offsets, but using FixedSizeListArray. + let field = Arc::new(Field::new("item", DataType::Int32, false)); + let full_values = Arc::new(Int32Array::from(vec![1, 2, 11, 12, 21, 22, 31, 32])); + let full = FixedSizeListArray::new(Arc::clone(&field), 2, full_values, None); + let sliced_haystack: ArrayRef = Arc::new(full.slice(1, 2)); + + // array_has_all: needle row 0 = [11, 12], row 1 = [21, 22] + let needle_all: ArrayRef = Arc::new(FixedSizeListArray::new( + Arc::clone(&field), + 2, + Arc::new(Int32Array::from(vec![11, 12, 21, 22])), + None, + )); + invoke_and_assert( + &ArrayHasAll::new(), + &sliced_haystack, + needle_all, + &[Some(true), Some(true)], + ); + + // array_has_any: needle row 0 = [99, 12], row 1 = [99, 22] + let needle_any: ArrayRef = Arc::new(FixedSizeListArray::new( + field, + 2, + Arc::new(Int32Array::from(vec![99, 12, 99, 22])), + None, + )); + invoke_and_assert( + &ArrayHasAny::new(), + &sliced_haystack, + needle_any, + &[Some(true), Some(true)], + ); + } } diff --git a/datafusion/functions-nested/src/array_normalize.rs b/datafusion/functions-nested/src/array_normalize.rs new file mode 100644 index 0000000000000..0ff7674032d7f --- /dev/null +++ b/datafusion/functions-nested/src/array_normalize.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_normalize function. + +use crate::utils::make_scalar_function; +use arrow::array::{ + Array, ArrayRef, Float64Array, GenericListArray, NullBufferBuilder, + OffsetBufferBuilder, OffsetSizeTrait, +}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayNormalize, + array_normalize, + array, + "returns the L2-normalized vector for a numeric array.", + array_normalize_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the L2-normalized vector for the input numeric array, computed as `array[i] / sqrt(sum(array[i]^2))` per element. Returns NULL if the input is NULL, contains NULL elements, or has zero magnitude (all elements are zero). Returns an empty array for an empty input array.", + syntax_example = "array_normalize(array)", + sql_example = r#"```sql +> select array_normalize([3.0, 4.0]); ++-----------------------------+ +| array_normalize(List([3.0,4.0])) | ++-----------------------------+ +| [0.6, 0.8] | ++-----------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayNormalize { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayNormalize { + fn default() -> Self { + Self::new() + } +} + +impl ArrayNormalize { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_normalize".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayNormalize { + fn name(&self) -> &str { + "array_normalize" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // After `coerce_types`, `arg_types[0]` is one of List(Float64) or LargeList(Float64). + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg_type] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{} does not support type {arg_type}", self.name()); + } + + let coerced = if matches!(arg_type, Null) { + List(Arc::new(Field::new_list_field(DataType::Float64, true))) + } else { + coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion) + }; + + Ok(vec![coerced]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_normalize_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_normalize_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_normalize", args)?; + match array.data_type() { + List(_) => general_array_normalize::(args), + LargeList(_) => general_array_normalize::(args), + arg_type => internal_err!( + "array_normalize received unexpected type after coercion: {arg_type}" + ), + } +} + +fn general_array_normalize(arrays: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&arrays[0])?; + let values = as_float64_array(list_array.values())?; + let offsets = list_array.value_offsets(); + + let mut new_values: Vec = Vec::with_capacity(values.len()); + let mut new_offsets = OffsetBufferBuilder::::new(list_array.len()); + let mut nulls = NullBufferBuilder::new(list_array.len()); + + for row in 0..list_array.len() { + if list_array.is_null(row) { + nulls.append_null(); + new_offsets.push_length(0); + continue; + } + + let start = offsets[row].as_usize(); + let end = offsets[row + 1].as_usize(); + let len = end - start; + + let slice = values.slice(start, len); + if slice.null_count() != 0 { + nulls.append_null(); + new_offsets.push_length(0); + continue; + } + + let vals = slice.values(); + + // Empty array: return empty array (no normalization needed, no division by zero risk) + if len == 0 { + nulls.append_non_null(); + new_offsets.push_length(0); + continue; + } + + // Compute squared magnitude. + let mut sq_sum = 0.0; + for i in 0..len { + sq_sum += vals[i] * vals[i]; + } + + // Zero magnitude: undefined normalization. Emit NULL row. + if sq_sum == 0.0 { + nulls.append_null(); + new_offsets.push_length(0); + continue; + } + + let mag = sq_sum.sqrt(); + for i in 0..len { + new_values.push(vals[i] / mag); + } + nulls.append_non_null(); + new_offsets.push_length(len); + } + + let values_array = Arc::new(Float64Array::from(new_values)); + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + + Ok(Arc::new(GenericListArray::::try_new( + field, + new_offsets.finish(), + values_array, + nulls.finish(), + )?)) +} diff --git a/datafusion/functions-nested/src/array_product.rs b/datafusion/functions-nested/src/array_product.rs new file mode 100644 index 0000000000000..a5cef43142fa0 --- /dev/null +++ b/datafusion/functions-nested/src/array_product.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_product function. + +use crate::utils::make_scalar_function; +use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayProduct, + array_product, + array, + "returns the product of the elements of a numeric array.", + array_product_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the product of the elements in the input numeric array. \ + NULL elements inside the array are skipped (matching SQL aggregate \ + convention). Returns NULL if the input is NULL, every element is \ + NULL, or the array is empty. The result is always returned as \ + `Float64`.", + syntax_example = "array_product(array)", + sql_example = r#"```sql +> select array_product([1.0, 2.0, 3.0]); ++------------------------------------+ +| array_product(List([1.0,2.0,3.0])) | ++------------------------------------+ +| 6.0 | ++------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayProduct { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayProduct { + fn default() -> Self { + Self::new() + } +} + +impl ArrayProduct { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_product".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayProduct { + fn name(&self) -> &str { + "array_product" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg_type] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{} does not support type {arg_type}", self.name()); + } + + let coerced = if matches!(arg_type, Null) { + List(Arc::new(Field::new_list_field(DataType::Float64, true))) + } else { + coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion) + }; + + Ok(vec![coerced]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_product_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_product_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_product", args)?; + match array.data_type() { + List(_) => general_array_product::(args), + LargeList(_) => general_array_product::(args), + arg_type => internal_err!( + "array_product received unexpected type after coercion: {arg_type}" + ), + } +} + +fn general_array_product(arrays: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&arrays[0])?; + let values = as_float64_array(list_array.values())?; + let offsets = list_array.value_offsets(); + + let mut builder = Float64Array::builder(list_array.len()); + + for row in 0..list_array.len() { + if list_array.is_null(row) { + builder.append_null(); + continue; + } + + let start = offsets[row].as_usize(); + let end = offsets[row + 1].as_usize(); + + let mut prod = 1.0_f64; + let mut any_valid = false; + for i in start..end { + if values.is_valid(i) { + prod *= values.value(i); + any_valid = true; + } + } + + if any_valid { + builder.append_value(prod); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) +} diff --git a/datafusion/functions-nested/src/array_scale.rs b/datafusion/functions-nested/src/array_scale.rs new file mode 100644 index 0000000000000..24750ade8a775 --- /dev/null +++ b/datafusion/functions-nested/src/array_scale.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_scale function. + +use crate::utils::make_scalar_function; +use arrow::array::{ + Array, ArrayRef, Float64Array, GenericListArray, OffsetBufferBuilder, OffsetSizeTrait, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayScale, + array_scale, + array scalar, + "scales each element of a numeric array by a scalar.", + array_scale_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns a new array with each element of the input array multiplied by a scalar value, computed as `array[i] * scalar`. Returns NULL if the input row is NULL or the scalar is NULL. If a NULL element appears in the input array at position `i`, the result element at position `i` is NULL. Returns an empty array for an empty input array.", + syntax_example = "array_scale(array, scalar)", + sql_example = r#"```sql +> select array_scale([1.0, 2.0, 3.0], 2.0); ++----------------------------------+ +| array_scale(List([1.0,2.0,3.0]),Float64(2.0)) | ++----------------------------------+ +| [2.0, 4.0, 6.0] | ++----------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "scalar", + description = "Numeric scalar to multiply each element by. Can be a constant or column expression." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayScale { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayScale { + fn default() -> Self { + Self::new() + } +} + +impl ArrayScale { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_scale".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayScale { + fn name(&self) -> &str { + "array_scale" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // After `coerce_types`, `arg_types[0]` is one of List(Float64) or LargeList(Float64). + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [array_type, scalar_type] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + + if !matches!( + array_type, + Null | List(_) | LargeList(_) | FixedSizeList(..) + ) { + return plan_err!( + "{} first argument must be a list type, got {array_type}", + self.name() + ); + } + + if !scalar_type.is_numeric() && !matches!(scalar_type, Null) { + return plan_err!( + "{} second argument must be numeric, got {scalar_type}", + self.name() + ); + } + + let coerced_array = if matches!(array_type, Null) { + List(Arc::new(Field::new_list_field(DataType::Float64, true))) + } else { + coerced_type_with_base_type_only(array_type, &DataType::Float64, coercion) + }; + + Ok(vec![coerced_array, DataType::Float64]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_scale_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_scale_inner(args: &[ArrayRef]) -> Result { + let [array, scalar] = take_function_args("array_scale", args)?; + match array.data_type() { + List(_) => general_array_scale::(array, scalar), + LargeList(_) => general_array_scale::(array, scalar), + arg_type => internal_err!( + "array_scale received unexpected type after coercion: {arg_type}" + ), + } +} + +fn general_array_scale( + array: &ArrayRef, + scalar: &ArrayRef, +) -> Result { + let list_array = as_generic_list_array::(array)?; + let scalar_array = as_float64_array(scalar)?; + + let values = as_float64_array(list_array.values())?; + let offsets = list_array.value_offsets(); + + // A row is null whenever either input row is null. The scalar applies + // uniformly across the array, so a null scalar makes the whole row + // undefined; union the two row-level null buffers in a single pass + // rather than tracking row nulls inside the value loop. + let row_nulls = NullBuffer::union(list_array.nulls(), scalar_array.nulls()); + + let mut value_builder = Float64Array::builder(values.len()); + let mut new_offsets = OffsetBufferBuilder::::new(list_array.len()); + + for row in 0..list_array.len() { + if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) { + new_offsets.push_length(0); + continue; + } + + let start = offsets[row].as_usize(); + let end = offsets[row + 1].as_usize(); + let len = end - start; + let scalar_val = scalar_array.value(row); + + let slice = values.slice(start, len); + + // Per-element NULL propagation for NULL elements inside the array. + for i in 0..len { + if slice.is_null(i) { + value_builder.append_null(); + } else { + value_builder.append_value(slice.value(i) * scalar_val); + } + } + + new_offsets.push_length(len); + } + + let values_array = Arc::new(value_builder.finish()); + + // Preserve the inner field from the input array (including any user + // metadata). After `coerce_types` the inner type is Float64, but the + // input may still carry field-level annotations worth keeping. + let field = match list_array.data_type() { + List(f) | LargeList(f) => Arc::clone(f), + other => { + return internal_err!("array_scale unexpected list type: {other}"); + } + }; + + Ok(Arc::new(GenericListArray::::try_new( + field, + new_offsets.finish(), + values_array, + row_nulls, + )?)) +} diff --git a/datafusion/functions-nested/src/array_subtract.rs b/datafusion/functions-nested/src/array_subtract.rs new file mode 100644 index 0000000000000..24600da04f74e --- /dev/null +++ b/datafusion/functions-nested/src/array_subtract.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_subtract function. + +use crate::utils::{ + array_math_binary_op, coerce_array_math_arg_types, make_scalar_function, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{ + DataType, + DataType::{LargeList, List}, +}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; + +make_udf_expr_and_func!( + ArraySubtract, + array_subtract, + array1 array2, + "returns the element-wise difference of two numeric arrays.", + array_subtract_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the element-wise difference of two numeric arrays of equal length, computed as `array1[i] - array2[i]` per position. NULL is propagated per element: if either input element at position `i` is NULL, the corresponding output element is NULL (positions are preserved). Returns NULL if either entire input array is NULL. Errors if the per-row lengths differ. Returns an empty array if both inputs are empty.", + syntax_example = "array_subtract(array1, array2)", + sql_example = r#"```sql +> select array_subtract([10.0, 20.0, 30.0], [1.0, 2.0, 3.0]); ++--------------------------------------------------------------+ +| array_subtract(List([10.0,20.0,30.0]),List([1.0,2.0,3.0])) | ++--------------------------------------------------------------+ +| [9.0, 18.0, 27.0] | ++--------------------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArraySubtract { + signature: Signature, + aliases: Vec, +} + +impl Default for ArraySubtract { + fn default() -> Self { + Self::new() + } +} + +impl ArraySubtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_subtract".to_string()], + } + } +} + +impl ScalarUDFImpl for ArraySubtract { + fn name(&self) -> &str { + "array_subtract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + coerce_array_math_arg_types(self.name(), arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_subtract_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_subtract_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("array_subtract", args)?; + let sub = |a: f64, b: f64| a - b; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => { + array_math_binary_op::("array_subtract", array1, array2, sub) + } + (LargeList(_), LargeList(_)) => { + array_math_binary_op::("array_subtract", array1, array2, sub) + } + (arg_type1, arg_type2) => exec_err!( + "array_subtract received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), + } +} diff --git a/datafusion/functions-nested/src/array_sum.rs b/datafusion/functions-nested/src/array_sum.rs new file mode 100644 index 0000000000000..d115355f5cbb9 --- /dev/null +++ b/datafusion/functions-nested/src/array_sum.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_sum function. + +use crate::utils::make_scalar_function; +use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArraySum, + array_sum, + array, + "returns the sum of elements in a numeric array.", + array_sum_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the sum of the elements of the input array, computed as `array[0] + array[1] + ...`. NULL elements are skipped (per SQL aggregate convention). Returns NULL if the input row is NULL, every element is NULL, or the array is empty.", + syntax_example = "array_sum(array)", + sql_example = r#"```sql +> select array_sum([1.0, 2.0, 3.0]); ++----------------------------+ +| array_sum(List([1.0,2.0,3.0])) | ++----------------------------+ +| 6.0 | ++----------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArraySum { + signature: Signature, + aliases: Vec, +} + +impl Default for ArraySum { + fn default() -> Self { + Self::new() + } +} + +impl ArraySum { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_sum".to_string()], + } + } +} + +impl ScalarUDFImpl for ArraySum { + fn name(&self) -> &str { + "array_sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [arg_type] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{} does not support type {arg_type}", self.name()); + } + + let coerced = if matches!(arg_type, Null) { + List(Arc::new(Field::new_list_field(DataType::Float64, true))) + } else { + coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion) + }; + + Ok(vec![coerced]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_sum_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_sum_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_sum", args)?; + match array.data_type() { + List(_) => general_array_sum::(array), + LargeList(_) => general_array_sum::(array), + arg_type => { + internal_err!("array_sum received unexpected type after coercion: {arg_type}") + } + } +} + +fn general_array_sum(array: &ArrayRef) -> Result { + let list_array = as_generic_list_array::(array)?; + let values = as_float64_array(list_array.values())?; + let offsets = list_array.value_offsets(); + + let mut builder = Float64Array::builder(list_array.len()); + + for row in 0..list_array.len() { + if list_array.is_null(row) { + builder.append_null(); + continue; + } + + let start = offsets[row].as_usize(); + let end = offsets[row + 1].as_usize(); + + // Skip NULL elements per SQL aggregate convention (matches PostgreSQL + // array_sum, DuckDB list_sum, Spark aggregate). Empty arrays and + // all-NULL arrays both yield NULL — same behavior as SQL SUM over + // an empty set or all-NULL column. + let mut sum = 0.0_f64; + let mut any_valid = false; + for i in start..end { + if values.is_valid(i) { + sum += values.value(i); + any_valid = true; + } + } + + if any_valid { + builder.append_value(sum); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) +} diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs new file mode 100644 index 0000000000000..1c1c5077344e1 --- /dev/null +++ b/datafusion/functions-nested/src/array_transform.rs @@ -0,0 +1,293 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`datafusion_expr::HigherOrderUDF`] definitions for array_transform function. + +use arrow::{ + array::{Array, ArrayRef, AsArray, LargeListArray, ListArray}, + compute::take_arrays, + datatypes::{DataType, Field, FieldRef}, +}; +use datafusion_common::{ + Result, exec_err, plan_err, + utils::{adjust_offsets_for_slice, list_values_row_number, take_function_args}, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, + HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +use crate::lambda_utils::{ + ListValuesResult, coerce_single_list_arg, extract_list_values, + single_list_lambda_parameters, +}; + +make_higher_order_function_expr_and_func!( + ArrayTransform, + array_transform, + array lambda, + "transforms the values of an array", + array_transform_higher_order_function +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "transforms the values of an array", + syntax_example = "array_transform(array, x -> x*2)", + sql_example = r#"```sql +> select array_transform([1, 2, 3, 4, 5], x -> x*2); ++-------------------------------------------+ +| array_transform([1, 2, 3, 4, 5], x -> x*2) | ++-------------------------------------------+ +| [2, 4, 6, 8, 10] | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "lambda", description = "Lambda") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: HigherOrderSignature, + aliases: Vec, +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: HigherOrderSignature::exact( + vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())], + Volatility::Immutable, + ), + aliases: vec![String::from("list_transform")], + } + } +} + +impl HigherOrderUDFImpl for ArrayTransform { + fn name(&self) -> &str { + "array_transform" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn coerce_value_types(&self, arg_types: &[DataType]) -> Result> { + coerce_single_list_arg(self.name(), arg_types) + } + + fn lambda_parameters( + &self, + _step: usize, + fields: &[ValueOrLambda>], + ) -> Result { + single_list_lambda_parameters(self.name(), fields) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result> { + let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] = + take_function_args(self.name(), args.arg_fields)? + else { + return plan_err!("{} expects a value followed by a lambda", self.name()); + }; + + //TODO: should metadata be copied into the transformed array? + + // lambda is the resulting field of executing the lambda body + // with the parameters returned in lambda_parameters + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + lambda.data_type().clone(), + lambda.is_nullable(), + )); + + let return_type = match list.data_type() { + DataType::List(_) => DataType::List(field), + DataType::LargeList(_) => DataType::LargeList(field), + other => plan_err!("expected list, got {other}")?, + }; + + Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) + } + + fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result { + let [list, lambda] = take_function_args(self.name(), &args.args)?; + let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda) + else { + return plan_err!("{} expects a value followed by a lambda", self.name()); + }; + + let list_array = list.to_array(args.number_rows)?; + + let list_values = match extract_list_values(&list_array, args.return_type())? { + ListValuesResult::EarlyReturn(v) => return Ok(v), + ListValuesResult::Values(v) => v, + }; + + // by passing closures, lambda.evaluate can evaluate only those actually needed + let values_param = || Ok(Arc::clone(&list_values)); + + // call the transforming lambda + let transformed_values = lambda + .evaluate(&[&values_param], |arrays| { + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of list with multitple values and removing values of empty lists + let indices = list_values_row_number(&list_array)?; + Ok(take_arrays(arrays, &indices, None)?) + })? + .into_array(list_values.len())?; + + let field = match args.return_field.data_type() { + DataType::List(field) | DataType::LargeList(field) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ); + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + // since we called list_values above which would return sliced values for + // a sliced list, we must adjust the offsets here as otherwise they would be invalid + let adjusted_offsets = adjust_offsets_for_slice(list); + + Arc::new(ListArray::new( + field, + adjusted_offsets, + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list(); + + // since we called list_values above which would return sliced values for + // a sliced list, we must adjust the offsets here as otherwise they would be invalid + let adjusted_offsets = adjust_offsets_for_slice(large_list); + + Arc::new(LargeListArray::new( + field, + adjusted_offsets, + transformed_values, + large_list.nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{Array, AsArray}, + buffer::{NullBuffer, OffsetBuffer}, + }; + + use crate::array_transform::array_transform_higher_order_function; + use crate::lambda_utils::test_utils::{create_i32_list, eval_hof_on_i32_list, v}; + use datafusion_expr::lit; + + fn divide_100_by( + list: impl Array + Clone + 'static, + ) -> datafusion_common::Result { + eval_hof_on_i32_list( + array_transform_higher_order_function(), + list, + lit(100i32) / v(), + ) + } + + #[test] + fn transform_on_sliced_list_should_not_evaluate_on_unreachable_values() { + let list = create_i32_list( + vec![ + // Have 0 here so if the expression is called on data that it will fail + 0, 4, 100, 25, 20, 5, 2, 1, 10, + ], + OffsetBuffer::::from_lengths(vec![1, 3, 4, 1]), + None, + ) + .slice(1, 3); + + let res = divide_100_by(list).unwrap(); + + let actual_list = res.as_list::(); + + let expected_list = create_i32_list( + vec![25, 1, 4, 5, 20, 50, 100, 10], + OffsetBuffer::::from_lengths(vec![3, 4, 1]), + None, + ); + + assert_eq!(actual_list, &expected_list); + } + + #[test] + fn transform_function_should_not_be_evaluated_on_values_underlying_null() { + let list = create_i32_list( + // 0 here for one of the values behind null, so if it will be evaluated + // it will fail due to divide by 0 + vec![100, 20, 10, 0, 1, 2, 0, 1, 50], + OffsetBuffer::::from_lengths(vec![3, 4, 2]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + let res = divide_100_by(list).unwrap(); + + let actual_list = res.as_list::(); + + let expected_list = create_i32_list( + vec![1, 5, 10, 100, 2], + OffsetBuffer::::from_lengths(vec![3, 0, 2]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + assert_eq!(actual_list.data_type(), expected_list.data_type()); + assert_eq!(actual_list, &expected_list); + } +} diff --git a/datafusion/functions-nested/src/arrays_zip.rs b/datafusion/functions-nested/src/arrays_zip.rs new file mode 100644 index 0000000000000..76b1b589f42f5 --- /dev/null +++ b/datafusion/functions-nested/src/arrays_zip.rs @@ -0,0 +1,613 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for arrays_zip function. + +use crate::utils::make_scalar_function; +use arrow::array::{ + Array, ArrayRef, Capacities, ListArray, MutableArrayData, NullBufferBuilder, + StructArray, new_null_array, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +/// Type-erased view of a list column (works for both List and LargeList). +/// Stores the information needed to iterate rows without re-downcasting. +struct ListColumnView { + /// The flat values array backing this list column. + values: ArrayRef, + /// Pre-computed per-row start offsets (length = num_rows + 1). + offsets: Vec, + /// Null bitmap from the input array (None means no nulls). + nulls: Option, +} + +impl ListColumnView { + fn is_null(&self, idx: usize) -> bool { + self.nulls.as_ref().is_some_and(|n| n.is_null(idx)) + } +} + +make_udf_expr_and_func!( + ArraysZip, + arrays_zip, + "combines one or multiple arrays into a single array of structs.", + arrays_zip_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns an array of structs created by combining the elements of each input array at the same index. If the arrays have different lengths, shorter arrays are padded with NULLs.", + syntax_example = "arrays_zip(array1[, ..., array_n])", + sql_example = r#"```sql +> select arrays_zip([1, 2, 3]); ++---------------------------------------------------+ +| arrays_zip([1, 2, 3]) | ++---------------------------------------------------+ +| [{1: 1}, {1: 2}, {1: 3}] | ++---------------------------------------------------+ +> select arrays_zip([1, 2], [3, 4, 5]); ++---------------------------------------------------+ +| arrays_zip([1, 2], [3, 4, 5]) | ++---------------------------------------------------+ +| [{1: 1, 2: 3}, {1: 2, 2: 4}, {1: NULL, 2: 5}] | ++---------------------------------------------------+ +```"#, + argument(name = "array1", description = "First array expression."), + argument( + name = "array_n", + description = "Optional additional array expressions." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArraysZip { + signature: Signature, + aliases: Vec, +} + +impl Default for ArraysZip { + fn default() -> Self { + Self::new() + } +} + +impl ArraysZip { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("list_zip")], + } + } +} + +impl ScalarUDFImpl for ArraysZip { + fn name(&self) -> &str { + "arrays_zip" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.is_empty() { + return exec_err!("arrays_zip requires at least one argument"); + } + + let mut fields = Vec::with_capacity(arg_types.len()); + for (i, arg_type) in arg_types.iter().enumerate() { + let element_type = match arg_type { + List(field) | LargeList(field) | FixedSizeList(field, _) => { + field.data_type().clone() + } + Null => Null, + dt => { + return exec_err!("arrays_zip expects array arguments, got {dt}"); + } + }; + fields.push(Field::new(arrays_zip_field_name(i), element_type, true)); + } + + Ok(List(Arc::new(Field::new_list_field( + DataType::Struct(Fields::from(fields)), + true, + )))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(arrays_zip_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Core implementation for arrays_zip. +/// +/// Takes N list arrays and produces a list of structs where each struct +/// has one field per input array. If arrays within a row have different +/// lengths, shorter arrays are padded with NULLs. +/// Supports List, LargeList, and Null input types. +fn arrays_zip_inner(args: &[ArrayRef]) -> Result { + if args.is_empty() { + return exec_err!("arrays_zip requires at least one argument"); + } + + let field_names = arrays_zip_field_names(args.len()); + let num_rows = args[0].len(); + + if let Some(result) = try_perfect_list_zip(args, &field_names)? { + return Ok(result); + } + + // Build a type-erased ListColumnView for each argument. + // None means the argument is Null-typed (all nulls, no backing data). + let mut views: Vec> = Vec::with_capacity(args.len()); + let mut element_types: Vec = Vec::with_capacity(args.len()); + + for (i, arg) in args.iter().enumerate() { + match arg.data_type() { + List(field) => { + let arr = as_list_array(arg)?; + let raw_offsets = arr.value_offsets(); + let offsets: Vec = + raw_offsets.iter().map(|&o| o as usize).collect(); + element_types.push(field.data_type().clone()); + views.push(Some(ListColumnView { + values: Arc::clone(arr.values()), + offsets, + nulls: arr.nulls().cloned(), + })); + } + LargeList(field) => { + let arr = as_large_list_array(arg)?; + let raw_offsets = arr.value_offsets(); + let offsets: Vec = + raw_offsets.iter().map(|&o| o as usize).collect(); + element_types.push(field.data_type().clone()); + views.push(Some(ListColumnView { + values: Arc::clone(arr.values()), + offsets, + nulls: arr.nulls().cloned(), + })); + } + FixedSizeList(field, size) => { + let arr = as_fixed_size_list_array(arg)?; + let size = *size as usize; + let offsets: Vec = (0..=num_rows).map(|row| row * size).collect(); + element_types.push(field.data_type().clone()); + views.push(Some(ListColumnView { + values: Arc::clone(arr.values()), + offsets, + nulls: arr.nulls().cloned(), + })); + } + Null => { + element_types.push(Null); + views.push(None); + } + dt => { + return exec_err!("arrays_zip argument {i} expected list type, got {dt}"); + } + } + } + + // Collect per-column values data for MutableArrayData builders. + let values_data: Vec<_> = views + .iter() + .map(|v| v.as_ref().map(|view| view.values.to_data())) + .collect(); + + let struct_fields: Fields = element_types + .iter() + .zip(field_names.iter()) + .map(|(dt, name)| Field::new(name.clone(), dt.clone(), true)) + .collect::>() + .into(); + + // Create a MutableArrayData builder per column. For None (Null-typed) + // args we only need extend_nulls, so we track them separately. + let mut builders: Vec> = values_data + .iter() + .map(|vd| { + vd.as_ref().map(|data| { + MutableArrayData::with_capacities(vec![data], true, Capacities::Array(0)) + }) + }) + .collect(); + + let mut offsets: Vec = Vec::with_capacity(num_rows + 1); + offsets.push(0); + let mut null_builder = NullBufferBuilder::new(num_rows); + let mut total_values: usize = 0; + + // Process each row: compute per-array lengths, then copy values + // and pad shorter arrays with NULLs. + for row_idx in 0..num_rows { + let mut max_len: usize = 0; + let mut all_null = true; + + for view in views.iter().flatten() { + if !view.is_null(row_idx) { + all_null = false; + let len = view.offsets[row_idx + 1] - view.offsets[row_idx]; + max_len = max_len.max(len); + } + } + + if all_null { + null_builder.append_null(); + offsets.push(*offsets.last().unwrap()); + continue; + } + null_builder.append_non_null(); + + // Extend each column builder for this row. + for (col_idx, view) in views.iter().enumerate() { + match view { + Some(v) if !v.is_null(row_idx) => { + let start = v.offsets[row_idx]; + let end = v.offsets[row_idx + 1]; + let len = end - start; + let builder = builders[col_idx].as_mut().unwrap(); + builder.extend(0, start, end); + if len < max_len { + builder.extend_nulls(max_len - len); + } + } + _ => { + // Null list entry or None (Null-typed) arg — all nulls. + if let Some(builder) = builders[col_idx].as_mut() { + builder.extend_nulls(max_len); + } + } + } + } + + total_values += max_len; + let last = *offsets.last().unwrap(); + offsets.push(last + max_len as i32); + } + + // Assemble struct columns from builders. + let struct_columns: Vec = builders + .into_iter() + .zip(element_types.iter()) + .map(|(builder, elem_type)| match builder { + Some(b) => arrow::array::make_array(b.freeze()), + None => new_null_array( + if elem_type.is_null() { + &Null + } else { + elem_type + }, + total_values, + ), + }) + .collect(); + + let struct_array = StructArray::try_new(struct_fields, struct_columns, None)?; + + let null_buffer = null_builder.finish(); + + let result = ListArray::try_new( + Arc::new(Field::new_list_field( + struct_array.data_type().clone(), + true, + )), + OffsetBuffer::new(offsets.into()), + Arc::new(struct_array), + null_buffer, + )?; + + Ok(Arc::new(result)) +} + +fn arrays_zip_field_name(index: usize) -> String { + (index + 1).to_string() +} + +fn arrays_zip_field_names(len: usize) -> Vec { + (0..len).map(arrays_zip_field_name).collect() +} + +/// Fast path for regular List inputs whose existing buffers already match the +/// zipped output: all offsets and values lengths match, and null rows cover no +/// values. This lets us reuse offsets and child values instead of rebuilding. +fn try_perfect_list_zip( + args: &[ArrayRef], + field_names: &[String], +) -> Result> { + debug_assert_eq!(args.len(), field_names.len()); + + let mut list_arrays = Vec::with_capacity(args.len()); + let mut struct_fields = Vec::with_capacity(args.len()); + + for (arg, field_name) in args.iter().zip(field_names) { + let arr = match arg.data_type() { + List(field) => { + struct_fields.push(Field::new( + field_name.clone(), + field.data_type().clone(), + true, + )); + as_list_array(arg)? + } + _ => return Ok(None), + }; + + list_arrays.push(arr); + } + + let first = list_arrays[0]; + let num_rows = first.len(); + let offsets = first.offsets().clone(); + let values_len = first.values().len(); + + // Reusing the child arrays is only valid when every list uses the exact + // same row boundaries and exposes the same total number of child values. + for arr in &list_arrays { + if arr.values().len() != values_len || arr.offsets() != &offsets { + return Ok(None); + } + } + + let nulls = if list_arrays.iter().any(|arr| arr.null_count() != 0) { + let first_nulls = first.nulls(); + if list_arrays.iter().all(|arr| arr.nulls() == first_nulls) { + first_nulls.cloned() + } else { + // Match the general path: arrays_zip only marks an output row null + // when every concrete input list is null. Mixed null and non-null + // empty lists still produce a non-null empty list, but mixed null + // rows with values must fall back to preserve field-level nulls. + let mut null_builder = NullBufferBuilder::new(num_rows); + for row_idx in 0..num_rows { + let mut all_null = true; + + for arr in &list_arrays { + if arr.is_null(row_idx) { + if arr.offsets()[row_idx + 1] != arr.offsets()[row_idx] { + return Ok(None); + } + } else { + all_null = false; + } + } + + if all_null { + null_builder.append_null(); + } else { + null_builder.append_non_null(); + } + } + + null_builder.finish() + } + } else { + None + }; + + let struct_columns = list_arrays + .iter() + .map(|arr| Arc::clone(arr.values())) + .collect::>(); + let struct_array = + StructArray::try_new(Fields::from(struct_fields), struct_columns, None)?; + let result = ListArray::try_new( + Arc::new(Field::new_list_field( + struct_array.data_type().clone(), + true, + )), + offsets, + Arc::new(struct_array), + nulls, + )?; + + Ok(Some(Arc::new(result))) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow::buffer::NullBuffer; + + fn list(values: Vec, offsets: Vec) -> Arc { + list_with_validity(values, offsets, None) + } + + fn list_with_validity( + values: Vec, + offsets: Vec, + valid: Option>, + ) -> Arc { + Arc::new( + ListArray::try_new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + valid.map(NullBuffer::from), + ) + .unwrap(), + ) + } + + #[test] + fn perfect_zip_reuses_input_values_and_offsets() { + let left = list(vec![1, 2, 3, 4, 5, 6], vec![0, 2, 3, 6]); + let right = list(vec![10, 20, 30, 40, 50, 60], vec![0, 2, 3, 6]); + + let result = arrays_zip_inner(&[ + Arc::clone(&left) as ArrayRef, + Arc::clone(&right) as ArrayRef, + ]) + .unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + let values = result + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(result.offsets().ptr_eq(left.offsets())); + assert!(Arc::ptr_eq(values.column(0), left.values())); + assert!(Arc::ptr_eq(values.column(1), right.values())); + } + + #[test] + fn perfect_zip_uses_supplied_field_names() { + let left = list(vec![1, 2, 3], vec![0, 1, 3]); + let right = list(vec![10, 20, 30], vec![0, 1, 3]); + let field_names = vec!["left".to_string(), "right".to_string()]; + + let result = try_perfect_list_zip( + &[ + Arc::clone(&left) as ArrayRef, + Arc::clone(&right) as ArrayRef, + ], + &field_names, + ) + .unwrap() + .unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + let values = result + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let names = values + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + + assert_eq!(names, vec!["left", "right"]); + } + + #[test] + fn perfect_zip_reuses_zero_length_null_rows() { + let left = list_with_validity( + vec![1, 2, 3, 4], + vec![0, 2, 2, 4], + Some(vec![true, false, true]), + ); + let right = list_with_validity( + vec![10, 20, 30, 40], + vec![0, 2, 2, 4], + Some(vec![true, false, true]), + ); + + let result = arrays_zip_inner(&[ + Arc::clone(&left) as ArrayRef, + Arc::clone(&right) as ArrayRef, + ]) + .unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert!(result.offsets().ptr_eq(left.offsets())); + assert!(result.is_null(1)); + } + + #[test] + fn perfect_zip_preserves_mixed_null_empty_rows() { + let left = + list_with_validity(vec![], vec![0, 0, 0, 0], Some(vec![false, true, false])); + let right = + list_with_validity(vec![], vec![0, 0, 0, 0], Some(vec![true, false, false])); + + let result = arrays_zip_inner(&[ + Arc::clone(&left) as ArrayRef, + Arc::clone(&right) as ArrayRef, + ]) + .unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert!(result.offsets().ptr_eq(left.offsets())); + assert!(!result.is_null(0)); + assert!(!result.is_null(1)); + assert!(result.is_null(2)); + } + + #[test] + fn perfect_zip_reuses_null_rows_with_hidden_values() { + let left = + list_with_validity(vec![1, 2, 3, 4], vec![0, 2, 4], Some(vec![true, false])); + let right = list_with_validity( + vec![10, 20, 30, 40], + vec![0, 2, 4], + Some(vec![true, false]), + ); + + let result = arrays_zip_inner(&[ + Arc::clone(&left) as ArrayRef, + Arc::clone(&right) as ArrayRef, + ]) + .unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert!(result.offsets().ptr_eq(left.offsets())); + assert_eq!(result.value_offsets(), &[0, 2, 4]); + assert!(result.is_null(1)); + } + + #[test] + fn mixed_null_row_with_hidden_values_uses_general_path() { + let left = + list_with_validity(vec![1, 2, 3, 4], vec![0, 2, 4], Some(vec![true, false])); + let right = list_with_validity( + vec![10, 20, 30, 40], + vec![0, 2, 4], + Some(vec![true, true]), + ); + + let result = arrays_zip_inner(&[ + Arc::clone(&left) as ArrayRef, + Arc::clone(&right) as ArrayRef, + ]) + .unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + let values = result + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(!result.offsets().ptr_eq(left.offsets())); + assert_eq!(result.value_offsets(), &[0, 2, 4]); + assert!(values.column(0).is_null(2)); + assert!(values.column(0).is_null(3)); + assert!(!values.column(1).is_null(2)); + assert!(!values.column(1).is_null(3)); + } +} diff --git a/datafusion/functions-nested/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs index 58a83feb66764..d21bb72a457a8 100644 --- a/datafusion/functions-nested/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -25,16 +25,15 @@ use arrow::datatypes::{ DataType, DataType::{LargeList, List, Map, Null, UInt64}, }; +use datafusion_common::Result; use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; use datafusion_common::exec_err; -use datafusion_common::utils::{take_function_args, ListCoercion}; -use datafusion_common::Result; +use datafusion_common::utils::{ListCoercion, take_function_args}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -90,9 +89,6 @@ impl Default for Cardinality { } } impl ScalarUDFImpl for Cardinality { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "cardinality" } @@ -105,10 +101,7 @@ impl ScalarUDFImpl for Cardinality { Ok(UInt64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(cardinality_inner)(&args.args) } @@ -120,7 +113,7 @@ impl ScalarUDFImpl for Cardinality { fn cardinality_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("cardinality", args)?; match array.data_type() { - Null => Ok(Arc::new(UInt64Array::from_value(0, array.len()))), + Null => Ok(Arc::new(UInt64Array::new_null(array.len()))), List(_) => { let list_array = as_list_array(array)?; generic_list_cardinality::(list_array) @@ -152,9 +145,14 @@ fn generic_list_cardinality( ) -> Result { let result = array .iter() - .map(|arr| match crate::utils::compute_array_dims(arr)? { - Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), - None => Ok(None), + .map(|arr| match arr { + Some(arr) if arr.is_empty() => Ok(Some(0u64)), + arr => match crate::utils::compute_array_dims(arr)? { + Some(vector) => { + Ok(Some(vector.iter().map(|x| x.unwrap()).product::())) + } + None => Ok(None), + }, }) .collect::>()?; Ok(Arc::new(result) as ArrayRef) diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index a565006a2577d..8d06140889a55 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,21 +17,20 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. -use std::any::Any; use std::sync::Arc; use crate::make_array::make_array_inner; use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; use arrow::array::{ Array, ArrayData, ArrayRef, Capacities, GenericListArray, MutableArrayData, - NullBufferBuilder, OffsetSizeTrait, + OffsetSizeTrait, }; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::datatypes::{DataType, Field}; +use datafusion_common::Result; use datafusion_common::utils::{ - base_type, coerced_type_with_base_type_only, ListCoercion, + ListCoercion, base_type, coerced_type_with_base_type_only, }; -use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, exec_err, plan_err, @@ -39,7 +38,8 @@ use datafusion_common::{ }; use datafusion_expr::binary::type_union_resolution; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use itertools::Itertools; @@ -96,10 +96,6 @@ impl ArrayAppend { } impl ScalarUDFImpl for ArrayAppend { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_append" } @@ -117,10 +113,7 @@ impl ScalarUDFImpl for ArrayAppend { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_append_inner)(&args.args) } @@ -185,10 +178,6 @@ impl ArrayPrepend { } impl ScalarUDFImpl for ArrayPrepend { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_prepend" } @@ -206,10 +195,7 @@ impl ScalarUDFImpl for ArrayPrepend { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_prepend_inner)(&args.args) } @@ -276,10 +262,6 @@ impl ArrayConcat { } impl ScalarUDFImpl for ArrayConcat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_concat" } @@ -297,7 +279,7 @@ impl ScalarUDFImpl for ArrayConcat { DataType::Null | DataType::List(_) | DataType::FixedSizeList(..) => (), DataType::LargeList(_) => large_list = true, arg_type => { - return plan_err!("{} does not support type {arg_type}", self.name()) + return plan_err!("{} does not support type {arg_type}", self.name()); } } @@ -326,10 +308,7 @@ impl ScalarUDFImpl for ArrayConcat { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_concat_inner)(&args.args) } @@ -338,10 +317,23 @@ impl ScalarUDFImpl for ArrayConcat { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let base_type = base_type(&self.return_type(arg_types)?); + let return_type = self.return_type(arg_types)?; + let base_type = base_type(&return_type); let coercion = Some(&ListCoercion::FixedSizedListToList); + // When the return type is a `LargeList`, the outer container of every + // input must be widened to `LargeList` as well. Otherwise + // `array_concat_inner` would later try to downcast a `List` argument + // to `GenericListArray` and fail. + let promote_to_large_list = matches!(return_type, DataType::LargeList(_)); let arg_types = arg_types.iter().map(|arg_type| { - coerced_type_with_base_type_only(arg_type, &base_type, coercion) + let coerced = + coerced_type_with_base_type_only(arg_type, &base_type, coercion); + match coerced { + DataType::List(field) if promote_to_large_list => { + DataType::LargeList(field) + } + other => other, + } }); Ok(arg_types.collect()) @@ -352,7 +344,7 @@ impl ScalarUDFImpl for ArrayConcat { } } -fn array_concat_inner(args: &[ArrayRef]) -> Result { +pub fn array_concat_inner(args: &[ArrayRef]) -> Result { if args.is_empty() { return exec_err!("array_concat expects at least one argument"); } @@ -396,58 +388,65 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .iter() .map(|arg| as_generic_list_array::(arg)) .collect::>>()?; - // Assume number of rows is the same for all arrays let row_count = list_arrays[0].len(); - let mut array_lengths = vec![]; - let mut arrays = vec![]; - let mut valid = NullBufferBuilder::new(row_count); - for i in 0..row_count { - let nulls = list_arrays + // Extract underlying values ArrayData from each list array for MutableArrayData. + let values_data: Vec = + list_arrays.iter().map(|la| la.values().to_data()).collect(); + let values_data_refs: Vec<&ArrayData> = values_data.iter().collect(); + + // Estimate capacity as the sum of all values arrays' lengths. + let total_capacity: usize = values_data.iter().map(|d| d.len()).sum(); + + let mut mutable = MutableArrayData::with_capacities( + values_data_refs, + false, + Capacities::Array(total_capacity), + ); + let mut offsets: Vec = Vec::with_capacity(row_count + 1); + offsets.push(O::zero()); + + // Compute the output null buffer: a row is null only if null in ALL input + // arrays. This is the bitwise OR of validity bits (valid if valid in ANY + // input). If any array has no null buffer (all valid), no output row can be + // null. + let nulls = list_arrays + .iter() + .filter_map(|la| la.nulls()) + .collect::>(); + let valid = if nulls.len() == list_arrays.len() { + nulls .iter() - .map(|arr| arr.is_null(i)) - .collect::>(); - - // If all the arrays are null, the concatenated array is null - let is_null = nulls.iter().all(|&x| x); - if is_null { - array_lengths.push(0); - valid.append_null(); - } else { - // Get all the arrays on i-th row - let values = list_arrays - .iter() - .map(|arr| arr.value(i)) - .collect::>(); - - let elements = values - .iter() - .map(|a| a.as_ref()) - .collect::>(); - - // Concatenated array on i-th row - let concatenated_array = arrow::compute::concat(elements.as_slice())?; - array_lengths.push(concatenated_array.len()); - arrays.push(concatenated_array); - valid.append_non_null(); + .map(|n| n.inner().clone()) + .reduce(|a, b| &a | &b) + .map(NullBuffer::new) + } else { + None + }; + + for row_idx in 0..row_count { + for (arr_idx, list_array) in list_arrays.iter().enumerate() { + if list_array.is_null(row_idx) { + continue; + } + let start = list_array.offsets()[row_idx].to_usize().unwrap(); + let end = list_array.offsets()[row_idx + 1].to_usize().unwrap(); + if start < end { + mutable.extend(arr_idx, start, end); + } } + offsets.push(O::usize_as(mutable.len())); } - // Assume all arrays have the same data type - let data_type = list_arrays[0].value_type(); - let elements = arrays - .iter() - .map(|a| a.as_ref()) - .collect::>(); + let data_type = list_arrays[0].value_type(); + let data = mutable.freeze(); - let list_arr = GenericListArray::::new( + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new_list_field(data_type, true)), - OffsetBuffer::from_lengths(array_lengths), - Arc::new(arrow::compute::concat(elements.as_slice())?), - valid.finish(), - ); - - Ok(Arc::new(list_arr)) + OffsetBuffer::new(offsets.into()), + arrow::array::make_array(data), + valid, + )?)) } // Kernel functions diff --git a/datafusion/functions-nested/src/cosine_distance.rs b/datafusion/functions-nested/src/cosine_distance.rs new file mode 100644 index 0000000000000..335856075046c --- /dev/null +++ b/datafusion/functions-nested/src/cosine_distance.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for cosine_distance function. + +use crate::utils::make_scalar_function; +use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{ + Result, exec_err, internal_err, plan_err, utils::take_function_args, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + CosineDistance, + cosine_distance, + array1 array2, + "returns the cosine distance between two numeric arrays.", + cosine_distance_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`. Returns NULL if either array is NULL or contains only zeros.", + syntax_example = "cosine_distance(array1, array2)", + sql_example = r#"```sql +> select cosine_distance([1.0, 0.0], [0.0, 1.0]); ++-----------------------------------------------+ +| cosine_distance(List([1.0,0.0]),List([0.0,1.0])) | ++-----------------------------------------------+ +| 1.0 | ++-----------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CosineDistance { + signature: Signature, +} + +impl Default for CosineDistance { + fn default() -> Self { + Self::new() + } +} + +impl CosineDistance { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for CosineDistance { + fn name(&self) -> &str { + "cosine_distance" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + + for arg_type in arg_types { + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{} does not support type {arg_type}", self.name()); + } + } + + // If any input is `LargeList`, both sides must be widened to `LargeList` + // so the runtime dispatch in `cosine_distance_inner` sees a homogeneous + // pair. Follows the pattern in `ArrayConcat::coerce_types`. + let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_))); + + let coerced = arg_types + .iter() + .map(|arg_type| { + if matches!(arg_type, Null) { + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + return if any_large_list { + LargeList(field) + } else { + List(field) + }; + } + let coerced = coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + ); + match coerced { + List(field) if any_large_list => LargeList(field), + other => other, + } + }) + .collect(); + + Ok(coerced) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(cosine_distance_inner)(&args.args) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn cosine_distance_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("cosine_distance", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_cosine_distance::(args), + (LargeList(_), LargeList(_)) => general_cosine_distance::(args), + (arg_type1, arg_type2) => internal_err!( + "cosine_distance received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), + } +} + +fn general_cosine_distance(arrays: &[ArrayRef]) -> Result { + let list_array1 = as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let values1 = as_float64_array(list_array1.values())?; + let values2 = as_float64_array(list_array2.values())?; + let offsets1 = list_array1.value_offsets(); + let offsets2 = list_array2.value_offsets(); + + let mut builder = Float64Array::builder(list_array1.len()); + for row in 0..list_array1.len() { + if list_array1.is_null(row) || list_array2.is_null(row) { + builder.append_null(); + continue; + } + + let start1 = offsets1[row].as_usize(); + let end1 = offsets1[row + 1].as_usize(); + let start2 = offsets2[row].as_usize(); + let end2 = offsets2[row + 1].as_usize(); + let len1 = end1 - start1; + let len2 = end2 - start2; + + if len1 != len2 { + return exec_err!( + "cosine_distance requires both list inputs to have the same length, got {len1} and {len2}" + ); + } + + let slice1 = values1.slice(start1, len1); + let slice2 = values2.slice(start2, len2); + if slice1.null_count() != 0 || slice2.null_count() != 0 { + builder.append_null(); + continue; + } + + let vals1 = slice1.values(); + let vals2 = slice2.values(); + + let mut dot = 0.0; + let mut sq1 = 0.0; + let mut sq2 = 0.0; + for i in 0..len1 { + let a = vals1[i]; + let b = vals2[i]; + dot += a * b; + sq1 += a * a; + sq2 += b * b; + } + + if sq1 == 0.0 || sq2 == 0.0 { + builder.append_null(); + } else { + builder.append_value(1.0 - dot / (sq1.sqrt() * sq2.sqrt())); + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} diff --git a/datafusion/functions-nested/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs index d0fa294fe42db..01fb81d878e0b 100644 --- a/datafusion/functions-nested/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -23,17 +23,17 @@ use arrow::datatypes::{ DataType::{FixedSizeList, LargeList, List, Null, UInt64}, UInt64Type, }; -use std::any::Any; use datafusion_common::cast::{ as_fixed_size_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use crate::utils::{compute_array_dims, make_scalar_function}; use datafusion_common::utils::list_ndims; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use itertools::Itertools; @@ -86,9 +86,6 @@ impl ArrayDims { } impl ScalarUDFImpl for ArrayDims { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_dims" } @@ -101,10 +98,7 @@ impl ScalarUDFImpl for ArrayDims { Ok(DataType::new_list(UInt64, true)) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_dims_inner)(&args.args) } @@ -158,9 +152,6 @@ impl ArrayNdims { } impl ScalarUDFImpl for ArrayNdims { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_ndims" } @@ -173,10 +164,7 @@ impl ScalarUDFImpl for ArrayNdims { Ok(UInt64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_ndims_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs index dc8eaa699f878..edf1806b66c2d 100644 --- a/datafusion/functions-nested/src/distance.rs +++ b/datafusion/functions-nested/src/distance.rs @@ -29,15 +29,15 @@ use datafusion_common::cast::{ as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, as_int64_array, }; -use datafusion_common::utils::{coerced_type_with_base_type_only, ListCoercion}; -use datafusion_common::{exec_err, plan_err, utils::take_function_args, Result}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; use itertools::Itertools; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -91,10 +91,6 @@ impl ArrayDistance { } impl ScalarUDFImpl for ArrayDistance { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_distance" } @@ -125,10 +121,7 @@ impl ScalarUDFImpl for ArrayDistance { arg_types.try_collect() } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_distance_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/empty.rs b/datafusion/functions-nested/src/empty.rs index 3f90775752054..262eb4935c968 100644 --- a/datafusion/functions-nested/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -25,12 +25,12 @@ use arrow::datatypes::{ DataType::{Boolean, FixedSizeList, LargeList, List}, }; use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -79,9 +79,6 @@ impl ArrayEmpty { } impl ScalarUDFImpl for ArrayEmpty { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "empty" } @@ -94,10 +91,7 @@ impl ScalarUDFImpl for ArrayEmpty { Ok(Boolean) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_empty_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/except.rs b/datafusion/functions-nested/src/except.rs index 8b6bcaa0620c1..dbf815c0ec539 100644 --- a/datafusion/functions-nested/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -15,20 +15,26 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_except function. +//! [`ScalarUDFImpl`] definition for array_except function. use crate::utils::{check_datatypes, make_scalar_function}; -use arrow::array::{cast::AsArray, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; -use arrow::buffer::OffsetBuffer; +use arrow::array::new_null_array; +use arrow::array::{ + Array, ArrayRef, GenericListArray, OffsetSizeTrait, UInt32Array, UInt64Array, + cast::AsArray, +}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::compute::take; use arrow::datatypes::{DataType, FieldRef}; use arrow::row::{RowConverter, SortField}; -use datafusion_common::utils::{take_function_args, ListCoercion}; -use datafusion_common::{internal_err, HashSet, Result}; +use datafusion_common::utils::{ListCoercion, normalize_float_zero, take_function_args}; +use datafusion_common::{HashSet, Result, internal_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; +use itertools::Itertools; use std::sync::Arc; make_udf_expr_and_func!( @@ -92,9 +98,6 @@ impl ArrayExcept { } impl ScalarUDFImpl for ArrayExcept { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_except" } @@ -104,16 +107,16 @@ impl ScalarUDFImpl for ArrayExcept { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match (&arg_types[0].clone(), &arg_types[1].clone()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()), + match (&arg_types[0], &arg_types[1]) { + (DataType::Null, DataType::Null) => { + Ok(DataType::new_list(DataType::Null, true)) + } + (DataType::Null, dt) | (dt, DataType::Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_except_inner)(&args.args) } @@ -129,8 +132,16 @@ impl ScalarUDFImpl for ArrayExcept { fn array_except_inner(args: &[ArrayRef]) -> Result { let [array1, array2] = take_function_args("array_except", args)?; + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (DataType::Null, _) | (_, DataType::Null) => Ok(array1.to_owned()), + (DataType::Null, DataType::Null) => Ok(new_null_array( + &DataType::new_list(DataType::Null, true), + len, + )), + (DataType::Null, dt @ DataType::List(_)) + | (DataType::Null, dt @ DataType::LargeList(_)) + | (dt @ DataType::List(_), DataType::Null) + | (dt @ DataType::LargeList(_), DataType::Null) => Ok(new_null_array(dt, len)), (DataType::List(field), DataType::List(_)) => { check_datatypes("array_except", &[array1, array2])?; let list1 = array1.as_list::(); @@ -158,43 +169,137 @@ fn general_except( ) -> Result> { let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; - let l_values = l.values().to_owned(); - let r_values = r.values().to_owned(); - let l_values = converter.convert_columns(&[l_values])?; - let r_values = converter.convert_columns(&[r_values])?; + // Normalize -0.0 → +0.0 so RowConverter (IEEE 754 totalOrder) groups + // ±0 together for both the rhs lookup set and the lhs probe. + let l_values_norm = normalize_float_zero(l.values()); + let r_values_norm = normalize_float_zero(r.values()); + + // Only convert the visible portion of the values array. For sliced + // ListArrays, values() returns the full underlying array but only + // elements between the first and last offset are referenced. + let l_first = l.offsets()[0].as_usize(); + let l_len = l.offsets()[l.len()].as_usize() - l_first; + let l_values = converter.convert_columns(&[l_values_norm.slice(l_first, l_len)])?; + + let r_first = r.offsets()[0].as_usize(); + let r_len = r.offsets()[r.len()].as_usize() - r_first; + let r_values = converter.convert_columns(&[r_values_norm.slice(r_first, r_len)])?; let mut offsets = Vec::::with_capacity(l.len() + 1); offsets.push(OffsetSize::usize_as(0)); - let mut rows = Vec::with_capacity(l_values.num_rows()); + let mut indices: Vec = Vec::with_capacity(l_values.num_rows()); let mut dedup = HashSet::new(); - for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) { - let l_slice = l_w[0].as_usize()..l_w[1].as_usize(); - let r_slice = r_w[0].as_usize()..r_w[1].as_usize(); - for i in r_slice { - let right_row = r_values.row(i); + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + + let l_offsets_iter = l.offsets().iter().tuple_windows(); + let r_offsets_iter = r.offsets().iter().tuple_windows(); + for (list_index, ((l_start, l_end), (r_start, r_end))) in + l_offsets_iter.zip(r_offsets_iter).enumerate() + { + if nulls + .as_ref() + .is_some_and(|nulls| nulls.is_null(list_index)) + { + offsets.push(OffsetSize::usize_as(indices.len())); + continue; + } + + for element_index in r_start.as_usize() - r_first..r_end.as_usize() - r_first { + let right_row = r_values.row(element_index); dedup.insert(right_row); } - for i in l_slice { - let left_row = l_values.row(i); + for element_index in l_start.as_usize() - l_first..l_end.as_usize() - l_first { + let left_row = l_values.row(element_index); if dedup.insert(left_row) { - rows.push(left_row); + indices.push(element_index + l_first); } } - offsets.push(OffsetSize::usize_as(rows.len())); + offsets.push(OffsetSize::usize_as(indices.len())); dedup.clear(); } - if let Some(values) = converter.convert_rows(rows)?.first() { - Ok(GenericListArray::::new( - field.to_owned(), - OffsetBuffer::new(offsets.into()), - values.to_owned(), - l.nulls().cloned(), - )) + // Gather distinct left-side values by index. + // Use UInt64Array for LargeList to support values arrays exceeding u32::MAX. + let values = if indices.is_empty() { + arrow::array::new_empty_array(&l.value_type()) + } else if OffsetSize::IS_LARGE { + let indices = + UInt64Array::from(indices.into_iter().map(|i| i as u64).collect::>()); + take(l_values_norm.as_ref(), &indices, None)? } else { - internal_err!("array_except failed to convert rows") + let indices = + UInt32Array::from(indices.into_iter().map(|i| i as u32).collect::>()); + take(l_values_norm.as_ref(), &indices, None)? + }; + + Ok(GenericListArray::::new( + field.to_owned(), + OffsetBuffer::new(offsets.into()), + values, + nulls, + )) +} + +#[cfg(test)] +mod tests { + use super::ArrayExcept; + use arrow::array::{Array, AsArray, Int32Array, ListArray}; + use arrow::datatypes::{Field, Int32Type}; + use datafusion_common::{Result, config::ConfigOptions}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_array_except_sliced_lists() -> Result<()> { + // l: [[1,2], [3,4], [5,6], [7,8]] → slice(1,2) → [[3,4], [5,6]] + // r: [[3], [5], [6], [8]] → slice(1,2) → [[5], [6]] + // except(l, r) should be [[3,4], [5]] + let l_full = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(5), Some(6)]), + Some(vec![Some(7), Some(8)]), + ]); + let r_full = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(3)]), + Some(vec![Some(5)]), + Some(vec![Some(6)]), + Some(vec![Some(8)]), + ]); + let l_sliced = l_full.slice(1, 2); + let r_sliced = r_full.slice(1, 2); + + let list_field = Arc::new(Field::new("item", l_sliced.data_type().clone(), true)); + let return_field = + Arc::new(Field::new("return", l_sliced.data_type().clone(), true)); + + let result = ArrayExcept::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(l_sliced)), + ColumnarValue::Array(Arc::new(r_sliced)), + ], + arg_fields: vec![Arc::clone(&list_field), Arc::clone(&list_field)], + number_rows: 2, + return_field, + config_options: Arc::new(ConfigOptions::default()), + })?; + + let output = result.into_array(2)?; + let output = output.as_list::(); + + // Row 0: [3,4] except [5] = [3,4] + let row0 = output.value(0); + let row0 = row0.as_any().downcast_ref::().unwrap(); + assert_eq!(row0.values().as_ref(), &[3, 4]); + + // Row 1: [5,6] except [6] = [5] + let row1 = output.value(1); + let row1 = row1.as_any().downcast_ref::().unwrap(); + assert_eq!(row1.values().as_ref(), &[5]); + + Ok(()) } } diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 57505c59493af..202a76bd0b035 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -19,9 +19,9 @@ use arrow::array::{ Array, ArrayRef, Capacities, GenericListArray, GenericListViewArray, Int64Array, - MutableArrayData, NullArray, NullBufferBuilder, OffsetSizeTrait, + MutableArrayData, NullArray, OffsetSizeTrait, }; -use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{FixedSizeList, LargeList, LargeListView, List, ListView, Null}, @@ -35,17 +35,17 @@ use datafusion_common::cast::{ use datafusion_common::internal_err; use datafusion_common::utils::ListCoercion; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_datafusion_err, plan_err, - utils::take_function_args, Result, + Result, exec_datafusion_err, exec_err, internal_datafusion_err, plan_err, + utils::take_function_args, }; use datafusion_expr::{ - ArrayFunctionArgument, ArrayFunctionSignature, Expr, TypeSignature, + ArrayFunctionArgument, ArrayFunctionSignature, Expr, ScalarFunctionArgs, + TypeSignature, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; use crate::utils::make_scalar_function; @@ -132,9 +132,6 @@ impl ArrayElement { } impl ScalarUDFImpl for ArrayElement { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_element" } @@ -172,10 +169,7 @@ impl ScalarUDFImpl for ArrayElement { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_element_inner)(&args.args) } @@ -262,8 +256,8 @@ where let end = offset_window[1]; let len = end - start; - // array is null - if len == O::usize_as(0) { + // array or index is null + if array.is_null(row_index) || indexes.is_null(row_index) { mutable.extend_nulls(1); continue; } @@ -358,10 +352,6 @@ impl ArraySlice { } impl ScalarUDFImpl for ArraySlice { - fn as_any(&self) -> &dyn Any { - self - } - fn display_name(&self, args: &[Expr]) -> Result { let args_name = args.iter().map(ToString::to_string).collect::>(); if let Some((arr, indexes)) = args_name.split_first() { @@ -395,10 +385,7 @@ impl ScalarUDFImpl for ArraySlice { Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_slice_inner)(&args.args) } @@ -608,6 +595,21 @@ where } } +/// Combine null bitmaps from all slice inputs into a single mask. +fn combine_input_nulls( + array: &dyn Array, + from_array: &Int64Array, + to_array: &Int64Array, + stride: Option<&Int64Array>, +) -> Option { + NullBuffer::union_many([ + array.nulls(), + from_array.nulls(), + to_array.nulls(), + stride.and_then(|s| s.nulls()), + ]) +} + fn general_array_slice( array: &GenericListArray, from_array: &Int64Array, @@ -628,25 +630,19 @@ where // The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb. let mut offsets = vec![O::usize_as(0)]; - let mut null_builder = NullBufferBuilder::new(array.len()); + + let nulls = combine_input_nulls(array, from_array, to_array, stride); for (row_index, offset_window) in array.offsets().windows(2).enumerate() { let start = offset_window[0]; let end = offset_window[1]; let len = end - start; - // If any input is null, return null. - if array.is_null(row_index) - || from_array.is_null(row_index) - || to_array.is_null(row_index) - || stride.is_some_and(|s| s.is_null(row_index)) - { + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { mutable.extend_nulls(1); offsets.push(offsets[row_index] + O::usize_as(1)); - null_builder.append_null(); continue; } - null_builder.append_non_null(); // Empty arrays always return an empty array. if len == O::usize_as(0) { @@ -689,7 +685,7 @@ where Arc::new(Field::new_list_field(array.value_type(), true)), OffsetBuffer::::new(offsets.into()), arrow::array::make_array(data), - null_builder.finish(), + nulls, )?)) } @@ -720,21 +716,15 @@ where let mut offsets = Vec::with_capacity(array.len()); let mut sizes = Vec::with_capacity(array.len()); let mut current_offset = O::usize_as(0); - let mut null_builder = NullBufferBuilder::new(array.len()); + + let nulls = combine_input_nulls(array, from_array, to_array, stride); for row_index in 0..array.len() { - // Propagate NULL semantics: any NULL input yields a NULL output slot. - if array.is_null(row_index) - || from_array.is_null(row_index) - || to_array.is_null(row_index) - || stride.is_some_and(|s| s.is_null(row_index)) - { - null_builder.append_null(); + if nulls.as_ref().is_some_and(|n| n.is_null(row_index)) { offsets.push(current_offset); sizes.push(O::usize_as(0)); continue; } - null_builder.append_non_null(); let len = array.value_size(row_index); @@ -790,7 +780,7 @@ where ScalarBuffer::from(offsets), ScalarBuffer::from(sizes), arrow::array::make_array(data), - null_builder.finish(), + nulls, )?)) } @@ -827,9 +817,6 @@ impl ArrayPopFront { } impl ScalarUDFImpl for ArrayPopFront { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_pop_front" } @@ -842,10 +829,7 @@ impl ScalarUDFImpl for ArrayPopFront { Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_pop_front_inner)(&args.args) } @@ -923,9 +907,6 @@ impl ArrayPopBack { } impl ScalarUDFImpl for ArrayPopBack { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_pop_back" } @@ -938,10 +919,7 @@ impl ScalarUDFImpl for ArrayPopBack { Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_pop_back_inner)(&args.args) } @@ -1023,9 +1001,6 @@ impl ArrayAnyValue { } impl ScalarUDFImpl for ArrayAnyValue { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_any_value" } @@ -1034,19 +1009,16 @@ impl ScalarUDFImpl for ArrayAnyValue { } fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) - | LargeList(field) - | FixedSizeList(field, _) => Ok(field.data_type().clone()), + List(field) | LargeList(field) | FixedSizeList(field, _) => { + Ok(field.data_type().clone()) + } _ => plan_err!( "array_any_value can only accept List, LargeList or FixedSizeList as the argument" ), } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_any_value_inner)(&args.args) } @@ -1090,11 +1062,9 @@ where for (row_index, offset_window) in array.offsets().windows(2).enumerate() { let start = offset_window[0]; - let end = offset_window[1]; - let len = end - start; // array is null - if len == O::usize_as(0) { + if array.is_null(row_index) { mutable.extend_nulls(1); continue; } @@ -1127,14 +1097,18 @@ where #[cfg(test)] mod tests { - use super::{array_element_udf, general_list_view_array_slice}; + use super::{ + array_element_udf, general_array_any_value, general_array_element, + general_list_view_array_slice, + }; use arrow::array::{ - cast::AsArray, Array, ArrayRef, GenericListViewArray, Int32Array, Int64Array, - ListViewArray, + Array, ArrayRef, GenericListViewArray, Int32Array, Int64Array, ListViewArray, + cast::AsArray, }; - use arrow::buffer::ScalarBuffer; - use arrow::datatypes::{DataType, Field}; - use datafusion_common::{Column, DFSchema, Result}; + use arrow::array::{ListArray, RecordBatch}; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::{Column, DFSchema, Result, assert_batches_eq}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{Expr, ExprSchemable}; use std::collections::HashMap; @@ -1195,6 +1169,73 @@ mod tests { ); } + #[test] + fn test_array_element_null_handling() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 4, 5])); + let nulls = NullBuffer::from(vec![true, false, true]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + let list_array = ListArray::new(field, offsets, values, Some(nulls)); + let indexes = Int64Array::from(vec![1, 1, 1]); + + let result = general_array_element(&list_array, &indexes)?; + + let expected = [ + "+--------+", + "| result |", + "+--------+", + "| 1 |", + "| |", + "| 5 |", + "+--------+", + ]; + + let batch = RecordBatch::try_from_iter([("result", result)])?; + + assert_batches_eq!(expected, &[batch]); + + Ok(()) + } + + #[test] + fn test_array_element_null_index_with_non_zero_buffer_returns_null() -> Result<()> { + let list_array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4)]), + Some(vec![Some(5)]), + ]); + let indexes = Int64Array::new( + ScalarBuffer::from(vec![1, 1, 1]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + let result = general_array_element(&list_array, &indexes)?; + let expected = Int32Array::from(vec![Some(1), None, Some(5)]); + + assert_eq!(result.as_primitive::(), &expected); + + Ok(()) + } + + #[test] + fn test_array_any_null_handling() -> Result<()> { + let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 3, 4, 5])); + let nulls = NullBuffer::from(vec![true, false, true]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + + let list_array = ListArray::new(field, offsets, values, Some(nulls)); + + let result = general_array_any_value(&list_array)?; + + assert!(!result.is_null(0)); + assert!(result.is_null(1)); + assert!(!result.is_null(2)); + + Ok(()) + } + #[test] fn test_array_slice_list_view_basic() -> Result<()> { let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 76c4714de1afc..23612c93fefd5 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -25,12 +25,12 @@ use arrow::datatypes::{ DataType::{FixedSizeList, LargeList, List, Null}, }; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -80,10 +80,6 @@ impl Flatten { } impl ScalarUDFImpl for Flatten { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "flatten" } @@ -114,10 +110,7 @@ impl ScalarUDFImpl for Flatten { Ok(data_type) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(flatten_inner)(&args.args) } @@ -208,7 +201,7 @@ fn flatten_inner(args: &[ArrayRef]) -> Result { } Null => Ok(Arc::clone(array)), _ => { - exec_err!("flatten does not support type '{:?}'", array.data_type()) + exec_err!("flatten does not support type '{}'", array.data_type()) } } } diff --git a/datafusion/functions-nested/src/inner_product.rs b/datafusion/functions-nested/src/inner_product.rs new file mode 100644 index 0000000000000..8d8c5656635aa --- /dev/null +++ b/datafusion/functions-nested/src/inner_product.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for inner_product function. + +use crate::utils::make_scalar_function; +use arrow::array::{Array, ArrayRef, Float64Array, OffsetSizeTrait}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{ + Result, exec_err, internal_err, plan_err, utils::take_function_args, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +make_udf_expr_and_func!( + InnerProduct, + inner_product, + array1 array2, + "returns the inner product (dot product) of two numeric arrays.", + inner_product_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the inner product (dot product) of two input arrays of equal length, computed as `sum(array1[i] * array2[i])`. Returns NULL if either array is NULL or contains NULL elements. Returns 0.0 for two empty arrays.", + syntax_example = "inner_product(array1, array2)", + sql_example = r#"```sql +> select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]); ++-------------------------------------------------------+ +| inner_product(List([1.0,2.0,3.0]),List([4.0,5.0,6.0])) | ++-------------------------------------------------------+ +| 32.0 | ++-------------------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct InnerProduct { + signature: Signature, + aliases: Vec, +} + +impl Default for InnerProduct { + fn default() -> Self { + Self::new() + } +} + +impl InnerProduct { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["dot_product".to_string()], + } + } +} + +impl ScalarUDFImpl for InnerProduct { + fn name(&self) -> &str { + "inner_product" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + + for arg_type in arg_types { + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{} does not support type {arg_type}", self.name()); + } + } + + // If any input is `LargeList`, both sides must be widened to `LargeList` + // so the runtime dispatch in `inner_product_inner` sees a homogeneous + // pair. Follows the pattern in `ArrayConcat::coerce_types`. + let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_))); + + let coerced = arg_types + .iter() + .map(|arg_type| { + if matches!(arg_type, Null) { + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + return if any_large_list { + LargeList(field) + } else { + List(field) + }; + } + let coerced = coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + ); + match coerced { + List(field) if any_large_list => LargeList(field), + other => other, + } + }) + .collect(); + + Ok(coerced) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(inner_product_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn inner_product_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("inner_product", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_inner_product::(args), + (LargeList(_), LargeList(_)) => general_inner_product::(args), + (arg_type1, arg_type2) => internal_err!( + "inner_product received unexpected types after coercion: {arg_type1} and {arg_type2}" + ), + } +} + +fn general_inner_product(arrays: &[ArrayRef]) -> Result { + let list_array1 = as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let values1 = as_float64_array(list_array1.values())?; + let values2 = as_float64_array(list_array2.values())?; + let offsets1 = list_array1.value_offsets(); + let offsets2 = list_array2.value_offsets(); + + let mut builder = Float64Array::builder(list_array1.len()); + for row in 0..list_array1.len() { + if list_array1.is_null(row) || list_array2.is_null(row) { + builder.append_null(); + continue; + } + + let start1 = offsets1[row].as_usize(); + let end1 = offsets1[row + 1].as_usize(); + let start2 = offsets2[row].as_usize(); + let end2 = offsets2[row + 1].as_usize(); + let len1 = end1 - start1; + let len2 = end2 - start2; + + if len1 != len2 { + return exec_err!( + "inner_product requires both list inputs to have the same length, got {len1} and {len2}" + ); + } + + let slice1 = values1.slice(start1, len1); + let slice2 = values2.slice(start2, len2); + if slice1.null_count() != 0 || slice2.null_count() != 0 { + builder.append_null(); + continue; + } + + let vals1 = slice1.values(); + let vals2 = slice2.values(); + + let mut dot = 0.0; + for i in 0..len1 { + dot += vals1[i] * vals2[i]; + } + builder.append_value(dot); + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} diff --git a/datafusion/functions-nested/src/lambda_utils.rs b/datafusion/functions-nested/src/lambda_utils.rs new file mode 100644 index 0000000000000..0f208ce5d26b2 --- /dev/null +++ b/datafusion/functions-nested/src/lambda_utils.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared utilities for `(array, lambda)` style higher-order functions. + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::{ + Result, ScalarValue, plan_err, + utils::{list_values, take_function_args}, +}; +use datafusion_expr::{ColumnarValue, LambdaParametersProgress, ValueOrLambda}; +use std::sync::Arc; + +/// Extracts a `(value, lambda)` pair from a [`ValueOrLambda`] slice. +pub(crate) fn value_lambda_pair<'a, V: std::fmt::Debug, L: std::fmt::Debug>( + name: &str, + args: &'a [ValueOrLambda], +) -> Result<(&'a V, &'a L)> { + let [value, lambda] = take_function_args(name, args)?; + + let (ValueOrLambda::Value(value), ValueOrLambda::Lambda(lambda)) = (value, lambda) + else { + return plan_err!( + "{name} expects a value followed by a lambda, got {value:?} and {lambda:?}" + ); + }; + + Ok((value, lambda)) +} + +/// Coerces a single list argument for `(array, lambda)` style higher-order functions. +/// +/// Normalises `ListView`/`FixedSizeList` → `List` and `LargeListView` → `LargeList`. +pub(crate) fn coerce_single_list_arg( + name: &str, + arg_types: &[DataType], +) -> Result> { + let list = if arg_types.len() == 1 { + &arg_types[0] + } else { + return plan_err!( + "{name} function requires 1 value arguments, got {}", + arg_types.len() + ); + }; + + let coerced = match list { + DataType::List(_) | DataType::LargeList(_) => list.clone(), + DataType::ListView(field) | DataType::FixedSizeList(field, _) => { + DataType::List(Arc::clone(field)) + } + DataType::LargeListView(field) => DataType::LargeList(Arc::clone(field)), + _ => return plan_err!("{name} expected a list as first argument, got {list}"), + }; + + Ok(vec![coerced]) +} + +/// Returns the single lambda parameter set for `(array, v -> body)` style HOFs. +pub(crate) fn single_list_lambda_parameters( + name: &str, + fields: &[ValueOrLambda>], +) -> Result { + let (list, _lambda) = value_lambda_pair(name, fields)?; + + let field = match list.data_type() { + DataType::List(field) | DataType::LargeList(field) => field, + _ => return plan_err!("expected list, got {list}"), + }; + + Ok(LambdaParametersProgress::Complete(vec![vec![Arc::clone( + field, + )]])) +} + +/// Result of extracting flat list values, with fast-path short-circuits handled. +pub(crate) enum ListValuesResult { + /// Caller should return this value immediately. + EarlyReturn(ColumnarValue), + /// Flat values extracted from the list; continue with execution. + Values(ArrayRef), +} + +/// Extracts flat list values, handling all fast-path short-circuits. +/// +/// - All-null input → `EarlyReturn(null scalar)` +/// - All sublists empty and non-null → `EarlyReturn(default empty-list scalar)` +/// - Otherwise → `Values(flat_values)` +pub(crate) fn extract_list_values( + list_array: &ArrayRef, + return_type: &DataType, +) -> Result { + if list_array.null_count() == list_array.len() { + return Ok(ListValuesResult::EarlyReturn(ColumnarValue::Scalar( + ScalarValue::try_new_null(return_type)?, + ))); + } + + let values = list_values(list_array)?; + + if values.is_empty() + && list_array.null_count() == 0 + && matches!(return_type, DataType::List(_) | DataType::LargeList(_)) + { + return Ok(ListValuesResult::EarlyReturn(ColumnarValue::Scalar( + ScalarValue::new_default(return_type)?, + ))); + } + + Ok(ListValuesResult::Values(values)) +} + +#[cfg(test)] +pub(crate) mod test_utils { + use std::{collections::HashMap, sync::Arc}; + + use arrow::{ + array::{Array, ArrayRef, Int32Array, ListArray, RecordBatch}, + buffer::{NullBuffer, OffsetBuffer}, + datatypes::{DataType, Field}, + }; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{ + Expr, HigherOrderUDF, col, + execution_props::ExecutionProps, + expr::{HigherOrderFunction, LambdaVariable}, + lambda, + }; + use datafusion_physical_expr::create_physical_expr; + + pub(crate) fn create_i32_list( + values: impl Into, + offsets: OffsetBuffer, + nulls: Option, + ) -> ListArray { + let list_field = Arc::new(Field::new_list_field(DataType::Int32, true)); + ListArray::new(list_field, offsets, Arc::new(values.into()), nulls) + } + + pub(crate) fn eval_hof_on_i32_list( + func: Arc, + list: impl Array + Clone + 'static, + lambda_body: Expr, + ) -> Result { + let schema = DFSchema::from_unqualified_fields( + vec![Field::new( + "list", + list.data_type().clone(), + list.is_nullable(), + )] + .into(), + HashMap::new(), + )?; + + create_physical_expr( + &Expr::HigherOrderFunction(HigherOrderFunction::new( + func, + vec![col("list"), lambda(["v"], lambda_body)], + )), + &schema, + &ExecutionProps::new(), + )? + .evaluate(&RecordBatch::try_new( + Arc::clone(schema.inner()), + vec![Arc::new(list.clone())], + )?)? + .into_array(list.len()) + } + + pub(crate) fn v() -> Expr { + Expr::LambdaVariable(LambdaVariable::new( + "v".to_string(), + Some(Arc::new(Field::new("v", DataType::Int32, true))), + )) + } +} diff --git a/datafusion/functions-nested/src/length.rs b/datafusion/functions-nested/src/length.rs index ceceee7bfa523..9579c3c9cd658 100644 --- a/datafusion/functions-nested/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -29,14 +29,13 @@ use arrow::datatypes::{ use datafusion_common::cast::{ as_fixed_size_list_array, as_generic_list_array, as_int64_array, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -102,9 +101,6 @@ impl ArrayLength { } impl ScalarUDFImpl for ArrayLength { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_length" } @@ -117,10 +113,7 @@ impl ScalarUDFImpl for ArrayLength { Ok(UInt64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_length_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 51210b9ae22d4..5b27e2780481b 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Nested type Functions for [DataFusion]. @@ -39,9 +37,24 @@ #[macro_use] pub mod macros; +#[macro_use] +pub mod macros_lambda; + +pub mod array_add; +pub mod array_any_match; +pub mod array_compact; +pub mod array_filter; pub mod array_has; +pub mod array_normalize; +pub mod array_product; +pub mod array_scale; +pub mod array_subtract; +pub mod array_sum; +pub mod array_transform; +pub mod arrays_zip; pub mod cardinality; pub mod concat; +pub mod cosine_distance; pub mod dimension; pub mod distance; pub mod empty; @@ -49,6 +62,8 @@ pub mod except; pub mod expr_ext; pub mod extract; pub mod flatten; +pub mod inner_product; +pub(crate) mod lambda_utils; pub mod length; pub mod make_array; pub mod map; @@ -72,19 +87,31 @@ pub mod utils; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; -use datafusion_expr::ScalarUDF; +use datafusion_expr::{HigherOrderUDF, ScalarUDF}; use log::debug; use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::array_add::array_add; + pub use super::array_any_match::array_any_match; + pub use super::array_compact::array_compact; + pub use super::array_filter::array_filter; pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::array_normalize::array_normalize; + pub use super::array_product::array_product; + pub use super::array_scale::array_scale; + pub use super::array_subtract::array_subtract; + pub use super::array_sum::array_sum; + pub use super::array_transform::array_transform; + pub use super::arrays_zip::arrays_zip; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; pub use super::concat::array_prepend; + pub use super::cosine_distance::cosine_distance; pub use super::dimension::array_dims; pub use super::dimension::array_ndims; pub use super::distance::array_distance; @@ -96,6 +123,7 @@ pub mod expr_fn { pub use super::extract::array_pop_front; pub use super::extract::array_slice; pub use super::flatten::flatten; + pub use super::inner_product::inner_product; pub use super::length::array_length; pub use super::make_array::make_array; pub use super::map_entries::map_entries; @@ -128,6 +156,7 @@ pub mod expr_fn { /// Return all default nested type functions pub fn all_default_nested_functions() -> Vec> { vec![ + array_compact::array_compact_udf(), string::array_to_string_udf(), string::string_to_array_udf(), range::range_udf(), @@ -150,6 +179,14 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_any_udf(), empty::array_empty_udf(), length::array_length_udf(), + array_normalize::array_normalize_udf(), + array_add::array_add_udf(), + array_product::array_product_udf(), + array_scale::array_scale_udf(), + array_subtract::array_subtract_udf(), + array_sum::array_sum_udf(), + cosine_distance::cosine_distance_udf(), + inner_product::inner_product_udf(), distance::array_distance_udf(), flatten::flatten_udf(), min_max::array_max_udf(), @@ -161,6 +198,7 @@ pub fn all_default_nested_functions() -> Vec> { set_ops::array_distinct_udf(), set_ops::array_intersect_udf(), set_ops::array_union_udf(), + arrays_zip::arrays_zip_udf(), position::array_position_udf(), position::array_positions_udf(), remove::array_remove_udf(), @@ -177,6 +215,14 @@ pub fn all_default_nested_functions() -> Vec> { ] } +pub fn all_default_higher_order_functions() -> Vec> { + vec![ + array_any_match::array_any_match_higher_order_function(), + array_filter::array_filter_higher_order_function(), + array_transform::array_transform_higher_order_function(), + ] +} + /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = all_default_nested_functions(); @@ -188,25 +234,43 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) as Result<()> })?; + let functions: Vec> = all_default_higher_order_functions(); + functions.into_iter().try_for_each(|function| { + let existing_function = registry.register_higher_order_function(function)?; + if let Some(existing_function) = existing_function { + debug!( + "Overwrite existing higher-order function: {}", + existing_function.name() + ); + } + Ok(()) as Result<()> + })?; + Ok(()) } #[cfg(test)] mod tests { - use crate::all_default_nested_functions; + use crate::{all_default_higher_order_functions, all_default_nested_functions}; use datafusion_common::Result; use std::collections::HashSet; #[test] fn test_no_duplicate_name() -> Result<()> { + let scalars = all_default_nested_functions(); + let scalars = scalars.iter().map(|s| (s.name(), s.aliases())); + + let lambdas = all_default_higher_order_functions(); + let lambdas = lambdas.iter().map(|l| (l.name(), l.aliases())); + let mut names = HashSet::new(); - for func in all_default_nested_functions() { + + for (name, aliases) in scalars.chain(lambdas) { assert!( - names.insert(func.name().to_string().to_lowercase()), - "duplicate function name: {}", - func.name() + names.insert(name.to_string().to_lowercase()), + "duplicate function name: {name}", ); - for alias in func.aliases() { + for alias in aliases { assert!( names.insert(alias.to_string().to_lowercase()), "duplicate function name: {alias}" diff --git a/datafusion/functions-nested/src/macros.rs b/datafusion/functions-nested/src/macros.rs index 5380f6b1272d1..5f12113150a40 100644 --- a/datafusion/functions-nested/src/macros.rs +++ b/datafusion/functions-nested/src/macros.rs @@ -50,7 +50,6 @@ macro_rules! make_udf_expr_and_func { make_udf_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $SCALAR_UDF_FN, $UDF::new); }; ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $SCALAR_UDF_FN:ident, $CTOR:path) => { - paste::paste! { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { @@ -60,13 +59,11 @@ macro_rules! make_udf_expr_and_func { )) } create_func!($UDF, $SCALAR_UDF_FN, $CTOR); - } }; ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $SCALAR_UDF_FN:ident) => { make_udf_expr_and_func!($UDF, $EXPR_FN, $DOC, $SCALAR_UDF_FN, $UDF::new); }; ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $SCALAR_UDF_FN:ident, $CTOR:path) => { - paste::paste! { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { @@ -76,7 +73,6 @@ macro_rules! make_udf_expr_and_func { )) } create_func!($UDF, $SCALAR_UDF_FN, $CTOR); - } }; } @@ -97,7 +93,6 @@ macro_rules! create_func { create_func!($UDF, $SCALAR_UDF_FN, $UDF::new); }; ($UDF:ident, $SCALAR_UDF_FN:ident, $CTOR:path) => { - paste::paste! { #[doc = concat!("ScalarFunction that returns a [`ScalarUDF`](datafusion_expr::ScalarUDF) for ")] #[doc = stringify!($UDF)] pub fn $SCALAR_UDF_FN() -> std::sync::Arc { @@ -110,6 +105,5 @@ macro_rules! create_func { }); std::sync::Arc::clone(&INSTANCE) } - } }; } diff --git a/datafusion/functions-nested/src/macros_lambda.rs b/datafusion/functions-nested/src/macros_lambda.rs new file mode 100644 index 0000000000000..c8fe670844b2d --- /dev/null +++ b/datafusion/functions-nested/src/macros_lambda.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Creates external API functions for an array UDF. Specifically, creates +/// +/// 1. Single `HigherOrderUDF` instance +/// +/// Creates a singleton `HigherOrderUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$HIGHER_ORDER_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `HigherOrderUDF` only happens once. +/// +/// # 2. `expr_fn` style function +/// +/// These are functions that create an `Expr` that invokes the UDF, used +/// primarily to programmatically create expressions. +/// +/// For example: +/// ```text +/// pub fn array_to_string(delimiter: Expr) -> Expr { +/// ... +/// } +/// ``` +/// # Arguments +/// * `UDF`: name of the [`HigherOrderUDF`] +/// * `EXPR_FN`: name of the expr_fn function to be created +/// * `arg`: 0 or more named arguments for the function +/// * `DOC`: documentation string for the function +/// * `HIGHER_ORDER_UDF_FUNC`: name of the function to create (just) the `HigherOrderUDF` +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDF::new()`. +/// +/// [`HigherOrderUDF`]: datafusion_expr::HigherOrderUDF +macro_rules! make_higher_order_function_expr_and_func { + ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident) => { + make_higher_order_function_expr_and_func!($UDF, $EXPR_FN, $($arg)*, $DOC, $HIGHER_ORDER_UDF_FN, $UDF::new); + }; + ($UDF:ident, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + datafusion_expr::Expr::HigherOrderFunction(datafusion_expr::expr::HigherOrderFunction::new( + $HIGHER_ORDER_UDF_FN(), + vec![$($arg),*], + )) + } + create_higher_order!($UDF, $HIGHER_ORDER_UDF_FN, $CTOR); + }; + ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident) => { + make_higher_order_function_expr_and_func!($UDF, $EXPR_FN, $DOC, $HIGHER_ORDER_UDF_FN, $UDF::new); + }; + ($UDF:ident, $EXPR_FN:ident, $DOC:expr, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { + datafusion_expr::Expr::HigherOrderFunction(datafusion_expr::expr::HigherOrderFunction::new( + $HIGHER_ORDER_UDF_FN(), + arg, + )) + } + create_higher_order!($UDF, $HIGHER_ORDER_UDF_FN, $CTOR); + }; +} + +/// Creates a singleton `HigherOrderUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$HIGHER_ORDER_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `HigherOrderUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`HigherOrderUDF`] +/// * `HIGHER_ORDER_UDF_FUNC`: name of the function to create (just) the `HigherOrderUDF` +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDF::new()`. +/// +/// [`HigherOrderUDF`]: datafusion_expr::HigherOrderUDF +macro_rules! create_higher_order { + ($UDF:ident, $HIGHER_ORDER_UDF_FN:ident) => { + create_higher_order!($UDF, $HIGHER_ORDER_UDF_FN, $UDF::new); + }; + ($UDF:ident, $HIGHER_ORDER_UDF_FN:ident, $CTOR:path) => { + #[doc = concat!("HigherOrderFunction that returns a [`HigherOrderUDF`](datafusion_expr::HigherOrderUDF) for ")] + #[doc = stringify!($UDF)] + pub fn $HIGHER_ORDER_UDF_FN() -> std::sync::Arc { + // Singleton instance of [`$UDF`], ensures the UDF is only created once + static INSTANCE: std::sync::LazyLock> = + std::sync::LazyLock::new(|| { + std::sync::Arc::new(datafusion_expr::HigherOrderUDF::new_from_impl($CTOR())) + }); + std::sync::Arc::clone(&INSTANCE) + } + }; +} diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index 97d64c70cd364..32af5df2c6019 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -17,26 +17,25 @@ //! [`ScalarUDFImpl`] definitions for `make_array` function. -use std::any::Any; use std::sync::Arc; use std::vec; use crate::utils::make_scalar_function; use arrow::array::{ - new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray, - MutableArrayData, NullArray, OffsetSizeTrait, + Array, ArrayData, ArrayRef, Capacities, GenericListArray, MutableArrayData, + NullArray, OffsetSizeTrait, new_null_array, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; use arrow::datatypes::{DataType::Null, Field}; use datafusion_common::utils::SingleRowListArrayBuilder; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{Result, plan_err}; use datafusion_expr::binary::{ try_type_union_resolution_with_struct, type_union_resolution, }; -use datafusion_expr::TypeSignature; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use itertools::Itertools as _; @@ -80,20 +79,13 @@ impl Default for MakeArray { impl MakeArray { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::Nullary, TypeSignature::UserDefined], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), aliases: vec![String::from("make_list")], } } } impl ScalarUDFImpl for MakeArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "make_array" } @@ -113,10 +105,7 @@ impl ScalarUDFImpl for MakeArray { Ok(DataType::new_list(element_type, true)) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(make_array_inner)(&args.args) } @@ -125,18 +114,10 @@ impl ScalarUDFImpl for MakeArray { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if let Ok(unified) = try_type_union_resolution_with_struct(arg_types) { - return Ok(unified); - } - - if let Some(unified) = type_union_resolution(arg_types) { - Ok(vec![unified; arg_types.len()]) + if arg_types.is_empty() { + Ok(vec![]) } else { - plan_err!( - "Failed to unify argument types of {}: [{}]", - self.name(), - arg_types.iter().join(", ") - ) + coerce_types_inner(arg_types, self.name()) } } @@ -163,7 +144,7 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { SingleRowListArrayBuilder::new(array).build_list_array(), )) } else { - array_array::(arrays, data_type.clone()) + array_array::(arrays, data_type.clone(), Field::LIST_FIELD_DEFAULT_NAME) } } @@ -207,9 +188,10 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { /// └──────────────┘ └──────────────┘ └─────────────────────────────┘ /// col1 col2 output /// ``` -fn array_array( +pub fn array_array( args: &[ArrayRef], data_type: DataType, + field_name: &str, ) -> Result { // do not accept 0 arguments. if args.is_empty() { @@ -252,9 +234,25 @@ fn array_array( let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new_list_field(data_type, true)), + Arc::new(Field::new(field_name, data_type, true)), OffsetBuffer::new(offsets.into()), arrow::array::make_array(data), None, )?)) } + +pub fn coerce_types_inner(arg_types: &[DataType], name: &str) -> Result> { + if let Ok(unified) = try_type_union_resolution_with_struct(arg_types) { + return Ok(unified); + } + + if let Some(unified) = type_union_resolution(arg_types) { + Ok(vec![unified; arg_types.len()]) + } else { + plan_err!( + "Failed to unify argument types of {}: [{}]", + name, + arg_types.iter().join(", ") + ) + } +} diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index fe9bc609c0130..c7418e9021494 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -15,21 +15,28 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::collections::VecDeque; +use std::hash::Hash; use std::sync::Arc; -use arrow::array::{Array, ArrayData, ArrayRef, MapArray, OffsetSizeTrait, StructArray}; +use arrow::array::{ + Array, ArrayData, ArrayRef, ArrowPrimitiveType, MapArray, OffsetSizeTrait, + StructArray, cast::AsArray, +}; use arrow::buffer::Buffer; -use arrow::datatypes::{DataType, Field, SchemaBuilder, ToByteSlice}; +use arrow::datatypes::{ + DataType, Date32Type, Date64Type, Field, Int8Type, Int16Type, Int32Type, Int64Type, + SchemaBuilder, ToByteSlice, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; use datafusion_common::{ - exec_err, utils::take_function_args, HashSet, Result, ScalarValue, + HashSet, Result, ScalarValue, exec_err, utils::take_function_args, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -65,23 +72,30 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { let key_array = keys.as_ref(); match keys_arg { - ColumnarValue::Array(_) => { - let row_keys = match key_array.data_type() { - DataType::List(_) => list_to_arrays::(&keys), - DataType::LargeList(_) => list_to_arrays::(&keys), - DataType::FixedSizeList(_, _) => fixed_size_list_to_arrays(&keys), - data_type => { - return exec_err!( - "Expected list, large_list or fixed_size_list, got {:?}", - data_type - ); - } - }; - - row_keys + ColumnarValue::Array(_) => match key_array.data_type() { + DataType::List(_) => keys + .as_list::() .iter() - .try_for_each(|key| validate_map_keys(key.as_ref()))?; - } + .flatten() + .try_for_each(|row| validate_map_keys(row.as_ref()))?, + DataType::LargeList(_) => keys + .as_list::() + .iter() + .flatten() + .try_for_each(|row| validate_map_keys(row.as_ref()))?, + DataType::FixedSizeList(_, _) => { + keys.as_fixed_size_list() + .iter() + .flatten() + .try_for_each(|row| validate_map_keys(row.as_ref()))? + } + data_type => { + return exec_err!( + "Expected list, large_list or fixed_size_list, got {:?}", + data_type + ); + } + }, ColumnarValue::Scalar(_) => { validate_map_keys(key_array)?; } @@ -92,8 +106,67 @@ fn make_map_batch(args: &[ColumnarValue]) -> Result { make_map_batch_internal(&keys, &values, can_evaluate_to_const, &keys_arg.data_type()) } -/// Validates that map keys are non-null and unique. -fn validate_map_keys(array: &dyn Array) -> Result<()> { +fn validate_unique_primitive_keys(array: &dyn Array) -> Result<()> +where + T::Native: Copy + Eq + Hash + std::fmt::Display, +{ + let primitive_array = array.as_primitive::(); + if primitive_array.null_count() > 0 { + return exec_err!("map key cannot be null"); + } + + if let Some(value) = find_duplicate_value( + primitive_array.len(), + primitive_array.values().iter().copied(), + ) { + return exec_err!("map key must be unique, duplicate key found: {}", value); + } + + Ok(()) +} + +fn validate_unique_str_keys<'a>( + null_count: usize, + len: usize, + values: impl IntoIterator, +) -> Result<()> { + if null_count > 0 { + return exec_err!("map key cannot be null"); + } + + if let Some(value) = find_duplicate_value(len, values) { + return exec_err!("map key must be unique, duplicate key found: {}", value); + } + + Ok(()) +} + +fn validate_unique_binary_keys<'a>( + null_count: usize, + len: usize, + values: impl IntoIterator, +) -> Result<()> { + if null_count > 0 { + return exec_err!("map key cannot be null"); + } + + if let Some(value) = find_duplicate_value(len, values) { + return exec_err!("map key must be unique, duplicate key found: {:?}", value); + } + + Ok(()) +} + +fn find_duplicate_value(len: usize, values: I) -> Option +where + T: Copy + Eq + Hash, + I: IntoIterator, +{ + let mut seen_keys = HashSet::with_capacity(len); + values.into_iter().find(|value| !seen_keys.insert(*value)) +} + +fn validate_unique_keys_generic(array: &dyn Array) -> Result<()> { let mut seen_keys = HashSet::with_capacity(array.len()); for i in 0..array.len() { @@ -113,13 +186,54 @@ fn validate_map_keys(array: &dyn Array) -> Result<()> { Ok(()) } +/// Validates that map keys are non-null and unique. +fn validate_map_keys(array: &dyn Array) -> Result<()> { + match array.data_type() { + DataType::Int8 => validate_unique_primitive_keys::(array), + DataType::Int16 => validate_unique_primitive_keys::(array), + DataType::Int32 => validate_unique_primitive_keys::(array), + DataType::Int64 => validate_unique_primitive_keys::(array), + DataType::UInt8 => validate_unique_primitive_keys::(array), + DataType::UInt16 => validate_unique_primitive_keys::(array), + DataType::UInt32 => validate_unique_primitive_keys::(array), + DataType::UInt64 => validate_unique_primitive_keys::(array), + DataType::Date32 => validate_unique_primitive_keys::(array), + DataType::Date64 => validate_unique_primitive_keys::(array), + DataType::Utf8 => { + let arr = array.as_string::(); + validate_unique_str_keys(arr.null_count(), arr.len(), arr.iter().flatten()) + } + DataType::LargeUtf8 => { + let arr = array.as_string::(); + validate_unique_str_keys(arr.null_count(), arr.len(), arr.iter().flatten()) + } + DataType::Utf8View => { + let arr = array.as_string_view(); + validate_unique_str_keys(arr.null_count(), arr.len(), arr.iter().flatten()) + } + DataType::Binary => { + let arr = array.as_binary::(); + validate_unique_binary_keys(arr.null_count(), arr.len(), arr.iter().flatten()) + } + DataType::LargeBinary => { + let arr = array.as_binary::(); + validate_unique_binary_keys(arr.null_count(), arr.len(), arr.iter().flatten()) + } + DataType::BinaryView => { + let arr = array.as_binary_view(); + validate_unique_binary_keys(arr.null_count(), arr.len(), arr.iter().flatten()) + } + _ => validate_unique_keys_generic(array), + } +} + fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { match columnar_value { ColumnarValue::Scalar(value) => match value { ScalarValue::List(array) => Ok(array.value(0)), ScalarValue::LargeList(array) => Ok(array.value(0)), ScalarValue::FixedSizeList(array) => Ok(array.value(0)), - _ => exec_err!("Expected array, got {:?}", value), + _ => exec_err!("Expected array, got {}", value), }, ColumnarValue::Array(array) => Ok(array.to_owned()), } @@ -256,10 +370,6 @@ impl MapFunc { } impl ScalarUDFImpl for MapFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "map" } @@ -288,10 +398,7 @@ impl ScalarUDFImpl for MapFunc { )) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_map_batch(&args.args) } @@ -381,7 +488,7 @@ fn make_map_array_internal( let nulls_bitmap = keys.nulls().cloned(); let keys = list_to_arrays::(keys); - let values = list_to_arrays::(values); + let values = list_to_arrays_skipping_null_rows::(values, nulls_bitmap.as_ref()); build_map_array( &keys, @@ -408,7 +515,8 @@ fn make_map_array_from_fixed_size_list( let nulls_bitmap = keys.nulls().cloned(); let keys = fixed_size_list_to_arrays(keys); - let values = fixed_size_list_to_arrays(values); + let values = + fixed_size_list_to_arrays_skipping_null_rows(values, nulls_bitmap.as_ref()); build_map_array( &keys, @@ -419,6 +527,41 @@ fn make_map_array_from_fixed_size_list( nulls_bitmap, ) } +fn list_to_arrays_skipping_null_rows( + array: &ArrayRef, + null_rows: Option<&arrow::buffer::NullBuffer>, +) -> Vec { + array + .as_list::() + .iter() + .enumerate() + .filter_map(|(i, row)| { + if null_rows.is_some_and(|nulls| nulls.is_null(i)) { + None + } else { + row + } + }) + .collect() +} + +fn fixed_size_list_to_arrays_skipping_null_rows( + array: &ArrayRef, + null_rows: Option<&arrow::buffer::NullBuffer>, +) -> Vec { + array + .as_fixed_size_list() + .iter() + .enumerate() + .filter_map(|(i, row)| { + if null_rows.is_some_and(|nulls| nulls.is_null(i)) { + None + } else { + row + } + }) + .collect() +} /// Common logic to build a MapArray from decomposed list arrays fn build_map_array( @@ -429,6 +572,10 @@ fn build_map_array( original_len: usize, nulls_bitmap: Option, ) -> Result { + if keys.len() != values.len() { + return exec_err!("map requires key and value lists to have the same length"); + } + let mut key_array_vec = vec![]; let mut value_array_vec = vec![]; for (k, v) in keys.iter().zip(values.iter()) { @@ -694,7 +841,7 @@ mod tests { use arrow::array::FixedSizeListBuilder; - // Build keys array as FixedSizeList(2): [['a', 'b'], ['c', 'd']] + // Build keys array as FixedSizeList(2): [['a', 'b'], NULL, ['c', 'd']] let key_values_builder = arrow::array::StringBuilder::new(); let mut key_builder = FixedSizeListBuilder::new(key_values_builder, 2); @@ -703,6 +850,11 @@ mod tests { key_builder.values().append_value("b"); key_builder.append(true); + // Second map: NULL (entire map is NULL) + key_builder.values().append_null(); + key_builder.values().append_null(); + key_builder.append(false); + // Second map: ['c', 'd'] key_builder.values().append_value("c"); key_builder.values().append_value("d"); @@ -710,7 +862,8 @@ mod tests { let keys_array = Arc::new(key_builder.finish()); - // Build values array as FixedSizeList(2): [[1, 2], [3, 4]] + // Build values array as FixedSizeList(2): [[1, 2], [99, 100], [3, 4]] + // The middle row should be ignored because the corresponding key row is NULL. let value_values_builder = arrow::array::Int32Builder::new(); let mut value_builder = FixedSizeListBuilder::new(value_values_builder, 2); @@ -718,6 +871,10 @@ mod tests { value_builder.values().append_value(2); value_builder.append(true); + value_builder.values().append_value(99); + value_builder.values().append_value(100); + value_builder.append(true); + value_builder.values().append_value(3); value_builder.values().append_value(4); value_builder.append(true); @@ -742,8 +899,9 @@ mod tests { _ => panic!("Expected Array result"), }; - assert_eq!(map_array.len(), 2, "Should have 2 maps"); + assert_eq!(map_array.len(), 3, "Should have 3 maps"); assert!(!map_array.is_null(0), "First map should not be NULL"); - assert!(!map_array.is_null(1), "Second map should not be NULL"); + assert!(map_array.is_null(1), "Second map should be NULL"); + assert!(!map_array.is_null(2), "Third map should not be NULL"); } } diff --git a/datafusion/functions-nested/src/map_entries.rs b/datafusion/functions-nested/src/map_entries.rs index 7d9d103206dbc..e465b39d02751 100644 --- a/datafusion/functions-nested/src/map_entries.rs +++ b/datafusion/functions-nested/src/map_entries.rs @@ -21,13 +21,12 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{Result, cast::as_map_array, exec_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -79,10 +78,6 @@ impl MapEntriesFunc { } impl ScalarUDFImpl for MapEntriesFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "map_entries" } @@ -111,10 +106,7 @@ impl ScalarUDFImpl for MapEntriesFunc { )))) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(map_entries_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 4aab5d7a60d18..aab0d013a4152 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -19,17 +19,17 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{ - make_array, Array, ArrayRef, Capacities, ListArray, MapArray, MutableArrayData, + Array, ArrayRef, Capacities, ListArray, MapArray, MutableArrayData, make_array, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{Result, cast::as_map_array, exec_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; use std::vec; @@ -57,6 +57,11 @@ SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); ---- +[NULL] + +-- non-existing key +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'a'); +---- [] ```"#, argument( @@ -90,9 +95,6 @@ impl MapExtract { } impl ScalarUDFImpl for MapExtract { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "map_extract" } @@ -110,10 +112,7 @@ impl ScalarUDFImpl for MapExtract { )))) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(map_extract_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs index 2fc44670d74a2..4d3421085c5d0 100644 --- a/datafusion/functions-nested/src/map_keys.rs +++ b/datafusion/functions-nested/src/map_keys.rs @@ -21,13 +21,12 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_common::{Result, cast::as_map_array, exec_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -79,10 +78,6 @@ impl MapKeysFunc { } impl ScalarUDFImpl for MapKeysFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "map_keys" } @@ -101,10 +96,7 @@ impl ScalarUDFImpl for MapKeysFunc { )))) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(map_keys_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 6ae8a278063da..d4fa3b0924f1b 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -21,13 +21,12 @@ use crate::utils::{get_map_entry_field, make_scalar_function}; use arrow::array::{Array, ArrayRef, ListArray}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::take_function_args; -use datafusion_common::{cast::as_map_array, exec_err, internal_err, Result}; +use datafusion_common::{Result, cast::as_map_array, exec_err, internal_err}; use datafusion_expr::{ - ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::ops::Deref; use std::sync::Arc; @@ -80,10 +79,6 @@ impl MapValuesFunc { } impl ScalarUDFImpl for MapValuesFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "map_values" } @@ -111,10 +106,7 @@ impl ScalarUDFImpl for MapValuesFunc { .into()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(map_values_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/min_max.rs b/datafusion/functions-nested/src/min_max.rs index 1f3623ca243dc..ba9e5e7a07eb4 100644 --- a/datafusion/functions-nested/src/min_max.rs +++ b/datafusion/functions-nested/src/min_max.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_max function. +//! [`ScalarUDFImpl`] definitions for array_min and array_max functions. use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, GenericListArray, + OffsetSizeTrait, PrimitiveBuilder, downcast_primitive, +}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeList, List}; +use datafusion_common::Result; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::utils::take_function_args; -use datafusion_common::Result; -use datafusion_common::{exec_err, plan_err, ScalarValue}; +use datafusion_common::{ScalarValue, exec_err, plan_err}; use datafusion_doc::Documentation; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -31,7 +34,7 @@ use datafusion_expr::{ use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; use datafusion_macros::user_doc; use itertools::Itertools; -use std::any::Any; +use std::sync::Arc; make_udf_expr_and_func!( ArrayMax, @@ -80,10 +83,6 @@ impl ArrayMax { } impl ScalarUDFImpl for ArrayMax { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_max" } @@ -116,8 +115,8 @@ impl ScalarUDFImpl for ArrayMax { fn array_max_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_max", args)?; match array.data_type() { - List(_) => array_min_max_helper(as_list_array(array)?, max_batch), - LargeList(_) => array_min_max_helper(as_large_list_array(array)?, max_batch), + List(_) => array_min_max_helper(as_list_array(array)?, false), + LargeList(_) => array_min_max_helper(as_large_list_array(array)?, false), arg_type => exec_err!("array_max does not support type: {arg_type}"), } } @@ -166,10 +165,6 @@ impl ArrayMin { } impl ScalarUDFImpl for ArrayMin { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_min" } @@ -198,16 +193,23 @@ impl ScalarUDFImpl for ArrayMin { fn array_min_inner(args: &[ArrayRef]) -> Result { let [array] = take_function_args("array_min", args)?; match array.data_type() { - List(_) => array_min_max_helper(as_list_array(array)?, min_batch), - LargeList(_) => array_min_max_helper(as_large_list_array(array)?, min_batch), + List(_) => array_min_max_helper(as_list_array(array)?, true), + LargeList(_) => array_min_max_helper(as_large_list_array(array)?, true), arg_type => exec_err!("array_min does not support type: {arg_type}"), } } fn array_min_max_helper( array: &GenericListArray, - agg_fn: fn(&ArrayRef) -> Result, + is_min: bool, ) -> Result { + // Try the primitive fast path first + if let Some(result) = try_primitive_array_min_max(array, is_min) { + return result; + } + + // Fallback: per-row ScalarValue path for non-primitive types + let agg_fn = if is_min { min_batch } else { max_batch }; let null_value = ScalarValue::try_from(array.value_type())?; let result_vec: Vec = array .iter() @@ -215,3 +217,96 @@ fn array_min_max_helper( .try_collect()?; ScalarValue::iter_to_array(result_vec) } + +/// Dispatches to a typed primitive min/max implementation, or returns `None` if +/// the element type is not a primitive. +fn try_primitive_array_min_max( + list_array: &GenericListArray, + is_min: bool, +) -> Option> { + macro_rules! helper { + ($t:ty) => { + return Some(primitive_array_min_max::(list_array, is_min)) + }; + } + downcast_primitive! { + list_array.value_type() => (helper), + _ => {} + } + None +} + +/// Threshold to switch from direct iteration to using `min` / `max` kernel from +/// `arrow::compute`. The latter has enough per-invocation overhead that direct +/// iteration is faster for small lists. +const ARROW_COMPUTE_THRESHOLD: usize = 32; + +/// Computes min or max for each row of a primitive ListArray. +fn primitive_array_min_max( + list_array: &GenericListArray, + is_min: bool, +) -> Result { + let values_array = list_array.values().as_primitive::(); + let values_slice = values_array.values(); + let values_nulls = values_array.nulls(); + let mut result_builder = PrimitiveBuilder::::with_capacity(list_array.len()) + .with_data_type(values_array.data_type().clone()); + + for (row, w) in list_array.offsets().windows(2).enumerate() { + let row_result = if list_array.is_null(row) { + None + } else { + let start = w[0].as_usize(); + let end = w[1].as_usize(); + let len = end - start; + + match len { + 0 => None, + _ if len < ARROW_COMPUTE_THRESHOLD => { + scalar_min_max::(values_slice, values_nulls, start, end, is_min) + } + _ => { + let slice = values_array.slice(start, len); + if is_min { + arrow::compute::min::(&slice) + } else { + arrow::compute::max::(&slice) + } + } + } + }; + + result_builder.append_option(row_result); + } + + Ok(Arc::new(result_builder.finish()) as ArrayRef) +} + +/// Computes min or max for a single list row by directly scanning a slice of +/// the flat values buffer. +#[inline] +fn scalar_min_max( + values_slice: &[T::Native], + values_nulls: Option<&arrow::buffer::NullBuffer>, + start: usize, + end: usize, + is_min: bool, +) -> Option { + let mut best: Option = None; + for (i, &val) in values_slice[start..end].iter().enumerate() { + if let Some(nulls) = values_nulls + && !nulls.is_valid(start + i) + { + continue; + } + let update_best = match best { + None => true, + Some(current) if is_min => val.is_lt(current), + Some(current) => val.is_gt(current), + }; + if update_best { + best = Some(val); + } + } + best +} diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 4fec5e38065b5..e96fdb7d4baca 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -18,15 +18,15 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] use arrow::datatypes::DataType; -use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; +use datafusion_common::{DFSchema, Result, plan_err, utils::list_ndims}; +use datafusion_expr::AggregateUDF; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; #[cfg(feature = "sql")] use datafusion_expr::sqlparser::ast::BinaryOperator; -use datafusion_expr::AggregateUDF; use datafusion_expr::{ - planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, Expr, ExprSchemable, GetFieldAccess, + planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, }; #[cfg(not(feature = "sql"))] use datafusion_expr_common::operator::Operator as BinaryOperator; @@ -37,7 +37,7 @@ use std::sync::Arc; use crate::map::map_udf; use crate::{ - array_has::{array_has_all, array_has_udf}, + array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, extract::{array_element, array_slice}, make_array::make_array, @@ -120,20 +120,6 @@ impl ExprPlanner for NestedFunctionPlanner { ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } - - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - array_has_udf(), - // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` - vec![expr.right, expr.left], - ), - ))) - } else { - plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) - } - } } #[derive(Debug)] @@ -148,6 +134,9 @@ impl ExprPlanner for FieldAccessPlanner { match field_access { // expr["field"] => get_field(expr, "field") + // Nested accesses like expr["a"]["b"] create nested get_field calls, + // which are then merged by the SimplifyExpressions optimizer pass via + // the GetFieldFunc::simplify() method. GetFieldAccess::NamedStructField { name } => { Ok(PlannerResult::Planned(get_field(expr, name))) } diff --git a/datafusion/functions-nested/src/position.rs b/datafusion/functions-nested/src/position.rs index 2844eefaf058d..d65620ede38e6 100644 --- a/datafusion/functions-nested/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -17,29 +17,30 @@ //! [`ScalarUDFImpl`] definitions for array_position and array_positions functions. +use arrow::array::Scalar; +use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List, UInt64}, Field, }; +use datafusion_common::ScalarValue; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; use arrow::array::{ - types::UInt64Type, Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, - UInt64Array, + Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, + types::UInt64Type, }; use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{ - assert_or_internal_err, exec_err, utils::take_function_args, Result, -}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; @@ -54,7 +55,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.", + description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL.", syntax_example = "array_position(array, element)\narray_position(array, element, index)", sql_example = r#"```sql > select array_position([1, 2, 2, 3, 1, 4], 2); @@ -74,10 +75,7 @@ make_udf_expr_and_func!( name = "array", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ), - argument( - name = "element", - description = "Element to search for position in the array." - ), + argument(name = "element", description = "Element to search for in the array."), argument( name = "index", description = "Index at which to start searching (1-indexed)." @@ -110,9 +108,6 @@ impl ArrayPosition { } impl ScalarUDFImpl for ArrayPosition { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_position" } @@ -125,11 +120,11 @@ impl ScalarUDFImpl for ArrayPosition { Ok(UInt64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_position_inner)(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match try_array_position_scalar(&args.args)? { + Some(result) => Ok(result), + None => make_scalar_function(array_position_inner)(&args.args), + } } fn aliases(&self) -> &[String] { @@ -141,6 +136,57 @@ impl ScalarUDFImpl for ArrayPosition { } } +/// Attempts the scalar-needle fast path for `array_position`. +fn try_array_position_scalar(args: &[ColumnarValue]) -> Result> { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_position expects two or three arguments"); + } + + // Fallback to the generic code path if the needle is an array + let scalar_needle = match &args[1] { + ColumnarValue::Scalar(s) => s, + ColumnarValue::Array(_) => return Ok(None), + }; + + // `not_distinct` doesn't support nested types (List, Struct, etc.), + // so fall back to the generic code path for those. + if scalar_needle.data_type().is_nested() { + return Ok(None); + } + + // Determine batch length from whichever argument is columnar; + // if all inputs are scalar, batch length is 1. + let (num_rows, all_inputs_scalar) = match (&args[0], args.get(2)) { + (ColumnarValue::Array(a), _) => (a.len(), false), + (_, Some(ColumnarValue::Array(a))) => (a.len(), false), + _ => (1, true), + }; + + let needle = scalar_needle.to_array_of_size(1)?; + let haystack = args[0].to_array(num_rows)?; + let arr_from = resolve_start_from(args.get(2), num_rows)?; + + let result = match haystack.data_type() { + List(_) => { + let list = as_list_array(&haystack)?; + array_position_scalar::(list, &needle, &arr_from) + } + LargeList(_) => { + let list = as_large_list_array(&haystack)?; + array_position_scalar::(list, &needle, &arr_from) + } + t => exec_err!("array_position does not support type '{t}'"), + }?; + + if all_inputs_scalar { + Ok(Some(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?))) + } else { + Ok(Some(ColumnarValue::Array(result))) + } +} + fn array_position_inner(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("array_position expects two or three arguments"); @@ -148,57 +194,148 @@ fn array_position_inner(args: &[ArrayRef]) -> Result { match &args[0].data_type() { List(_) => general_position_dispatch::(args), LargeList(_) => general_position_dispatch::(args), - array_type => exec_err!("array_position does not support type '{array_type}'."), + dt => exec_err!("array_position does not support type '{dt}'"), + } +} + +/// Resolves the optional `start_from` argument into a `Vec` of +/// 0-indexed starting positions. +fn resolve_start_from( + third_arg: Option<&ColumnarValue>, + num_rows: usize, +) -> Result> { + match third_arg { + None => Ok(vec![0i64; num_rows]), + Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => { + Ok(vec![v - 1; num_rows]) + } + Some(ColumnarValue::Scalar(s)) => { + exec_err!("array_position expected Int64 for start_from, got {s}") + } + Some(ColumnarValue::Array(a)) => { + Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect()) + } } } +/// Fast path for `array_position` when the needle is scalar. +/// +/// Performs a single bulk `not_distinct` comparison of the needle against the +/// entire flat values buffer, then walks the result bitmap using offsets to +/// find per-row first-match positions. +fn array_position_scalar( + haystack: &GenericListArray, + needle: &ArrayRef, + arr_from: &[i64], // 0-indexed +) -> Result { + crate::utils::check_datatypes("array_position", &[haystack.values(), needle])?; + + if haystack.len() == 0 { + return Ok(Arc::new(UInt64Array::new_null(0))); + } + + let needle_datum = Scalar::new(Arc::clone(needle)); + let validity = haystack.nulls(); + + // Only convert the visible portion of the values array. For sliced + // ListArrays, values() returns the full underlying array but only + // elements between the first and last offset are referenced. + let offsets = haystack.offsets(); + let first_offset = offsets[0].as_usize(); + let last_offset = offsets[haystack.len()].as_usize(); + let visible_values = haystack + .values() + .slice(first_offset, last_offset - first_offset); + + // `not_distinct` treats NULL=NULL as true, matching the semantics of + // `array_position`. + let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &needle_datum)?; + let eq_bits = eq_array.values(); + + let mut result: Vec> = Vec::with_capacity(haystack.len()); + let mut matches = eq_bits.set_indices().peekable(); + + // Match positions are relative to visible_values (0-based), so + // subtract first_offset from each offset when comparing. + for i in 0..haystack.len() { + let start = offsets[i].as_usize() - first_offset; + let end = offsets[i + 1].as_usize() - first_offset; + + if validity.is_some_and(|v| v.is_null(i)) { + // Null row -> null output; advance past matches in range + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + result.push(None); + continue; + } + + let from = arr_from[i]; + let row_len = end - start; + if !(from >= 0 && (from as usize) <= row_len) { + return exec_err!("start_from out of bounds: {}", from + 1); + } + let search_start = start + from as usize; + + // Advance past matches before search_start + while matches.peek().is_some_and(|&p| p < search_start) { + matches.next(); + } + + // First match in [search_start, end)? + if matches.peek().is_some_and(|&p| p < end) { + let pos = *matches.peek().unwrap(); + result.push(Some((pos - start + 1) as u64)); + // Advance past remaining matches in this row + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + } else { + result.push(None); + } + } + + debug_assert_eq!(result.len(), haystack.len()); + Ok(Arc::new(UInt64Array::from(result))) +} + fn general_position_dispatch(args: &[ArrayRef]) -> Result { - let list_array = as_generic_list_array::(&args[0])?; - let element_array = &args[1]; + let haystack = as_generic_list_array::(&args[0])?; + let needle = &args[1]; - crate::utils::check_datatypes( - "array_position", - &[list_array.values(), element_array], - )?; + crate::utils::check_datatypes("array_position", &[haystack.values(), needle])?; let arr_from = if args.len() == 3 { as_int64_array(&args[2])? .values() - .to_vec() .iter() .map(|&x| x - 1) .collect::>() } else { - vec![0; list_array.len()] + vec![0; haystack.len()] }; - // if `start_from` index is out of bounds, return error - for (arr, &from) in list_array.iter().zip(arr_from.iter()) { - // If `arr` is `None`: we will get null if we got null in the array, so we don't need to check - assert_or_internal_err!( - arr.is_none_or(|arr| from >= 0 && (from as usize) <= arr.len()), - "start_from index out of bounds" - ); + for (row, &from) in haystack.iter().zip(arr_from.iter()) { + if !row.is_none_or(|row| from >= 0 && (from as usize) <= row.len()) { + return exec_err!("start_from out of bounds: {}", from + 1); + } } - generic_position::(list_array, element_array, &arr_from) + generic_position::(haystack, needle, &arr_from) } -fn generic_position( - list_array: &GenericListArray, - element_array: &ArrayRef, +fn generic_position( + haystack: &GenericListArray, + needle: &ArrayRef, arr_from: &[i64], // 0-indexed ) -> Result { - let mut data = Vec::with_capacity(list_array.len()); + let mut data = Vec::with_capacity(haystack.len()); - for (row_index, (list_array_row, &from)) in - list_array.iter().zip(arr_from.iter()).enumerate() - { + for (row_index, (row, &from)) in haystack.iter().zip(arr_from.iter()).enumerate() { let from = from as usize; - if let Some(list_array_row) = list_array_row { - let eq_array = - compare_element_to_list(&list_array_row, element_array, row_index, true)?; + if let Some(row) = row { + let eq_array = compare_element_to_list(&row, needle, row_index, true)?; // Collect `true`s in 1-indexed positions let index = eq_array @@ -240,17 +377,20 @@ make_udf_expr_and_func!( name = "array", description = "Array expression. Can be a constant, column, or function, and any combination of array operators." ), - argument( - name = "element", - description = "Element to search for position in the array." - ) + argument(name = "element", description = "Element to search for in the array.") )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayPositions { +pub struct ArrayPositions { signature: Signature, aliases: Vec, } +impl Default for ArrayPositions { + fn default() -> Self { + Self::new() + } +} + impl ArrayPositions { pub fn new() -> Self { Self { @@ -261,9 +401,6 @@ impl ArrayPositions { } impl ScalarUDFImpl for ArrayPositions { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "array_positions" } @@ -276,11 +413,11 @@ impl ScalarUDFImpl for ArrayPositions { Ok(List(Arc::new(Field::new_list_field(UInt64, true)))) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_positions_inner)(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match try_array_positions_scalar(&args.args)? { + Some(result) => Ok(result), + None => make_scalar_function(array_positions_inner)(&args.args), + } } fn aliases(&self) -> &[String] { @@ -292,36 +429,70 @@ impl ScalarUDFImpl for ArrayPositions { } } -fn array_positions_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_positions", args)?; +/// Attempts the scalar-needle fast path for `array_positions`. +fn try_array_positions_scalar(args: &[ColumnarValue]) -> Result> { + let [haystack_arg, needle_arg] = take_function_args("array_positions", args)?; + + let scalar_needle = match needle_arg { + ColumnarValue::Scalar(s) => s, + ColumnarValue::Array(_) => return Ok(None), + }; - match &array.data_type() { + // `not_distinct` doesn't support nested types (List, Struct, etc.), + // so fall back to the per-row path for those. + if scalar_needle.data_type().is_nested() { + return Ok(None); + } + + let (num_rows, all_inputs_scalar) = match haystack_arg { + ColumnarValue::Array(a) => (a.len(), false), + ColumnarValue::Scalar(_) => (1, true), + }; + + let needle = scalar_needle.to_array_of_size(1)?; + let haystack = haystack_arg.to_array(num_rows)?; + + let result = match haystack.data_type() { List(_) => { - let arr = as_list_array(&array)?; - crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; - general_positions::(arr, element) + let list = as_list_array(&haystack)?; + array_positions_scalar::(list, &needle) } LargeList(_) => { - let arr = as_large_list_array(&array)?; - crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; - general_positions::(arr, element) - } - array_type => { - exec_err!("array_positions does not support type '{array_type}'.") + let list = as_large_list_array(&haystack)?; + array_positions_scalar::(list, &needle) } + t => exec_err!("array_positions does not support type '{t}'"), + }?; + + if all_inputs_scalar { + Ok(Some(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?))) + } else { + Ok(Some(ColumnarValue::Array(result))) } } -fn general_positions( - list_array: &GenericListArray, - element_array: &ArrayRef, +fn array_positions_inner(args: &[ArrayRef]) -> Result { + let [haystack, needle] = take_function_args("array_positions", args)?; + + match &haystack.data_type() { + List(_) => general_positions::(as_list_array(&haystack)?, needle), + LargeList(_) => general_positions::(as_large_list_array(&haystack)?, needle), + dt => exec_err!("array_positions does not support type '{dt}'"), + } +} + +fn general_positions( + haystack: &GenericListArray, + needle: &ArrayRef, ) -> Result { - let mut data = Vec::with_capacity(list_array.len()); + crate::utils::check_datatypes("array_positions", &[haystack.values(), needle])?; + let mut data = Vec::with_capacity(haystack.len()); - for (row_index, list_array_row) in list_array.iter().enumerate() { - if let Some(list_array_row) = list_array_row { - let eq_array = - compare_element_to_list(&list_array_row, element_array, row_index, true)?; + for (row_index, row) in haystack.iter().enumerate() { + if let Some(row) = row { + let eq_array = compare_element_to_list(&row, needle, row_index, true)?; // Collect `true`s in 1-indexed positions let indexes = eq_array @@ -340,3 +511,243 @@ fn general_positions( ListArray::from_iter_primitive::(data), )) } + +/// Fast path for `array_positions` when the needle is scalar. +/// +/// Performs a single bulk `not_distinct` comparison of the needle against the +/// entire flat values buffer, then walks the result bitmap using offsets to +/// collect all per-row match positions. +fn array_positions_scalar( + haystack: &GenericListArray, + needle: &ArrayRef, +) -> Result { + crate::utils::check_datatypes("array_positions", &[haystack.values(), needle])?; + + let num_rows = haystack.len(); + if num_rows == 0 { + return Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new_list_field(UInt64, true)), + OffsetBuffer::new_zeroed(1), + Arc::new(UInt64Array::from(Vec::::new())), + None, + )?)); + } + + let needle_datum = Scalar::new(Arc::clone(needle)); + let validity = haystack.nulls(); + + // Only convert the visible portion of the values array. For sliced + // ListArrays, values() returns the full underlying array but only + // elements between the first and last offset are referenced. + let offsets = haystack.offsets(); + let first_offset = offsets[0].as_usize(); + let last_offset = offsets[num_rows].as_usize(); + let visible_values = haystack + .values() + .slice(first_offset, last_offset - first_offset); + + // `not_distinct` treats NULL=NULL as true, matching the semantics of + // `array_positions`. + let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &needle_datum)?; + let eq_bits = eq_array.values(); + + let num_matches = eq_bits.count_set_bits(); + let mut positions: Vec = Vec::with_capacity(num_matches); + let mut result_offsets: Vec = Vec::with_capacity(num_rows + 1); + result_offsets.push(0); + let mut matches = eq_bits.set_indices().peekable(); + + // Match positions are relative to visible_values (0-based), so + // subtract first_offset from each offset when comparing. + for i in 0..num_rows { + let start = offsets[i].as_usize() - first_offset; + let end = offsets[i + 1].as_usize() - first_offset; + + if validity.is_some_and(|v| v.is_null(i)) { + // Null row -> null output; advance past matches in range. + while matches.peek().is_some_and(|&p| p < end) { + matches.next(); + } + result_offsets.push(positions.len() as i32); + continue; + } + + // Collect all matches in [start, end). + while let Some(pos) = matches.next_if(|&p| p < end) { + positions.push((pos - start + 1) as u64); + } + result_offsets.push(positions.len() as i32); + } + + debug_assert_eq!(result_offsets.len(), num_rows + 1); + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new_list_field(UInt64, true)), + OffsetBuffer::new(result_offsets.into()), + Arc::new(UInt64Array::from(positions)), + validity.cloned(), + )?)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::AsArray; + use arrow::datatypes::Int32Type; + use datafusion_common::config::ConfigOptions; + + #[test] + fn test_array_position_sliced_list() -> Result<()> { + // [[10, 20], [30, 40], [50, 60], [70, 80]] → slice(1,2) → [[30, 40], [50, 60]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(30), Some(40)]), + Some(vec![Some(50), Some(60)]), + Some(vec![Some(70), Some(80)]), + ]); + let sliced = list.slice(1, 2); + let haystack_field = + Arc::new(Field::new("haystack", sliced.data_type().clone(), true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new("return", UInt64, true)); + + // Search for elements that exist only in sliced-away rows: + // 10 is in the prefix row, 70 is in the suffix row. + let invoke = |needle: i32| -> Result { + ArrayPosition::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(sliced.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))), + ], + arg_fields: vec![ + Arc::clone(&haystack_field), + Arc::clone(&needle_field), + ], + number_rows: 2, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(2) + }; + + let output = invoke(10)?; + let output = output.as_primitive::(); + assert!(output.is_null(0)); + assert!(output.is_null(1)); + + let output = invoke(70)?; + let output = output.as_primitive::(); + assert!(output.is_null(0)); + assert!(output.is_null(1)); + + Ok(()) + } + + #[test] + fn test_array_positions_sliced_list() -> Result<()> { + // [[10, 20, 30], [30, 40, 30], [50, 60, 30], [70, 80, 30]] + // → slice(1,2) → [[30, 40, 30], [50, 60, 30]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20), Some(30)]), + Some(vec![Some(30), Some(40), Some(30)]), + Some(vec![Some(50), Some(60), Some(30)]), + Some(vec![Some(70), Some(80), Some(30)]), + ]); + let sliced = list.slice(1, 2); + let haystack_field = + Arc::new(Field::new("haystack", sliced.data_type().clone(), true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new( + "return", + List(Arc::new(Field::new_list_field(UInt64, true))), + true, + )); + + let invoke = |needle: i32| -> Result { + ArrayPositions::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(sliced.clone())), + ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))), + ], + arg_fields: vec![ + Arc::clone(&haystack_field), + Arc::clone(&needle_field), + ], + number_rows: 2, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(2) + }; + + // Needle 30: appears at positions 1,3 in row 0 ([30,40,30]) + // and position 3 in row 1 ([50,60,30]). + let output = invoke(30)?; + let output = output.as_list::(); + let row0 = output.value(0); + let row0 = row0.as_primitive::(); + assert_eq!(row0.values().as_ref(), &[1, 3]); + let row1 = output.value(1); + let row1 = row1.as_primitive::(); + assert_eq!(row1.values().as_ref(), &[3]); + + // Needle 10: only in the sliced-away prefix row → empty lists. + let output = invoke(10)?; + let output = output.as_list::(); + assert!(output.value(0).is_empty()); + assert!(output.value(1).is_empty()); + + // Needle 70: only in the sliced-away suffix row → empty lists. + let output = invoke(70)?; + let output = output.as_list::(); + assert!(output.value(0).is_empty()); + assert!(output.value(1).is_empty()); + + Ok(()) + } + + #[test] + fn test_array_positions_sliced_list_with_nulls() -> Result<()> { + // [[1, 2], null, [3, 1], [4, 5]] → slice(1,2) → [null, [3, 1]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + Some(vec![Some(3), Some(1)]), + Some(vec![Some(4), Some(5)]), + ]); + let sliced = list.slice(1, 2); + let haystack_field = + Arc::new(Field::new("haystack", sliced.data_type().clone(), true)); + let needle_field = Arc::new(Field::new("needle", DataType::Int32, true)); + let return_field = Arc::new(Field::new( + "return", + List(Arc::new(Field::new_list_field(UInt64, true))), + true, + )); + + let output = ArrayPositions::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(sliced)), + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ], + arg_fields: vec![Arc::clone(&haystack_field), Arc::clone(&needle_field)], + number_rows: 2, + return_field: Arc::clone(&return_field), + config_options: Arc::new(ConfigOptions::default()), + })? + .into_array(2)?; + + let output = output.as_list::(); + // Row 0 is null (from the sliced null row). + assert!(output.is_null(0)); + // Row 1 is [3, 1] → needle 1 found at position 2. + assert!(!output.is_null(1)); + let row1 = output.value(1); + let row1 = row1.as_primitive::(); + assert_eq!(row1.values().as_ref(), &[2]); + + Ok(()) + } +} diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index e570ecf97420f..65d9244ecdd4c 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -23,35 +23,33 @@ use arrow::datatypes::TimeUnit; use arrow::datatypes::{DataType, Field, IntervalUnit::MonthDayNano}; use arrow::{ array::{ + Array, ArrayRef, Int64Array, ListArray, ListBuilder, NullBufferBuilder, builder::{Date32Builder, TimestampNanosecondBuilder}, temporal_conversions::as_datetime_with_timezone, timezone::Tz, types::{Date32Type, IntervalMonthDayNanoType, TimestampNanosecondType}, - Array, ArrayRef, Int64Array, ListArray, ListBuilder, NullBufferBuilder, }, compute::cast, }; use datafusion_common::internal_err; use datafusion_common::{ + Result, exec_datafusion_err, exec_err, utils::take_function_args, +}; +use datafusion_common::{ + ScalarValue, cast::{ as_date32_array, as_int64_array, as_interval_mdn_array, as_timestamp_nanosecond_array, }, types::{ - logical_date, logical_int64, logical_interval_mdn, logical_string, NativeType, + NativeType, logical_date, logical_int64, logical_interval_mdn, logical_string, }, - ScalarValue, -}; -use datafusion_common::{ - exec_datafusion_err, exec_err, not_impl_datafusion_err, utils::take_function_args, - Result, }; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::cmp::Ordering; use std::iter::from_fn; use std::str::FromStr; @@ -212,10 +210,6 @@ impl Range { } impl ScalarUDFImpl for Range { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { if self.include_upper_bound { "generate_series" @@ -252,10 +246,7 @@ impl ScalarUDFImpl for Range { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; if args.iter().any(|arg| arg.data_type().is_null()) { @@ -297,7 +288,7 @@ impl Range { /// /// # Arguments /// - /// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. + /// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step (step value can not be zero) values. /// /// # Examples /// @@ -332,16 +323,13 @@ impl Range { ); } Some((start, stop, step)) => { - // Below, we utilize `usize` to represent steps. - // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. - let step_abs = - usize::try_from(step.unsigned_abs()).map_err(|_| { - not_impl_datafusion_err!("step {} can't fit into usize", step) - })?; - values.extend( - gen_range_iter(start, stop, step < 0, self.include_upper_bound) - .step_by(step_abs), - ); + generate_range_values( + start, + stop, + step, + self.include_upper_bound, + &mut values, + )?; offsets.push(values.len() as i32); valid.append_non_null(); } @@ -392,20 +380,27 @@ impl Range { } let stop = if !self.include_upper_bound { - Date32Type::subtract_month_day_nano(stop, step) + Date32Type::subtract_month_day_nano_opt(stop, step).ok_or_else(|| { + exec_datafusion_err!( + "Cannot generate date range where stop {} - {step:?}) overflows", + date32_to_string(stop) + ) + })? } else { stop }; let neg = months < 0 || days < 0; - let mut new_date = start; + let mut new_date = Some(start); let values = from_fn(|| { - if (neg && new_date < stop) || (!neg && new_date > stop) { + let Some(current_date) = new_date else { + return None; // previous overflow + }; + if (neg && current_date < stop) || (!neg && current_date > stop) { None } else { - let current_date = new_date; - new_date = Date32Type::add_month_day_nano(new_date, step); + new_date = Date32Type::add_month_day_nano_opt(current_date, step); Some(Some(current_date)) } }); @@ -543,38 +538,88 @@ fn retrieve_range_args( Some((start, stop, step)) } -/// Returns an iterator of i64 values from start to stop -fn gen_range_iter( +/// Reserve space for `count` more elements, returning an error when the +/// allocation would overflow `Vec`'s capacity limit or the allocator +/// rejects it, rather than panicking on user-supplied SQL. +fn reserve_range_capacity(values: &mut Vec, count: u64) -> Result<()> { + let count_usize = usize::try_from(count).map_err(|_| { + exec_datafusion_err!( + "Range too large to materialize: would produce {count} elements" + ) + })?; + values.try_reserve(count_usize).map_err(|e| { + exec_datafusion_err!( + "Range too large to materialize: failed to allocate {count} elements: {e}" + ) + }) +} + +/// Generate integer range values directly into the provided buffer. +#[inline] +fn generate_range_values( start: i64, stop: i64, - decreasing: bool, + step: i64, include_upper: bool, -) -> Box> { - match (decreasing, include_upper) { - // Decreasing range, stop is inclusive - (true, true) => Box::new((stop..=start).rev()), - // Decreasing range, stop is exclusive - (true, false) => { - if stop == i64::MAX { - // start is never greater than stop, and stop is exclusive, - // so the decreasing range must be empty. - Box::new(std::iter::empty()) - } else { - // Increase the stop value by one to exclude it. - // Since stop is not i64::MAX, `stop + 1` will not overflow. - Box::new((stop + 1..=start).rev()) + values: &mut Vec, +) -> Result<()> { + if !include_upper && start == stop { + return Ok(()); + } + + if step > 0 { + let limit = if include_upper { + stop + } else { + stop.saturating_sub(1) + }; + if start > limit { + return Ok(()); + } + let count = (start.abs_diff(limit) / step.unsigned_abs()).saturating_add(1); + reserve_range_capacity(values, count)?; + let mut current = start; + while current <= limit { + values.push(current); + match current.checked_add(step) { + Some(next) => current = next, + None => break, + } + } + } else if step < 0 { + let limit = if include_upper { + stop + } else { + stop.saturating_add(1) + }; + if start < limit { + return Ok(()); + } + let count = (start.abs_diff(limit) / step.unsigned_abs()).saturating_add(1); + reserve_range_capacity(values, count)?; + let mut current = start; + while current >= limit { + values.push(current); + match current.checked_add(step) { + Some(next) => current = next, + None => break, } } - // Increasing range, stop is inclusive - (false, true) => Box::new(start..=stop), - // Increasing range, stop is exclusive - (false, false) => Box::new(start..stop), } + Ok(()) } fn parse_tz(tz: &Option<&str>) -> Result { - let tz = tz.as_ref().map_or_else(|| "+00", |s| s); + let tz = tz.unwrap_or_else(|| "+00"); Tz::from_str(tz) .map_err(|op| exec_datafusion_err!("failed to parse timezone {tz}: {:?}", op)) } + +fn date32_to_string(value: i32) -> String { + if let Some(d) = Date32Type::to_naive_date_opt(value) { + format!("{value} ({d})") + } else { + format!("{value} (unknown date)") + } +} diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 46111b0c2d122..44ef56c039b71 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -18,35 +18,36 @@ //! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. use crate::utils; -use crate::utils::make_scalar_function; use arrow::array::{ - cast::AsArray, new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray, - OffsetSizeTrait, + Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, NullBufferBuilder, + OffsetBufferBuilder, OffsetSizeTrait, Scalar, cast::AsArray, make_array, + new_null_array, }; use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( ArrayRemove, array_remove, array element, - "removes the first element from the array equal to the given value.", + "removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", array_remove_udf ); #[user_doc( doc_section(label = "Array Functions"), - description = "Removes the first element from the array equal to the given value.", + description = "Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", syntax_example = "array_remove(array, element)", sql_example = r#"```sql > select array_remove([1, 2, 2, 3, 2, 1, 4], 2); @@ -55,6 +56,13 @@ make_udf_expr_and_func!( +----------------------------------------------+ | [1, 2, 3, 2, 1, 4] | +----------------------------------------------+ + +> select array_remove([1, 2, NULL, 2, 4], 2); ++---------------------------------------------------+ +| array_remove(List([1,2,NULL,2,4]),Int64(2)) | ++---------------------------------------------------+ +| [1, NULL, 2, 4] | ++---------------------------------------------------+ ```"#, argument( name = "array", @@ -87,10 +95,6 @@ impl ArrayRemove { } impl ScalarUDFImpl for ArrayRemove { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_remove" } @@ -99,15 +103,39 @@ impl ScalarUDFImpl for ArrayRemove { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") } - fn invoke_with_args( + fn return_field_from_args( &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_remove_inner)(&args.args) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let array_field = args.arg_fields[0].as_ref().clone(); + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(array_field.with_nullable(nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match element_arg { + ColumnarValue::Scalar(scalar_element) + if !scalar_element.is_null() + && !scalar_element.data_type().is_nested() => + { + let result = + array_remove_with_scalar_args(&list_array, scalar_element, 1i64)?; + Ok(ColumnarValue::Array(result)) + } + element_arg => { + let element_array = element_arg.to_array(num_rows)?; + let result = + array_remove_internal(&list_array, &element_array, &[Some(1)])?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -123,14 +151,14 @@ make_udf_expr_and_func!( ArrayRemoveN, array_remove_n, array element max, - "removes the first `max` elements from the array equal to the given value.", + "removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", array_remove_n_udf ); #[user_doc( doc_section(label = "Array Functions"), - description = "Removes the first `max` elements from the array equal to the given value.", - syntax_example = "array_remove_n(array, element, max))", + description = "Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", + syntax_example = "array_remove_n(array, element, max)", sql_example = r#"```sql > select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); +---------------------------------------------------------+ @@ -138,6 +166,13 @@ make_udf_expr_and_func!( +---------------------------------------------------------+ | [1, 3, 2, 1, 4] | +---------------------------------------------------------+ + +> select array_remove_n([1, 2, NULL, 2, 4], 2, 2); ++----------------------------------------------------------+ +| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) | ++----------------------------------------------------------+ +| [1, NULL, 4] | ++----------------------------------------------------------+ ```"#, argument( name = "array", @@ -150,11 +185,17 @@ make_udf_expr_and_func!( argument(name = "max", description = "Number of first occurrences to remove.") )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayRemoveN { +pub struct ArrayRemoveN { signature: Signature, aliases: Vec, } +impl Default for ArrayRemoveN { + fn default() -> Self { + Self::new() + } +} + impl ArrayRemoveN { pub fn new() -> Self { Self { @@ -175,10 +216,6 @@ impl ArrayRemoveN { } impl ScalarUDFImpl for ArrayRemoveN { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_remove_n" } @@ -187,15 +224,47 @@ impl ScalarUDFImpl for ArrayRemoveN { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") } - fn invoke_with_args( + fn return_field_from_args( &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_remove_n_inner)(&args.args) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let array_field = args.arg_fields[0].as_ref().clone(); + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(array_field.with_nullable(nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [list_arg, element_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (element_arg, max_arg) { + ( + ColumnarValue::Scalar(scalar_element), + ColumnarValue::Scalar(scalar_max), + ) if !scalar_element.is_null() && !scalar_element.data_type().is_nested() => { + let ScalarValue::Int64(Some(n)) = scalar_max else { + return Ok(ColumnarValue::Array(new_null_array( + list_array.data_type(), + num_rows, + ))); + }; + let result = + array_remove_with_scalar_args(&list_array, scalar_element, *n)?; + Ok(ColumnarValue::Array(result)) + } + (element_arg, max_arg) => { + let element_array = element_arg.to_array(num_rows)?; + let max_array = max_arg.to_array(num_rows)?; + let arr_n = as_int64_array(&max_array)?.iter().collect::>(); + let result = array_remove_internal(&list_array, &element_array, &arr_n)?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -211,13 +280,13 @@ make_udf_expr_and_func!( ArrayRemoveAll, array_remove_all, array element, - "removes all elements from the array equal to the given value.", + "removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", array_remove_all_udf ); #[user_doc( doc_section(label = "Array Functions"), - description = "Removes all elements from the array equal to the given value.", + description = "Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", syntax_example = "array_remove_all(array, element)", sql_example = r#"```sql > select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); @@ -226,6 +295,13 @@ make_udf_expr_and_func!( +--------------------------------------------------+ | [1, 3, 1, 4] | +--------------------------------------------------+ + +> select array_remove_all([1, 2, NULL, 2, 4], 2); ++-----------------------------------------------------+ +| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) | ++-----------------------------------------------------+ +| [1, NULL, 4] | ++-----------------------------------------------------+ ```"#, argument( name = "array", @@ -237,11 +313,17 @@ make_udf_expr_and_func!( ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayRemoveAll { +pub struct ArrayRemoveAll { signature: Signature, aliases: Vec, } +impl Default for ArrayRemoveAll { + fn default() -> Self { + Self::new() + } +} + impl ArrayRemoveAll { pub fn new() -> Self { Self { @@ -252,10 +334,6 @@ impl ArrayRemoveAll { } impl ScalarUDFImpl for ArrayRemoveAll { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_remove_all" } @@ -264,15 +342,42 @@ impl ScalarUDFImpl for ArrayRemoveAll { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") } - fn invoke_with_args( + fn return_field_from_args( &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_remove_all_inner)(&args.args) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + let array_field = args.arg_fields[0].as_ref().clone(); + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(array_field.with_nullable(nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match element_arg { + ColumnarValue::Scalar(scalar_element) + if !scalar_element.is_null() + && !scalar_element.data_type().is_nested() => + { + let result = + array_remove_with_scalar_args(&list_array, scalar_element, i64::MAX)?; + Ok(ColumnarValue::Array(result)) + } + element_arg => { + let element_array = element_arg.to_array(num_rows)?; + let result = array_remove_internal( + &list_array, + &element_array, + &[Some(i64::MAX)], + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -284,31 +389,10 @@ impl ScalarUDFImpl for ArrayRemoveAll { } } -fn array_remove_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_remove", args)?; - - let arr_n = vec![1; array.len()]; - array_remove_internal(array, element, &arr_n) -} - -fn array_remove_n_inner(args: &[ArrayRef]) -> Result { - let [array, element, max] = take_function_args("array_remove_n", args)?; - - let arr_n = as_int64_array(max)?.values().to_vec(); - array_remove_internal(array, element, &arr_n) -} - -fn array_remove_all_inner(args: &[ArrayRef]) -> Result { - let [array, element] = take_function_args("array_remove_all", args)?; - - let arr_n = vec![i64::MAX; array.len()]; - array_remove_internal(array, element, &arr_n) -} - fn array_remove_internal( array: &ArrayRef, element_array: &ArrayRef, - arr_n: &[i64], + arr_n: &[Option], ) -> Result { match array.data_type() { DataType::List(_) => { @@ -325,6 +409,28 @@ fn array_remove_internal( } } +/// Fast path for `array_remove` when the needle is a non-null, non-nested scalar. +/// Dispatches to the bulk `not_distinct` comparison kernel. +fn array_remove_with_scalar_args( + array: &ArrayRef, + scalar_needle: &ScalarValue, + max_removals: i64, +) -> Result { + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_remove_with_scalar::(list_array, scalar_needle, max_removals) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_remove_with_scalar::(list_array, scalar_needle, max_removals) + } + array_type => exec_err!( + "array_remove/array_remove_n/array_remove_all does not support type '{array_type}'." + ), + } +} + /// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences /// of `element_array[i]`. /// @@ -345,74 +451,722 @@ fn array_remove_internal( fn general_remove( list_array: &GenericListArray, element_array: &ArrayRef, - arr_n: &[i64], + arr_n: &[Option], ) -> Result { - let data_type = list_array.value_type(); - let mut new_values = vec![]; + let list_field = match list_array.data_type() { + DataType::List(field) | DataType::LargeList(field) => field, + _ => { + return exec_err!( + "Expected List or LargeList data type, got {:?}", + list_array.data_type() + ); + } + }; + let original_data = list_array.values().to_data(); // Build up the offsets for the final output array let mut offsets = Vec::::with_capacity(arr_n.len() + 1); offsets.push(OffsetSize::zero()); - // n is the number of elements to remove in this row - for (row_index, (list_array_row, n)) in - list_array.iter().zip(arr_n.iter()).enumerate() - { - match list_array_row { - Some(list_array_row) => { - let eq_array = utils::compare_element_to_list( - &list_array_row, - element_array, - row_index, - false, - )?; + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(original_data.len()), + ); + let mut valid = NullBufferBuilder::new(list_array.len()); - // We need to keep at most first n elements as `false`, which represent the elements to remove. - let eq_array = if eq_array.false_count() < *n as usize { - eq_array - } else { - let mut count = 0; - eq_array - .iter() - .map(|e| { - // Keep first n `false` elements, and reverse other elements to `true`. - if let Some(false) = e { - if count < *n { - count += 1; - e - } else { - Some(true) - } - } else { - e - } - }) - .collect::() - }; + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + if list_array.is_null(row_index) || element_array.is_null(row_index) { + offsets.push(offsets[row_index]); + valid.append_null(); + continue; + } - let filtered_array = arrow::compute::filter(&list_array_row, &eq_array)?; - offsets.push( - offsets[row_index] + OffsetSize::usize_as(filtered_array.len()), - ); - new_values.push(filtered_array); - } - None => { - // Null element results in a null row (no new offsets) - offsets.push(offsets[row_index]); + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; + let Some(n) = n else { + offsets.push(offsets[row_index]); + valid.append_null(); + continue; + }; + + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + + // compare each element in the list, `false` means the element matches and should be removed + let eq_array = utils::compare_element_to_list( + &list_array.value(row_index), + element_array, + row_index, + false, + )?; + + let num_to_remove = eq_array.false_count(); + + // Fast path: no elements to remove, copy entire row + if num_to_remove == 0 { + mutable.extend(0, start, end); + offsets.push(offsets[row_index] + OffsetSize::usize_as(end - start)); + valid.append_non_null(); + continue; + } + + // Remove at most `n` matching elements + let max_removals = n.min(num_to_remove as i64); + let mut removed = 0i64; + let mut copied = 0usize; + // marks the beginning of a range of elements pending to be copied. + let mut pending_batch_to_retain: Option = None; + for (i, keep) in eq_array.iter().enumerate() { + if keep == Some(false) && removed < max_removals { + // Flush pending batch before skipping this element + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + i); + copied += i - bs; + pending_batch_to_retain = None; + } + removed += 1; + } else if pending_batch_to_retain.is_none() { + pending_batch_to_retain = Some(i); } } + + // Flush remaining batch + if let Some(bs) = pending_batch_to_retain { + mutable.extend(0, start + bs, start + eq_array.len()); + copied += eq_array.len() - bs; + } + + offsets.push(offsets[row_index] + OffsetSize::usize_as(copied)); + valid.append_non_null(); + } + + let new_values = make_array(mutable.freeze()); + Ok(Arc::new(GenericListArray::::try_new( + Arc::clone(list_field), + OffsetBuffer::new(offsets.into()), + new_values, + valid.finish(), + )?)) +} + +/// For each element of `list_array[i]`, removes up to `max_removals` occurrences +/// of the scalar needle. +/// +/// This is a specialized version of `general_remove` for scalar elements that +/// uses bulk comparison for better performance. +fn general_remove_with_scalar( + list_array: &GenericListArray, + scalar_needle: &ScalarValue, + max_removals: i64, +) -> Result { + if max_removals <= 0 { + return Ok(Arc::new(list_array.clone())); } - let values = if new_values.is_empty() { - new_empty_array(&data_type) - } else { - let new_values = new_values.iter().map(|x| x.as_ref()).collect::>(); - arrow::compute::concat(&new_values)? + let list_field = match list_array.data_type() { + DataType::List(field) | DataType::LargeList(field) => field, + _ => { + return exec_err!( + "Expected List or LargeList data type, got {:?}", + list_array.data_type() + ); + } }; + let list_offsets = list_array.offsets(); + let first_offset = list_offsets[0].to_usize().unwrap(); + let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap(); + let values_range_len = last_offset - first_offset; + let values_slice = list_array.values().slice(first_offset, values_range_len); + let original_data = values_slice.to_data(); + let mut offsets = OffsetBufferBuilder::::new(list_array.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + Capacities::Array(original_data.len()), + ); + let nulls = list_array.nulls().cloned(); + let needle = scalar_needle.to_array_of_size(1)?; + let remove_mask = arrow_ord::cmp::not_distinct(&values_slice, &Scalar::new(needle))?; + let remove_bits = remove_mask.values(); + + for (row_index, offset_window) in list_offsets.windows(2).enumerate() { + if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) { + offsets.push_length(0); + continue; + } + + let start = offset_window[0].to_usize().unwrap() - first_offset; + let end = offset_window[1].to_usize().unwrap() - first_offset; + let row_len = end - start; + + let row_remove_bits = remove_bits.slice(start, row_len); + let num_to_remove = row_remove_bits.count_set_bits(); + + if num_to_remove == 0 { + mutable.extend(0, start, end); + offsets.push_length(row_len); + continue; + } + + let removals_to_apply = max_removals.min(num_to_remove as i64) as usize; + + // Iterate only over the removal positions via set_indices. This is + // efficient when the number of removals is small relative to the row + // length (common case), since it skips over retained elements. + let mut removed = 0usize; + let mut copied = 0usize; + let mut prev_end = start; + for remove_pos in row_remove_bits.set_indices() { + let abs_pos = start + remove_pos; + if abs_pos > prev_end { + mutable.extend(0, prev_end, abs_pos); + copied += abs_pos - prev_end; + } + prev_end = abs_pos + 1; + removed += 1; + if removed == removals_to_apply { + break; + } + } + // Copy the remaining tail after the last removal + if prev_end < end { + mutable.extend(0, prev_end, end); + copied += end - prev_end; + } + + offsets.push_length(copied); + } + + let new_values = make_array(mutable.freeze()); Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new_list_field(data_type, true)), - OffsetBuffer::new(offsets.into()), - values, - list_array.nulls().cloned(), + Arc::clone(list_field), + offsets.finish(), + new_values, + nulls, )?)) } + +#[cfg(test)] +mod tests { + use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; + use arrow::array::{ + Array, ArrayRef, AsArray, GenericListArray, Int32Array, Int64Array, ListArray, + OffsetSizeTrait, + }; + use arrow::buffer::{NullBuffer, ScalarBuffer}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use std::ops::Deref; + use std::sync::Arc; + + #[test] + fn test_array_remove_nullability() { + for nullability in [true, false] { + for item_nullability in [true, false] { + for element_nullability in [true, false] { + let input_field = Arc::new(Field::new( + "num", + DataType::new_list(DataType::Int32, item_nullability), + nullability, + )); + let args_fields = vec![ + Arc::clone(&input_field), + Arc::new(Field::new("a", DataType::Int32, element_nullability)), + ]; + let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))]; + + let result = ArrayRemove::new() + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let expected = Arc::new( + input_field + .as_ref() + .clone() + .with_nullable(nullability || element_nullability), + ); + + assert_eq!(result, expected); + } + } + } + } + + #[test] + fn test_array_remove_n_nullability() { + for nullability in [true, false] { + for item_nullability in [true, false] { + for element_nullability in [true, false] { + for count_nullability in [true, false] { + let input_field = Arc::new(Field::new( + "num", + DataType::new_list(DataType::Int32, item_nullability), + nullability, + )); + let args_fields = vec![ + Arc::clone(&input_field), + Arc::new(Field::new( + "a", + DataType::Int32, + element_nullability, + )), + Arc::new(Field::new("b", DataType::Int64, count_nullability)), + ]; + let scalar_args = vec![ + None, + Some(&ScalarValue::Int32(Some(1))), + Some(&ScalarValue::Int64(Some(1))), + ]; + + let result = ArrayRemoveN::new() + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let expected_nullable = + nullability || element_nullability || count_nullability; + let expected = Arc::new( + input_field + .as_ref() + .clone() + .with_nullable(expected_nullable), + ); + + assert_eq!(result, expected); + } + } + } + } + } + + #[test] + fn test_array_remove_all_nullability() { + for nullability in [true, false] { + for item_nullability in [true, false] { + for element_nullability in [true, false] { + let input_field = Arc::new(Field::new( + "num", + DataType::new_list(DataType::Int32, item_nullability), + nullability, + )); + let args_fields = vec![ + Arc::clone(&input_field), + Arc::new(Field::new("a", DataType::Int32, element_nullability)), + ]; + let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))]; + let result = ArrayRemoveAll::new() + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let expected = Arc::new( + input_field + .as_ref() + .clone() + .with_nullable(nullability || element_nullability), + ); + + assert_eq!(result, expected); + } + } + } + } + + fn ensure_field_nullability( + field_nullable: bool, + list: GenericListArray, + ) -> GenericListArray { + let (field, offsets, values, nulls) = list.into_parts(); + + if field.is_nullable() == field_nullable { + return GenericListArray::new(field, offsets, values, nulls); + } + if !field_nullable { + assert_eq!(nulls, None); + } + + let field = Arc::new(field.deref().clone().with_nullable(field_nullable)); + + GenericListArray::new(field, offsets, values, nulls) + } + + #[test] + fn test_array_remove_non_nullable() { + let input_list = Arc::new(ensure_field_nullability( + false, + ListArray::from_iter_primitive::(vec![ + Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)), + Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)), + ]), + )); + let expected_list = ensure_field_nullability( + false, + ListArray::from_iter_primitive::(vec![ + Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)), + Some(([42, 55, 63, 2]).iter().copied().map(Some)), + ]), + ); + + let element_to_remove = ScalarValue::Int32(Some(2)); + + assert_array_remove(input_list, expected_list, element_to_remove); + } + + #[test] + fn test_array_remove_nullable() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + Some(1), + Some(4), + ]), + Some(vec![Some(42), Some(2), None, Some(63), Some(2)]), + ]), + )); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]), + Some(vec![Some(42), None, Some(63), Some(2)]), + ]), + ); + + let element_to_remove = ScalarValue::Int32(Some(2)); + + assert_array_remove(input_list, expected_list, element_to_remove); + } + + fn assert_array_remove( + input_list: ArrayRef, + expected_list: GenericListArray, + element_to_remove: ScalarValue, + ) { + assert_eq!(input_list.data_type(), expected_list.data_type()); + assert_eq!(expected_list.value_type(), element_to_remove.data_type()); + let input_list_len = input_list.len(); + let input_list_data_type = input_list.data_type().clone(); + + let udf = ArrayRemove::new(); + let args_fields = vec![ + Arc::new(Field::new("num", input_list.data_type().clone(), false)), + Arc::new(Field::new( + "el", + element_to_remove.data_type(), + element_to_remove.is_null(), + )), + ]; + let scalar_args = vec![None, Some(&element_to_remove)]; + + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input_list), + ColumnarValue::Scalar(element_to_remove), + ], + arg_fields: args_fields, + number_rows: input_list_len, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + + assert_eq!(result.data_type(), input_list_data_type); + match result { + ColumnarValue::Array(array) => { + let result_list = array.as_list::(); + assert_eq!(result_list, &expected_list); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } + + #[test] + fn test_array_remove_n_non_nullable() { + let input_list = Arc::new(ensure_field_nullability( + false, + ListArray::from_iter_primitive::(vec![ + Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)), + Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)), + ]), + )); + let expected_list = ensure_field_nullability( + false, + ListArray::from_iter_primitive::(vec![ + Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)), + Some(([42, 55, 63]).iter().copied().map(Some)), + ]), + ); + + let element_to_remove = ScalarValue::Int32(Some(2)); + + assert_array_remove_n(input_list, expected_list, element_to_remove, 2); + } + + #[test] + fn test_array_remove_n_nullable() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + Some(1), + Some(4), + ]), + Some(vec![Some(42), Some(2), None, Some(63), Some(2)]), + ]), + )); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(3), None, Some(1), Some(4)]), + Some(vec![Some(42), None, Some(63)]), + ]), + ); + + let element_to_remove = ScalarValue::Int32(Some(2)); + + assert_array_remove_n(input_list, expected_list, element_to_remove, 2); + } + + #[test] + fn test_array_remove_n_null_count_returns_null() { + let array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(2)]), + Some(vec![Some(4), Some(2)]), + ])); + let element: ArrayRef = Arc::new(Int32Array::from(vec![2, 2])); + let max: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![1, 1]), + Some(NullBuffer::from(vec![true, false])), + )); + + let udf = ArrayRemoveN::new(); + let args_fields = vec![ + Arc::new(Field::new("num", array.data_type().clone(), false)), + Arc::new(Field::new("el", DataType::Int32, false)), + Arc::new(Field::new("count", DataType::Int64, true)), + ]; + let scalar_args = vec![None, None, None]; + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(array), + ColumnarValue::Array(element), + ColumnarValue::Array(max), + ], + arg_fields: args_fields, + number_rows: 2, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ]); + + match result { + ColumnarValue::Array(array) => { + assert_eq!(array.as_list::(), &expected); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } + + fn assert_array_remove_n( + input_list: ArrayRef, + expected_list: GenericListArray, + element_to_remove: ScalarValue, + n: i64, + ) { + assert_eq!(input_list.data_type(), expected_list.data_type()); + assert_eq!(expected_list.value_type(), element_to_remove.data_type()); + let input_list_len = input_list.len(); + let input_list_data_type = input_list.data_type().clone(); + + let count_scalar = ScalarValue::Int64(Some(n)); + + let udf = ArrayRemoveN::new(); + let args_fields = vec![ + Arc::new(Field::new("num", input_list.data_type().clone(), false)), + Arc::new(Field::new( + "el", + element_to_remove.data_type(), + element_to_remove.is_null(), + )), + Arc::new(Field::new("count", DataType::Int64, false)), + ]; + let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)]; + + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input_list), + ColumnarValue::Scalar(element_to_remove), + ColumnarValue::Scalar(count_scalar), + ], + arg_fields: args_fields, + number_rows: input_list_len, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + + assert_eq!(result.data_type(), input_list_data_type); + match result { + ColumnarValue::Array(array) => { + let result_list = array.as_list::(); + assert_eq!(result_list, &expected_list); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } + + #[test] + fn test_array_remove_all_non_nullable() { + let input_list = Arc::new(ensure_field_nullability( + false, + ListArray::from_iter_primitive::(vec![ + Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)), + Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)), + ]), + )); + let expected_list = ensure_field_nullability( + false, + ListArray::from_iter_primitive::(vec![ + Some(([1, 3, 1, 4]).iter().copied().map(Some)), + Some(([42, 55, 63]).iter().copied().map(Some)), + ]), + ); + + let element_to_remove = ScalarValue::Int32(Some(2)); + + assert_array_remove_all(input_list, expected_list, element_to_remove); + } + + #[test] + fn test_array_remove_all_nullable() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + Some(1), + Some(4), + ]), + Some(vec![Some(42), Some(2), None, Some(63), Some(2)]), + ]), + )); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(3), None, Some(1), Some(4)]), + Some(vec![Some(42), None, Some(63)]), + ]), + ); + + let element_to_remove = ScalarValue::Int32(Some(2)); + + assert_array_remove_all(input_list, expected_list, element_to_remove); + } + + fn assert_array_remove_all( + input_list: ArrayRef, + expected_list: GenericListArray, + element_to_remove: ScalarValue, + ) { + assert_eq!(input_list.data_type(), expected_list.data_type()); + assert_eq!(expected_list.value_type(), element_to_remove.data_type()); + let input_list_len = input_list.len(); + let input_list_data_type = input_list.data_type().clone(); + + let udf = ArrayRemoveAll::new(); + let args_fields = vec![ + Arc::new(Field::new("num", input_list.data_type().clone(), false)), + Arc::new(Field::new( + "el", + element_to_remove.data_type(), + element_to_remove.is_null(), + )), + ]; + let scalar_args = vec![None, Some(&element_to_remove)]; + + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input_list), + ColumnarValue::Scalar(element_to_remove), + ], + arg_fields: args_fields, + number_rows: input_list_len, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + + assert_eq!(result.data_type(), input_list_data_type); + match result { + ColumnarValue::Array(array) => { + let result_list = array.as_list::(); + assert_eq!(result_list, &expected_list); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } +} diff --git a/datafusion/functions-nested/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs index d978081e490c8..d7dff21141429 100644 --- a/datafusion/functions-nested/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -19,26 +19,31 @@ use crate::utils::make_scalar_function; use arrow::array::{ - new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, - MutableArrayData, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait, + UInt64Array, }; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute; -use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::{ DataType::{LargeList, List}, Field, }; -use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array}; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::{Result, exec_datafusion_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; -use std::any::Any; +use std::mem::size_of; use std::sync::Arc; +const ARRAY_REPEAT_LENGTH_EXCEEDED: &str = + "array_repeat: requested length exceeds maximum array size"; + make_udf_expr_and_func!( ArrayRepeat, array_repeat, @@ -89,17 +94,23 @@ impl Default for ArrayRepeat { impl ArrayRepeat { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), + ], + Volatility::Immutable, + ), aliases: vec![String::from("list_repeat")], } } } impl ScalarUDFImpl for ArrayRepeat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_repeat" } @@ -109,16 +120,20 @@ impl ScalarUDFImpl for ArrayRepeat { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new_list_field( - arg_types[0].clone(), - true, - )))) + let element_type = &arg_types[0]; + match element_type { + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + element_type.clone(), + true, + )))), + _ => Ok(List(Arc::new(Field::new_list_field( + element_type.clone(), + true, + )))), + } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_repeat_inner)(&args.args) } @@ -126,23 +141,6 @@ impl ScalarUDFImpl for ArrayRepeat { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [first_type, second_type] = take_function_args(self.name(), arg_types)?; - - // Coerce the second argument to Int64/UInt64 if it's a numeric type - let second = match second_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - DataType::UInt64 - } - _ => return exec_err!("count must be an integer type"), - }; - - Ok(vec![first_type.clone(), second]) - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -150,15 +148,7 @@ impl ScalarUDFImpl for ArrayRepeat { fn array_repeat_inner(args: &[ArrayRef]) -> Result { let element = &args[0]; - let count_array = &args[1]; - - let count_array = match count_array.data_type() { - DataType::Int64 => &cast(count_array, &DataType::UInt64)?, - DataType::UInt64 => count_array, - _ => return exec_err!("count must be an integer type"), - }; - - let count_array = as_uint64_array(count_array)?; + let count_array = as_int64_array(&args[1])?; match element.data_type() { List(_) => { @@ -187,45 +177,43 @@ fn array_repeat_inner(args: &[ArrayRef]) -> Result { /// ``` fn general_repeat( array: &ArrayRef, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let data_type = array.data_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (row_index, &count) in count_vec.iter().enumerate() { - let repeated_array = if array.is_null(row_index) { - new_null_array(data_type, count) - } else { - let original_data = array.to_data(); - let capacity = Capacities::Array(count); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for _ in 0..count { - mutable.extend(0, row_index, row_index + 1); - } - - let data = mutable.freeze(); - arrow::array::make_array(data) + let total_repeated_values = + (0..count_array.len()).try_fold(0usize, |total, idx| { + total + .checked_add(repeat_count(count_array, idx).unwrap_or_default()) + .ok_or_else(|| { + exec_datafusion_err!( + "array_repeat: total repeated values overflowed usize" + ) + }) + })?; + ensure_repeated_values_fit::(total_repeated_values)?; + let (offsets, _) = build_repeat_offsets::(count_array)?; + + let mut take_indices = Vec::with_capacity(total_repeated_values); + + for idx in 0..count_array.len() { + let Some(count) = repeat_count(count_array, idx) else { + continue; }; - new_values.push(repeated_array); + take_indices.extend(std::iter::repeat_n(idx as u64, count)); } - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; + // Build the flattened values + let repeated_values = compute::take( + array.as_ref(), + &UInt64Array::from_iter_values(take_indices), + None, + )?; + // Construct final ListArray Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new_list_field(data_type.to_owned(), true)), - OffsetBuffer::from_lengths(count_vec), - values, - None, + Arc::new(Field::new_list_field(array.data_type().to_owned(), true)), + OffsetBuffer::new(offsets.into()), + repeated_values, + count_array.nulls().cloned(), )?)) } @@ -241,58 +229,247 @@ fn general_repeat( /// ``` fn general_list_repeat( list_array: &GenericListArray, - count_array: &UInt64Array, + count_array: &Int64Array, ) -> Result { - let data_type = list_array.data_type(); - let value_type = list_array.value_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { - let list_arr = match list_array_row { - Some(list_array_row) => { - let original_data = list_array_row.to_data(); - let capacity = Capacities::Array(original_data.len() * count); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data], - false, - capacity, - ); - - for _ in 0..count { - mutable.extend(0, 0, original_data.len()); - } - - let data = mutable.freeze(); - let repeated_array = arrow::array::make_array(data); - - let list_arr = GenericListArray::::try_new( - Arc::new(Field::new_list_field(value_type.clone(), true)), - OffsetBuffer::::from_lengths(vec![original_data.len(); count]), - repeated_array, - None, - )?; - Arc::new(list_arr) as ArrayRef - } - None => new_null_array(data_type, count), + let list_offsets = list_array.value_offsets(); + let (outer_offsets, outer_total) = build_repeat_offsets::(count_array)?; + + // calculate capacities for pre-allocation + let mut inner_total = 0usize; + for i in 0..count_array.len() { + let Some(count) = repeat_count(count_array, i) else { + continue; }; - new_values.push(list_arr); + if count > 0 && list_array.is_valid(i) { + let len = list_offsets[i + 1].to_usize().unwrap() + - list_offsets[i].to_usize().unwrap(); + inner_total = + checked_repeat_len_add(inner_total, checked_repeat_len_mul(len, count)?)?; + ensure_repeated_values_fit::(inner_total)?; + } } - let lengths = new_values.iter().map(|a| a.len()).collect::>(); - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; + // Build inner structures + let inner_offsets_capacity = checked_offset_slots_capacity::(outer_total)?; + let mut inner_offsets = Vec::with_capacity(inner_offsets_capacity); + let mut take_indices = Vec::with_capacity(inner_total); + let mut inner_nulls = BooleanBufferBuilder::new(outer_total); + let mut inner_running = 0usize; + inner_offsets.push(O::zero()); + + for row_idx in 0..count_array.len() { + let Some(count) = repeat_count(count_array, row_idx) else { + continue; + }; + let list_is_valid = list_array.is_valid(row_idx); + let start = list_offsets[row_idx].to_usize().unwrap(); + let end = list_offsets[row_idx + 1].to_usize().unwrap(); + let row_len = end - start; + + for _ in 0..count { + inner_running = checked_repeat_len_add(inner_running, row_len)?; + ensure_repeated_values_fit::(inner_running)?; + let offset = checked_repeat_offset::(inner_running)?; + inner_offsets.push(offset); + inner_nulls.append(list_is_valid); + if list_is_valid { + take_indices.extend(start as u64..end as u64); + } + } + } - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new_list_field(data_type.to_owned(), true)), - OffsetBuffer::::from_lengths(lengths), - values, + // Build inner ListArray + let inner_values = compute::take( + list_array.values().as_ref(), + &UInt64Array::from_iter_values(take_indices), None, + )?; + let inner_list = GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type().clone(), true)), + OffsetBuffer::new(inner_offsets.into()), + inner_values, + Some(NullBuffer::new(inner_nulls.finish())), + )?; + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field( + list_array.data_type().to_owned(), + true, + )), + OffsetBuffer::new(outer_offsets.into()), + Arc::new(inner_list), + count_array.nulls().cloned(), )?)) } + +fn build_repeat_offsets( + count_array: &Int64Array, +) -> Result<(Vec, usize)> { + let offsets_capacity = checked_offset_slots_capacity::(count_array.len())?; + let mut offsets = Vec::with_capacity(offsets_capacity); + offsets.push(O::zero()); + let mut running_offset = 0usize; + + for idx in 0..count_array.len() { + let Some(count) = repeat_count(count_array, idx) else { + offsets.push(*offsets.last().unwrap()); + continue; + }; + running_offset = checked_repeat_len_add(running_offset, count)?; + ensure_repeated_values_fit::(running_offset)?; + let offset = checked_repeat_offset::(running_offset)?; + offsets.push(offset); + } + + Ok((offsets, running_offset)) +} + +fn checked_repeat_len_add(lhs: usize, rhs: usize) -> Result { + lhs.checked_add(rhs) + .ok_or_else(|| exec_datafusion_err!("{}", ARRAY_REPEAT_LENGTH_EXCEEDED)) +} + +fn checked_repeat_len_mul(lhs: usize, rhs: usize) -> Result { + lhs.checked_mul(rhs) + .ok_or_else(|| exec_datafusion_err!("{}", ARRAY_REPEAT_LENGTH_EXCEEDED)) +} + +fn ensure_repeated_values_fit(len: usize) -> Result<()> { + ensure_vec_capacity::(len)?; + checked_repeat_offset::(len)?; + + Ok(()) +} + +fn ensure_vec_capacity(len: usize) -> Result<()> { + if len > max_vec_elements::() { + return Err(exec_datafusion_err!("{}", ARRAY_REPEAT_LENGTH_EXCEEDED)); + } + + Ok(()) +} + +fn checked_offset_slots_capacity(len: usize) -> Result { + let capacity = checked_repeat_len_add(len, 1)?; + ensure_vec_capacity::(capacity)?; + + Ok(capacity) +} + +fn checked_repeat_offset(offset: usize) -> Result { + O::from_usize(offset).ok_or_else(|| { + exec_datafusion_err!( + "array_repeat: offset {offset} exceeds the maximum value for offset type" + ) + }) +} + +fn max_vec_elements() -> usize { + let element_size = size_of::(); + (isize::MAX as usize) + .checked_div(element_size) + .unwrap_or(usize::MAX) +} + +/// Helper function to get count from count_array at given index. +/// Returns `None` for NULL values and `Some(0)` for non-positive counts. +#[inline] +fn repeat_count(count_array: &Int64Array, idx: usize) -> Option { + if count_array.is_null(idx) { + None + } else { + let c = count_array.value(idx); + Some(if c > 0 { c as usize } else { 0 }) + } +} + +#[cfg(test)] +mod tests { + use super::{array_repeat_inner, general_list_repeat, general_repeat}; + use arrow::array::{Array, ArrayRef, AsArray, Int32Array, Int64Array, ListArray}; + use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow::datatypes::{Field, Int32Type}; + use datafusion_common::Result; + use std::sync::Arc; + + #[test] + fn test_array_repeat_null_count_stays_null() -> Result<()> { + let array: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + let counts = Int64Array::new( + ScalarBuffer::from(vec![2, 1, 1]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + let result = general_repeat::(&array, &counts)?; + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(1)]), + None, + Some(vec![Some(3)]), + ]); + + assert_eq!(result.as_list::(), &expected); + + Ok(()) + } + + #[test] + fn test_array_repeat_nested_null_count_stays_null() -> Result<()> { + let list_array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(5)]), + ]); + let counts = Int64Array::new( + ScalarBuffer::from(vec![2, 1, 1]), + Some(NullBuffer::from(vec![true, false, true])), + ); + + let result = general_list_repeat::(&list_array, &counts)?; + let repeated_values = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(1), Some(2)]), + Some(vec![Some(5)]), + ]); + let expected = ListArray::new( + Arc::new(Field::new_list_field( + repeated_values.data_type().clone(), + true, + )), + OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 2, 3])), + Arc::new(repeated_values), + Some(NullBuffer::from(vec![true, false, true])), + ); + + assert_eq!(result.as_list::(), &expected); + + Ok(()) + } + + #[test] + fn scalar_count_exceeding_max_array_size_returns_error() { + let element: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let count: ArrayRef = Arc::new(Int64Array::from(vec![i64::MAX])); + + let err = array_repeat_inner(&[element, count]).unwrap_err(); + assert!( + err.to_string().starts_with( + "Execution error: array_repeat: requested length exceeds maximum array size" + ), + "unexpected error: {err}" + ); + } + + #[test] + fn scalar_count_exceeding_list_offset_limit_returns_error() { + let element: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let count: ArrayRef = Arc::new(Int64Array::from(vec![i32::MAX as i64 + 1])); + + let err = array_repeat_inner(&[element, count]).unwrap_err(); + assert!( + err.to_string().starts_with( + "Execution error: array_repeat: offset 2147483648 exceeds the maximum value for offset type" + ), + "unexpected error: {err}" + ); + } +} diff --git a/datafusion/functions-nested/src/replace.rs b/datafusion/functions-nested/src/replace.rs index 53182b58988f2..f129972fc7ea8 100644 --- a/datafusion/functions-nested/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -18,25 +18,22 @@ //! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions. use arrow::array::{ - new_null_array, Array, ArrayRef, AsArray, Capacities, GenericListArray, - MutableArrayData, NullBufferBuilder, OffsetSizeTrait, + Array, ArrayRef, AsArray, Capacities, GenericListArray, MutableArrayData, + NullBufferBuilder, OffsetBufferBuilder, OffsetSizeTrait, Scalar, new_null_array, }; -use arrow::datatypes::{DataType, Field}; - use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::utils::ListCoercion; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; use crate::utils::compare_element_to_list; -use crate::utils::make_scalar_function; -use std::any::Any; use std::sync::Arc; // Create static instances of ScalarUDFs for each function @@ -113,10 +110,6 @@ impl ArrayReplace { } impl ScalarUDFImpl for ArrayReplace { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_replace" } @@ -129,11 +122,32 @@ impl ScalarUDFImpl for ArrayReplace { Ok(args[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_replace_inner)(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (from_arg, to_arg) { + (ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => { + let result = array_replace_with_scalar_args( + &list_array, + scalar_from, + scalar_to, + 1i64, + )?; + Ok(ColumnarValue::Array(result)) + } + (from_arg, to_arg) => { + let from_array = from_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let result = array_replace_internal( + &list_array, + &from_array, + &to_array, + &[Some(1)], + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -195,10 +209,6 @@ impl ArrayReplaceN { } impl ScalarUDFImpl for ArrayReplaceN { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_replace_n" } @@ -211,11 +221,44 @@ impl ScalarUDFImpl for ArrayReplaceN { Ok(args[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_replace_n_inner)(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [list_arg, from_arg, to_arg, max_arg] = + take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (from_arg, to_arg, max_arg) { + ( + ColumnarValue::Scalar(scalar_from), + ColumnarValue::Scalar(scalar_to), + ColumnarValue::Scalar(scalar_max), + ) => { + let ScalarValue::Int64(Some(n)) = scalar_max else { + return Ok(ColumnarValue::Array(new_null_array( + list_array.data_type(), + num_rows, + ))); + }; + let result = array_replace_with_scalar_args( + &list_array, + scalar_from, + scalar_to, + *n, + )?; + Ok(ColumnarValue::Array(result)) + } + (from_arg, to_arg, max_arg) => { + let from_array = from_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let max_array = max_arg.to_array(num_rows)?; + let result = array_replace_n_inner( + &list_array, + &from_array, + &to_array, + &max_array, + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -275,10 +318,6 @@ impl ArrayReplaceAll { } impl ScalarUDFImpl for ArrayReplaceAll { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_replace_all" } @@ -291,11 +330,32 @@ impl ScalarUDFImpl for ArrayReplaceAll { Ok(args[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(array_replace_all_inner)(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [list_arg, from_arg, to_arg] = take_function_args(self.name(), &args.args)?; + let num_rows = args.number_rows; + let list_array = list_arg.to_array(num_rows)?; + match (from_arg, to_arg) { + (ColumnarValue::Scalar(scalar_from), ColumnarValue::Scalar(scalar_to)) => { + let result = array_replace_with_scalar_args( + &list_array, + scalar_from, + scalar_to, + i64::MAX, + )?; + Ok(ColumnarValue::Array(result)) + } + (from_arg, to_arg) => { + let from_array = from_arg.to_array(num_rows)?; + let to_array = to_arg.to_array(num_rows)?; + let result = array_replace_internal( + &list_array, + &from_array, + &to_array, + &[Some(i64::MAX)], + )?; + Ok(ColumnarValue::Array(result)) + } + } } fn aliases(&self) -> &[String] { @@ -328,7 +388,7 @@ fn general_replace( list_array: &GenericListArray, from_array: &ArrayRef, to_array: &ArrayRef, - arr_n: &[i64], + arr_n: &[Option], ) -> Result { // Build up the offsets for the final output array let mut offsets: Vec = vec![O::usize_as(0)]; @@ -353,6 +413,17 @@ fn general_replace( continue; } + let n = if arr_n.len() == 1 { + arr_n[0] + } else { + arr_n[row_index] + }; + let Some(n) = n else { + offsets.push(offsets[row_index]); + valid.append_null(); + continue; + }; + let start = offset_window[0]; let end = offset_window[1]; @@ -365,11 +436,10 @@ fn general_replace( let original_idx = O::usize_as(0); let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; let mut counter = 0; // All elements are false, no need to replace, just copy original data - if eq_array.false_count() == eq_array.len() { + if n <= 0 || !eq_array.has_true() { mutable.extend( original_idx.to_usize().unwrap(), start.to_usize().unwrap(), @@ -380,9 +450,18 @@ fn general_replace( continue; } + let mut pending_retain: Option = None; for (i, to_replace) in eq_array.iter().enumerate() { let i = O::usize_as(i); - if let Some(true) = to_replace { + if to_replace == Some(true) && counter < n { + // Flush any pending retain run before emitting the replacement. + if let Some(rs) = pending_retain.take() { + mutable.extend( + original_idx.to_usize().unwrap(), + (start + rs).to_usize().unwrap(), + (start + i).to_usize().unwrap(), + ); + } mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); counter += 1; if counter == n { @@ -394,16 +473,23 @@ fn general_replace( ); break; } - } else { - // copy original data for false / null matches - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - ); + } else if pending_retain.is_none() { + pending_retain = Some(i); } } + // Flush trailing retain run when we exited the loop without ever + // hitting `counter == n` (i.e. fewer than `n` matches in this row). + if counter < n + && let Some(rs) = pending_retain + { + mutable.extend( + original_idx.to_usize().unwrap(), + (start + rs).to_usize().unwrap(), + end.to_usize().unwrap(), + ); + } + offsets.push(offsets[row_index] + (end - start)); valid.append_non_null(); } @@ -418,63 +504,238 @@ fn general_replace( )?)) } -fn array_replace_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace", args)?; +/// Replaces up to `max_replacements` occurrences of `needle` with the single +/// element in `to_array` for each row in `list_array`. +/// +/// This is a specialized fast path for the all-scalar case that uses a single +/// bulk `not_distinct` comparison over only the visible values range, then +/// iterates match positions via `set_indices` instead of scanning every bit. +fn general_replace_with_scalar( + list_array: &GenericListArray, + needle: &Scalar, + scalar_to: &ScalarValue, + max_replacements: i64, +) -> Result { + // No replacement needed - return unchanged. + if max_replacements <= 0 { + return Ok(Arc::new(list_array.clone())); + } + + let first_offset = list_array.offsets()[0].to_usize().unwrap(); + let last_offset = list_array.offsets()[list_array.len()].to_usize().unwrap(); + let visible_values = list_array + .values() + .slice(first_offset, last_offset - first_offset); - // replace at most one occurrence for each element - let arr_n = vec![1; array.len()]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let to_array = scalar_to.to_array_of_size(1)?; + let original_data = visible_values.to_data(); + let to_data = to_array.to_data(); + let capacity = Capacities::Array(original_data.len()); + + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &to_data], + false, + capacity, + ); + + let mut offsets = OffsetBufferBuilder::::new(list_array.len()); + + // Single bulk comparison over the visible values only. + let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?; + let match_bits = match_bitmap.values(); + + for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { + // Offsets relative to visible_values (subtract first_offset). + let start = offset_window[0].to_usize().unwrap() - first_offset; + let end = offset_window[1].to_usize().unwrap() - first_offset; + let row_len = end - start; + + if list_array.is_null(row_index) { + offsets.push_length(0); + continue; } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + + // Slice the match bits to this row and iterate only over true positions. + let row_bits = match_bits.slice(start, row_len); + let mut match_positions = row_bits + .set_indices() + .take(max_replacements as usize) + .peekable(); + if match_positions.peek().is_none() { + mutable.extend(0, start, end); + offsets.push_length(row_len); + continue; } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => exec_err!("array_replace does not support type '{array_type}'."), + + // Iterate only over the positions that match using set_indices, + // which is more efficient than scanning every bit because the number + // of matches is typically much smaller than the total array size. + let mut prev_end = 0usize; + for match_pos in match_positions { + // Retain elements before this match. + if match_pos > prev_end { + mutable.extend(0, start + prev_end, start + match_pos); + } + // Emit the replacement element. + mutable.extend(1, 0, 1); + prev_end = match_pos + 1; + } + + // Copy remaining elements after the last replacement. + if prev_end < row_len { + mutable.extend(0, start + prev_end, end); + } + + offsets.push_length(row_len); } + + let data = mutable.freeze(); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new_list_field(list_array.value_type(), true)), + offsets.finish(), + arrow::array::make_array(data), + list_array.nulls().cloned(), + )?)) } -fn array_replace_n_inner(args: &[ArrayRef]) -> Result { - let [array, from, to, max] = take_function_args("array_replace_n", args)?; +/// Fast path for `array_replace` when all arguments are scalars. +/// +/// Uses a single bulk `not_distinct` comparison instead of per-row comparisons. +fn array_replace_with_scalar_args( + list_array: &ArrayRef, + scalar_from: &ScalarValue, + scalar_to: &ScalarValue, + max_replacements: i64, +) -> Result { + // `not_distinct` doesn't support nested types, fall back to the generic array path. + if scalar_from.data_type().is_nested() { + let num_rows = list_array.len(); + let from_array = scalar_from.to_array_of_size(num_rows)?; + let to_array = scalar_to.to_array_of_size(num_rows)?; + return array_replace_internal( + list_array, + &from_array, + &to_array, + &vec![Some(max_replacements); num_rows], + ); + } - // replace the specified number of occurrences - let arr_n = as_int64_array(max)?.values().to_vec(); - match array.data_type() { + let needle = Scalar::new(scalar_from.to_array_of_size(1)?); + match list_array.data_type() { DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, scalar_to, max_replacements) } DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) - } - DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_n does not support type '{array_type}'.") + let list = list_array.as_list::(); + general_replace_with_scalar::(list, &needle, scalar_to, max_replacements) } + DataType::Null => Ok(new_null_array(list_array.data_type(), 1)), + array_type => exec_err!("array_replace does not support type '{array_type}'."), } } -fn array_replace_all_inner(args: &[ArrayRef]) -> Result { - let [array, from, to] = take_function_args("array_replace_all", args)?; - - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; array.len()]; +fn array_replace_internal( + array: &ArrayRef, + from: &ArrayRef, + to: &ArrayRef, + arr_n: &[Option], +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::LargeList(_) => { let list_array = array.as_list::(); - general_replace::(list_array, from, to, &arr_n) + general_replace::(list_array, from, to, arr_n) } DataType::Null => Ok(new_null_array(array.data_type(), 1)), - array_type => { - exec_err!("array_replace_all does not support type '{array_type}'.") - } + array_type => exec_err!("array_replace does not support type '{array_type}'."), + } +} + +fn array_replace_n_inner( + array: &ArrayRef, + from: &ArrayRef, + to: &ArrayRef, + max: &ArrayRef, +) -> Result { + let arr_n = as_int64_array(max)?.iter().collect::>(); + array_replace_internal(array, from, to, &arr_n) +} + +#[cfg(test)] +mod tests { + use super::{ArrayReplaceN, array_replace_n_inner}; + use arrow::array::{ArrayRef, AsArray, Int32Array, Int64Array, ListArray}; + use arrow::buffer::{NullBuffer, ScalarBuffer}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::{Result, ScalarValue, config::ConfigOptions}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_array_replace_n_null_max_returns_null() -> Result<()> { + let array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(2)]), + ])); + let from: ArrayRef = Arc::new(Int32Array::from(vec![2, 2])); + let to: ArrayRef = Arc::new(Int32Array::from(vec![9, 9])); + let max: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![1, 1]), + Some(NullBuffer::from(vec![true, false])), + )); + + let result = array_replace_n_inner(&array, &from, &to, &max)?; + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(9), Some(3)]), + None, + ]); + + assert_eq!(result.as_list::(), &expected); + + Ok(()) + } + + #[test] + fn test_array_replace_n_scalar_null_max_returns_null() -> Result<()> { + let array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(2)]), + ])); + let array_field = Arc::new(Field::new("array", array.data_type().clone(), true)); + + let result = ArrayReplaceN::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&array)), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(9))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + arg_fields: vec![ + Arc::clone(&array_field), + Arc::new(Field::new("from", DataType::Int32, false)), + Arc::new(Field::new("to", DataType::Int32, false)), + Arc::new(Field::new("max", DataType::Int64, true)), + ], + number_rows: array.len(), + return_field: Arc::clone(&array_field), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let result = result.into_array(array.len())?; + let expected = ListArray::from_iter_primitive::(vec![ + Option::>>::None, + Option::>>::None, + ]); + + assert_eq!(result.as_list::(), &expected); + + Ok(()) } } diff --git a/datafusion/functions-nested/src/resize.rs b/datafusion/functions-nested/src/resize.rs index c76f7970d2064..d11064bf7efd6 100644 --- a/datafusion/functions-nested/src/resize.rs +++ b/datafusion/functions-nested/src/resize.rs @@ -19,8 +19,8 @@ use crate::utils::make_scalar_function; use arrow::array::{ - new_null_array, Array, ArrayRef, Capacities, GenericListArray, Int64Array, - MutableArrayData, NullBufferBuilder, OffsetSizeTrait, + Array, ArrayRef, Capacities, GenericListArray, Int64Array, MutableArrayData, + NullBufferBuilder, OffsetSizeTrait, new_null_array, }; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; @@ -31,13 +31,12 @@ use arrow::datatypes::{ }; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -111,10 +110,6 @@ impl ArrayResize { } impl ScalarUDFImpl for ArrayResize { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_resize" } @@ -136,10 +131,7 @@ impl ScalarUDFImpl for ArrayResize { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_resize_inner)(&args.args) } @@ -168,7 +160,7 @@ fn array_resize_inner(arg: &[ArrayRef]) -> Result { return exec_err!( "array_resize does not support type '{:?}'.", array.data_type() - ) + ); } }; return Ok(new_null_array(&return_type, array.len())); @@ -206,28 +198,117 @@ fn general_list_resize>( let values = array.values(); let original_data = values.to_data(); - // create default element array - let default_element = if let Some(default_element) = default_element { - default_element + // Track the largest per-row growth so the uniform-fill fast path can + // materialize one reusable fill buffer of the required size. + let mut max_extra: usize = 0; + let mut output_values_len: usize = 0; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + if array.is_null(row_index) || count_array.is_null(row_index) { + continue; + } + let target_count = count_array.value(row_index).to_usize().ok_or_else(|| { + internal_datafusion_err!("array_resize: failed to convert size to usize") + })?; + output_values_len = + output_values_len.checked_add(target_count).ok_or_else(|| { + internal_datafusion_err!("array_resize: output size overflow") + })?; + let current_len = (offset_window[1] - offset_window[0]).to_usize().unwrap(); + if target_count > current_len { + max_extra = max_extra.max(target_count - current_len); + } + } + + // The fast path is valid when at least one row grows and every row would + // use the same fill value. + let use_bulk_fill = max_extra > 0 + && match &default_element { + None => true, + Some(fill_array) => { + let len = fill_array.len(); + let null_count = fill_array.logical_null_count(); + + len <= 1 + || null_count == len + || (null_count == 0 && { + let first = fill_array.slice(0, 1); + (1..len) + .all(|i| fill_array.slice(i, 1).as_ref() == first.as_ref()) + }) + } + }; + + if use_bulk_fill { + // Fast path: materialize one reusable fill buffer for all grown rows. + let fill_scalar = match &default_element { + None => ScalarValue::try_from(&data_type)?, + Some(fill_array) if fill_array.logical_null_count() == fill_array.len() => { + ScalarValue::try_from(&data_type)? + } + Some(fill_array) => ScalarValue::try_from_array(fill_array.as_ref(), 0)?, + }; + let fill_values = fill_scalar.to_array_of_size(max_extra)?; + let default_value_data = fill_values.to_data(); + build_resized_list( + array, + count_array, + field, + &original_data, + &default_value_data, + output_values_len, + |mutable, _, extra_count| mutable.extend(1, 0, extra_count), + ) } else { - let null_scalar = ScalarValue::try_from(&data_type)?; - null_scalar.to_array_of_size(original_data.len())? - }; - let default_value_data = default_element.to_data(); + // Slow path: rows may need different fill values, so append from the + // corresponding slot in the input fill array for each grown element. + let fill_values = match default_element { + Some(fill_values) => fill_values, + None => { + let null_scalar = ScalarValue::try_from(&data_type)?; + null_scalar.to_array_of_size(original_data.len())? + } + }; + let default_value_data = fill_values.to_data(); + build_resized_list( + array, + count_array, + field, + &original_data, + &default_value_data, + output_values_len, + |mutable, row_index, extra_count| { + for _ in 0..extra_count { + mutable.extend(1, row_index, row_index + 1); + } + }, + ) + } +} - // create a mutable array to store the original data - let capacity = Capacities::Array(original_data.len() + default_value_data.len()); +fn build_resized_list( + array: &GenericListArray, + count_array: &Int64Array, + field: &FieldRef, + original_data: &arrow::array::ArrayData, + default_value_data: &arrow::array::ArrayData, + output_values_len: usize, + mut append_fill_values: F, +) -> Result +where + O: OffsetSizeTrait + TryInto, + F: FnMut(&mut MutableArrayData, usize, usize), +{ + let capacity = Capacities::Array(output_values_len); let mut offsets = vec![O::usize_as(0)]; let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &default_value_data], + vec![original_data, default_value_data], false, capacity, ); - let mut null_builder = NullBufferBuilder::new(array.len()); for (row_index, offset_window) in array.offsets().windows(2).enumerate() { - if array.is_null(row_index) { + if array.is_null(row_index) || count_array.is_null(row_index) { null_builder.append_null(); offsets.push(offsets[row_index]); continue; @@ -240,21 +321,13 @@ fn general_list_resize>( let count = O::usize_as(count); let start = offset_window[0]; if start + count > offset_window[1] { - let extra_count = - (start + count - offset_window[1]).try_into().map_err(|_| { - internal_datafusion_err!( - "array_resize: failed to convert size to i64" - ) - })?; + let extra_count = (start + count - offset_window[1]).to_usize().unwrap(); let end = offset_window[1]; - mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); - // append default element - for _ in 0..extra_count { - mutable.extend(1, row_index, row_index + 1); - } + mutable.extend(0, start.to_usize().unwrap(), end.to_usize().unwrap()); + append_fill_values(&mut mutable, row_index, extra_count); } else { let end = start + count; - mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + mutable.extend(0, start.to_usize().unwrap(), end.to_usize().unwrap()); }; offsets.push(offsets[row_index] + count); } @@ -268,3 +341,36 @@ fn general_list_resize>( null_builder.finish(), )?)) } + +#[cfg(test)] +mod tests { + use super::array_resize_inner; + use arrow::array::{ArrayRef, AsArray, Int64Array, ListArray}; + use arrow::buffer::{NullBuffer, ScalarBuffer}; + use arrow::datatypes::Int32Type; + use datafusion_common::Result; + use std::sync::Arc; + + #[test] + fn test_array_resize_null_size_returns_null() -> Result<()> { + let array: ArrayRef = + Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ])); + let size: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::from(vec![2, 1]), + Some(NullBuffer::from(vec![true, false])), + )); + + let result = array_resize_inner(&[array, size])?; + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + None, + ]); + + assert_eq!(result.as_list::(), &expected); + + Ok(()) + } +} diff --git a/datafusion/functions-nested/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs index df873ade798d3..587d0dd29f306 100644 --- a/datafusion/functions-nested/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -32,13 +32,13 @@ use datafusion_common::cast::{ as_fixed_size_list_array, as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array, }; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_common::{Result, exec_err, utils::take_function_args}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use itertools::Itertools; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -88,10 +88,6 @@ impl ArrayReverse { } impl ScalarUDFImpl for ArrayReverse { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_reverse" } @@ -104,10 +100,7 @@ impl ScalarUDFImpl for ArrayReverse { Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_reverse_inner)(&args.args) } diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 4350bfdc5a9bc..2214d3d35bb7b 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -19,26 +19,25 @@ use crate::utils::make_scalar_function; use arrow::array::{ - new_null_array, Array, ArrayRef, GenericListArray, LargeListArray, ListArray, - OffsetSizeTrait, + Array, ArrayRef, GenericListArray, OffsetSizeTrait, UInt32Array, UInt64Array, + new_empty_array, new_null_array, }; -use arrow::buffer::OffsetBuffer; -use arrow::compute; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::compute::{concat, take}; use arrow::datatypes::DataType::{LargeList, List, Null}; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::utils::ListCoercion; +use datafusion_common::utils::{ListCoercion, normalize_float_zero}; use datafusion_common::{ - assert_eq_or_internal_err, exec_err, internal_err, utils::take_function_args, Result, + Result, assert_eq_or_internal_err, exec_err, internal_err, utils::take_function_args, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use itertools::Itertools; -use std::any::Any; -use std::collections::HashSet; +use hashbrown::HashSet; use std::fmt::{Display, Formatter}; use std::sync::Arc; @@ -69,7 +68,7 @@ make_udf_expr_and_func!( #[user_doc( doc_section(label = "Array Functions"), - description = "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.", + description = "Returns an array of elements that are present in both arrays (all elements from both arrays) without duplicates.", syntax_example = "array_union(array1, array2)", sql_example = r#"```sql > select array_union([1, 2, 3, 4], [5, 6, 3, 4]); @@ -120,10 +119,6 @@ impl ArrayUnion { } impl ScalarUDFImpl for ArrayUnion { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_union" } @@ -136,16 +131,12 @@ impl ScalarUDFImpl for ArrayUnion { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_union_inner)(&args.args) } @@ -186,11 +177,17 @@ impl ScalarUDFImpl for ArrayUnion { ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayIntersect { +pub struct ArrayIntersect { signature: Signature, aliases: Vec, } +impl Default for ArrayIntersect { + fn default() -> Self { + Self::new() + } +} + impl ArrayIntersect { pub fn new() -> Self { Self { @@ -205,10 +202,6 @@ impl ArrayIntersect { } impl ScalarUDFImpl for ArrayIntersect { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_intersect" } @@ -221,16 +214,12 @@ impl ScalarUDFImpl for ArrayIntersect { let [array1, array2] = take_function_args(self.name(), arg_types)?; match (array1, array2) { (Null, Null) => Ok(DataType::new_list(Null, true)), - (Null, dt) => Ok(dt.clone()), - (dt, Null) => Ok(dt.clone()), + (Null, dt) | (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_intersect_inner)(&args.args) } @@ -261,7 +250,7 @@ impl ScalarUDFImpl for ArrayIntersect { ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct ArrayDistinct { +pub struct ArrayDistinct { signature: Signature, aliases: Vec, } @@ -275,11 +264,13 @@ impl ArrayDistinct { } } -impl ScalarUDFImpl for ArrayDistinct { - fn as_any(&self) -> &dyn Any { - self +impl Default for ArrayDistinct { + fn default() -> Self { + Self::new() } +} +impl ScalarUDFImpl for ArrayDistinct { fn name(&self) -> &str { "array_distinct" } @@ -292,10 +283,7 @@ impl ScalarUDFImpl for ArrayDistinct { Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_distinct_inner)(&args.args) } @@ -361,66 +349,165 @@ fn generic_set_lists( "{set_op:?} is not implemented for '{l:?}' and '{r:?}'" ); - let mut offsets = vec![OffsetSize::usize_as(0)]; - let mut new_arrays = vec![]; let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; - for (first_arr, second_arr) in l.iter().zip(r.iter()) { - let l_values = if let Some(first_arr) = first_arr { - converter.convert_columns(&[first_arr])? - } else { - converter.convert_columns(&[])? - }; - let r_values = if let Some(second_arr) = second_arr { - converter.convert_columns(&[second_arr])? - } else { - converter.convert_columns(&[])? - }; + // Normalize -0.0 → +0.0 so RowConverter (which uses IEEE 754 totalOrder + // and treats ±0 as distinct) groups them together. Use the normalized + // arrays for both row conversion and the final output values. + let l_values_norm = normalize_float_zero(l.values()); + let r_values_norm = normalize_float_zero(r.values()); + + // Only convert the visible portion of the values array. For sliced + // ListArrays, values() returns the full underlying array but only + // elements between the first and last offset are referenced. + let l_first = l.offsets()[0].as_usize(); + let l_len = l.offsets()[l.len()].as_usize() - l_first; + let l_values = l_values_norm.slice(l_first, l_len); + let rows_l = converter.convert_columns(&[Arc::clone(&l_values)])?; + + let r_first = r.offsets()[0].as_usize(); + let r_len = r.offsets()[r.len()].as_usize() - r_first; + let r_values = r_values_norm.slice(r_first, r_len); + let rows_r = converter.convert_columns(&[Arc::clone(&r_values)])?; + + // Indices from the row converter are 0-based in the per-side slice; + // concatenating those same slices lets indices map directly into the + // combined values array. + let combined_values = concat(&[l_values.as_ref(), r_values.as_ref()])?; + let r_offset = l_len; + + match set_op { + SetOp::Union => generic_set_loop::( + l, + r, + &rows_l, + &rows_r, + field, + &combined_values, + r_offset, + ), + SetOp::Intersect => generic_set_loop::( + l, + r, + &rows_l, + &rows_r, + field, + &combined_values, + r_offset, + ), + } +} - let l_iter = l_values.iter().sorted().dedup(); - let values_set: HashSet<_> = l_iter.clone().collect(); - let mut rows = if set_op == SetOp::Union { - l_iter.collect() - } else { - vec![] - }; - - for r_val in r_values.iter().sorted().dedup() { - match set_op { - SetOp::Union => { - if !values_set.contains(&r_val) { - rows.push(r_val); - } +/// Inner loop for set operations, parameterized by const generic to +/// avoid branching inside the hot loop. +fn generic_set_loop( + l: &GenericListArray, + r: &GenericListArray, + rows_l: &arrow::row::Rows, + rows_r: &arrow::row::Rows, + field: Arc, + combined_values: &ArrayRef, + r_offset: usize, +) -> Result { + let l_offsets = l.value_offsets(); + let r_offsets = r.value_offsets(); + let l_first = l.offsets()[0].as_usize(); + let r_first = r.offsets()[0].as_usize(); + + let mut result_offsets = Vec::with_capacity(l.len() + 1); + result_offsets.push(OffsetSize::usize_as(0)); + let initial_capacity = if IS_UNION { + // Union can include all elements from both sides + rows_l.num_rows() + } else { + // Intersect result is bounded by the smaller side + rows_l.num_rows().min(rows_r.num_rows()) + }; + + let mut indices: Vec = Vec::with_capacity(initial_capacity); + + // Reuse hash sets across iterations + let mut seen = HashSet::new(); + let mut lookup_set = HashSet::new(); + for i in 0..l.len() { + let last_offset = *result_offsets.last().unwrap(); + + if l.is_null(i) || r.is_null(i) { + result_offsets.push(last_offset); + continue; + } + + let l_start = l_offsets[i].as_usize() - l_first; + let l_end = l_offsets[i + 1].as_usize() - l_first; + let r_start = r_offsets[i].as_usize() - r_first; + let r_end = r_offsets[i + 1].as_usize() - r_first; + + seen.clear(); + + if IS_UNION { + for idx in l_start..l_end { + let row = rows_l.row(idx); + if seen.insert(row) { + indices.push(idx); } - SetOp::Intersect => { - if values_set.contains(&r_val) { - rows.push(r_val); - } + } + for idx in r_start..r_end { + let row = rows_r.row(idx); + if seen.insert(row) { + indices.push(idx + r_offset); } } - } - - let last_offset = match offsets.last() { - Some(offset) => *offset, - None => return internal_err!("offsets should not be empty"), - }; - - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("{set_op}: failed to get array from rows"); + } else { + let l_len = l_end - l_start; + let r_len = r_end - r_start; + + // Select shorter side for lookup, longer side for probing. + // Track the probe side's offset into the combined values array. + let (lookup_rows, lookup_range, probe_rows, probe_range, probe_offset) = + if l_len < r_len { + (rows_l, l_start..l_end, rows_r, r_start..r_end, r_offset) + } else { + (rows_r, r_start..r_end, rows_l, l_start..l_end, 0) + }; + lookup_set.clear(); + lookup_set.reserve(lookup_range.len()); + + // Build lookup table + for idx in lookup_range { + lookup_set.insert(lookup_rows.row(idx)); } - }; - new_arrays.push(array); - } + // Probe and emit distinct intersected rows + for idx in probe_range { + let row = probe_rows.row(idx); + if lookup_set.contains(&row) && seen.insert(row) { + indices.push(idx + probe_offset); + } + } + } + result_offsets.push(last_offset + OffsetSize::usize_as(seen.len())); + } + + // Gather distinct values by index from the combined values array. + // Use UInt64Array for LargeList to support values arrays exceeding u32::MAX. + let final_values = if indices.is_empty() { + new_empty_array(&l.value_type()) + } else if OffsetSize::IS_LARGE { + let indices = + UInt64Array::from(indices.into_iter().map(|i| i as u64).collect::>()); + take(combined_values.as_ref(), &indices, None)? + } else { + let indices = + UInt32Array::from(indices.into_iter().map(|i| i as u32).collect::>()); + take(combined_values.as_ref(), &indices, None)? + }; - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref: Vec<_> = new_arrays.iter().map(|v| v.as_ref()).collect(); - let values = compute::concat(&new_arrays_ref)?; - let arr = GenericListArray::::try_new(field, offsets, values, None)?; + let arr = GenericListArray::::try_new( + field, + OffsetBuffer::new(result_offsets.into()), + final_values, + NullBuffer::union(l.nulls(), r.nulls()), + )?; Ok(Arc::new(arr)) } @@ -429,59 +516,13 @@ fn general_set_op( array2: &ArrayRef, set_op: SetOp, ) -> Result { - fn empty_array(data_type: &DataType, len: usize, large: bool) -> Result { - let field = Arc::new(Field::new_list_field(data_type.clone(), true)); - let values = new_null_array(data_type, len); - if large { - Ok(Arc::new(LargeListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } else { - Ok(Arc::new(ListArray::try_new( - field, - OffsetBuffer::new_zeroed(len), - values, - None, - )?)) - } - } - + let len = array1.len(); match (array1.data_type(), array2.data_type()) { - (Null, Null) => Ok(Arc::new(ListArray::new_null( - Arc::new(Field::new_list_field(Null, true)), - array1.len(), - ))), - (Null, List(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array2)?; - general_array_distinct::(array, field) - } - (List(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), false); - } - let array = as_list_array(&array1)?; - general_array_distinct::(array, field) - } - (Null, LargeList(field)) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array2)?; - general_array_distinct::(array, field) - } - (LargeList(field), Null) => { - if set_op == SetOp::Intersect { - return empty_array(field.data_type(), array1.len(), true); - } - let array = as_large_list_array(&array1)?; - general_array_distinct::(array, field) - } + (Null, Null) => Ok(new_null_array(&DataType::new_list(Null, true), len)), + (Null, dt @ List(_)) + | (Null, dt @ LargeList(_)) + | (dt @ List(_), Null) + | (dt @ LargeList(_), Null) => Ok(new_null_array(dt, len)), (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; @@ -517,42 +558,72 @@ fn general_array_distinct( if array.is_empty() { return Ok(Arc::new(array.clone()) as ArrayRef); } + let value_offsets = array.value_offsets(); let dt = array.value_type(); - let mut offsets = Vec::with_capacity(array.len()); + let mut offsets = Vec::with_capacity(array.len() + 1); offsets.push(OffsetSize::usize_as(0)); - let mut new_arrays = Vec::with_capacity(array.len()); - let converter = RowConverter::new(vec![SortField::new(dt)])?; - // distinct for each list in ListArray - for arr in array.iter() { - let last_offset: OffsetSize = offsets.last().copied().unwrap(); - let Some(arr) = arr else { - // Add same offset for null + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + + // Normalize -0.0 → +0.0 so RowConverter (which uses IEEE 754 totalOrder + // and treats ±0 as distinct) groups them together, and so the output + // carries the canonical sign. + let values_norm = normalize_float_zero(array.values()); + + // Only convert the visible portion of the values array. For sliced + // ListArrays, values() returns the full underlying array but only + // elements between the first and last offset are referenced. + let first_offset = value_offsets[0].as_usize(); + let visible_len = value_offsets[array.len()].as_usize() - first_offset; + let rows = + converter.convert_columns(&[values_norm.slice(first_offset, visible_len)])?; + + let mut indices: Vec = Vec::with_capacity(rows.num_rows()); + let mut seen = HashSet::new(); + for i in 0..array.len() { + let last_offset = *offsets.last().unwrap(); + + // Null list entries produce no output; just carry forward the offset. + if array.is_null(i) { offsets.push(last_offset); continue; - }; - let values = converter.convert_columns(&[arr])?; - // sort elements in list and remove duplicates - let rows = values.iter().sorted().dedup().collect::>(); - offsets.push(last_offset + OffsetSize::usize_as(rows.len())); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.first() { - Some(array) => Arc::clone(array), - None => { - return internal_err!("array_distinct: failed to get array from rows") + } + + let start = value_offsets[i].as_usize() - first_offset; + let end = value_offsets[i + 1].as_usize() - first_offset; + seen.clear(); + seen.reserve(end - start); + + // Walk the sub-array and keep only the first occurrence of each value. + for idx in start..end { + let row = rows.row(idx); + if seen.insert(row) { + indices.push(idx + first_offset); } - }; - new_arrays.push(array); - } - if new_arrays.is_empty() { - return Ok(Arc::new(array.clone()) as ArrayRef); - } - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); - let values = compute::concat(&new_arrays_ref)?; + } + offsets.push(last_offset + OffsetSize::usize_as(seen.len())); + } + + // Gather distinct values in a single pass, using the computed `indices`. + // Indices are absolute positions in the (normalized) values array, so we + // can take directly from the full values. + // Use UInt64Array for LargeList to support values arrays exceeding u32::MAX. + let final_values = if indices.is_empty() { + new_empty_array(&dt) + } else if OffsetSize::IS_LARGE { + let indices = + UInt64Array::from(indices.into_iter().map(|i| i as u64).collect::>()); + take(values_norm.as_ref(), &indices, None)? + } else { + let indices = + UInt32Array::from(indices.into_iter().map(|i| i as u32).collect::>()); + take(values_norm.as_ref(), &indices, None)? + }; + Ok(Arc::new(GenericListArray::::try_new( Arc::clone(field), - offsets, - values, + OffsetBuffer::new(offsets.into()), + final_values, // Keep the list nulls array.nulls().cloned(), )?)) @@ -563,18 +634,136 @@ mod tests { use std::sync::Arc; use arrow::{ - array::{Int32Array, ListArray}, + array::{Array, AsArray, Int32Array, ListArray}, buffer::OffsetBuffer, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, Int32Type}, }; - use datafusion_common::{config::ConfigOptions, DataFusionError}; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; + use datafusion_common::{DataFusionError, Result, config::ConfigOptions}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use crate::set_ops::{ArrayDistinct, ArrayIntersect, ArrayUnion, array_distinct_udf}; + + /// Build two sliced ListArrays and return them along with the shared list + /// field. + /// + /// l: [[1,2], [3,4], [5,6], [7,8]] → slice(1,2) → [[3,4], [5,6]] + /// r: [[1,3], [3,5], [5,7], [7,1]] → slice(1,2) → [[3,5], [5,7]] + fn make_sliced_pair() -> (ListArray, ListArray, Arc) { + let l = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(5), Some(6)]), + Some(vec![Some(7), Some(8)]), + ]); + let r = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(3)]), + Some(vec![Some(3), Some(5)]), + Some(vec![Some(5), Some(7)]), + Some(vec![Some(7), Some(1)]), + ]); + let field = Arc::new(Field::new("item", l.data_type().clone(), true)); + (l.slice(1, 2), r.slice(1, 2), field) + } + + fn collect_i32_list(list: &ListArray) -> Vec> { + (0..list.len()) + .map(|i| { + let arr = list.value(i); + arr.as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec() + }) + .collect() + } + + #[test] + fn test_array_union_sliced_lists() -> Result<()> { + let (l, r, field) = make_sliced_pair(); + + let result = ArrayUnion::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(l)), + ColumnarValue::Array(Arc::new(r)), + ], + arg_fields: vec![Arc::clone(&field), Arc::clone(&field)], + number_rows: 2, + return_field: Arc::clone(&field), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let output = result.into_array(2)?; + let output = output.as_list::(); + let rows = collect_i32_list(output); + + // Row 0: union([3,4], [3,5]) = [3,4,5] + assert_eq!(rows[0], vec![3, 4, 5]); + // Row 1: union([5,6], [5,7]) = [5,6,7] + assert_eq!(rows[1], vec![5, 6, 7]); + Ok(()) + } + + #[test] + fn test_array_intersect_sliced_lists() -> Result<()> { + let (l, r, field) = make_sliced_pair(); + + let result = ArrayIntersect::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::new(l)), + ColumnarValue::Array(Arc::new(r)), + ], + arg_fields: vec![Arc::clone(&field), Arc::clone(&field)], + number_rows: 2, + return_field: Arc::clone(&field), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let output = result.into_array(2)?; + let output = output.as_list::(); + let rows = collect_i32_list(output); + + // Row 0: intersect([3,4], [3,5]) = [3] + assert_eq!(rows[0], vec![3]); + // Row 1: intersect([5,6], [5,7]) = [5] + assert_eq!(rows[1], vec![5]); + Ok(()) + } - use crate::set_ops::array_distinct_udf; + #[test] + fn test_array_distinct_sliced_list() -> Result<()> { + // [[1,1], [3,3,4], [5,5,6], [7,7]] → slice(1,2) → [[3,3,4], [5,5,6]] + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(1)]), + Some(vec![Some(3), Some(3), Some(4)]), + Some(vec![Some(5), Some(5), Some(6)]), + Some(vec![Some(7), Some(7)]), + ]); + let sliced = list.slice(1, 2); + let field = Arc::new(Field::new("item", sliced.data_type().clone(), true)); + + let result = ArrayDistinct::new().invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::new(sliced))], + arg_fields: vec![Arc::clone(&field)], + number_rows: 2, + return_field: field, + config_options: Arc::new(ConfigOptions::default()), + })?; + + let output = result.into_array(2)?; + let output = output.as_list::(); + let rows = collect_i32_list(output); + + // Row 0: distinct([3,3,4]) = [3,4] + assert_eq!(rows[0], vec![3, 4]); + // Row 1: distinct([5,5,6]) = [5,6] + assert_eq!(rows[1], vec![5, 6]); + Ok(()) + } #[test] - fn test_array_distinct_inner_nullability_result_type_match_return_type( - ) -> Result<(), DataFusionError> { + fn test_array_distinct_inner_nullability_result_type_match_return_type() + -> Result<(), DataFusionError> { let udf = array_distinct_udf(); for inner_nullable in [true, false] { diff --git a/datafusion/functions-nested/src/sort.rs b/datafusion/functions-nested/src/sort.rs index 8cfc8a297b7b7..0a34cce6b965f 100644 --- a/datafusion/functions-nested/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -18,22 +18,23 @@ //! [`ScalarUDFImpl`] definitions for array_sort function. use crate::utils::make_scalar_function; +use arrow::array::BooleanBufferBuilder; use arrow::array::{ - new_null_array, Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait, + Array, ArrayRef, ArrowPrimitiveType, GenericListArray, OffsetSizeTrait, + PrimitiveArray, UInt32Array, UInt64Array, new_empty_array, new_null_array, }; -use arrow::buffer::OffsetBuffer; -use arrow::compute::SortColumn; -use arrow::datatypes::{DataType, FieldRef}; -use arrow::{compute, compute::SortOptions}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, FieldRef}; +use arrow::row::{RowConverter, SortField}; +use arrow::{compute, compute::SortOptions, downcast_primitive_array}; use datafusion_common::cast::{as_large_list_array, as_list_array, as_string_array}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::{Result, exec_err, internal_datafusion_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation, - ScalarUDFImpl, Signature, TypeSignature, Volatility, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; make_udf_expr_and_func!( @@ -69,11 +70,11 @@ make_udf_expr_and_func!( ), argument( name = "desc", - description = "Whether to sort in descending order(`ASC` or `DESC`)." + description = "Whether to sort in ascending (`ASC`) or descending (`DESC`) order. The default is `ASC`." ), argument( name = "nulls_first", - description = "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`)." + description = "Whether to sort nulls first (`NULLS FIRST`) or last (`NULLS LAST`). The default is `NULLS FIRST`." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -121,10 +122,6 @@ impl ArraySort { } impl ScalarUDFImpl for ArraySort { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_sort" } @@ -134,24 +131,10 @@ impl ScalarUDFImpl for ArraySort { } fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[0] { - DataType::Null => Ok(DataType::Null), - DataType::List(field) => { - Ok(DataType::new_list(field.data_type().clone(), true)) - } - DataType::LargeList(field) => { - Ok(DataType::new_large_list(field.data_type().clone(), true)) - } - arg_type => { - plan_err!("{} does not support type {arg_type}", self.name()) - } - } + Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_sort_inner)(&args.args) } @@ -177,25 +160,20 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { return Ok(new_null_array(args[0].data_type(), args[0].len())); } - let sort_options = match args.len() { - 1 => None, - 2 => { - let sort = as_string_array(&args[1])?.value(0); - Some(SortOptions { - descending: order_desc(sort)?, - nulls_first: true, - }) - } - 3 => { - let sort = as_string_array(&args[1])?.value(0); - let nulls_first = as_string_array(&args[2])?.value(0); - Some(SortOptions { - descending: order_desc(sort)?, - nulls_first: order_nulls_first(nulls_first)?, - }) - } - // We guard at the top - _ => unreachable!(), + let sort_options = if args.len() >= 2 { + let order = as_string_array(&args[1])?.value(0); + let descending = order_desc(order)?; + let nulls_first = if args.len() >= 3 { + order_nulls_first(as_string_array(&args[2])?.value(0))? + } else { + true + }; + Some(SortOptions { + descending, + nulls_first, + }) + } else { + None }; match args[0].data_type() { @@ -206,11 +184,11 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { } DataType::List(field) => { let array = as_list_array(&args[0])?; - array_sort_generic(array, field, sort_options) + array_sort_generic(array, Arc::clone(field), sort_options) } DataType::LargeList(field) => { let array = as_large_list_array(&args[0])?; - array_sort_generic(array, field, sort_options) + array_sort_generic(array, Arc::clone(field), sort_options) } // Signature should prevent this arm ever occurring _ => exec_err!("array_sort expects list for first argument"), @@ -219,62 +197,286 @@ fn array_sort_inner(args: &[ArrayRef]) -> Result { fn array_sort_generic( list_array: &GenericListArray, - field: &FieldRef, + field: FieldRef, + sort_options: Option, +) -> Result { + let values = list_array.values(); + + if values.data_type().is_primitive() { + array_sort_primitive(list_array, field, sort_options) + } else { + array_sort_non_primitive(list_array, field, sort_options) + } +} + +/// Sort each row of a primitive-typed ListArray using a custom in-place sort +/// kernel. +fn array_sort_primitive( + list_array: &GenericListArray, + field: FieldRef, sort_options: Option, ) -> Result { + let values = list_array.values().as_ref(); + downcast_primitive_array! { + values => sort_primitive_list(values, list_array, field, sort_options), + _ => exec_err!("array_sort: unsupported primitive type") + } +} + +fn sort_primitive_list( + prim_values: &PrimitiveArray, + list_array: &GenericListArray, + field: FieldRef, + sort_options: Option, +) -> Result +where + T::Native: ArrowNativeTypeOp, +{ + if prim_values.null_count() > 0 { + sort_list_with_nulls(prim_values, list_array, field, sort_options) + } else { + sort_list_no_nulls(prim_values, list_array, field, sort_options) + } +} + +/// Fast path for primitive values with no element-level nulls. Copies all +/// values into a single `Vec` and sorts each row's slice in-place. +fn sort_list_no_nulls( + prim_values: &PrimitiveArray, + list_array: &GenericListArray, + field: FieldRef, + sort_options: Option, +) -> Result +where + T::Native: ArrowNativeTypeOp, +{ let row_count = list_array.len(); + let offsets = list_array.offsets(); + let values_start = offsets[0].as_usize(); + let values_end = offsets[row_count].as_usize(); + + let descending = sort_options.is_some_and(|o| o.descending); + + // Copy all values into a mutable buffer + let mut values: Vec = + prim_values.values()[values_start..values_end].to_vec(); - let mut array_lengths = vec![]; - let mut arrays = vec![]; - let mut valid = NullBufferBuilder::new(row_count); - for i in 0..row_count { - if list_array.is_null(i) { - array_lengths.push(0); - valid.append_null(); + for (row_index, window) in offsets.windows(2).enumerate() { + if list_array.is_null(row_index) { + continue; + } + let start = window[0].as_usize() - values_start; + let end = window[1].as_usize() - values_start; + let slice = &mut values[start..end]; + if descending { + slice.sort_unstable_by(|a, b| b.compare(*a)); } else { - let arr_ref = list_array.value(i); - - // arrow sort kernel does not support Structs, so use - // lexsort_to_indices instead: - // https://github.com/apache/arrow-rs/issues/6911#issuecomment-2562928843 - let sorted_array = match arr_ref.data_type() { - DataType::Struct(_) => { - let sort_columns: Vec = vec![SortColumn { - values: Arc::clone(&arr_ref), - options: sort_options, - }]; - let indices = compute::lexsort_to_indices(&sort_columns, None)?; - compute::take(arr_ref.as_ref(), &indices, None)? - } - _ => { - let arr_ref = arr_ref.as_ref(); - compute::sort(arr_ref, sort_options)? - } - }; - array_lengths.push(sorted_array.len()); - arrays.push(sorted_array); - valid.append_non_null(); + slice.sort_unstable_by(|a, b| a.compare(*b)); } } - let buffer = valid.finish(); + let new_offsets = rebase_offsets(offsets); + let sorted_values = Arc::new( + PrimitiveArray::::new(values.into(), None) + .with_data_type(prim_values.data_type().clone()), + ); + + Ok(Arc::new(GenericListArray::::try_new( + field, + new_offsets, + sorted_values, + list_array.nulls().cloned(), + )?)) +} + +/// Slow path for primitive values with element-level nulls. +fn sort_list_with_nulls( + prim_values: &PrimitiveArray, + list_array: &GenericListArray, + field: FieldRef, + sort_options: Option, +) -> Result +where + T::Native: ArrowNativeTypeOp, +{ + let row_count = list_array.len(); + let offsets = list_array.offsets(); + let values_start = offsets[0].as_usize(); + let values_end = offsets[row_count].as_usize(); + let total_values = values_end - values_start; - let elements = arrays - .iter() - .map(|a| a.as_ref()) - .collect::>(); + let descending = sort_options.is_some_and(|o| o.descending); + let nulls_first = sort_options.is_none_or(|o| o.nulls_first); - let list_arr = if elements.is_empty() { - GenericListArray::::new_null(Arc::clone(field), row_count) - } else { - GenericListArray::::new( - Arc::clone(field), - OffsetBuffer::from_lengths(array_lengths), - Arc::new(compute::concat(elements.as_slice())?), - buffer, + let mut out_values: Vec = vec![T::Native::default(); total_values]; + let mut validity = BooleanBufferBuilder::new(total_values); + + let src_nulls = prim_values.nulls().ok_or_else(|| { + internal_datafusion_err!( + "sort_list_with_nulls called but values have no null buffer" ) + })?; + let src_values = prim_values.values(); + + for (row_index, window) in offsets.windows(2).enumerate() { + let start = window[0].as_usize(); + let end = window[1].as_usize(); + let row_len = end - start; + let out_start = start - values_start; + + if list_array.is_null(row_index) || row_len == 0 { + validity.append_n(row_len, false); + continue; + } + + let null_count = src_nulls.slice(start, row_len).null_count(); + let valid_count = row_len - null_count; + + // Compact valid values directly into the target region of the output + // buffer: after nulls (if nulls_first) or at the start (if nulls_last). + let valid_offset = if nulls_first { null_count } else { 0 }; + let mut write_pos = out_start + valid_offset; + for i in start..end { + if src_nulls.is_valid(i) { + out_values[write_pos] = src_values[i]; + write_pos += 1; + } + } + + let valid_slice = &mut out_values + [out_start + valid_offset..out_start + valid_offset + valid_count]; + if descending { + valid_slice.sort_unstable_by(|a, b| b.compare(*a)); + } else { + valid_slice.sort_unstable_by(|a, b| a.compare(*b)); + } + + // Build validity bits + if nulls_first { + validity.append_n(null_count, false); + validity.append_n(valid_count, true); + } else { + validity.append_n(valid_count, true); + validity.append_n(null_count, false); + } + } + + let new_offsets = rebase_offsets(offsets); + + let null_buffer = NullBuffer::from(validity.finish()); + let sorted_values = Arc::new( + PrimitiveArray::::new(out_values.into(), Some(null_buffer)) + .with_data_type(prim_values.data_type().clone()), + ); + + Ok(Arc::new(GenericListArray::::try_new( + field, + new_offsets, + sorted_values, + list_array.nulls().cloned(), + )?)) +} + +/// Sort a non-pritive-typed ListArray by converting all rows at once using +/// `RowConverter`, and then sort row indices by comparing encoded bytes (sort +/// direction and null ordering are baked into the encoding), and materialize +/// the result with a single `take()`. +fn array_sort_non_primitive( + list_array: &GenericListArray, + field: FieldRef, + sort_options: Option, +) -> Result { + let row_count = list_array.len(); + let values = list_array.values(); + let offsets = list_array.offsets(); + let values_start = offsets[0].as_usize(); + let total_values = offsets[row_count].as_usize() - values_start; + + let converter = RowConverter::new(vec![SortField::new_with_options( + values.data_type().clone(), + sort_options.unwrap_or_default(), + )])?; + let values_sliced = values.slice(values_start, total_values); + let rows = converter.convert_columns(&[Arc::clone(&values_sliced)])?; + + let mut indices: Vec = Vec::with_capacity(total_values); + let mut new_offsets = Vec::with_capacity(row_count + 1); + new_offsets.push(OffsetSize::usize_as(0)); + + let mut sort_scratch: Vec = Vec::new(); + + for (row_index, window) in offsets.windows(2).enumerate() { + let start = window[0]; + let end = window[1]; + + if list_array.is_null(row_index) { + new_offsets.push(new_offsets[row_index]); + continue; + } + + let len = (end - start).as_usize(); + let local_start = start.as_usize() - values_start; + + if len <= 1 { + indices.extend((local_start..local_start + len).map(OffsetSize::usize_as)); + } else { + sort_scratch.clear(); + sort_scratch.extend(local_start..local_start + len); + sort_scratch.sort_unstable_by(|&a, &b| rows.row(a).cmp(&rows.row(b))); + indices.extend(sort_scratch.iter().map(|&i| OffsetSize::usize_as(i))); + } + + new_offsets.push(new_offsets[row_index] + (end - start)); + } + + let sorted_values = if indices.is_empty() { + new_empty_array(values.data_type()) + } else { + take_by_indices(&values_sliced, indices)? }; - Ok(Arc::new(list_arr)) + + Ok(Arc::new(GenericListArray::::try_new( + field, + OffsetBuffer::::new(new_offsets.into()), + sorted_values, + list_array.nulls().cloned(), + )?)) +} + +/// Select elements from `values` at the given `indices` using `compute::take`. +/// We consume `indices` in order to avoid an intermediate copy. +fn take_by_indices( + values: &ArrayRef, + indices: Vec, +) -> Result { + let len = indices.len(); + let buffer = arrow::buffer::Buffer::from_vec(indices); + let indices_array: ArrayRef = if OffsetSize::IS_LARGE { + Arc::new(UInt64Array::new( + arrow::buffer::ScalarBuffer::new(buffer, 0, len), + None, + )) + } else { + Arc::new(UInt32Array::new( + arrow::buffer::ScalarBuffer::new(buffer, 0, len), + None, + )) + }; + Ok(compute::take(values.as_ref(), &indices_array, None)?) +} + +/// Rebase offsets so they start at 0. For non-sliced ListArrays (the common +/// case) offsets already start at 0 and we can clone the Arc-backed buffer +/// cheaply instead of allocating a new Vec. +fn rebase_offsets( + offsets: &OffsetBuffer, +) -> OffsetBuffer { + if offsets[0].as_usize() == 0 { + offsets.clone() + } else { + let rebased: Vec = offsets.iter().map(|o| *o - offsets[0]).collect(); + OffsetBuffer::new(rebased.into()) + } } fn order_desc(modifier: &str) -> Result { diff --git a/datafusion/functions-nested/src/string.rs b/datafusion/functions-nested/src/string.rs index e19025cf673e0..b76736672cffa 100644 --- a/datafusion/functions-nested/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -19,104 +19,41 @@ use arrow::array::{ Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray, - Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, ListBuilder, - OffsetSizeTrait, StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, - UInt8Array, + Int8Array, Int16Array, Int32Array, Int64Array, LargeStringArray, ListBuilder, + OffsetSizeTrait, StringArray, StringBuilder, UInt8Array, UInt16Array, UInt32Array, + UInt64Array, }; use arrow::datatypes::{DataType, Field}; use datafusion_common::utils::ListCoercion; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue, not_impl_err}; -use std::any::Any; +use std::fmt::{self, Write}; use crate::utils::make_scalar_function; use arrow::array::{ + StringArrayType, StringViewArray, builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder}, cast::AsArray, - GenericStringArray, StringArrayType, StringViewArray, }; -use arrow::compute::cast; +use arrow::compute::{can_cast_types, cast}; use arrow::datatypes::DataType::{ Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, Utf8View, }; use datafusion_common::cast::{ as_fixed_size_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::exec_err; use datafusion_common::types::logical_string; +use datafusion_common::{exec_datafusion_err, exec_err}; use datafusion_expr::{ ArrayFunctionArgument, ArrayFunctionSignature, Coercion, ColumnarValue, - Documentation, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, - Volatility, + Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; use datafusion_functions::downcast_arg; use datafusion_macros::user_doc; use std::sync::Arc; -macro_rules! call_array_function { - ($DATATYPE:expr, false) => { - match $DATATYPE { - DataType::Utf8 => array_function!(StringArray), - DataType::Utf8View => array_function!(StringViewArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), - } - }; - ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ - match $DATATYPE { - DataType::List(_) => array_function!(ListArray), - DataType::Utf8 => array_function!(StringArray), - DataType::Utf8View => array_function!(StringViewArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), - } - }}; -} - -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - // Create static instances of ScalarUDFs for each function make_udf_expr_and_func!( ArrayToString, @@ -145,7 +82,7 @@ make_udf_expr_and_func!( argument(name = "delimiter", description = "Array element separator."), argument( name = "null_string", - description = "Optional. String to replace null values in the array. If not provided, nulls will be handled by default behavior." + description = "Optional. String to use for null values in the output. If not provided, nulls will be omitted." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -193,10 +130,6 @@ impl ArrayToString { } impl ScalarUDFImpl for ArrayToString { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array_to_string" } @@ -209,10 +142,7 @@ impl ScalarUDFImpl for ArrayToString { Ok(Utf8) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(array_to_string_inner)(&args.args) } @@ -259,11 +189,17 @@ make_udf_expr_and_func!( ) )] #[derive(Debug, PartialEq, Eq, Hash)] -pub(super) struct StringToArray { +pub struct StringToArray { signature: Signature, aliases: Vec, } +impl Default for StringToArray { + fn default() -> Self { + Self::new() + } +} + impl StringToArray { pub fn new() -> Self { Self { @@ -287,10 +223,6 @@ impl StringToArray { } impl ScalarUDFImpl for StringToArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "string_to_array" } @@ -306,17 +238,72 @@ impl ScalarUDFImpl for StringToArray { )))) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = &args.args; - match args[0].data_type() { - Utf8 | Utf8View => make_scalar_function(string_to_array_inner::)(args), - LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + + let delimiter_is_scalar = matches!(&args[1], ColumnarValue::Scalar(_)); + let null_str_is_scalar = args + .get(2) + .is_none_or(|a| matches!(a, ColumnarValue::Scalar(_))); + + if !(delimiter_is_scalar && null_str_is_scalar) { + return make_scalar_function(string_to_array_fallback)(&args); + } + + // Delimiter and null_str (if given) are scalar, so use the fast path + let delimiter = match &args[1] { + ColumnarValue::Scalar(s) => s.try_as_str().ok_or_else(|| { + exec_datafusion_err!( + "unsupported type for string_to_array delimiter: {:?}", + args[1].data_type() + ) + })?, + _ => unreachable!("delimiter must be scalar in this branch"), + }; + let null_value = match args.get(2) { + Some(ColumnarValue::Scalar(s)) => s.try_as_str().ok_or_else(|| { + exec_datafusion_err!( + "unsupported type for string_to_array null_str: {:?}", + args[2].data_type() + ) + })?, + _ => None, + }; + + let (all_scalar, string_array) = match &args[0] { + ColumnarValue::Array(a) => (false, Arc::clone(a)), + ColumnarValue::Scalar(s) => (true, s.to_array_of_size(1)?), + }; + + let result = match string_array.data_type() { + Utf8 => { + let arr = string_array.as_string::(); + let builder = + StringBuilder::with_capacity(arr.len(), arr.get_buffer_memory_size()); + string_to_array_scalar_args(&arr, delimiter, null_value, builder) + } + Utf8View => { + let arr = string_array.as_string_view(); + let builder = StringViewBuilder::with_capacity(arr.len()); + string_to_array_scalar_args(&arr, delimiter, null_value, builder) + } + LargeUtf8 => { + let arr = string_array.as_string::(); + let builder = LargeStringBuilder::with_capacity( + arr.len(), + arr.get_buffer_memory_size(), + ); + string_to_array_scalar_args(&arr, delimiter, null_value, builder) + } other => { exec_err!("unsupported type for string_to_array function as {other:?}") } + }?; + + if all_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) } } @@ -329,428 +316,468 @@ impl ScalarUDFImpl for StringToArray { } } -fn array_to_string_inner(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("array_to_string expects two or three arguments"); +/// Appends `value` to the string builder, or NULL if it matches `null_value`. +#[inline(always)] +fn append_part( + builder: &mut impl StringArrayBuilderType, + value: &str, + null_value: Option<&str>, +) { + if null_value == Some(value) { + builder.append_null(); + } else { + builder.append_value(value); } +} - let arr = &args[0]; - - let delimiters: Vec> = match args[1].data_type() { - Utf8 => args[1].as_string::().iter().collect(), - Utf8View => args[1].as_string_view().iter().collect(), - LargeUtf8 => args[1].as_string::().iter().collect(), - other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") - }; +/// Optimized `string_to_array` implementation for the common case where +/// delimiter and null_value are scalar values. +fn string_to_array_scalar_args<'a, StringArrType, StringBuilderType>( + string_array: &StringArrType, + delimiter: Option<&str>, + null_value: Option<&str>, + string_builder: StringBuilderType, +) -> Result +where + StringArrType: StringArrayType<'a>, + StringBuilderType: StringArrayBuilderType, +{ + let mut list_builder = ListBuilder::new(string_builder); - let mut null_string = String::from(""); - let mut with_null_string = false; - if args.len() == 3 { - null_string = match args[2].data_type() { - Utf8 => args[2].as_string::().value(0).to_string(), - Utf8View => args[2].as_string_view().value(0).to_string(), - LargeUtf8 => args[2].as_string::().value(0).to_string(), - other => return exec_err!("unsupported type for second argument to array_to_string function as {other:?}") - }; - with_null_string = true; - } - - /// Creates a single string from single element of a ListArray (which is - /// itself another Array) - fn compute_array_to_string<'a>( - arg: &'a mut String, - arr: &ArrayRef, - delimiter: String, - null_string: String, - with_null_string: bool, - ) -> Result<&'a mut String> { - match arr.data_type() { - List(..) => { - let list_array = as_list_array(&arr)?; - for i in 0..list_array.len() { - if !list_array.is_null(i) { - compute_array_to_string( - arg, - &list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } else if with_null_string { - arg.push_str(&null_string); - arg.push_str(&delimiter); - } + match delimiter { + Some("") => { + // Empty delimiter: each non-empty string becomes a single-element list. + // Empty strings produce an empty array (PostgreSQL compat). + for i in 0..string_array.len() { + if string_array.is_null(i) { + list_builder.append(false); + continue; } - - Ok(arg) - } - FixedSizeList(..) => { - let list_array = as_fixed_size_list_array(&arr)?; - - for i in 0..list_array.len() { - if !list_array.is_null(i) { - compute_array_to_string( - arg, - &list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } else if with_null_string { - arg.push_str(&null_string); - arg.push_str(&delimiter); - } + let string = string_array.value(i); + if !string.is_empty() { + append_part(list_builder.values(), string, null_value); } - - Ok(arg) + list_builder.append(true); } - LargeList(..) => { - let list_array = as_large_list_array(&arr)?; - for i in 0..list_array.len() { - if !list_array.is_null(i) { - compute_array_to_string( - arg, - &list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } else if with_null_string { - arg.push_str(&null_string); - arg.push_str(&delimiter); - } + } + Some(delimiter) => { + // Rather than using `str::split`, do the split ourselves using + // `memmem::Finder`. This allows pre-compiling the delimiter search + // pattern once and reusing it for all rows. + let finder = memchr::memmem::Finder::new(delimiter.as_bytes()); + let delim_len = delimiter.len(); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + list_builder.append(false); + continue; } - - Ok(arg) - } - Dictionary(_key_type, value_type) => { - // Call cast to unwrap the dictionary. This could be optimized if we wanted - // to accept the overhead of extra code - let values = cast(&arr, value_type.as_ref()).map_err(|e| { - DataFusionError::from(e).context( - "Casting dictionary to values in compute_array_to_string", - ) - })?; - compute_array_to_string( - arg, - &values, - delimiter, - null_string, - with_null_string, - ) - } - Null => Ok(arg), - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - to_string!( - arg, - arr, - &delimiter, - &null_string, - with_null_string, - $ARRAY_TYPE - ) - }; + let string = string_array.value(i); + if !string.is_empty() { + let bytes = string.as_bytes(); + let mut start = 0; + for pos in finder.find_iter(bytes) { + append_part( + list_builder.values(), + &string[start..pos], + null_value, + ); + start = pos + delim_len; + } + // Trailing part after last delimiter (or entire string if no + // delimiter was found). + append_part(list_builder.values(), &string[start..], null_value); } - call_array_function!(data_type, false) + list_builder.append(true); } } - } - - fn generate_string_array( - list_arr: &GenericListArray, - delimiters: &[Option<&str>], - null_string: &str, - with_null_string: bool, - ) -> Result { - let mut res: Vec> = Vec::new(); - for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - let mut arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - &arr, - delimiter.to_string(), - null_string.to_string(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); + None => { + // NULL delimiter: split into individual characters. + for i in 0..string_array.len() { + if string_array.is_null(i) { + list_builder.append(false); + continue; + } + let string = string_array.value(i); + for (pos, c) in string.char_indices() { + append_part( + list_builder.values(), + &string[pos..pos + c.len_utf8()], + null_value, + ); } - } else { - res.push(None); + list_builder.append(true); } } - - Ok(StringArray::from(res)) } - let string_arr = match arr.data_type() { - List(_) => { - let list_array = as_list_array(&arr)?; - generate_string_array::( - list_array, - &delimiters, - &null_string, - with_null_string, - )? - } - LargeList(_) => { - let list_array = as_large_list_array(&arr)?; - generate_string_array::( - list_array, - &delimiters, - &null_string, - with_null_string, - )? - } - // Signature guards against this arm - _ => return exec_err!("array_to_string expects list as first argument"), - }; - - Ok(Arc::new(string_arr)) + Ok(Arc::new(list_builder.finish()) as ArrayRef) } -/// String_to_array SQL function -/// Splits string at occurrences of delimiter and returns an array of parts -/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -fn string_to_array_inner(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("string_to_array expects two or three arguments"); - } +/// Fallback path for `string_to_array` when delimiter and/or null_value +/// are array columns rather than scalars. +fn string_to_array_fallback(args: &[ArrayRef]) -> Result { + let null_value_array = args.get(2); match args[0].data_type() { Utf8 => { - let string_array = args[0].as_string::(); - let builder = StringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); - string_to_array_inner_2::<&GenericStringArray, StringBuilder>(args, &string_array, builder) + let arr = args[0].as_string::(); + let builder = + StringBuilder::with_capacity(arr.len(), arr.get_buffer_memory_size()); + string_to_array_column_args(&arr, &args[1], null_value_array, builder) } Utf8View => { - let string_array = args[0].as_string_view(); - let builder = StringViewBuilder::with_capacity(string_array.len()); - string_to_array_inner_2::<&StringViewArray, StringViewBuilder>(args, &string_array, builder) + let arr = args[0].as_string_view(); + let builder = StringViewBuilder::with_capacity(arr.len()); + string_to_array_column_args(&arr, &args[1], null_value_array, builder) } LargeUtf8 => { - let string_array = args[0].as_string::(); - let builder = LargeStringBuilder::with_capacity(string_array.len(), string_array.get_buffer_memory_size()); - string_to_array_inner_2::<&GenericStringArray, LargeStringBuilder>(args, &string_array, builder) + let arr = args[0].as_string::(); + let builder = LargeStringBuilder::with_capacity( + arr.len(), + arr.get_buffer_memory_size(), + ); + string_to_array_column_args(&arr, &args[1], null_value_array, builder) } - other => exec_err!("unsupported type for first argument to string_to_array function as {other:?}") + other => exec_err!("unsupported type for string_to_array function as {other:?}"), } } -fn string_to_array_inner_2<'a, StringArrType, StringBuilderType>( - args: &'a [ArrayRef], +fn string_to_array_column_args<'a, StringArrType, StringBuilderType>( string_array: &StringArrType, + delimiter_array: &ArrayRef, + null_value_array: Option<&ArrayRef>, string_builder: StringBuilderType, ) -> Result where StringArrType: StringArrayType<'a>, StringBuilderType: StringArrayBuilderType, { - match args[1].data_type() { - Utf8 => { - let delimiter_array = args[1].as_string::(); - if args.len() == 2 { - string_to_array_impl::< - StringArrType, - &GenericStringArray, - &StringViewArray, - StringBuilderType, - >(string_array, &delimiter_array, None, string_builder) - } else { - string_to_array_inner_3::, - StringBuilderType>(args, string_array, &delimiter_array, string_builder) - } + let mut list_builder = ListBuilder::new(string_builder); + + for i in 0..string_array.len() { + if string_array.is_null(i) { + list_builder.append(false); + continue; } - Utf8View => { - let delimiter_array = args[1].as_string_view(); - - if args.len() == 2 { - string_to_array_impl::< - StringArrType, - &StringViewArray, - &StringViewArray, - StringBuilderType, - >(string_array, &delimiter_array, None, string_builder) - } else { - string_to_array_inner_3::(args, string_array, &delimiter_array, string_builder) + + let string = string_array.value(i); + let delimiter = get_str_value(delimiter_array, i); + let null_value = null_value_array.and_then(|arr| get_str_value(arr, i)); + + match delimiter { + Some("") => { + if !string.is_empty() { + append_part(list_builder.values(), string, null_value); + } } - } - LargeUtf8 => { - let delimiter_array = args[1].as_string::(); - if args.len() == 2 { - string_to_array_impl::< - StringArrType, - &GenericStringArray, - &StringViewArray, - StringBuilderType, - >(string_array, &delimiter_array, None, string_builder) - } else { - string_to_array_inner_3::, - StringBuilderType>(args, string_array, &delimiter_array, string_builder) + Some(delimiter) => { + if !string.is_empty() { + for part in string.split(delimiter) { + append_part(list_builder.values(), part, null_value); + } + } + } + None => { + for (pos, c) in string.char_indices() { + append_part( + list_builder.values(), + &string[pos..pos + c.len_utf8()], + null_value, + ); + } } } - other => exec_err!("unsupported type for second argument to string_to_array function as {other:?}") + + list_builder.append(true); } + + Ok(Arc::new(list_builder.finish()) as ArrayRef) } -fn string_to_array_inner_3<'a, StringArrType, DelimiterArrType, StringBuilderType>( - args: &'a [ArrayRef], - string_array: &StringArrType, - delimiter_array: &DelimiterArrType, - string_builder: StringBuilderType, -) -> Result -where - StringArrType: StringArrayType<'a>, - DelimiterArrType: StringArrayType<'a>, - StringBuilderType: StringArrayBuilderType, -{ - match args[2].data_type() { - Utf8 => { - let null_type_array = Some(args[2].as_string::()); - string_to_array_impl::< - StringArrType, - DelimiterArrType, - &GenericStringArray, - StringBuilderType, - >( - string_array, - delimiter_array, - null_type_array, - string_builder, - ) +/// Returns the string value at index `i` from a string array of any type. +fn get_str_value(array: &ArrayRef, i: usize) -> Option<&str> { + if array.is_null(i) { + return None; + } + match array.data_type() { + Utf8 => Some(array.as_string::().value(i)), + LargeUtf8 => Some(array.as_string::().value(i)), + Utf8View => Some(array.as_string_view().value(i)), + other => { + debug_assert!(false, "unexpected type in get_str_value: {other:?}"); + None } - Utf8View => { - let null_type_array = Some(args[2].as_string_view()); - string_to_array_impl::< - StringArrType, - DelimiterArrType, - &StringViewArray, - StringBuilderType, - >( - string_array, - delimiter_array, - null_type_array, - string_builder, - ) + } +} + +fn array_to_string_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + + let arr = &args[0]; + + let delimiters: Vec> = match args[1].data_type() { + Utf8 => args[1].as_string::().iter().collect(), + Utf8View => args[1].as_string_view().iter().collect(), + LargeUtf8 => args[1].as_string::().iter().collect(), + other => { + return exec_err!( + "unsupported type for second argument to array_to_string function as {other:?}" + ); } - LargeUtf8 => { - let null_type_array = Some(args[2].as_string::()); - string_to_array_impl::< - StringArrType, - DelimiterArrType, - &GenericStringArray, - StringBuilderType, - >( - string_array, - delimiter_array, - null_type_array, - string_builder, - ) + }; + + let null_strings: Vec> = if args.len() == 3 { + match args[2].data_type() { + Utf8 => args[2].as_string::().iter().collect(), + Utf8View => args[2].as_string_view().iter().collect(), + LargeUtf8 => args[2].as_string::().iter().collect(), + other => { + return exec_err!( + "unsupported type for third argument to array_to_string function as {other:?}" + ); + } } - other => { - exec_err!("unsupported type for string_to_array function as {other:?}") + } else { + // If `null_strings` is not specified, we treat it as equivalent to + // explicitly passing a NULL value for `null_strings` in every row. + vec![None; args[0].len()] + }; + + let string_arr = match arr.data_type() { + List(_) => { + let list_array = as_list_array(&arr)?; + generate_string_array::(list_array, &delimiters, &null_strings)? } - } + LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::(list_array, &delimiters, &null_strings)? + } + // Signature guards against this arm + _ => return exec_err!("array_to_string expects list as first argument"), + }; + + Ok(Arc::new(string_arr)) } -fn string_to_array_impl< - 'a, - StringArrType, - DelimiterArrType, - NullValueArrType, - StringBuilderType, ->( - string_array: &StringArrType, - delimiter_array: &DelimiterArrType, - null_value_array: Option, - string_builder: StringBuilderType, -) -> Result -where - StringArrType: StringArrayType<'a>, - DelimiterArrType: StringArrayType<'a>, - NullValueArrType: StringArrayType<'a>, - StringBuilderType: StringArrayBuilderType, -{ - let mut list_builder = ListBuilder::new(string_builder); +fn generate_string_array( + list_arr: &GenericListArray, + delimiters: &[Option<&str>], + null_strings: &[Option<&str>], +) -> Result { + let mut builder = StringBuilder::with_capacity(list_arr.len(), 0); + + for ((arr, &delimiter), &null_string) in list_arr + .iter() + .zip(delimiters.iter()) + .zip(null_strings.iter()) + { + let (Some(arr), Some(delimiter)) = (arr, delimiter) else { + builder.append_null(); + continue; + }; - match null_value_array { - None => { - string_array.iter().zip(delimiter_array.iter()).for_each( - |(string, delimiter)| { - match (string, delimiter) { - (Some(string), Some("")) => { - list_builder.values().append_value(string); - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - list_builder.values().append_value(s); - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c.as_str()); - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value + let mut first = true; + compute_array_to_string(&mut builder, &arr, delimiter, null_string, &mut first)?; + builder.append_value(""); + } + + Ok(builder.finish()) +} + +fn compute_array_to_string( + w: &mut impl Write, + arr: &ArrayRef, + delimiter: &str, + null_string: Option<&str>, + first: &mut bool, +) -> Result<()> { + // Handle lists by recursing on each list element. + macro_rules! handle_list { + ($list_array:expr) => { + for i in 0..$list_array.len() { + if !$list_array.is_null(i) { + compute_array_to_string( + w, + &$list_array.value(i), + delimiter, + null_string, + first, + )?; + } else if let Some(ns) = null_string { + if *first { + *first = false; + } else { + w.write_str(delimiter)?; } - }, - ) + w.write_str(ns)?; + } + } + }; + } + + match arr.data_type() { + List(..) => { + let list_array = as_list_array(arr)?; + handle_list!(list_array); + Ok(()) } - Some(null_value_array) => string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(string); - } - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); + FixedSizeList(..) => { + let list_array = as_fixed_size_list_array(arr)?; + handle_list!(list_array); + Ok(()) + } + LargeList(..) => { + let list_array = as_large_list_array(arr)?; + handle_list!(list_array); + Ok(()) + } + Dictionary(_key_type, value_type) => { + // Call cast to unwrap the dictionary. This could be optimized if we wanted + // to accept the overhead of extra code + let values = cast(arr, value_type.as_ref()).map_err(|e| { + DataFusionError::from(e) + .context("Casting dictionary to values in compute_array_to_string") + })?; + compute_array_to_string(w, &values, delimiter, null_string, first) + } + Null => Ok(()), + data_type => { + macro_rules! str_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + w, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |w, x: &str| w.write_str(x), + )? + }; + } + macro_rules! bool_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + w, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |w, x: bool| { + if x { + w.write_str("true") } else { - list_builder.values().append_value(c.as_str()); + w.write_str("false") } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value + }, + )? + }; + } + macro_rules! int_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + w, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + |w, x| { + let mut itoa_buf = itoa::Buffer::new(); + w.write_str(itoa_buf.format(x)) + }, + )? + }; + } + macro_rules! float_leaf { + ($ARRAY_TYPE:ident) => { + write_leaf_to_string( + w, + downcast_arg!(arr, $ARRAY_TYPE), + delimiter, + null_string, + first, + // TODO: Consider switching to a more efficient + // floating point display library (e.g., ryu). This + // might result in some differences in the output + // format, however. + |w, x| write!(w, "{}", x), + )? + }; + } + match data_type { + Utf8 => str_leaf!(StringArray), + Utf8View => str_leaf!(StringViewArray), + LargeUtf8 => str_leaf!(LargeStringArray), + DataType::Boolean => bool_leaf!(BooleanArray), + DataType::Float32 => float_leaf!(Float32Array), + DataType::Float64 => float_leaf!(Float64Array), + DataType::Int8 => int_leaf!(Int8Array), + DataType::Int16 => int_leaf!(Int16Array), + DataType::Int32 => int_leaf!(Int32Array), + DataType::Int64 => int_leaf!(Int64Array), + DataType::UInt8 => int_leaf!(UInt8Array), + DataType::UInt16 => int_leaf!(UInt16Array), + DataType::UInt32 => int_leaf!(UInt32Array), + DataType::UInt64 => int_leaf!(UInt64Array), + data_type if can_cast_types(data_type, &Utf8) => { + let str_arr = cast(arr, &Utf8).map_err(|e| { + DataFusionError::from(e) + .context("Casting to string in array_to_string") + })?; + return compute_array_to_string( + w, + &str_arr, + delimiter, + null_string, + first, + ); } - }), - }; + data_type => { + return not_impl_err!( + "Unsupported data type in array_to_string: {data_type}" + ); + } + } + Ok(()) + } + } +} + +/// Appends the string representation of each element in a leaf (non-list) +/// array to `w`, separated by `delimiter`. Null elements are rendered +/// using `null_string` if provided, or skipped otherwise. The `append` +/// closure controls how each non-null element is written. +fn write_leaf_to_string<'a, W: Write, A, T>( + w: &mut W, + arr: &'a A, + delimiter: &str, + null_string: Option<&str>, + first: &mut bool, + append: impl Fn(&mut W, T) -> fmt::Result, +) -> Result<()> +where + &'a A: IntoIterator>, +{ + for x in arr { + // Skip nulls when no null_string is provided + if x.is_none() && null_string.is_none() { + continue; + } + + if *first { + *first = false; + } else { + w.write_str(delimiter)?; + } - let list_array = list_builder.finish(); - Ok(Arc::new(list_array) as ArrayRef) + match x { + Some(x) => append(w, x)?, + None => w.write_str(null_string.unwrap())?, + } + } + Ok(()) } trait StringArrayBuilderType: ArrayBuilder { diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 464301b6ffcf0..1b2bf428ff2d8 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -22,13 +22,15 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields}; use arrow::array::{ - Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, + Array, ArrayRef, BooleanArray, Float64Array, GenericListArray, NullBufferBuilder, + OffsetBufferBuilder, OffsetSizeTrait, Scalar, }; -use arrow::buffer::OffsetBuffer; +use arrow::buffer::{NullBuffer, OffsetBuffer}; use datafusion_common::cast::{ - as_fixed_size_list_array, as_large_list_array, as_list_array, + as_fixed_size_list_array, as_float64_array, as_generic_list_array, + as_large_list_array, as_large_list_view_array, as_list_array, as_list_view_array, }; -use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; use datafusion_expr::ColumnarValue; use itertools::Itertools as _; @@ -161,8 +163,7 @@ pub(crate) fn compare_element_to_list( ); } - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; + let element_array_row = element_array.slice(row_index, 1); // Compute all positions in list_row_array (that is itself an // array) that are equal to `from_array_row` @@ -244,6 +245,14 @@ pub(crate) fn compute_array_dims( value = as_large_list_array(&value)?.value(0); res.push(Some(value.len() as u64)); } + DataType::ListView(_) => { + value = as_list_view_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } + DataType::LargeListView(_) => { + value = as_large_list_view_array(&value)?.value(0); + res.push(Some(value.len() as u64)); + } DataType::FixedSizeList(..) => { value = as_fixed_size_list_array(&value)?.value(0); res.push(Some(value.len() as u64)); @@ -260,7 +269,7 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { match field_data_type { DataType::Struct(fields) => Ok(fields), _ => { - internal_err!("Expected a Struct type, got {:?}", field_data_type) + internal_err!("Expected a Struct type, got {}", field_data_type) } } } @@ -268,6 +277,141 @@ pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { } } +/// Shared `coerce_types` impl for array-math UDFs whose kernels expect +/// `List` / `LargeList` (e.g. `array_add`, `cosine_distance`, +/// `inner_product`, `array_normalize`). +/// +/// Each input must be `Null`, `List`, `LargeList`, or `FixedSizeList`; otherwise +/// returns a plan error naming `name`. `FixedSizeList` is widened to `List`, +/// `Null` is coerced to a list of `Float64`, and if any input is `LargeList` +/// the rest are widened to `LargeList` so the runtime sees a homogeneous pair. +pub(crate) fn coerce_array_math_arg_types( + name: &str, + arg_types: &[DataType], +) -> Result> { + use DataType::{FixedSizeList, LargeList, List, Null}; + use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; + + let coercion = Some(&ListCoercion::FixedSizedListToList); + + for arg_type in arg_types { + if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + return plan_err!("{name} does not support type {arg_type}"); + } + } + + // If any input is `LargeList`, both sides must be widened to `LargeList` + // so the runtime dispatch in `inner_product_inner` sees a homogeneous + // pair. Follows the pattern in `ArrayConcat::coerce_types`. + let any_large_list = arg_types.iter().any(|t| matches!(t, LargeList(_))); + + let coerced = arg_types + .iter() + .map(|arg_type| { + if matches!(arg_type, Null) { + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + return if any_large_list { + LargeList(field) + } else { + List(field) + }; + } + let coerced = + coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion); + match coerced { + List(field) if any_large_list => LargeList(field), + other => other, + } + }) + .collect(); + + Ok(coerced) +} + +/// Element-wise binary operation kernel for two `Float64` lists of equal per-row +/// length. The caller is responsible for type-dispatching on `O` (`i32` for +/// `List`, `i64` for `LargeList`). +/// +/// Semantics: +/// - whole-row NULL on either side → NULL output row, length 0 +/// - per-element NULL on either side → NULL at that output position +/// - per-row length mismatch → exec error tagged with `op_name` +/// +/// `op_name` flows into the error message; `op` is the per-element scalar op +/// (e.g. `|a, b| a + b` for `array_add`, `|a, b| a - b` for `array_subtract`). +pub(crate) fn array_math_binary_op( + op_name: &str, + lhs: &ArrayRef, + rhs: &ArrayRef, + op: F, +) -> Result +where + O: OffsetSizeTrait, + F: Fn(f64, f64) -> f64, +{ + let lhs = as_generic_list_array::(lhs)?; + let rhs = as_generic_list_array::(rhs)?; + + let lhs_values = as_float64_array(lhs.values())?; + let rhs_values = as_float64_array(rhs.values())?; + let lhs_offsets = lhs.value_offsets(); + let rhs_offsets = rhs.value_offsets(); + + let row_nulls = NullBuffer::union(lhs.nulls(), rhs.nulls()); + + let mut out_values: Vec = Vec::with_capacity(lhs_values.len()); + let mut out_inner_nulls = NullBufferBuilder::new(lhs_values.len()); + let mut out_offsets = OffsetBufferBuilder::::new(lhs.len()); + + for row in 0..lhs.len() { + if row_nulls.as_ref().is_some_and(|nb| nb.is_null(row)) { + out_offsets.push_length(0); + continue; + } + + let start1 = lhs_offsets[row].as_usize(); + let len1 = lhs.value_length(row).as_usize(); + let start2 = rhs_offsets[row].as_usize(); + let len2 = rhs.value_length(row).as_usize(); + + if len1 != len2 { + return exec_err!( + "{op_name} requires both list inputs to have the same length per row, got {len1} and {len2} at row {row}" + ); + } + + let l_slice = lhs_values.slice(start1, len1); + let r_slice = rhs_values.slice(start2, len2); + + let l_vals = l_slice.values(); + let r_vals = r_slice.values(); + + for i in 0..len1 { + out_values.push(op(l_vals[i], r_vals[i])); + } + + match NullBuffer::union(l_slice.nulls(), r_slice.nulls()) { + Some(nb) => out_inner_nulls.append_buffer(&nb), + None => out_inner_nulls.append_n_non_nulls(len1), + } + + out_offsets.push_length(len1); + } + + let values_array = Arc::new(Float64Array::new( + out_values.into(), + out_inner_nulls.finish(), + )); + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + + Ok(Arc::new(GenericListArray::::try_new( + field, + out_offsets.finish(), + values_array, + row_nulls, + )?)) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index ad52a551a7c17..4eca16961fa8c 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -40,7 +40,7 @@ workspace = true [features] crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] # enable datetime functions -datetime_expressions = [] +datetime_expressions = ["chrono-tz"] # Enable encoding by default so the doctests work. In general don't automatically enable all packages. default = [ "datetime_expressions", @@ -59,7 +59,7 @@ regex_expressions = ["regex"] # enable string functions string_expressions = ["uuid"] # enable unicode functions -unicode_expressions = ["unicode-segmentation"] +unicode_expressions = [] [lib] name = "datafusion_functions" @@ -71,22 +71,24 @@ base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.8", optional = true } chrono = { workspace = true } +chrono-tz = { version = "0.10.4", optional = true } datafusion-common = { workspace = true } datafusion-doc = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-macros = { workspace = true } -hex = { version = "0.4", optional = true } +datafusion-physical-expr-common = { workspace = true } +hex = { workspace = true, optional = true } itertools = { workspace = true } log = { workspace = true } -md-5 = { version = "^0.10.0", optional = true } +md-5 = { version = "^0.11.0", optional = true } +memchr = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } regex = { workspace = true, optional = true } -sha2 = { version = "^0.10.9", optional = true } -unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.18", features = ["v4"], optional = true } +sha2 = { workspace = true, optional = true } +uuid = { workspace = true, features = ["v4"], optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -106,6 +108,11 @@ harness = false name = "concat" required-features = ["string_expressions"] +[[bench]] +harness = false +name = "concat_ws" +required-features = ["string_expressions"] + [[bench]] harness = false name = "to_timestamp" @@ -126,6 +133,16 @@ harness = false name = "gcd" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "lcm" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "nanvl" +required-features = ["math_expressions"] + [[bench]] harness = false name = "uuid" @@ -170,6 +187,16 @@ harness = false name = "to_char" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "to_local_time" +required-features = ["datetime_expressions"] + +[[bench]] +harness = false +name = "to_time" +required-features = ["datetime_expressions"] + [[bench]] harness = false name = "isnan" @@ -180,6 +207,16 @@ harness = false name = "signum" required-features = ["math_expressions"] +[[bench]] +harness = false +name = "atan2" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "power" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" @@ -187,7 +224,7 @@ required-features = ["unicode_expressions"] [[bench]] harness = false -name = "ltrim" +name = "trim" required-features = ["string_expressions"] [[bench]] @@ -210,6 +247,15 @@ harness = false name = "repeat" required-features = ["string_expressions"] +[[bench]] +harness = false +name = "replace" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "overlay" + [[bench]] harness = false name = "random" @@ -254,3 +300,63 @@ required-features = ["unicode_expressions"] harness = false name = "find_in_set" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "contains" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "starts_with" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "ends_with" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "regexp_count" +required-features = ["regex_expressions"] + +[[bench]] +harness = false +name = "crypto" +required-features = ["crypto_expressions"] + +[[bench]] +harness = false +name = "translate" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "levenshtein" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "split_part" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "left_right" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "factorial" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "floor_ceil" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "round" +required-features = ["math_expressions"] diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 03d25e9c3d4fe..a2424ed352afc 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -15,19 +15,47 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; mod helper; use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; -use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use helper::gen_string_array; use std::hint::black_box; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let ascii = datafusion_functions::string::ascii(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("ascii/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello".to_string(), + )))], + arg_fields: vec![Field::new("a", DataType::Utf8, false).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(ascii.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("ascii/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello".to_string(), + )))], + arg_fields: vec![Field::new("a", DataType::Utf8View, false).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(ascii.invoke_with_args(args.clone()).unwrap())) + }); // All benches are single batch run with 8192 rows const N_ROWS: usize = 8192; diff --git a/datafusion/functions/benches/atan2.rs b/datafusion/functions/benches/atan2.rs new file mode 100644 index 0000000000000..f1c9756a0cc08 --- /dev/null +++ b/datafusion/functions/benches/atan2.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::atan2; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let atan2_fn = atan2(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let y_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f32 = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(y_f32), ColumnarValue::Array(x_f32)]; + let f32_arg_fields = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f32 = Field::new("f", DataType::Float32, true).into(); + + c.bench_function(&format!("atan2 f32 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f32_args.clone(), + arg_fields: f32_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let y_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let x_f64 = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(y_f64), ColumnarValue::Array(x_f64)]; + let f64_arg_fields = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let return_field_f64 = Field::new("f", DataType::Float64, true).into(); + + c.bench_function(&format!("atan2 f64 array: {size}"), |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: f64_args.clone(), + arg_fields: f64_arg_fields.clone(), + number_rows: size, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + } + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), + ]; + let scalar_f32_arg_fields = vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Float32, false).into(), + ]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("atan2 f32 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), + ]; + let scalar_f64_arg_fields = vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Float64, false).into(), + ]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("atan2 f64 scalar", |b| { + b.iter(|| { + black_box( + atan2_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 4a1a63d62765f..4927627ec2f05 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8356cf7c31726..a702dc161ae06 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::{array::PrimitiveArray, datatypes::Int64Type}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; @@ -35,11 +34,32 @@ pub fn seedable_rng() -> StdRng { } fn criterion_benchmark(c: &mut Criterion) { - let cot_fn = chr(); + let chr_fn = chr(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks + c.bench_function("chr/scalar", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(65)))]; + let arg_fields = vec![Field::new("arg_0", DataType::Int64, true).into()]; + b.iter(|| { + black_box( + chr_fn + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + let size = 1024; let input: PrimitiveArray = { let null_density = 0.2; - let mut rng = StdRng::seed_from_u64(42); + let mut rng = seedable_rng(); (0..size) .map(|_| { if rng.random::() < null_density { @@ -57,12 +77,11 @@ fn criterion_benchmark(c: &mut Criterion) { .enumerate() .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - c.bench_function("chr", |b| { + c.bench_function("chr/array", |b| { b.iter(|| { black_box( - cot_fn + chr_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 09200139a244b..0fb910800e3bc 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -17,16 +17,18 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow::util::bench_util::create_string_array_with_len; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use datafusion_common::config::ConfigOptions; +use arrow::util::bench_util::{create_string_array_with_len, create_string_view_array}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::concat; +use rand::Rng; +use rand::distr::Alphanumeric; use std::hint::black_box; use std::sync::Arc; -fn create_args(size: usize, str_len: usize) -> Vec { +fn create_array_args(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ @@ -36,9 +38,37 @@ fn create_args(size: usize, str_len: usize) -> Vec { ] } +fn create_array_args_view(size: usize) -> Vec { + let array = Arc::new(create_string_view_array(size, 0.2)); + let scalar = ScalarValue::Utf8(Some(", ".to_string())); + vec![ + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), + ColumnarValue::Scalar(scalar), + ColumnarValue::Array(array), + ] +} + +fn generate_random_string(str_len: usize) -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect() +} + +fn create_scalar_args(count: usize, str_len: usize) -> Vec { + std::iter::repeat_with(|| { + let s = generate_random_string(str_len); + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + }) + .take(count) + .collect() +} + fn criterion_benchmark(c: &mut Criterion) { + // Benchmark for array concat for size in [1024, 4096, 8192] { - let args = create_args(size, 32); + let args = create_array_args(size, 32); let arg_fields = args .iter() .enumerate() @@ -67,6 +97,70 @@ fn criterion_benchmark(c: &mut Criterion) { }); group.finish(); } + + // Benchmark for StringViewArray concat + for size in [1024, 4096, 8192] { + let args = create_array_args_view(size); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + // Use Utf8View for array args + let dt = if matches!(arg, ColumnarValue::Array(_)) { + DataType::Utf8View + } else { + DataType::Utf8 // scalar remains Utf8 + }; + Field::new(format!("arg_{idx}"), dt, true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut group = c.benchmark_group("concat function"); + group.bench_function(BenchmarkId::new("concat_view", size), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + group.finish(); + } + + // Benchmark for scalar concat + let scalar_args = create_scalar_args(10, 100); + let scalar_arg_fields = scalar_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let mut group = c.benchmark_group("concat function"); + group.bench_function(BenchmarkId::new("concat", "scalar"), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/concat_ws.rs b/datafusion/functions/benches/concat_ws.rs new file mode 100644 index 0000000000000..97d6d96411d73 --- /dev/null +++ b/datafusion/functions/benches/concat_ws.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string::concat_ws; +use rand::Rng; +use rand::distr::Alphanumeric; +use std::hint::black_box; +use std::sync::Arc; + +fn create_array_args(size: usize, str_len: usize) -> Vec { + let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + let scalar = ScalarValue::Utf8(Some(", ".to_string())); + vec![ + ColumnarValue::Scalar(scalar), + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), + ColumnarValue::Array(array), + ] +} + +fn generate_random_string(str_len: usize) -> String { + rand::rng() + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect() +} + +fn create_scalar_args(count: usize, str_len: usize) -> Vec { + let mut args = Vec::with_capacity(count + 1); + + args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ",".to_string(), + )))); + + for _ in 0..count { + let s = generate_random_string(str_len); + args.push(ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))); + } + args +} + +fn criterion_benchmark(c: &mut Criterion) { + // Benchmark for array concat_ws + for size in [1024, 4096, 8192] { + let args = create_array_args(size, 32); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + let mut group = c.benchmark_group("concat_ws function"); + group.bench_function(BenchmarkId::new("concat_ws", size), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box( + concat_ws() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + group.finish(); + } + + // Benchmark for scalar concat_ws + let scalar_args = create_scalar_args(10, 100); + let scalar_arg_fields = scalar_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let mut group = c.benchmark_group("concat_ws function"); + group.bench_function(BenchmarkId::new("concat_ws", "scalar"), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + concat_ws() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/contains.rs b/datafusion/functions/benches/contains.rs new file mode 100644 index 0000000000000..6c39f45e14fa6 --- /dev/null +++ b/datafusion/functions/benches/contains.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +/// Generate a StringArray/StringViewArray with random ASCII strings +fn gen_string_array( + n_rows: usize, + str_len: usize, + is_string_view: bool, +) -> ColumnarValue { + let mut rng = StdRng::seed_from_u64(42); + let strings: Vec> = (0..n_rows) + .map(|_| { + let s: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect(); + Some(s) + }) + .collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +/// Generate a scalar search string +fn gen_scalar_search(search_str: &str, is_string_view: bool) -> ColumnarValue { + if is_string_view { + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(search_str.to_string()))) + } else { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(search_str.to_string()))) + } +} + +/// Generate an array of search strings (same string repeated) +fn gen_array_search( + search_str: &str, + n_rows: usize, + is_string_view: bool, +) -> ColumnarValue { + let strings: Vec> = + (0..n_rows).map(|_| Some(search_str.to_string())).collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let contains = datafusion_functions::string::contains(); + let n_rows = 8192; + let str_len = 128; + let search_str = "xyz"; // A pattern that likely won't be found + + // Benchmark: StringArray with scalar search (the optimized path) + let str_array = gen_string_array(n_rows, str_len, false); + let scalar_search = gen_scalar_search(search_str, false); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("contains_StringArray_scalar_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringArray with array search (for comparison) + let array_search = gen_array_search(search_str, n_rows, false); + c.bench_function("contains_StringArray_array_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), array_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with scalar search (the optimized path) + let str_view_array = gen_string_array(n_rows, str_len, true); + let scalar_search_view = gen_scalar_search(search_str, true); + let arg_fields_view = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function("contains_StringViewArray_scalar_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), scalar_search_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with array search (for comparison) + let array_search_view = gen_array_search(search_str, n_rows, true); + c.bench_function("contains_StringViewArray_array_search", |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), array_search_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark different string lengths with scalar search + for str_len in [8, 32, 128, 512] { + let str_array = gen_string_array(n_rows, str_len, true); + let scalar_search = gen_scalar_search(search_str, true); + let arg_fields = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function( + &format!("contains_StringViewArray_scalar_strlen_{str_len}"), + |b| { + b.iter(|| { + black_box(contains.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_search.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 97f21ccd6d55e..16c3fba2175fe 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -15,23 +15,25 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::cot; use std::hint::black_box; use arrow::datatypes::{DataType, Field}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let cot_fn = cot(); + let config_options = Arc::new(ConfigOptions::default()); + + // Array benchmarks - run for different sizes for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; @@ -42,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { Field::new(format!("arg_{idx}"), arg.data_type(), true).into() }) .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("cot f32 array: {size}"), |b| { b.iter(|| { @@ -59,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; let arg_fields = f64_args @@ -86,6 +88,47 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - run only once since size doesn't affect scalar performance + let scalar_f32_args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("cot f32 scalar", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))]; + let scalar_f64_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("cot f64 scalar", |b| { + b.iter(|| { + black_box( + cot_fn + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/crypto.rs b/datafusion/functions/benches/crypto.rs new file mode 100644 index 0000000000000..9a86efbff9ed8 --- /dev/null +++ b/datafusion/functions/benches/crypto.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_functions::crypto; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let crypto = vec![ + crypto::md5(), + crypto::sha224(), + crypto::sha256(), + crypto::sha384(), + crypto::sha512(), + ]; + let config_options = Arc::new(ConfigOptions::default()); + + for func in crypto { + let size = 1024; + let arr_args = vec![ColumnarValue::Array(Arc::new( + create_string_array_with_len::(size, 0.2, 32), + ))]; + c.bench_function(&format!("{}_array", func.name()), |b| { + b.iter(|| { + let args_cloned = arr_args.clone(); + black_box(func.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + let scalar_args = vec![ColumnarValue::Scalar("test_string".into())]; + c.bench_function(&format!("{}_scalar", func.name()), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box(func.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 74390491d538c..28dee96987261 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; use arrow::datatypes::Field; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::date_bin; -use rand::rngs::ThreadRng; use rand::Rng; +use rand::rngs::ThreadRng; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index 498a3e63ef290..0668a1cc5085c 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; use arrow::datatypes::Field; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_functions::datetime::date_trunc; -use rand::rngs::ThreadRng; use rand::Rng; +use rand::rngs::ThreadRng; fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { let mut seconds = vec![]; @@ -57,10 +55,13 @@ fn criterion_benchmark(c: &mut Criterion) { }) .collect::>(); - let return_type = udf - .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) + let scalar_arguments = vec![None; arg_fields.len()]; + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }) .unwrap(); - let return_field = Arc::new(Field::new("f", return_type, true)); let config_options = Arc::new(ConfigOptions::default()); b.iter(|| { diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 98faee91e1911..0b8f0c5c51a58 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::Array; use arrow::datatypes::{DataType, Field}; -use arrow::util::bench_util::create_string_array_with_len; -use criterion::{criterion_group, criterion_main, Criterion}; +use arrow::util::bench_util::create_binary_array; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::encoding; @@ -32,20 +30,22 @@ fn criterion_benchmark(c: &mut Criterion) { let config_options = Arc::new(ConfigOptions::default()); for size in [1024, 4096, 8192] { - let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); + let bin_array = Arc::new(create_binary_array::(size, 0.2)); c.bench_function(&format!("base64_decode/{size}"), |b| { let method = ColumnarValue::Scalar("base64".into()); let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + args: vec![ColumnarValue::Array(bin_array.clone()), method.clone()], arg_fields: vec![ - Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("a", bin_array.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ], number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) + .unwrap() + .cast_to(&DataType::Binary, None) .unwrap(); let arg_fields = vec![ @@ -61,7 +61,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: arg_fields.clone(), number_rows: size, - return_field: Field::new("f", DataType::Utf8, true).into(), + return_field: Field::new("f", DataType::Binary, true).into(), config_options: Arc::clone(&config_options), }) .unwrap(), @@ -72,24 +72,26 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); let arg_fields = vec![ - Field::new("a", str_array.data_type().to_owned(), true).into(), + Field::new("a", bin_array.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + args: vec![ColumnarValue::Array(bin_array.clone()), method.clone()], arg_fields, number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) + .unwrap() + .cast_to(&DataType::Binary, None) .unwrap(); let arg_fields = vec![ Field::new("a", encoded.data_type().to_owned(), true).into(), Field::new("b", method.data_type().to_owned(), true).into(), ]; - let return_field = Field::new("f", DataType::Utf8, true).into(); + let return_field = Field::new("f", DataType::Binary, true).into(); let args = vec![encoded, method]; b.iter(|| { diff --git a/datafusion/functions/benches/ends_with.rs b/datafusion/functions/benches/ends_with.rs new file mode 100644 index 0000000000000..474e8a1555cf2 --- /dev/null +++ b/datafusion/functions/benches/ends_with.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +/// Generate a StringArray/StringViewArray with random ASCII strings +fn gen_string_array( + n_rows: usize, + str_len: usize, + is_string_view: bool, +) -> ColumnarValue { + let mut rng = StdRng::seed_from_u64(42); + let strings: Vec> = (0..n_rows) + .map(|_| { + let s: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect(); + Some(s) + }) + .collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +/// Generate a scalar suffix string +fn gen_scalar_suffix(suffix_str: &str, is_string_view: bool) -> ColumnarValue { + if is_string_view { + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(suffix_str.to_string()))) + } else { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(suffix_str.to_string()))) + } +} + +/// Generate an array of suffix strings (same string repeated) +fn gen_array_suffix( + suffix_str: &str, + n_rows: usize, + is_string_view: bool, +) -> ColumnarValue { + let strings: Vec> = + (0..n_rows).map(|_| Some(suffix_str.to_string())).collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let ends_with = datafusion_functions::string::ends_with(); + let n_rows = 8192; + let str_len = 128; + let suffix_str = "xyz"; // A pattern that likely won't match + + // Benchmark: StringArray with scalar suffix (the optimized path) + let str_array = gen_string_array(n_rows, str_len, false); + let scalar_suffix = gen_scalar_suffix(suffix_str, false); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("ends_with_StringArray_scalar_suffix", |b| { + b.iter(|| { + black_box(ends_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_suffix.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringArray with array suffix (for comparison) + let array_suffix = gen_array_suffix(suffix_str, n_rows, false); + c.bench_function("ends_with_StringArray_array_suffix", |b| { + b.iter(|| { + black_box(ends_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), array_suffix.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with scalar suffix (the optimized path) + let str_view_array = gen_string_array(n_rows, str_len, true); + let scalar_suffix_view = gen_scalar_suffix(suffix_str, true); + let arg_fields_view = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function("ends_with_StringViewArray_scalar_suffix", |b| { + b.iter(|| { + black_box(ends_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), scalar_suffix_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with array suffix (for comparison) + let array_suffix_view = gen_array_suffix(suffix_str, n_rows, true); + c.bench_function("ends_with_StringViewArray_array_suffix", |b| { + b.iter(|| { + black_box(ends_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), array_suffix_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark different string lengths with scalar suffix + for str_len in [8, 32, 128, 512] { + let str_array = gen_string_array(n_rows, str_len, true); + let scalar_suffix = gen_scalar_suffix(suffix_str, true); + let arg_fields = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function( + &format!("ends_with_StringViewArray_scalar_strlen_{str_len}"), + |b| { + b.iter(|| { + black_box(ends_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_suffix.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/factorial.rs b/datafusion/functions/benches/factorial.rs new file mode 100644 index 0000000000000..c441b50c288c3 --- /dev/null +++ b/datafusion/functions/benches/factorial.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Int64Array; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_functions::math::factorial; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let factorial = factorial(); + let config_options = Arc::new(ConfigOptions::default()); + + let arr_args = vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter( + (0..1024).map(|i| Some(i % 21)), + )))]; + c.bench_function(&format!("{}_array", factorial.name()), |b| { + b.iter(|| { + let args_cloned = arr_args.clone(); + black_box(factorial.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], + number_rows: arr_args.len(), + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(20)))]; + c.bench_function(&format!("{}_scalar", factorial.name()), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box(factorial.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index a928f5655806c..9ee20ecd14fdf 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -15,16 +15,14 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use rand::distr::Alphanumeric; use rand::prelude::StdRng; diff --git a/datafusion/functions/benches/floor_ceil.rs b/datafusion/functions/benches/floor_ceil.rs new file mode 100644 index 0000000000000..dc095e0152c4d --- /dev/null +++ b/datafusion/functions/benches/floor_ceil.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::{ceil, floor}; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let floor_fn = floor(); + let ceil_fn = ceil(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("floor_ceil size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + + group.bench_function("floor_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + floor_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.bench_function("ceil_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + ceil_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))]; + + group.bench_function("floor_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + floor_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.bench_function("ceil_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box( + ceil_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 19e196d9a3eab..3c72a46e6643d 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -15,16 +15,14 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, datatypes::DataType, }; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::gcd; use rand::Rng; diff --git a/datafusion/functions/benches/helper.rs b/datafusion/functions/benches/helper.rs index a2b110ae4d63b..d6d6afd48f2ca 100644 --- a/datafusion/functions/benches/helper.rs +++ b/datafusion/functions/benches/helper.rs @@ -18,7 +18,7 @@ use arrow::array::{StringArray, StringViewArray}; use datafusion_expr::ColumnarValue; use rand::distr::Alphanumeric; -use rand::{rngs::StdRng, Rng, SeedableRng}; +use rand::{Rng, SeedableRng, rngs::StdRng}; use std::sync::Arc; /// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 50aee8dbb9161..b5e653e4136a3 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -15,19 +15,19 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - -use arrow::array::OffsetSizeTrait; +use arrow::array::{ArrayRef, OffsetSizeTrait, StringArray, StringViewBuilder}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; +use std::time::Duration; fn create_args( size: usize, @@ -47,62 +47,161 @@ fn create_args( } } +/// Create a Utf8 array where every value contains non-ASCII Unicode text. +fn create_unicode_utf8_args(size: usize) -> Vec { + let array = Arc::new(StringArray::from_iter_values(std::iter::repeat_n( + "ñAnDÚ ÁrBOL ОлЕГ ÍslENsku", + size, + ))) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + +/// Create a Utf8View array where every value contains non-ASCII Unicode text. +fn create_unicode_utf8view_args(size: usize) -> Vec { + let mut builder = StringViewBuilder::with_capacity(size); + for _ in 0..size { + builder.append_value("ñAnDÚ ÁrBOL ОлЕГ ÍslENsku"); + } + let array = Arc::new(builder.finish()) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); - for size in [1024, 4096] { - let args = create_args::(size, 8, true); - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); + let config_options = Arc::new(ConfigOptions::default()); + + // Array benchmarks: vary both row count and string length + for size in [1024, 4096, 8192] { + for str_len in [16, 128] { + let mut group = + c.benchmark_group(format!("initcap size={size} str_len={str_len}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Utf8 + let array_args = create_args::(size, str_len, false); + let array_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()]; - c.bench_function( - format!("initcap string view shorter than 12 [size={size}]").as_str(), - |b| { + group.bench_function("array_utf8", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), + args: array_args.clone(), + arg_fields: array_arg_fields.clone(), number_rows: size, - return_field: Field::new("f", DataType::Utf8View, true).into(), + return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), })) }) - }, - ); + }); + + // Utf8View + let array_view_args = create_args::(size, str_len, true); + let array_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, true).into()]; - let args = create_args::(size, 16, true); - c.bench_function( - format!("initcap string view longer than 12 [size={size}]").as_str(), - |b| { + group.bench_function("array_utf8view", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), + args: array_view_args.clone(), + arg_fields: array_view_arg_fields.clone(), number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), })) }) - }, - ); + }); - let args = create_args::(size, 16, false); - c.bench_function(format!("initcap string [size={size}]").as_str(), |b| { + group.finish(); + } + } + + // Unicode array benchmarks + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("initcap unicode size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + let unicode_args = create_unicode_utf8_args(size); + let unicode_arg_fields = vec![Field::new("arg_0", DataType::Utf8, true).into()]; + + group.bench_function("array_utf8", |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), + args: unicode_args.clone(), + arg_fields: unicode_arg_fields.clone(), number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), })) }) }); + + let unicode_view_args = create_unicode_utf8view_args(size); + let unicode_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, true).into()]; + + group.bench_function("array_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: unicode_view_args.clone(), + arg_fields: unicode_view_arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + group.finish(); + } + + // Scalar benchmarks: independent of array size, run once + { + let mut group = c.benchmark_group("initcap scalar"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Utf8 + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "hello world test string".to_string(), + )))]; + let scalar_arg_fields = vec![Field::new("arg_0", DataType::Utf8, false).into()]; + + group.bench_function("scalar_utf8", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Utf8View + let scalar_view_args = vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hello world test string".to_string(), + )))]; + let scalar_view_arg_fields = + vec![Field::new("arg_0", DataType::Utf8View, false).into()]; + + group.bench_function("scalar_utf8view", |b| { + b.iter(|| { + black_box(initcap.invoke_with_args(ScalarFunctionArgs { + args: scalar_view_args.clone(), + arg_fields: scalar_view_arg_fields.clone(), + number_rows: 1, + return_field: Field::new("f", DataType::Utf8View, false).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + group.finish(); } } diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 4a90d45d66223..e353b9d27a0a1 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::isnan; diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 961cba7200ce0..c6d0aed4c615c 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::iszero; @@ -31,6 +30,8 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let iszero = iszero(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); @@ -43,7 +44,6 @@ fn criterion_benchmark(c: &mut Criterion) { }) .collect::>(); let return_field = Arc::new(Field::new("f", DataType::Boolean, true)); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("iszero f32 array: {size}"), |b| { b.iter(|| { @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; @@ -88,6 +89,46 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - run once since size doesn't affect scalar performance + let scalar_f32_args = vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_scalar = Arc::new(Field::new("f", DataType::Boolean, false)); + + c.bench_function("iszero f32 scalar", |b| { + b.iter(|| { + black_box( + iszero + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_scalar), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))]; + let scalar_f64_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + + c.bench_function("iszero f64 scalar", |b| { + b.iter(|| { + black_box( + iszero + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_scalar), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/lcm.rs b/datafusion/functions/benches/lcm.rs new file mode 100644 index 0000000000000..247c0ec749d15 --- /dev/null +++ b/datafusion/functions/benches/lcm.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::Field; +use arrow::{ + array::{ArrayRef, Int64Array}, + datatypes::DataType, +}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::lcm; +use rand::Rng; +use std::hint::black_box; +use std::sync::Arc; + +fn generate_i64_array(n_rows: usize) -> ArrayRef { + let mut rng = rand::rng(); + let values = (0..n_rows) + .map(|_| rng.random_range(0..1000)) + .collect::>(); + Arc::new(Int64Array::from(values)) as ArrayRef +} + +fn criterion_benchmark(c: &mut Criterion) { + let n_rows = 100000; + let array_a = ColumnarValue::Array(generate_i64_array(n_rows)); + let array_b = ColumnarValue::Array(generate_i64_array(n_rows)); + let udf = lcm(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("lcm both array", |b| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![array_a.clone(), array_b.clone()], + arg_fields: vec![ + Field::new("a", array_a.data_type(), true).into(), + Field::new("b", array_b.data_type(), true).into(), + ], + number_rows: n_rows, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("lcm should work on valid values"), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/left_right.rs b/datafusion/functions/benches/left_right.rs new file mode 100644 index 0000000000000..8d5865acb845e --- /dev/null +++ b/datafusion/functions/benches/left_right.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::ops::Range; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::unicode::{left, right}; + +const BATCH_SIZE: usize = 8192; + +fn create_args( + str_len: usize, + n_range: Range, + is_string_view: bool, +) -> Vec { + let string_arg = if is_string_view { + ColumnarValue::Array(Arc::new(create_string_view_array_with_len( + BATCH_SIZE, 0.1, str_len, true, + ))) + } else { + ColumnarValue::Array(Arc::new(create_string_array_with_len::( + BATCH_SIZE, 0.1, str_len, + ))) + }; + + let n_span = (n_range.end - n_range.start) as usize; + let n_values: Vec = (0..BATCH_SIZE) + .map(|i| n_range.start + (i % n_span) as i64) + .collect(); + let n_array = Arc::new(Int64Array::from(n_values)); + + vec![ + string_arg, + ColumnarValue::Array(Arc::clone(&n_array) as ArrayRef), + ] +} + +fn criterion_benchmark(c: &mut Criterion) { + // Short results (1-10 chars) produce inline StringView entries (≤12 bytes). + // Long results (20-29 chars) produce out-of-line entries. + let cases = [ + ("short_result", 32, 1..11_i64), + ("long_result", 32, 20..30_i64), + ]; + + for function in [left(), right()] { + let mut group = c.benchmark_group(function.name().to_string()); + + for is_string_view in [false, true] { + let array_type = if is_string_view { + "string_view" + } else { + "string" + }; + + for (case_name, str_len, n_range) in &cases { + let bench_name = format!("{array_type} {case_name}"); + let args = create_args(*str_len, n_range.clone(), is_string_view); + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + let return_field = Field::new("f", DataType::Utf8View, true).into(); + + group.bench_function(&bench_name, |b| { + b.iter(|| { + black_box( + function + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: BATCH_SIZE, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("should work"), + ) + }) + }); + } + } + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/levenshtein.rs b/datafusion/functions/benches/levenshtein.rs new file mode 100644 index 0000000000000..08733b245ffb4 --- /dev/null +++ b/datafusion/functions/benches/levenshtein.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::OffsetSizeTrait; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::DataFusionError; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn create_args(size: usize, str_len: usize) -> Vec { + let string1_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + let string2_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string1_array), + ColumnarValue::Array(string2_array), + ] +} + +fn invoke_levenshtein_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + string::levenshtein().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Int32, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + let mut group = c.benchmark_group(format!("levenshtein size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for str_len in [8, 32] { + let args = create_args::(size, str_len); + group.bench_function( + format!("levenshtein_string [size={size}, str_len={str_len}]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_levenshtein_with_args(args_cloned, size)) + }) + }, + ); + } + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 6a5178b87fdce..2764491c69c71 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -15,14 +15,12 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - -use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; +use arrow::array::{Array, ArrayRef, StringArray, StringViewBuilder}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; @@ -197,6 +195,43 @@ fn criterion_benchmark(c: &mut Criterion) { ); } + { + let parent_size = 65536; + let slice_len = 128; + let str_len = 32; + let parent = Arc::new(create_string_array_with_len::( + parent_size, + 0.2, + str_len, + )) as ArrayRef; + let offset = (parent_size - slice_len) / 2; + let sliced = parent.slice(offset, slice_len); + let args = vec![ColumnarValue::Array(sliced)]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + c.bench_function( + &format!("lower_sliced_ascii: parent={parent_size}, slice={slice_len}, str_len={str_len}"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: slice_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } + let sizes = [4096, 8192]; let str_lens = [10, 64, 128]; let mixes = [true, false]; diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 15a895468db93..1c7b61ec60497 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -15,24 +15,22 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array}; use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::make_date; -use rand::rngs::ThreadRng; use rand::Rng; +use rand::rngs::ThreadRng; fn years(rng: &mut ThreadRng) -> Int32Array { let mut years = vec![]; - for _ in 0..1000 { + for _ in 0..8192 { years.push(rng.random_range(1900..2050)); } @@ -41,7 +39,7 @@ fn years(rng: &mut ThreadRng) -> Int32Array { fn months(rng: &mut ThreadRng) -> Int32Array { let mut months = vec![]; - for _ in 0..1000 { + for _ in 0..8192 { months.push(rng.random_range(1..13)); } @@ -50,14 +48,14 @@ fn months(rng: &mut ThreadRng) -> Int32Array { fn days(rng: &mut ThreadRng) -> Int32Array { let mut days = vec![]; - for _ in 0..1000 { + for _ in 0..8192 { days.push(rng.random_range(1..29)); } Int32Array::from(days) } fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("make_date_col_col_col_1000", |b| { + c.bench_function("make_date_col_col_col_8192", |b| { let mut rng = rand::rng(); let years_array = Arc::new(years(&mut rng)) as ArrayRef; let batch_len = years_array.len(); @@ -87,7 +85,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("make_date_scalar_col_col_1000", |b| { + c.bench_function("make_date_scalar_col_col_8192", |b| { let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let months_arr = Arc::new(months(&mut rng)) as ArrayRef; @@ -117,7 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("make_date_scalar_scalar_col_1000", |b| { + c.bench_function("make_date_scalar_scalar_col_8192", |b| { let mut rng = rand::rng(); let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); diff --git a/datafusion/functions/benches/nanvl.rs b/datafusion/functions/benches/nanvl.rs new file mode 100644 index 0000000000000..206eebd81eb81 --- /dev/null +++ b/datafusion/functions/benches/nanvl.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::nanvl; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let nanvl_fn = nanvl(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks + c.bench_function("nanvl/scalar_f64", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ], + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Float64, true).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }; + + b.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("nanvl/scalar_f32", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(f32::NAN))), + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ], + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Float32, true).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }; + + b.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + // Array benchmarks + for size in [1024, 4096, 8192] { + let a64: ArrayRef = Arc::new(Float64Array::from(vec![f64::NAN; size])); + let b64: ArrayRef = Arc::new(Float64Array::from(vec![1.0; size])); + c.bench_function(&format!("nanvl/array_f64/{size}"), |bench| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&a64)), + ColumnarValue::Array(Arc::clone(&b64)), + ], + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Float64, true).into(), + ], + number_rows: size, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }; + bench.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + + let a32: ArrayRef = Arc::new(Float32Array::from(vec![f32::NAN; size])); + let b32: ArrayRef = Arc::new(Float32Array::from(vec![1.0; size])); + c.bench_function(&format!("nanvl/array_f32/{size}"), |bench| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(Arc::clone(&a32)), + ColumnarValue::Array(Arc::clone(&b32)), + ], + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Float32, true).into(), + ], + number_rows: size, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }; + bench.iter(|| black_box(nanvl_fn.invoke_with_args(args.clone()).unwrap())) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index d649697cc5188..f9f063c52d0d4 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -15,13 +15,11 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::core::nullif; use std::hint::black_box; diff --git a/datafusion/functions/benches/overlay.rs b/datafusion/functions/benches/overlay.rs new file mode 100644 index 0000000000000..0b7fff5989d1f --- /dev/null +++ b/datafusion/functions/benches/overlay.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod helper; + +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; +use helper::gen_string_array; +use std::hint::black_box; +use std::sync::Arc; + +#[expect(clippy::too_many_arguments)] +fn bench_overlay( + c: &mut Criterion, + name: &str, + overlay: &ScalarUDF, + n_rows: usize, + null_density: f32, + utf8_density: f32, + is_string_view: bool, + with_for: bool, +) { + const STR_LEN: usize = 128; + + let mut args = + gen_string_array(n_rows, STR_LEN, null_density, utf8_density, is_string_view); + // The substring scalar's type must match the string column's type (the + // function dispatches per-type without coercion). + let substr = "DataFusion".to_string(); + let substr_scalar = if is_string_view { + ScalarValue::Utf8View(Some(substr)) + } else { + ScalarValue::Utf8(Some(substr)) + }; + args.push(ColumnarValue::Scalar(substr_scalar)); + args.push(ColumnarValue::Scalar(ScalarValue::Int64(Some(32)))); + if with_for { + args.push(ColumnarValue::Scalar(ScalarValue::Int64(Some(8)))); + } + + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let return_field = Arc::new(Field::new("f", DataType::Utf8, true)); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(name, |b| { + b.iter(|| { + black_box( + overlay + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + const N_ROWS: usize = 8192; + const MIXED_UTF8: f32 = 0.5; + let overlay = datafusion_functions::core::overlay(); + + // Null-density variants on StringArray (mixed ASCII/UTF-8, 4-arg form). + bench_overlay( + c, + "overlay_StringArray_low_nulls", + &overlay, + N_ROWS, + 0.1, + MIXED_UTF8, + false, + true, + ); + bench_overlay( + c, + "overlay_StringArray_high_nulls", + &overlay, + N_ROWS, + 0.9, + MIXED_UTF8, + false, + true, + ); + bench_overlay( + c, + "overlay_StringArray_no_nulls", + &overlay, + N_ROWS, + 0.0, + MIXED_UTF8, + false, + true, + ); + + // Content variants on StringArray (no nulls, 4-arg form). Pair against + // `overlay_StringArray_no_nulls` to isolate the impact of UTF-8 density. + bench_overlay( + c, + "overlay_StringArray_ascii", + &overlay, + N_ROWS, + 0.0, + 0.0, + false, + true, + ); + bench_overlay( + c, + "overlay_StringArray_all_utf8", + &overlay, + N_ROWS, + 0.0, + 1.0, + false, + true, + ); + + // 3-arg form (no FOR clause), where the replace length is derived from + // the substring per row. + bench_overlay( + c, + "overlay_StringArray_no_for", + &overlay, + N_ROWS, + 0.0, + MIXED_UTF8, + false, + false, + ); + + // StringViewArray counterparts. + bench_overlay( + c, + "overlay_StringViewArray_low_nulls", + &overlay, + N_ROWS, + 0.1, + MIXED_UTF8, + true, + true, + ); + bench_overlay( + c, + "overlay_StringViewArray_ascii", + &overlay, + N_ROWS, + 0.0, + 0.0, + true, + true, + ); + bench_overlay( + c, + "overlay_StringViewArray_all_utf8", + &overlay, + N_ROWS, + 0.0, + 1.0, + true, + true, + ); + bench_overlay( + c, + "overlay_StringViewArray_no_for", + &overlay, + N_ROWS, + 0.0, + MIXED_UTF8, + true, + false, + ); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f92a69bbf4f92..c71d5a7161a66 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -15,20 +15,69 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; +use arrow::array::{ + ArrowPrimitiveType, GenericStringBuilder, OffsetSizeTrait, PrimitiveArray, + StringViewBuilder, +}; use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; -use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; -use datafusion_functions::unicode::{lpad, rpad}; -use rand::distr::{Distribution, Uniform}; +use datafusion_functions::unicode; use rand::Rng; +use rand::distr::{Distribution, Uniform}; use std::hint::black_box; use std::sync::Arc; +use std::time::Duration; + +const UNICODE_STRINGS: &[&str] = &[ + "Ñandú", + "Íslensku", + "Þjóðarinnar", + "Ελληνική", + "Иванович", + "データフュージョン", + "José García", + "Ölçü bïrïmï", + "Ÿéšṱëṟḏàÿ", + "Ährenstraße", +]; + +fn create_unicode_string_array( + size: usize, + null_density: f32, +) -> arrow::array::GenericStringArray { + let mut rng = rand::rng(); + let mut builder = GenericStringBuilder::::new(); + for i in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value(UNICODE_STRINGS[i % UNICODE_STRINGS.len()]); + } + } + builder.finish() +} + +fn create_unicode_string_view_array( + size: usize, + null_density: f32, +) -> arrow::array::StringViewArray { + let mut rng = rand::rng(); + let mut builder = StringViewBuilder::with_capacity(size); + for i in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value(UNICODE_STRINGS[i % UNICODE_STRINGS.len()]); + } + } + builder.finish() +} struct Filter { dist: Dist, @@ -67,103 +116,642 @@ where .collect() } -fn create_args( +/// Create args for pad benchmark with Unicode strings +fn create_unicode_pad_args( size: usize, - str_len: usize, - force_view_types: bool, + target_len: usize, + use_string_view: bool, ) -> Vec { - let length_array = Arc::new(create_primitive_array::(size, 0.0, str_len)); - - if !force_view_types { - let string_array = - Arc::new(create_string_array_with_len::(size, 0.1, str_len)); - let fill_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + let length_array = + Arc::new(create_primitive_array::(size, 0.0, target_len)); + if use_string_view { + let string_array = create_unicode_string_view_array(size, 0.1); + let fill_array = create_unicode_string_view_array(size, 0.1); vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), - ColumnarValue::Array(fill_array), + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(length_array), + ColumnarValue::Array(Arc::new(fill_array)), ] } else { - let string_array = - Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); - let fill_array = - Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); - + let string_array = create_unicode_string_array::(size, 0.1); + let fill_array = create_unicode_string_array::(size, 0.1); vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), - ColumnarValue::Array(fill_array), + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(length_array), + ColumnarValue::Array(Arc::new(fill_array)), ] } } -fn invoke_pad_with_args( - args: Vec, - number_rows: usize, - left_pad: bool, -) -> Result { - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - let scalar_args = ScalarFunctionArgs { - args: args.clone(), - arg_fields, - number_rows, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }; +/// Create args for pad benchmark +fn create_pad_args( + size: usize, + str_len: usize, + target_len: usize, + use_string_view: bool, +) -> Vec { + let length_array = + Arc::new(create_primitive_array::(size, 0.0, target_len)); + + if use_string_view { + let string_array = create_string_view_array_with_len(size, 0.1, str_len, false); + let fill_array = create_string_view_array_with_len(size, 0.1, str_len, false); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(length_array), + ColumnarValue::Array(Arc::new(fill_array)), + ] + } else { + let string_array = create_string_array_with_len::(size, 0.1, str_len); + let fill_array = create_string_array_with_len::(size, 0.1, str_len); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(length_array), + ColumnarValue::Array(Arc::new(fill_array)), + ] + } +} - if left_pad { - lpad().invoke_with_args(scalar_args) +/// Create args for pad benchmark with scalar length and fill (common pattern: +/// `lpad(column, 20, '0')`). +fn create_scalar_pad_args( + size: usize, + str_len: usize, + target_len: i64, + fill: &str, + use_string_view: bool, +) -> Vec { + if use_string_view { + let string_array = create_string_view_array_with_len(size, 0.1, str_len, false); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(target_len))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(fill.to_string()))), + ] } else { - rpad().invoke_with_args(scalar_args) + let string_array = create_string_array_with_len::(size, 0.1, str_len); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(target_len))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(fill.to_string()))), + ] } } fn criterion_benchmark(c: &mut Criterion) { - for size in [1024, 2048] { - let mut group = c.benchmark_group("lpad function"); + for size in [1024, 4096] { + let mut group = c.benchmark_group(format!("lpad size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Utf8 type + let args = create_pad_args::(size, 5, 20, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + group.bench_function( + format!("lpad utf8 [size={size}, str_len=5, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type + let args = create_pad_args::(size, 5, 20, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("lpad stringview [size={size}, str_len=5, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // Utf8 type with longer strings + let args = create_pad_args::(size, 20, 50, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("lpad utf8 [size={size}, str_len=20, target=50]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type with longer strings + let args = create_pad_args::(size, 20, 50, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("lpad stringview [size={size}, str_len=20, target=50]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // Utf8 type with Unicode strings + let args = create_unicode_pad_args(size, 20, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); - let args = create_args::(size, 32, false); + group.bench_function( + format!("lpad utf8 unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); - group.bench_function(BenchmarkId::new("utf8 type", size), |b| { - b.iter(|| black_box(invoke_pad_with_args(args.clone(), size, true).unwrap())) - }); + // StringView type with Unicode strings + let args = create_unicode_pad_args(size, 20, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); - let args = create_args::(size, 32, false); - group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { - b.iter(|| black_box(invoke_pad_with_args(args.clone(), size, true).unwrap())) - }); + group.bench_function( + format!("lpad stringview unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); - let args = create_args::(size, 32, true); - group.bench_function(BenchmarkId::new("stringview type", size), |b| { - b.iter(|| black_box(invoke_pad_with_args(args.clone(), size, true).unwrap())) - }); + // --- Scalar length + fill benchmarks --- + + // Utf8 with scalar length and fill (3-arg) + let args = create_scalar_pad_args::(size, 5, 20, "x", false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("lpad utf8 scalar [size={size}, str_len=5, target=20, fill='x']"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView with scalar length and fill (3-arg) + let args = create_scalar_pad_args::(size, 5, 20, "x", true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!( + "lpad stringview scalar [size={size}, str_len=5, target=20, fill='x']" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // Utf8 with scalar length and unicode fill + let args = create_scalar_pad_args::(size, 5, 20, "é", false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!( + "lpad utf8 scalar unicode [size={size}, str_len=5, target=20, fill='é']" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // Utf8 with scalar truncation (str_len > target) and unicode fill + let args = create_scalar_pad_args::(size, 20, 5, "é", false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!( + "lpad utf8 scalar truncate [size={size}, str_len=20, target=5, fill='é']" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::lpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); group.finish(); + } + + for size in [1024, 4096] { + let mut group = c.benchmark_group(format!("rpad size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Utf8 type + let args = create_pad_args::(size, 5, 20, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + group.bench_function( + format!("rpad utf8 [size={size}, str_len=5, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type + let args = create_pad_args::(size, 5, 20, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad stringview [size={size}, str_len=5, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // Utf8 type with longer strings + let args = create_pad_args::(size, 20, 50, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad utf8 [size={size}, str_len=20, target=50]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type with longer strings + let args = create_pad_args::(size, 20, 50, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad stringview [size={size}, str_len=20, target=50]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // Utf8 type with Unicode strings + let args = create_unicode_pad_args(size, 20, false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad utf8 unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView type with Unicode strings + let args = create_unicode_pad_args(size, 20, true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad stringview unicode [size={size}, target=20]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // --- Scalar length + fill benchmarks --- + + // Utf8 with scalar length and fill (3-arg) + let args = create_scalar_pad_args::(size, 5, 20, "x", false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!("rpad utf8 scalar [size={size}, str_len=5, target=20, fill='x']"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + + // StringView with scalar length and fill (3-arg) + let args = create_scalar_pad_args::(size, 5, 20, "x", true); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + + group.bench_function( + format!( + "rpad stringview scalar [size={size}, str_len=5, target=20, fill='x']" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); - let mut group = c.benchmark_group("rpad function"); + // Utf8 with scalar length and unicode fill + let args = create_scalar_pad_args::(size, 5, 20, "é", false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); - let args = create_args::(size, 32, false); - group.bench_function(BenchmarkId::new("utf8 type", size), |b| { - b.iter(|| black_box(invoke_pad_with_args(args.clone(), size, false).unwrap())) - }); + group.bench_function( + format!( + "rpad utf8 scalar unicode [size={size}, str_len=5, target=20, fill='é']" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); - let args = create_args::(size, 32, false); - group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { - b.iter(|| black_box(invoke_pad_with_args(args.clone(), size, false).unwrap())) - }); + // Utf8 with scalar truncation (str_len > target) and unicode fill + let args = create_scalar_pad_args::(size, 20, 5, "é", false); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); - // rpad for stringview type - let args = create_args::(size, 32, true); - group.bench_function(BenchmarkId::new("stringview type", size), |b| { - b.iter(|| black_box(invoke_pad_with_args(args.clone(), size, false).unwrap())) - }); + group.bench_function( + format!( + "rpad utf8 scalar truncate [size={size}, str_len=20, target=5, fill='é']" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(unicode::rpad().invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); group.finish(); } diff --git a/datafusion/functions/benches/power.rs b/datafusion/functions/benches/power.rs new file mode 100644 index 0000000000000..5336e42ebe59b --- /dev/null +++ b/datafusion/functions/benches/power.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Microbenchmark for `power(decimal_array, int_*)`. +//! +//! Covers both array- and scalar-shaped integer exponents on a Decimal +//! base. Both shapes are dispatched to the native per-row decimal kernel; +//! the bench guards against any future change that routes either shape +//! through a Float64 round-trip, which is measurably slower than the +//! decimal kernel for the cases the kernel can handle. + +extern crate criterion; + +use arrow::array::{Decimal128Array, Int64Array}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; +use datafusion_functions::math::power; +use std::hint::black_box; +use std::sync::Arc; + +fn make_decimal_array(size: usize, precision: u8, scale: i8) -> Decimal128Array { + // Use a fixed unscaled value (250) so the bench is independent of `scale`. + // The four-arm dispatch in `power` only cares about the Decimal variant + // and the exponent's shape, not the numeric value. + let arr = Decimal128Array::from(vec![250i128; size]); + arr.with_precision_and_scale(precision, scale).unwrap() +} + +fn make_int_array(size: usize, value: i64) -> Int64Array { + Int64Array::from(vec![value; size]) +} + +fn run_power( + power_fn: &ScalarUDF, + args: &[ColumnarValue], + arg_fields: &[FieldRef], + return_field: &FieldRef, + config_options: &Arc, + num_rows: usize, +) { + black_box( + power_fn + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields: arg_fields.to_vec(), + number_rows: num_rows, + return_field: Arc::clone(return_field), + config_options: Arc::clone(config_options), + }) + .unwrap(), + ); +} + +fn criterion_benchmark(c: &mut Criterion) { + let power_fn = power(); + let config_options = Arc::new(ConfigOptions::default()); + let precision: u8 = 20; + let scale: i8 = 2; + let decimal_ty = DataType::Decimal128(precision, scale); + + // Exponents are bounded by what the native decimal kernel can handle + // without overflowing the i128 intermediate; see + // + let exponents = [2i64, 4, 8]; + + for size in [1024usize, 8192] { + let base_arr = Arc::new(make_decimal_array(size, precision, scale)); + let base_field: FieldRef = Field::new("base", decimal_ty.clone(), true).into(); + let exp_field: FieldRef = Field::new("exp", DataType::Int64, true).into(); + let return_field: FieldRef = Field::new("r", decimal_ty.clone(), true).into(); + let arg_fields = vec![base_field, exp_field]; + + for &exp in &exponents { + let exp_arr = Arc::new(make_int_array(size, exp)); + let array_args = vec![ + ColumnarValue::Array(base_arr.clone()), + ColumnarValue::Array(exp_arr), + ]; + c.bench_function( + &format!( + "power decimal({precision},{scale}) array x int array, exp={exp}, n={size}" + ), + |b| { + b.iter(|| { + run_power( + &power_fn, + &array_args, + &arg_fields, + &return_field, + &config_options, + size, + ) + }) + }, + ); + + let scalar_args = vec![ + ColumnarValue::Array(base_arr.clone()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(exp))), + ]; + c.bench_function( + &format!( + "power decimal({precision},{scale}) array x int scalar, exp={exp}, n={size}" + ), + |b| { + b.iter(|| { + run_power( + &power_fn, + &scalar_args, + &arg_fields, + &return_field, + &config_options, + size, + ) + }) + }, + ); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 88efb2d1b5b93..71ded120eb515 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; diff --git a/datafusion/functions/benches/regexp_count.rs b/datafusion/functions/benches/regexp_count.rs new file mode 100644 index 0000000000000..bce76c05585b9 --- /dev/null +++ b/datafusion/functions/benches/regexp_count.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Int64Array; +use arrow::array::OffsetSizeTrait; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::regex; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn create_args( + size: usize, + str_len: usize, + with_start: bool, +) -> Vec { + let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + // Use a simple pattern that matches common characters + let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some("a".to_string()))); + + if with_start { + // Test with start position (this is where the optimization matters) + let start_array = Arc::new(Int64Array::from( + (0..size).map(|i| (i % 10 + 1) as i64).collect::>(), + )); + vec![ + ColumnarValue::Array(string_array), + pattern, + ColumnarValue::Array(start_array), + ] + } else { + vec![ColumnarValue::Array(string_array), pattern] + } +} + +fn invoke_regexp_count_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + regex::regexp_count().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Int64, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + let mut group = c.benchmark_group(format!("regexp_count size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Test without start position (no optimization impact) + for str_len in [32, 128] { + let args = create_args::(size, str_len, false); + group.bench_function( + format!("regexp_count_no_start [size={size}, str_len={str_len}]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_regexp_count_with_args(args_cloned, size)) + }) + }, + ); + } + + // Test with start position (optimization should help here) + for str_len in [32, 128] { + let args = create_args::(size, str_len, true); + group.bench_function( + format!("regexp_count_with_start [size={size}, str_len={str_len}]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_regexp_count_with_args(args_cloned, size)) + }) + }, + ); + } + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index a415330245bf5..a46b548236d08 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -15,25 +15,27 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; +use std::hint::black_box; +use std::iter; +use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray}; use arrow::compute::cast; -use arrow::datatypes::DataType; -use criterion::{criterion_group, criterion_main, Criterion}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexpinstr::regexp_instr_func; -use datafusion_functions::regex::regexplike::regexp_like; +use datafusion_functions::regex::regexplike::{RegexpLikeFunc, regexp_like}; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; +use rand::Rng; use rand::distr::Alphanumeric; use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::Rng; -use std::hint::black_box; -use std::iter; -use std::sync::Arc; fn data(rng: &mut ThreadRng) -> StringArray { let mut data: Vec = vec![]; for _ in 0..1000 { @@ -107,6 +109,8 @@ fn subexp(rng: &mut ThreadRng) -> Int64Array { } fn criterion_benchmark(c: &mut Criterion) { + let regexp_like_func = RegexpLikeFunc::new(); + let config_options = Arc::new(ConfigOptions::default()); c.bench_function("regexp_count_1000 string", |b| { let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; @@ -221,6 +225,32 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + let scalar_args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("foobarbequebaz".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("(bar)(beque)".to_string()))), + ]; + let scalar_arg_fields = vec![ + Field::new("arg_0", DataType::Utf8, false).into(), + Field::new("arg_1", DataType::Utf8, false).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, true).into(); + + c.bench_function("regexp_like scalar utf8", |b| { + b.iter(|| { + black_box( + regexp_like_func + .invoke_with_args(ScalarFunctionArgs { + args: scalar_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("regexp_like scalar should work on valid values"), + ) + }) + }); + c.bench_function("regexp_match_1000", |b| { let mut rng = rand::rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 80ffa8ee38f1a..354812c0d2ea2 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::DataFusionError; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::hint::black_box; @@ -80,6 +79,44 @@ fn invoke_repeat_with_args( } fn criterion_benchmark(c: &mut Criterion) { + let repeat_fn = string::repeat(); + let config_options = Arc::new(ConfigOptions::default()); + + // Scalar benchmarks (outside loop) + c.bench_function("repeat/scalar_utf8", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + + c.bench_function("repeat/scalar_utf8view", |b| { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("hello".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + arg_fields: vec![ + Field::new("a", DataType::Utf8View, false).into(), + Field::new("b", DataType::Int64, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }; + b.iter(|| black_box(repeat_fn.invoke_with_args(args.clone()).unwrap())) + }); + for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; diff --git a/datafusion/functions/benches/replace.rs b/datafusion/functions/benches/replace.rs new file mode 100644 index 0000000000000..7ad198995a028 --- /dev/null +++ b/datafusion/functions/benches/replace.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{GenericStringArray, OffsetSizeTrait, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::DataFusionError; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::string; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +/// Build a string array, dropping the null buffer when `null_density == 0.0` +fn make_string_array( + size: usize, + null_density: f32, + str_len: usize, +) -> GenericStringArray { + let arr = create_string_array_with_len::(size, null_density, str_len); + if null_density == 0.0 { + let (offsets, values, _) = arr.into_parts(); + GenericStringArray::::new(offsets, values, None) + } else { + arr + } +} + +fn make_string_view_array( + size: usize, + null_density: f32, + str_len: usize, +) -> StringViewArray { + let arr = create_string_view_array_with_len(size, null_density, str_len, false); + if null_density == 0.0 { + let (views, buffers, _) = arr.into_parts(); + StringViewArray::new(views, buffers, None) + } else { + arr + } +} + +fn create_args( + size: usize, + str_len: usize, + force_view_types: bool, + from_len: usize, + to_len: usize, + null_density: f32, +) -> Vec { + // Apply `null_density` only to the string column; `from` and `to` are + // typically not NULL in real-world workloads. + if force_view_types { + let string_array = Arc::new(make_string_view_array(size, null_density, str_len)); + let from_array = Arc::new(make_string_view_array(size, 0.0, from_len)); + let to_array = Arc::new(make_string_view_array(size, 0.0, to_len)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(from_array), + ColumnarValue::Array(to_array), + ] + } else { + let string_array = Arc::new(make_string_array::(size, null_density, str_len)); + let from_array = Arc::new(make_string_array::(size, 0.0, from_len)); + let to_array = Arc::new(make_string_array::(size, 0.0, to_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(from_array), + ColumnarValue::Array(to_array), + ] + } +} + +fn invoke_replace_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + string::replace().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + let mut group = c.benchmark_group(format!("replace size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for &nulls in &[0.0_f32, 0.2] { + for &str_len in &[32_usize, 128] { + // ASCII single character replacement (fast path) + let args = create_args::(size, str_len, false, 1, 1, nulls); + group.bench_function( + format!( + "replace_string_ascii_single [size={size}, str_len={str_len}, nulls={nulls}]" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_replace_with_args(args_cloned, size)) + }) + }, + ); + + // Multi-character strings (general path) + let args = create_args::(size, str_len, true, 3, 5, nulls); + group.bench_function( + format!( + "replace_string_view [size={size}, str_len={str_len}, nulls={nulls}]" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_replace_with_args(args_cloned, size)) + }) + }, + ); + + let args = create_args::(size, str_len, false, 3, 5, nulls); + group.bench_function( + format!( + "replace_string [size={size}, str_len={str_len}, nulls={nulls}]" + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_replace_with_args(args_cloned, size)) + }) + }, + ); + } + } + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index b1eca654fb254..f2e2898bbfe43 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; mod helper; use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; diff --git a/datafusion/functions/benches/round.rs b/datafusion/functions/benches/round.rs new file mode 100644 index 0000000000000..7010aa3507dbc --- /dev/null +++ b/datafusion/functions/benches/round.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Float32Type, Float64Type}; +use arrow::util::bench_util::create_primitive_array; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::math::round; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +fn criterion_benchmark(c: &mut Criterion) { + let round_fn = round(); + let config_options = Arc::new(ConfigOptions::default()); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("round size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // Float64 array benchmark + let f64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = f64_array.len(); + let f64_args = vec![ + ColumnarValue::Array(f64_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_array", |b| { + b.iter(|| { + let args_cloned = f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Float32 array benchmark + let f32_array = Arc::new(create_primitive_array::(size, 0.1)); + let f32_args = vec![ + ColumnarValue::Array(f32_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_array", |b| { + b.iter(|| { + let args_cloned = f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, true).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: batch_len, + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // Scalar benchmark (the optimization we added) + let scalar_f64_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float64(Some(std::f64::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f64_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f64_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float64, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ + ColumnarValue::Scalar(ScalarValue::Float32(Some(std::f32::consts::PI))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ]; + + group.bench_function("round_f32_scalar", |b| { + b.iter(|| { + let args_cloned = scalar_f32_args.clone(); + black_box( + round_fn + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![ + Field::new("a", DataType::Float32, false).into(), + Field::new("b", DataType::Int32, false).into(), + ], + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false) + .into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 24b8861e4d28c..e98d1b2c22ea2 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::DataType; use arrow::{ datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::signum; @@ -88,6 +87,51 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + // Scalar benchmarks (the optimization we added) + let scalar_f32_args = + vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(-42.5)))]; + let scalar_f32_arg_fields = + vec![Field::new("a", DataType::Float32, false).into()]; + let return_field_f32 = Field::new("f", DataType::Float32, false).into(); + + c.bench_function(&format!("signum f32 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f32), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f64_args = + vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(-42.5)))]; + let scalar_f64_arg_fields = + vec![Field::new("a", DataType::Float64, false).into()]; + let return_field_f64 = Field::new("f", DataType::Float64, false).into(); + + c.bench_function(&format!("signum f64 scalar: {size}"), |b| { + b.iter(|| { + black_box( + signum + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_f64_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&return_field_f64), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } } diff --git a/datafusion/functions/benches/split_part.rs b/datafusion/functions/benches/split_part.rs new file mode 100644 index 0000000000000..d578339368768 --- /dev/null +++ b/datafusion/functions/benches/split_part.rs @@ -0,0 +1,272 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; +use datafusion_functions::string::split_part; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const N_ROWS: usize = 8192; + +/// Creates an array of strings with `num_parts` random alphanumeric segments +/// of `part_len` bytes each, joined by `delimiter`. +fn gen_string_array( + n_rows: usize, + num_parts: usize, + part_len: usize, + delimiter: &str, + use_string_view: bool, +) -> ColumnarValue { + let mut rng = StdRng::seed_from_u64(42); + + let mut strings: Vec = Vec::with_capacity(n_rows); + for _ in 0..n_rows { + let mut parts: Vec = Vec::with_capacity(num_parts); + for _ in 0..num_parts { + let part: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(part_len) + .map(char::from) + .collect(); + parts.push(part); + } + strings.push(parts.join(delimiter)); + } + + if use_string_view { + let string_array: StringViewArray = strings.into_iter().map(Some).collect(); + ColumnarValue::Array(Arc::new(string_array) as ArrayRef) + } else { + let string_array: StringArray = strings.into_iter().map(Some).collect(); + ColumnarValue::Array(Arc::new(string_array) as ArrayRef) + } +} + +#[expect(clippy::too_many_arguments)] +fn bench_split_part( + group: &mut criterion::BenchmarkGroup<'_, criterion::measurement::WallTime>, + func: &ScalarUDF, + config_options: &Arc, + name: &str, + tag: &str, + strings: ColumnarValue, + delimiter: ColumnarValue, + position: ColumnarValue, +) { + let args = vec![strings, delimiter, position]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let return_type = match args[0].data_type() { + DataType::Utf8View => DataType::Utf8View, + _ => DataType::Utf8, + }; + let return_field = Field::new("f", return_type, true).into(); + + group.bench_function(BenchmarkId::new(name, tag), |b| { + b.iter(|| { + black_box( + func.invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(config_options), + }) + .expect("split_part should work"), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let split_part_func = split_part(); + let config_options = Arc::new(ConfigOptions::default()); + let mut group = c.benchmark_group("split_part"); + + // ── Scalar delimiter and position ──────────────── + + // Utf8, single-char delimiter, scalar args + { + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8_single_char", + "pos_first", + strings, + delimiter, + position, + ); + } + + { + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(5))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8_single_char", + "pos_middle", + strings, + delimiter, + position, + ); + } + + { + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8_single_char", + "pos_negative", + strings, + delimiter, + position, + ); + } + + // Utf8, multi-char delimiter, scalar args + { + let strings = gen_string_array(N_ROWS, 10, 8, "~@~", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some("~@~".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(5))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8_multi_char", + "pos_middle", + strings, + delimiter, + position, + ); + } + + // Utf8, long strings, scalar args + { + let strings = gen_string_array(N_ROWS, 50, 16, ".", false); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(25))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8_long_strings", + "pos_middle", + strings, + delimiter, + position, + ); + } + + // Utf8View, long parts, scalar args + { + let strings = gen_string_array(N_ROWS, 10, 32, ".", true); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(5))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8view_long_parts", + "pos_middle", + strings, + delimiter, + position, + ); + } + + // Utf8View, very long parts (256 bytes), position 1 + { + let strings = gen_string_array(N_ROWS, 5, 256, ".", true); + let delimiter = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))); + let position = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "scalar_utf8view_very_long_parts", + "pos_first", + strings, + delimiter, + position, + ); + } + + // ── Array delimiter and position ───────────────── + + // Utf8, single-char delimiter, array args + { + let strings = gen_string_array(N_ROWS, 10, 8, ".", false); + let delimiters: StringArray = vec![Some("."); N_ROWS].into_iter().collect(); + let delimiter = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let positions = ColumnarValue::Array(Arc::new(Int64Array::from(vec![5; N_ROWS]))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "array_utf8_single_char", + "pos_middle", + strings, + delimiter, + positions, + ); + } + + // Utf8, multi-char delimiter, array args + { + let strings = gen_string_array(N_ROWS, 10, 8, "~@~", false); + let delimiters: StringArray = vec![Some("~@~"); N_ROWS].into_iter().collect(); + let delimiter = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let positions = ColumnarValue::Array(Arc::new(Int64Array::from(vec![5; N_ROWS]))); + bench_split_part( + &mut group, + &split_part_func, + &config_options, + "array_utf8_multi_char", + "pos_middle", + strings, + delimiter, + positions, + ); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/starts_with.rs b/datafusion/functions/benches/starts_with.rs new file mode 100644 index 0000000000000..17483f0da7a07 --- /dev/null +++ b/datafusion/functions/benches/starts_with.rs @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use rand::distr::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +/// Generate a StringArray/StringViewArray with random ASCII strings +fn gen_string_array( + n_rows: usize, + str_len: usize, + is_string_view: bool, +) -> ColumnarValue { + let mut rng = StdRng::seed_from_u64(42); + let strings: Vec> = (0..n_rows) + .map(|_| { + let s: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(str_len) + .map(char::from) + .collect(); + Some(s) + }) + .collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +/// Generate a scalar prefix string +fn gen_scalar_prefix(prefix_str: &str, is_string_view: bool) -> ColumnarValue { + if is_string_view { + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(prefix_str.to_string()))) + } else { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(prefix_str.to_string()))) + } +} + +/// Generate an array of prefix strings (same string repeated) +fn gen_array_prefix( + prefix_str: &str, + n_rows: usize, + is_string_view: bool, +) -> ColumnarValue { + let strings: Vec> = + (0..n_rows).map(|_| Some(prefix_str.to_string())).collect(); + + if is_string_view { + ColumnarValue::Array(Arc::new(StringViewArray::from(strings))) + } else { + ColumnarValue::Array(Arc::new(StringArray::from(strings))) + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let starts_with = datafusion_functions::string::starts_with(); + let n_rows = 8192; + let str_len = 128; + let prefix_str = "xyz"; // A pattern that likely won't match + + // Benchmark: StringArray with scalar prefix (the optimized path) + let str_array = gen_string_array(n_rows, str_len, false); + let scalar_prefix = gen_scalar_prefix(prefix_str, false); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ]; + let return_field = Field::new("f", DataType::Boolean, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function("starts_with_StringArray_scalar_prefix", |b| { + b.iter(|| { + black_box(starts_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_prefix.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringArray with array prefix (for comparison) + let array_prefix = gen_array_prefix(prefix_str, n_rows, false); + c.bench_function("starts_with_StringArray_array_prefix", |b| { + b.iter(|| { + black_box(starts_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), array_prefix.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with scalar prefix (the optimized path) + let str_view_array = gen_string_array(n_rows, str_len, true); + let scalar_prefix_view = gen_scalar_prefix(prefix_str, true); + let arg_fields_view = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function("starts_with_StringViewArray_scalar_prefix", |b| { + b.iter(|| { + black_box(starts_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), scalar_prefix_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark: StringViewArray with array prefix (for comparison) + let array_prefix_view = gen_array_prefix(prefix_str, n_rows, true); + c.bench_function("starts_with_StringViewArray_array_prefix", |b| { + b.iter(|| { + black_box(starts_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_view_array.clone(), array_prefix_view.clone()], + arg_fields: arg_fields_view.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); + + // Benchmark different string lengths with scalar prefix + for str_len in [8, 32, 128, 512] { + let str_array = gen_string_array(n_rows, str_len, true); + let scalar_prefix = gen_scalar_prefix(prefix_str, true); + let arg_fields = vec![ + Field::new("a", DataType::Utf8View, true).into(), + Field::new("b", DataType::Utf8View, true).into(), + ]; + + c.bench_function( + &format!("starts_with_StringViewArray_scalar_strlen_{str_len}"), + |b| { + b.iter(|| { + black_box(starts_with.invoke_with_args(ScalarFunctionArgs { + args: vec![str_array.clone(), scalar_prefix.clone()], + arg_fields: arg_fields.clone(), + number_rows: n_rows, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 18a99e44bf487..549186dbab14d 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -15,180 +15,219 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use rand::distr::Alphanumeric; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; use std::hint::black_box; -use std::str::Chars; use std::sync::Arc; -/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with -/// 4096 rows, each row containing a string with 128 random characters. -/// around 10% of the rows are null, around 10% of the rows are non-ASCII. -fn gen_string_array( - n_rows: usize, +#[rustfmt::skip] +const UTF8_CORPUS: &[char] = &[ + // Cyrillic (2 bytes each) + 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ж', 'З', 'И', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', + 'Т', 'У', 'Ф', 'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Э', 'Ю', 'Я', + // CJK (3 bytes each) + '数', '据', '融', '合', '查', '询', '引', '擎', '优', '化', '执', '行', '计', '划', + '表', '达', + // Emoji (4 bytes each) + '📊', '🔥', '🚀', '⚡', '🎯', '💡', '🔧', '📈', +]; +const N_ROWS: usize = 8192; + +/// Returns a random string of `len` characters. If `ascii` is true, the string +/// is ASCII-only; otherwise it is drawn from `UTF8_CORPUS`. +fn random_string(rng: &mut StdRng, len: usize, ascii: bool) -> String { + if ascii { + let value: Vec = rng.sample_iter(&Alphanumeric).take(len).collect(); + String::from_utf8(value).unwrap() + } else { + (0..len) + .map(|_| UTF8_CORPUS[rng.random_range(0..UTF8_CORPUS.len())]) + .collect() + } +} + +/// Wraps `strings` into either a `StringArray` or `StringViewArray`. +fn to_columnar_value( + strings: Vec>, + is_string_view: bool, +) -> ColumnarValue { + if is_string_view { + let arr: StringViewArray = strings.into_iter().collect(); + ColumnarValue::Array(Arc::new(arr)) + } else { + let arr: StringArray = strings.into_iter().collect(); + ColumnarValue::Array(Arc::new(arr)) + } +} + +/// Returns haystack and needle, where both are arrays. Each needle is a +/// contiguous substring of its corresponding haystack. Around `null_density` +/// fraction of rows are null and `utf8_density` fraction contain non-ASCII +/// characters. +fn make_array_needle_args( + rng: &mut StdRng, str_len_chars: usize, null_density: f32, utf8_density: f32, - is_string_view: bool, // false -> StringArray, true -> StringViewArray + is_string_view: bool, ) -> Vec { - let mut rng = StdRng::seed_from_u64(42); - let rng_ref = &mut rng; - - let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes - let corpus_char_count = utf8.chars().count(); - - let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); - let mut output_sub_string_vec: Vec> = Vec::with_capacity(n_rows); - for _ in 0..n_rows { - let rand_num = rng_ref.random::(); // [0.0, 1.0) - if rand_num < null_density { - output_sub_string_vec.push(None); - output_string_vec.push(None); - } else if rand_num < null_density + utf8_density { - // Generate random UTF8 string - let mut generated_string = String::with_capacity(str_len_chars); - for _ in 0..str_len_chars { - let idx = rng_ref.random_range(0..corpus_char_count); - let char = utf8.chars().nth(idx).unwrap(); - generated_string.push(char); - } - output_sub_string_vec.push(Some(random_substring(generated_string.chars()))); - output_string_vec.push(Some(generated_string)); + let mut haystacks: Vec> = Vec::with_capacity(N_ROWS); + let mut needles: Vec> = Vec::with_capacity(N_ROWS); + for _ in 0..N_ROWS { + let r = rng.random::(); + if r < null_density { + haystacks.push(None); + needles.push(None); } else { - // Generate random ASCII-only string - let value = rng_ref + let ascii = r >= null_density + utf8_density; + let s = random_string(rng, str_len_chars, ascii); + needles.push(Some(random_substring(rng, &s))); + haystacks.push(Some(s)); + } + } + + vec![ + to_columnar_value(haystacks, is_string_view), + to_columnar_value(needles, is_string_view), + ] +} + +/// Returns haystack array with a fixed scalar needle inserted into each row. +/// Around `null_density` fraction of rows are null and `utf8_density` fraction +/// contain non-ASCII characters. The needle must be ASCII. +fn make_scalar_needle_args( + rng: &mut StdRng, + str_len_chars: usize, + needle: &str, + null_density: f32, + utf8_density: f32, + is_string_view: bool, +) -> Vec { + let needle_len = needle.len(); + assert!( + str_len_chars >= needle_len, + "str_len_chars must be >= needle length" + ); + + let mut haystacks: Vec> = Vec::with_capacity(N_ROWS); + for _ in 0..N_ROWS { + let r = rng.random::(); + if r < null_density { + haystacks.push(None); + } else if r >= null_density + utf8_density { + let mut value: Vec = (&mut *rng) .sample_iter(&Alphanumeric) .take(str_len_chars) .collect(); - let value = String::from_utf8(value).unwrap(); - output_sub_string_vec.push(Some(random_substring(value.chars()))); - output_string_vec.push(Some(value)); + let pos = rng.random_range(0..=str_len_chars - needle_len); + value[pos..pos + needle_len].copy_from_slice(needle.as_bytes()); + haystacks.push(Some(String::from_utf8(value).unwrap())); + } else { + let mut s = random_string(rng, str_len_chars, false); + let char_positions: Vec = s.char_indices().map(|(i, _)| i).collect(); + let insert_pos = if char_positions.len() > 1 { + char_positions[rng.random_range(0..char_positions.len())] + } else { + 0 + }; + s.insert_str(insert_pos, needle); + haystacks.push(Some(s)); } } - if is_string_view { - let string_view_array: StringViewArray = output_string_vec.into_iter().collect(); - let sub_string_view_array: StringViewArray = - output_sub_string_vec.into_iter().collect(); - vec![ - ColumnarValue::Array(Arc::new(string_view_array)), - ColumnarValue::Array(Arc::new(sub_string_view_array)), - ] - } else { - let string_array: StringArray = output_string_vec.clone().into_iter().collect(); - let sub_string_array: StringArray = output_sub_string_vec.into_iter().collect(); - vec![ - ColumnarValue::Array(Arc::new(string_array)), - ColumnarValue::Array(Arc::new(sub_string_array)), - ] - } + let needle_cv = ColumnarValue::Scalar(ScalarValue::Utf8(Some(needle.to_string()))); + vec![to_columnar_value(haystacks, is_string_view), needle_cv] } -fn random_substring(chars: Chars) -> String { - // get the substring of a random length from the input string by byte unit - let mut rng = StdRng::seed_from_u64(44); - let count = chars.clone().count(); +/// Extracts a random contiguous substring from `s`. +fn random_substring(rng: &mut StdRng, s: &str) -> String { + let count = s.chars().count(); + + assert!(count > 0, "random_substring requires a non-empty string"); + if count == 1 { + return s.to_string(); + } + let start = rng.random_range(0..count - 1); let end = rng.random_range(start + 1..count); - chars - .enumerate() - .filter(|(i, _)| *i >= start && *i < end) - .map(|(_, c)| c) - .collect() + s.chars().skip(start).take(end - start).collect() +} + +fn bench_strpos( + c: &mut Criterion, + name: &str, + args: &[ColumnarValue], + strpos: &datafusion_expr::ScalarUDF, +) { + let arg_fields = vec![Field::new("a", args[0].data_type(), true).into()]; + let return_field: Arc = Field::new("f", DataType::Int32, true).into(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(name, |b| { + b.iter(|| { + black_box(strpos.invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields: arg_fields.clone(), + number_rows: N_ROWS, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + })) + }) + }); } fn criterion_benchmark(c: &mut Criterion) { - // All benches are single batch run with 8192 rows let strpos = datafusion_functions::unicode::strpos(); + let mut rng = StdRng::seed_from_u64(42); - let n_rows = 8192; for str_len in [8, 32, 128, 4096] { - // StringArray ASCII only - let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); - let arg_fields = - vec![Field::new("a", args_string_ascii[0].data_type(), true).into()]; - let return_field = Field::new("f", DataType::Int32, true).into(); - let config_options = Arc::new(ConfigOptions::default()); - - c.bench_function( - &format!("strpos_StringArray_ascii_str_len_{str_len}"), - |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_ascii.clone(), - arg_fields: arg_fields.clone(), - number_rows: n_rows, - return_field: Arc::clone(&return_field), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); - - // StringArray UTF8 - let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); - let arg_fields = - vec![Field::new("a", args_string_utf8[0].data_type(), true).into()]; - let return_field = Field::new("f", DataType::Int32, true).into(); - c.bench_function(&format!("strpos_StringArray_utf8_str_len_{str_len}"), |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_utf8.clone(), - arg_fields: arg_fields.clone(), - number_rows: n_rows, - return_field: Arc::clone(&return_field), - config_options: Arc::clone(&config_options), - })) - }) - }); - - // StringViewArray ASCII only - let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); - let arg_fields = - vec![Field::new("a", args_string_view_ascii[0].data_type(), true).into()]; - let return_field = Field::new("f", DataType::Int32, true).into(); - c.bench_function( - &format!("strpos_StringViewArray_ascii_str_len_{str_len}"), - |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_view_ascii.clone(), - arg_fields: arg_fields.clone(), - number_rows: n_rows, - return_field: Arc::clone(&return_field), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); - - // StringViewArray UTF8 - let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); - let arg_fields = - vec![Field::new("a", args_string_view_utf8[0].data_type(), true).into()]; - let return_field = Field::new("f", DataType::Int32, true).into(); - c.bench_function( - &format!("strpos_StringViewArray_utf8_str_len_{str_len}"), - |b| { - b.iter(|| { - black_box(strpos.invoke_with_args(ScalarFunctionArgs { - args: args_string_view_utf8.clone(), - arg_fields: arg_fields.clone(), - number_rows: n_rows, - return_field: Arc::clone(&return_field), - config_options: Arc::clone(&config_options), - })) - }) - }, - ); + // Array needle benchmarks + for (label, utf8_density, is_view) in [ + ("StringArray_ascii", 0.0, false), + ("StringArray_utf8", 0.5, false), + ("StringViewArray_ascii", 0.0, true), + ("StringViewArray_utf8", 0.5, true), + ] { + let args = + make_array_needle_args(&mut rng, str_len, 0.1, utf8_density, is_view); + bench_strpos( + c, + &format!("strpos_{label}_str_len_{str_len}"), + &args, + strpos.as_ref(), + ); + } + + // Scalar needle benchmarks + let needle = "xyz"; + for (label, utf8_density, is_view) in [ + ("StringArray_scalar_needle_ascii", 0.0, false), + ("StringArray_scalar_needle_utf8", 0.5, false), + ("StringViewArray_scalar_needle_ascii", 0.0, true), + ("StringViewArray_scalar_needle_utf8", 0.5, true), + ] { + let args = make_scalar_needle_args( + &mut rng, + str_len, + needle, + 0.1, + utf8_density, + is_view, + ); + bench_strpos( + c, + &format!("strpos_{label}_str_len_{str_len}"), + &args, + strpos.as_ref(), + ); + } } } diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 771413458c1fb..3939fd100e459 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -15,55 +15,48 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; -use criterion::{criterion_group, criterion_main, Criterion, SamplingMode}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; -use datafusion_common::DataFusionError; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::hint::black_box; use std::sync::Arc; +fn make_i64_arg(value: i64, size: usize, as_scalar: bool) -> ColumnarValue { + if as_scalar { + ColumnarValue::Scalar(ScalarValue::from(value)) + } else { + ColumnarValue::Array(Arc::new(Int64Array::from(vec![value; size]))) + } +} + fn create_args_without_count( size: usize, str_len: usize, start_half_way: bool, force_view_types: bool, + scalar_start: bool, ) -> Vec { - let start_array = Arc::new(Int64Array::from( - (0..size) - .map(|_| { - if start_half_way { - (str_len / 2) as i64 - } else { - 1i64 - } - }) - .collect::>(), - )); - - if force_view_types { - let string_array = - Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); - vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(start_array), - ] + let start_val = if start_half_way { + (str_len / 2) as i64 } else { - let string_array = - Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + 1i64 + }; + let start = make_i64_arg(start_val, size, scalar_start); - vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), - ] - } + let string_array: ArrayRef = if force_view_types { + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)) + } else { + Arc::new(create_string_array_with_len::(size, 0.1, str_len)) + }; + + vec![ColumnarValue::Array(string_array), start] } fn create_args_with_count( @@ -71,34 +64,22 @@ fn create_args_with_count( str_len: usize, count_max: usize, force_view_types: bool, + scalar_args: bool, ) -> Vec { - let start_array = - Arc::new(Int64Array::from((0..size).map(|_| 1).collect::>())); let count = count_max.min(str_len) as i64; - let count_array = Arc::new(Int64Array::from( - (0..size).map(|_| count).collect::>(), - )); - - if force_view_types { - let string_array = - Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); - vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(start_array), - ColumnarValue::Array(count_array), - ] + let start = make_i64_arg(1i64, size, scalar_args); + let count = make_i64_arg(count, size, scalar_args); + + let string_array: ArrayRef = if force_view_types { + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)) } else { - let string_array = - Arc::new(create_string_array_with_len::(size, 0.1, str_len)); - - vec![ - ColumnarValue::Array(string_array), - ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), - ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef), - ] - } + Arc::new(create_string_array_with_len::(size, 0.1, str_len)) + }; + + vec![ColumnarValue::Array(string_array), start, count] } +#[expect(clippy::needless_pass_by_value)] fn invoke_substr_with_args( args: Vec, number_rows: usize, @@ -123,22 +104,22 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096] { // string_len = 12, substring_len=6 (see `create_args_without_count`) let len = 12; - let mut group = c.benchmark_group("SHORTER THAN 12"); + let mut group = c.benchmark_group("substr, no count, short strings"); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); - let args = create_args_without_count::(size, len, true, true); + let args = create_args_without_count::(size, len, true, true, false); group.bench_function( format!("substr_string_view [size={size}, strlen={len}]"), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); - let args = create_args_without_count::(size, len, false, false); + let args = create_args_without_count::(size, len, false, false, false); group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) }); - let args = create_args_without_count::(size, len, true, false); + let args = create_args_without_count::(size, len, true, false, false); group.bench_function( format!("substr_large_string [size={size}, strlen={len}]"), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), @@ -149,23 +130,23 @@ fn criterion_benchmark(c: &mut Criterion) { // string_len = 128, start=1, count=64, substring_len=64 let len = 128; let count = 64; - let mut group = c.benchmark_group("LONGER THAN 12"); + let mut group = c.benchmark_group("substr, with count, long strings"); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); - let args = create_args_with_count::(size, len, count, true); + let args = create_args_with_count::(size, len, count, true, false); group.bench_function( format!("substr_string_view [size={size}, count={count}, strlen={len}]",), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); - let args = create_args_with_count::(size, len, count, false); + let args = create_args_with_count::(size, len, count, false, false); group.bench_function( format!("substr_string [size={size}, count={count}, strlen={len}]",), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); - let args = create_args_with_count::(size, len, count, false); + let args = create_args_with_count::(size, len, count, false, false); group.bench_function( format!("substr_large_string [size={size}, count={count}, strlen={len}]",), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), @@ -176,29 +157,136 @@ fn criterion_benchmark(c: &mut Criterion) { // string_len = 128, start=1, count=6, substring_len=6 let len = 128; let count = 6; - let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12"); + let mut group = c.benchmark_group("substr, short count, long strings"); group.sampling_mode(SamplingMode::Flat); group.sample_size(10); - let args = create_args_with_count::(size, len, count, true); + let args = create_args_with_count::(size, len, count, true, false); group.bench_function( format!("substr_string_view [size={size}, count={count}, strlen={len}]",), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); - let args = create_args_with_count::(size, len, count, false); + let args = create_args_with_count::(size, len, count, false, false); group.bench_function( format!("substr_string [size={size}, count={count}, strlen={len}]",), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); - let args = create_args_with_count::(size, len, count, false); + let args = create_args_with_count::(size, len, count, false, false); group.bench_function( format!("substr_large_string [size={size}, count={count}, strlen={len}]",), |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); + + // Scalar start, no count, short strings + let len = 12; + let mut group = + c.benchmark_group("substr, scalar start, no count, short strings"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true, true); + group.bench_function( + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_without_count::(size, len, false, false, true); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); + + group.finish(); + + // Scalar start, no count, long strings, start near middle + let len = 128; + let mut group = c.benchmark_group("substr, scalar start, no count, long strings"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true, true); + group.bench_function( + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_without_count::(size, len, false, false, true); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); + + group.finish(); + + // Scalar start=1, no count, long strings + let len = 128; + let mut group = + c.benchmark_group("substr, scalar start=1, no count, long strings"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, false, true, true); + group.bench_function( + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_without_count::(size, len, false, false, true); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); + + group.finish(); + + // Scalar start and count, short strings + let len = 12; + let count = 6; + let mut group = c.benchmark_group("substr, scalar args, short strings"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false, true); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // Scalar start and count, long strings + let len = 128; + let count = 64; + let mut group = c.benchmark_group("substr, scalar args, long strings"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false, true); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false, true); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); } } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index d0941d9baedda..a0c3784dbeee5 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -15,20 +15,23 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array, StringArray}; +use arrow::array::{ArrayRef, Int64Array, StringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::{ScalarValue, config::ConfigOptions}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::substr_index; +use rand::Rng; +use rand::SeedableRng; use rand::distr::{Alphanumeric, Uniform}; use rand::prelude::Distribution; -use rand::Rng; +use rand::rngs::StdRng; + +const ARRAY_DATA_SEED: u64 = 0x5EED_AAAA; +const SCALAR_DATA_SEED: u64 = 0x5EED_BBBB; struct Filter { dist: Dist, @@ -50,71 +53,207 @@ where } } -fn data() -> (StringArray, StringArray, Int64Array) { - let dist = Filter { - dist: Uniform::new(-4, 5), +#[derive(Clone, Copy)] +enum StringRep { + Utf8, + Utf8View, +} + +impl StringRep { + fn name(self) -> &'static str { + match self { + Self::Utf8 => "utf8", + Self::Utf8View => "utf8view", + } + } + + fn data_type(self) -> DataType { + match self { + Self::Utf8 => DataType::Utf8, + Self::Utf8View => DataType::Utf8View, + } + } + + fn array(self, values: &[String]) -> ArrayRef { + match self { + Self::Utf8 => Arc::new(StringArray::from(values.to_vec())) as ArrayRef, + Self::Utf8View => Arc::new( + values + .iter() + .map(|value| Some(value.as_str())) + .collect::(), + ) as ArrayRef, + } + } + + fn scalar(self, value: &str) -> ScalarValue { + match self { + Self::Utf8 => ScalarValue::Utf8(Some(value.to_string())), + Self::Utf8View => ScalarValue::Utf8View(Some(value.to_string())), + } + } +} + +fn random_token(rng: &mut R, len: usize) -> String { + rng.sample_iter(&Alphanumeric) + .take(len) + .map(char::from) + .collect() +} + +fn array_data( + batch_size: usize, + single_char_delimiter: bool, +) -> (Vec, Vec, Vec) { + let count_dist = Filter { + dist: Uniform::new(-4, 5).expect("valid count distribution"), test: |x: &i64| x != &0, }; - let mut rng = rand::rng(); - let mut strings: Vec = vec![]; - let mut delimiters: Vec = vec![]; - let mut counts: Vec = vec![]; + let mut rng = StdRng::seed_from_u64(ARRAY_DATA_SEED); + let mut strings = Vec::with_capacity(batch_size); + let mut delimiters = Vec::with_capacity(batch_size); + let mut counts = Vec::with_capacity(batch_size); - for _ in 0..1000 { + for _ in 0..batch_size { let length = rng.random_range(20..50); - let text: String = (&mut rng) - .sample_iter(&Alphanumeric) - .take(length) - .map(char::from) - .collect(); - let char = rng.random_range(0..text.len()); - let delimiter = &text.chars().nth(char).unwrap(); - let count = rng.sample(dist.dist.unwrap()); - - strings.push(text); - delimiters.push(delimiter.to_string()); - counts.push(count); + let base = random_token(&mut rng, length); + + let (string_value, delimiter) = if single_char_delimiter { + let char_idx = rng.random_range(0..base.chars().count()); + let delimiter = base.chars().nth(char_idx).unwrap().to_string(); + (base, delimiter) + } else { + let long_delimiters = ["|||", "***", "&&&", "###", "@@@", "$$$"]; + let delimiter = + long_delimiters[rng.random_range(0..long_delimiters.len())].to_string(); + + let delimiter_count = rng.random_range(1..4); + let mut result = String::new(); + for i in 0..delimiter_count { + result.push_str(&base); + if i < delimiter_count - 1 { + result.push_str(&delimiter); + } + } + (result, delimiter) + }; + + strings.push(string_value); + delimiters.push(delimiter); + counts.push(count_dist.sample(&mut rng)); } - ( - StringArray::from(strings), - StringArray::from(delimiters), - Int64Array::from(counts), - ) + (strings, delimiters, counts) } -fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("substr_index_array_array_1000", |b| { - let (strings, delimiters, counts) = data(); - let batch_len = counts.len(); - let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); - let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); - let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); - - let args = vec![strings, delimiters, counts]; - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - Field::new(format!("arg_{idx}"), arg.data_type(), true).into() - }) - .collect::>(); - let config_options = Arc::new(ConfigOptions::default()); - - b.iter(|| { - black_box( - substr_index() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: arg_fields.clone(), - number_rows: batch_len, - return_field: Field::new("f", DataType::Utf8, true).into(), - config_options: Arc::clone(&config_options), - }) - .expect("substr_index should work on valid values"), - ) +fn scalar_data(batch_size: usize, delimiter: &str) -> Vec { + let mut rng = StdRng::seed_from_u64(SCALAR_DATA_SEED); + let mut strings = Vec::with_capacity(batch_size); + + for _ in 0..batch_size { + let left_len = rng.random_range(12..24); + let middle_len = rng.random_range(12..24); + let right_len = rng.random_range(12..24); + let left = random_token(&mut rng, left_len); + let middle = random_token(&mut rng, middle_len); + let right = random_token(&mut rng, right_len); + strings.push(format!("{left}{delimiter}{middle}{delimiter}{right}")); + } + + strings +} + +fn run_benchmark( + b: &mut criterion::Bencher, + args: &[ColumnarValue], + return_type: &DataType, + number_rows: usize, +) { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type().clone(), true).into() }) - }); + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + b.iter(|| { + black_box( + substr_index() + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields: arg_fields.clone(), + number_rows, + return_field: Field::new("f", return_type.clone(), true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("substr_index should work on valid values"), + ) + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("substr_index"); + let batch_sizes = [100, 1000, 10_000]; + + for batch_size in batch_sizes { + for rep in [StringRep::Utf8, StringRep::Utf8View] { + let rep_name = rep.name(); + + group.bench_function( + format!("substr_index_{rep_name}_{batch_size}_array_single_delimiter"), + |b| { + let (strings, delimiters, counts) = array_data(batch_size, true); + let args = vec![ + ColumnarValue::Array(rep.array(&strings)), + ColumnarValue::Array(rep.array(&delimiters)), + ColumnarValue::Array( + Arc::new(Int64Array::from(counts)) as ArrayRef + ), + ]; + run_benchmark(b, &args, &rep.data_type(), batch_size); + }, + ); + + group.bench_function( + format!("substr_index_{rep_name}_{batch_size}_array_long_delimiter"), + |b| { + let (strings, delimiters, counts) = array_data(batch_size, false); + let args = vec![ + ColumnarValue::Array(rep.array(&strings)), + ColumnarValue::Array(rep.array(&delimiters)), + ColumnarValue::Array( + Arc::new(Int64Array::from(counts)) as ArrayRef + ), + ]; + run_benchmark(b, &args, &rep.data_type(), batch_size); + }, + ); + + for (name, delimiter, count) in [ + ("single_delimiter_pos", ".", 1_i64), + ("single_delimiter_neg", ".", -1_i64), + ("long_delimiter_pos", "|||", 1_i64), + ("long_delimiter_neg", "|||", -1_i64), + ] { + group.bench_function( + format!("substr_index_{rep_name}_{batch_size}_scalar_{name}"), + |b| { + let strings = scalar_data(batch_size, delimiter); + let args = vec![ + ColumnarValue::Array(rep.array(&strings)), + ColumnarValue::Scalar(rep.scalar(delimiter)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(count))), + ]; + run_benchmark(b, &args, &rep.data_type(), batch_size); + }, + ); + } + } + } + + group.finish(); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 945508aec7405..350a55a37135c 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -15,24 +15,21 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; -use arrow::array::{ArrayRef, Date32Array, StringArray}; +use arrow::array::{ArrayRef, Date32Array, Date64Array, StringArray}; use arrow::datatypes::{DataType, Field}; -use chrono::prelude::*; use chrono::TimeDelta; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_common::config::ConfigOptions; +use chrono::prelude::*; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ScalarValue; -use datafusion_common::ScalarValue::TimestampNanosecond; +use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::to_char; +use rand::Rng; use rand::prelude::IndexedRandom; use rand::rngs::ThreadRng; -use rand::Rng; fn pick_date_in_range( rng: &mut ThreadRng, @@ -65,6 +62,26 @@ fn generate_date32_array(rng: &mut ThreadRng) -> Date32Array { Date32Array::from(data) } +fn generate_date64_array(rng: &mut ThreadRng) -> Date64Array { + let start_date = "1970-01-01" + .parse::() + .expect("Date should parse"); + let end_date = "2050-12-31" + .parse::() + .expect("Date should parse"); + let mut data: Vec = Vec::with_capacity(1000); + for _ in 0..1000 { + let date = pick_date_in_range(rng, start_date, end_date); + let millis = date + .and_hms_opt(0, 0, 0) + .unwrap() + .and_utc() + .timestamp_millis(); + data.push(millis); + } + Date64Array::from(data) +} + const DATE_PATTERNS: [&str; 5] = ["%Y:%m:%d", "%d-%m-%Y", "%d%m%Y", "%Y%m%d", "%Y...%m...%d"]; @@ -157,7 +174,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_datetime_patterns_1000", |b| { let mut rng = rand::rng(); - let data_arr = generate_date32_array(&mut rng); + let data_arr = generate_date64_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Array(Arc::new(generate_datetime_pattern_array( @@ -184,7 +201,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_array_mixed_patterns_1000", |b| { let mut rng = rand::rng(); - let data_arr = generate_date32_array(&mut rng); + let data_arr = generate_date64_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Array(Arc::new(generate_mixed_pattern_array( @@ -237,7 +254,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_char_scalar_datetime_pattern_1000", |b| { let mut rng = rand::rng(); - let data_arr = generate_date32_array(&mut rng); + let data_arr = generate_date64_array(&mut rng); let batch_len = data_arr.len(); let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); let patterns = ColumnarValue::Scalar(ScalarValue::Utf8(Some( @@ -262,30 +279,58 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("to_char_scalar_1000", |b| { + // These bellow 02 benchmarks use Date32 data with format strings that contain + // time specifiers (%H, %M, %S, ...). Arrow's Date32 formatter cannot + // handle time specifiers and falls back to a Date64 cast. + + // Covers full fallback (every row triggers the cast) + c.bench_function("to_char_array_date32_datetime_patterns_1000", |b| { let mut rng = rand::rng(); - let timestamp = "2026-07-08T09:10:11" - .parse::() - .unwrap() - .with_nanosecond(56789) - .unwrap() - .and_utc() - .timestamp_nanos_opt() - .unwrap(); - let data = ColumnarValue::Scalar(TimestampNanosecond(Some(timestamp), None)); - let pattern = - ColumnarValue::Scalar(ScalarValue::Utf8(Some(pick_date_pattern(&mut rng)))); + let data_arr = generate_date32_array(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); + let patterns = ColumnarValue::Array(Arc::new(generate_datetime_pattern_array( + &mut rng, + )) as ArrayRef); b.iter(|| { black_box( to_char() .invoke_with_args(ScalarFunctionArgs { - args: vec![data.clone(), pattern.clone()], + args: vec![data.clone(), patterns.clone()], arg_fields: vec![ Field::new("a", data.data_type(), true).into(), - Field::new("b", pattern.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), ], - number_rows: 1, + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .expect("to_char should work on valid values"), + ) + }) + }); + + // Covers partial fallback (roughly half the rows trigger it) + c.bench_function("to_char_array_date32_mixed_patterns_1000", |b| { + let mut rng = rand::rng(); + let data_arr = generate_date32_array(&mut rng); + let batch_len = data_arr.len(); + let data = ColumnarValue::Array(Arc::new(data_arr) as ArrayRef); + let patterns = ColumnarValue::Array(Arc::new(generate_mixed_pattern_array( + &mut rng, + )) as ArrayRef); + + b.iter(|| { + black_box( + to_char() + .invoke_with_args(ScalarFunctionArgs { + args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + Field::new("a", data.data_type(), true).into(), + Field::new("b", patterns.data_type(), true).into(), + ], + number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a75ed9258791e..33f8d9c49e8eb 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -15,33 +15,31 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - +use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::hint::black_box; use std::sync::Arc; +use std::time::Duration; fn criterion_benchmark(c: &mut Criterion) { let hex = string::to_hex(); - let size = 1024; - let i32_array = Arc::new(create_primitive_array::(size, 0.2)); - let batch_len = i32_array.len(); - let i32_args = vec![ColumnarValue::Array(i32_array)]; let config_options = Arc::new(ConfigOptions::default()); - c.bench_function(&format!("to_hex i32 array: {size}"), |b| { + c.bench_function("to_hex/scalar_i32", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(2147483647)))]; + let arg_fields = vec![Field::new("a", DataType::Int32, true).into()]; b.iter(|| { - let args_cloned = i32_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![Field::new("a", DataType::Int32, false).into()], - number_rows: batch_len, + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) @@ -49,17 +47,18 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); - let i64_array = Arc::new(create_primitive_array::(size, 0.2)); - let batch_len = i64_array.len(); - let i64_args = vec![ColumnarValue::Array(i64_array)]; - c.bench_function(&format!("to_hex i64 array: {size}"), |b| { + + c.bench_function("to_hex/scalar_i64", |b| { + let args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some( + 9223372036854775807, + )))]; + let arg_fields = vec![Field::new("a", DataType::Int64, true).into()]; b.iter(|| { - let args_cloned = i64_args.clone(); black_box( hex.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![Field::new("a", DataType::Int64, false).into()], - number_rows: batch_len, + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), }) @@ -67,6 +66,88 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + for size in [1024, 4096, 8192] { + let mut group = c.benchmark_group(format!("to_hex size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + // i32 array with random values + let i32_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = i32_array.len(); + let i32_args = vec![ColumnarValue::Array(i32_array)]; + + group.bench_function("i32_random", |b| { + b.iter(|| { + let args_cloned = i32_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int32, true).into()], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // i64 array with random values (produces longer hex strings) + let i64_array = Arc::new(create_primitive_array::(size, 0.1)); + let batch_len = i64_array.len(); + let i64_args = vec![ColumnarValue::Array(i64_array)]; + + group.bench_function("i64_random", |b| { + b.iter(|| { + let args_cloned = i64_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int64, true).into()], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + // i64 array with large values (max length hex strings) + let i64_large_array = Arc::new(Int64Array::from( + (0..size) + .map(|i| { + if i % 10 == 0 { + None + } else { + Some(i64::MAX - i as i64) + } + }) + .collect::>(), + )); + let batch_len = i64_large_array.len(); + let i64_large_args = vec![ColumnarValue::Array(i64_large_array)]; + + group.bench_function("i64_large_values", |b| { + b.iter(|| { + let args_cloned = i64_large_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Int64, true).into()], + number_rows: batch_len, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + group.finish(); + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/to_local_time.rs b/datafusion/functions/benches/to_local_time.rs new file mode 100644 index 0000000000000..42d1e271980e8 --- /dev/null +++ b/datafusion/functions/benches/to_local_time.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, TimestampNanosecondArray}; +use arrow::datatypes::Field; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::datetime::to_local_time; +use rand::Rng; +use rand::rngs::ThreadRng; + +fn timestamps(rng: &mut ThreadRng) -> TimestampNanosecondArray { + let nanos: Vec = (0..100_000) + .map(|_| rng.random_range(0..1_000_000_000_000_000_000i64)) + .collect(); + TimestampNanosecondArray::from(nanos).with_timezone("America/New_York") +} + +fn timestamps_with_nulls(rng: &mut ThreadRng) -> TimestampNanosecondArray { + let values: Vec> = (0..100_000) + .map(|_| { + if rng.random_range(0..10u32) == 0 { + None + } else { + Some(rng.random_range(0..1_000_000_000_000_000_000i64)) + } + }) + .collect(); + TimestampNanosecondArray::from(values).with_timezone("America/New_York") +} + +fn bench_to_local_time(c: &mut Criterion, name: &str, array: ArrayRef) { + let batch_len = array.len(); + let input = ColumnarValue::Array(array); + let udf = to_local_time(); + let return_type = udf.return_type(&[input.data_type()]).unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + let arg_fields = vec![Field::new("a", input.data_type(), true).into()]; + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(name, |b| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![input.clone()], + arg_fields: arg_fields.clone(), + number_rows: batch_len, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("to_local_time should work on valid values"), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = rand::rng(); + bench_to_local_time( + c, + "to_local_time_no_nulls_100k", + Arc::new(timestamps(&mut rng)), + ); + bench_to_local_time( + c, + "to_local_time_10pct_nulls_100k", + Arc::new(timestamps_with_nulls(&mut rng)), + ); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/to_time.rs b/datafusion/functions/benches/to_time.rs new file mode 100644 index 0000000000000..6b3aa192415a3 --- /dev/null +++ b/datafusion/functions/benches/to_time.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hint::black_box; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, StringArray}; +use arrow::datatypes::Field; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::datetime::to_time; +use rand::Rng; +use rand::rngs::ThreadRng; + +fn random_time_string(rng: &mut ThreadRng) -> String { + format!( + "{:02}:{:02}:{:02}.{:06}", + rng.random_range(0..24u32), + rng.random_range(0..60u32), + rng.random_range(0..60u32), + rng.random_range(0..1_000_000u32), + ) +} + +fn time_strings(rng: &mut ThreadRng) -> StringArray { + let strings: Vec = (0..100_000).map(|_| random_time_string(rng)).collect(); + StringArray::from(strings) +} + +fn time_strings_with_nulls(rng: &mut ThreadRng) -> StringArray { + let values: Vec> = (0..100_000) + .map(|_| { + if rng.random_range(0..10u32) == 0 { + None + } else { + Some(random_time_string(rng)) + } + }) + .collect(); + StringArray::from(values) +} + +fn bench_to_time(c: &mut Criterion, name: &str, array: ArrayRef) { + let batch_len = array.len(); + let input = ColumnarValue::Array(array); + let udf = to_time(); + let return_type = udf.return_type(&[input.data_type()]).unwrap(); + let return_field = Arc::new(Field::new("f", return_type, true)); + let arg_fields = vec![Field::new("a", input.data_type(), true).into()]; + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(name, |b| { + b.iter(|| { + black_box( + udf.invoke_with_args(ScalarFunctionArgs { + args: vec![input.clone()], + arg_fields: arg_fields.clone(), + number_rows: batch_len, + return_field: Arc::clone(&return_field), + config_options: Arc::clone(&config_options), + }) + .expect("to_time should work on valid values"), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = rand::rng(); + bench_to_time(c, "to_time_no_nulls_100k", Arc::new(time_strings(&mut rng))); + bench_to_time( + c, + "to_time_10pct_nulls_100k", + Arc::new(time_strings_with_nulls(&mut rng)), + ); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index a8f5c5816d4da..90ea145d5d2c0 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use std::hint::black_box; use std::sync::Arc; @@ -24,7 +22,7 @@ use arrow::array::builder::StringBuilder; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute::cast; use arrow::datatypes::{DataType, Field, TimeUnit}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::datetime::to_timestamp; @@ -114,16 +112,21 @@ fn criterion_benchmark(c: &mut Criterion) { Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(); let arg_field = Field::new("a", DataType::Utf8, false).into(); let arg_fields = vec![arg_field]; - let config_options = Arc::new(ConfigOptions::default()); + let mut options = ConfigOptions::default(); + options.execution.time_zone = Some("UTC".into()); + let config_options = Arc::new(options); + + let to_timestamp_udf = to_timestamp(config_options.as_ref()); c.bench_function("to_timestamp_no_formats_utf8", |b| { + let to_timestamp_udf = Arc::clone(&to_timestamp_udf); let arr_data = data(); let batch_len = arr_data.len(); let string_array = ColumnarValue::Array(Arc::new(arr_data) as ArrayRef); b.iter(|| { black_box( - to_timestamp() + to_timestamp_udf .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], arg_fields: arg_fields.clone(), @@ -137,13 +140,14 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_timestamp_no_formats_largeutf8", |b| { + let to_timestamp_udf = Arc::clone(&to_timestamp_udf); let data = cast(&data(), &DataType::LargeUtf8).unwrap(); let batch_len = data.len(); let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef); b.iter(|| { black_box( - to_timestamp() + to_timestamp_udf .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], arg_fields: arg_fields.clone(), @@ -157,13 +161,14 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_timestamp_no_formats_utf8view", |b| { + let to_timestamp_udf = Arc::clone(&to_timestamp_udf); let data = cast(&data(), &DataType::Utf8View).unwrap(); let batch_len = data.len(); let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef); b.iter(|| { black_box( - to_timestamp() + to_timestamp_udf .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], arg_fields: arg_fields.clone(), @@ -177,6 +182,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_timestamp_with_formats_utf8", |b| { + let to_timestamp_udf = Arc::clone(&to_timestamp_udf); let (inputs, format1, format2, format3) = data_with_formats(); let batch_len = inputs.len(); @@ -196,7 +202,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - to_timestamp() + to_timestamp_udf .invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), @@ -210,6 +216,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_timestamp_with_formats_largeutf8", |b| { + let to_timestamp_udf = Arc::clone(&to_timestamp_udf); let (inputs, format1, format2, format3) = data_with_formats(); let batch_len = inputs.len(); @@ -237,7 +244,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - to_timestamp() + to_timestamp_udf .invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), @@ -251,6 +258,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("to_timestamp_with_formats_utf8view", |b| { + let to_timestamp_udf = Arc::clone(&to_timestamp_udf); let (inputs, format1, format2, format3) = data_with_formats(); let batch_len = inputs.len(); @@ -279,7 +287,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - to_timestamp() + to_timestamp_udf .invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: arg_fields.clone(), diff --git a/datafusion/functions/benches/translate.rs b/datafusion/functions/benches/translate.rs new file mode 100644 index 0000000000000..adde7b4bd763d --- /dev/null +++ b/datafusion/functions/benches/translate.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_functions::unicode; +use rand::SeedableRng; +use rand::prelude::IndexedRandom; +use rand::rngs::StdRng; +use std::hint::black_box; +use std::sync::Arc; +use std::time::Duration; + +// Mix of 2-byte (Greek) and 3-byte (CJK/Hangul) UTF-8 to exercise +// variable-width char paths in translate. +const NON_ASCII_ALPHABET: &[char] = &[ + 'α', 'β', 'γ', 'δ', 'ε', 'ζ', 'η', 'θ', 'ι', 'κ', 'λ', 'μ', 'ν', 'ξ', 'ο', 'π', 'ρ', + 'σ', 'τ', 'υ', 'φ', 'χ', 'ψ', 'ω', '日', '本', '語', '中', '文', '한', '국', '어', +]; + +fn create_non_ascii_string_array( + size: usize, + char_count: usize, + seed: u64, +) -> GenericStringArray { + let mut rng = StdRng::seed_from_u64(seed); + (0..size) + .map(|_| { + Some( + (0..char_count) + .map(|_| *NON_ASCII_ALPHABET.choose(&mut rng).unwrap()) + .collect::(), + ) + }) + .collect() +} + +fn create_args_array_from_to( + size: usize, + str_len: usize, +) -> Vec { + let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + let from_array = Arc::new(create_string_array_with_len::(size, 0.1, 3)); + let to_array = Arc::new(create_string_array_with_len::(size, 0.1, 2)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(from_array), + ColumnarValue::Array(to_array), + ] +} + +fn create_args_array_from_to_non_ascii( + size: usize, + str_len: usize, +) -> Vec { + let string_array = Arc::new(create_non_ascii_string_array::( + size, + str_len, + 0xA110_AAAA, + )); + let from_array = Arc::new(create_non_ascii_string_array::(size, 3, 0xA110_BBBB)); + let to_array = Arc::new(create_non_ascii_string_array::(size, 2, 0xA110_CCCC)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(from_array), + ColumnarValue::Array(to_array), + ] +} + +fn create_args_scalar_from_to( + size: usize, + str_len: usize, +) -> Vec { + let string_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(ScalarValue::from("aeiou")), + ColumnarValue::Scalar(ScalarValue::from("AEIOU")), + ] +} + +fn invoke_translate_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + unicode::translate().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + let mut group = c.benchmark_group(format!("translate size={size}")); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + for str_len in [8, 32, 128, 1024] { + let args = create_args_array_from_to::(size, str_len); + group.bench_function(format!("array_from_to [str_len={str_len}]"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }); + + let args = create_args_array_from_to_non_ascii::(size, str_len); + group.bench_function( + format!("array_from_to_non_ascii [str_len={str_len}]"), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }, + ); + + let args = create_args_scalar_from_to::(size, str_len); + group.bench_function(format!("scalar_from_to [str_len={str_len}]"), |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(invoke_translate_with_args(args_cloned, size)) + }) + }); + } + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/trim.rs b/datafusion/functions/benches/trim.rs new file mode 100644 index 0000000000000..21d99592d1820 --- /dev/null +++ b/datafusion/functions/benches/trim.rs @@ -0,0 +1,435 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; +use arrow::datatypes::{DataType, Field}; +use criterion::{ + BenchmarkGroup, Criterion, SamplingMode, criterion_group, criterion_main, + measurement::Measurement, +}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; +use datafusion_functions::string; +use rand::{Rng, SeedableRng, distr::Alphanumeric, rngs::StdRng}; +use std::hint::black_box; +use std::{fmt, sync::Arc}; + +#[derive(Clone, Copy)] +pub enum StringArrayType { + Utf8View, + Utf8, + LargeUtf8, +} + +impl fmt::Display for StringArrayType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StringArrayType::Utf8View => f.write_str("string_view"), + StringArrayType::Utf8 => f.write_str("string"), + StringArrayType::LargeUtf8 => f.write_str("large_string"), + } + } +} + +#[derive(Clone, Copy)] +pub enum TrimType { + Ltrim, + Rtrim, + Btrim, +} + +impl fmt::Display for TrimType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TrimType::Ltrim => f.write_str("ltrim"), + TrimType::Rtrim => f.write_str("rtrim"), + TrimType::Btrim => f.write_str("btrim"), + } + } +} + +/// Returns an array of strings with trim characters positioned according to trim type, +/// and `characters` as a ScalarValue. +/// +/// For ltrim: trim characters are at the start (prefix) +/// For rtrim: trim characters are at the end (suffix) +/// For btrim: trim characters are at both start and end +fn create_string_array_and_characters( + size: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, + string_array_type: StringArrayType, + trim_type: TrimType, +) -> (ArrayRef, ScalarValue) { + let rng = &mut StdRng::seed_from_u64(42); + + // Create `size` rows: + // - 10% rows will be `None` + // - Other 90% will be strings with `remaining_len` content length + let string_iter = (0..size).map(|_| { + if rng.random::() < 0.1 { + None + } else { + let content: String = rng + .sample_iter(&Alphanumeric) + .take(remaining_len) + .map(char::from) + .collect(); + + let value = match trim_type { + TrimType::Ltrim => format!("{trimmed}{content}"), + TrimType::Rtrim => format!("{content}{trimmed}"), + TrimType::Btrim => format!("{trimmed}{content}{trimmed}"), + }; + Some(value) + } + }); + + // Build the target `string array` and `characters` according to `string_array_type` + match string_array_type { + StringArrayType::Utf8View => ( + Arc::new(string_iter.collect::()), + ScalarValue::Utf8View(Some(characters.to_string())), + ), + StringArrayType::Utf8 => ( + Arc::new(string_iter.collect::()), + ScalarValue::Utf8(Some(characters.to_string())), + ), + StringArrayType::LargeUtf8 => ( + Arc::new(string_iter.collect::()), + ScalarValue::LargeUtf8(Some(characters.to_string())), + ), + } +} + +/// Create args for the trim benchmark +fn create_args( + size: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, + string_array_type: StringArrayType, + trim_type: TrimType, +) -> Vec { + let (string_array, pattern) = create_string_array_and_characters( + size, + characters, + trimmed, + remaining_len, + string_array_type, + trim_type, + ); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(pattern), + ] +} + +/// Create args for trim benchmark where space characters are being trimmed +fn create_space_trim_args( + size: usize, + pad_len: usize, + remaining_len: usize, + string_array_type: StringArrayType, + trim_type: TrimType, +) -> Vec { + let rng = &mut StdRng::seed_from_u64(42); + let spaces = " ".repeat(pad_len); + + let string_iter = (0..size).map(|_| { + if rng.random::() < 0.1 { + None + } else { + let content: String = rng + .sample_iter(&Alphanumeric) + .take(remaining_len) + .map(char::from) + .collect(); + + let value = match trim_type { + TrimType::Ltrim => format!("{spaces}{content}"), + TrimType::Rtrim => format!("{content}{spaces}"), + TrimType::Btrim => format!("{spaces}{content}{spaces}"), + }; + Some(value) + } + }); + + let string_array: ArrayRef = match string_array_type { + StringArrayType::Utf8View => Arc::new(string_iter.collect::()), + StringArrayType::Utf8 => Arc::new(string_iter.collect::()), + StringArrayType::LargeUtf8 => Arc::new(string_iter.collect::()), + }; + + vec![ColumnarValue::Array(string_array)] +} + +#[expect(clippy::too_many_arguments)] +fn run_with_string_type( + group: &mut BenchmarkGroup<'_, M>, + trim_func: &ScalarUDF, + trim_type: TrimType, + size: usize, + total_len: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, + string_type: StringArrayType, +) { + let args = create_args( + size, + characters, + trimmed, + remaining_len, + string_type, + trim_type, + ); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + group.bench_function( + format!( + "{trim_type} {string_type} [size={size}, len={total_len}, remaining={remaining_len}]", + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(trim_func.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); +} + +#[expect(clippy::too_many_arguments)] +fn run_trim_benchmark( + c: &mut Criterion, + group_name: &str, + trim_func: &ScalarUDF, + trim_type: TrimType, + string_types: &[StringArrayType], + size: usize, + total_len: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, +) { + let mut group = c.benchmark_group(group_name); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + for string_type in string_types { + run_with_string_type( + &mut group, + trim_func, + trim_type, + size, + total_len, + characters, + trimmed, + remaining_len, + *string_type, + ); + } + + group.finish(); +} + +#[expect(clippy::too_many_arguments)] +fn run_space_trim_benchmark( + c: &mut Criterion, + group_name: &str, + trim_func: &ScalarUDF, + trim_type: TrimType, + string_types: &[StringArrayType], + size: usize, + pad_len: usize, + remaining_len: usize, +) { + let mut group = c.benchmark_group(group_name); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let total_len = match trim_type { + TrimType::Btrim => 2 * pad_len + remaining_len, + _ => pad_len + remaining_len, + }; + + for string_type in string_types { + let args = + create_space_trim_args(size, pad_len, remaining_len, *string_type, trim_type); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true).into() + }) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + group.bench_function( + format!( + "{trim_type} {string_type} [size={size}, len={total_len}, pad={pad_len}]", + ), + |b| { + b.iter(|| { + let args_cloned = args.clone(); + black_box(trim_func.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }, + ); + } + + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let ltrim = string::ltrim(); + let rtrim = string::rtrim(); + let btrim = string::btrim(); + + let characters = ",!()"; + + let string_types = [ + StringArrayType::Utf8View, + StringArrayType::Utf8, + StringArrayType::LargeUtf8, + ]; + + let trim_funcs = [ + (<rim, TrimType::Ltrim), + (&rtrim, TrimType::Rtrim), + (&btrim, TrimType::Btrim), + ]; + + for size in [4096] { + for (trim_func, trim_type) in &trim_funcs { + // Scenario 1: Short strings (len <= 12, inline in StringView) + // trimmed_len=4, remaining_len=8 + let total_len = 12; + let trimmed = characters; + let remaining_len = total_len - trimmed.len(); + run_trim_benchmark( + c, + "short strings (len <= 12)", + trim_func, + *trim_type, + &string_types, + size, + total_len, + characters, + trimmed, + remaining_len, + ); + + // Scenario 2: Long strings, short trim (len > 12, output > 12) + // trimmed_len=4, remaining_len=60 + let total_len = 64; + let trimmed = characters; + let remaining_len = total_len - trimmed.len(); + run_trim_benchmark( + c, + "long strings, short trim", + trim_func, + *trim_type, + &string_types, + size, + total_len, + characters, + trimmed, + remaining_len, + ); + + // Scenario 3: Long strings, long trim (len > 12, output <= 12) + // trimmed_len=56, remaining_len=8 + let total_len = 64; + let trimmed = characters.repeat(14); + let remaining_len = total_len - trimmed.len(); + run_trim_benchmark( + c, + "long strings, long trim", + trim_func, + *trim_type, + &string_types, + size, + total_len, + characters, + &trimmed, + remaining_len, + ); + + // Scenario 4: Trim spaces, short strings (len <= 12) + // pad_len=4, remaining_len=8 + run_space_trim_benchmark( + c, + "trim spaces, short strings (len <= 12)", + trim_func, + *trim_type, + &string_types, + size, + 4, + 8, + ); + + // Scenario 5: Trim spaces, long strings (len > 12) + // pad_len=4, remaining_len=60 + run_space_trim_benchmark( + c, + "trim spaces, long strings", + trim_func, + *trim_type, + &string_types, + size, + 4, + 60, + ); + + // Scenario 6: Trim spaces, long strings, heavy padding + // pad_len=56, remaining_len=8 + run_space_trim_benchmark( + c, + "trim spaces, heavy padding", + trim_func, + *trim_type, + &string_types, + size, + 56, + 8, + ); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 6e225e0e7038b..ffbedcb142c71 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -15,13 +15,11 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::{ datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::trunc; use std::hint::black_box; @@ -32,12 +30,13 @@ use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let trunc = trunc(); + let config_options = Arc::new(ConfigOptions::default()); + for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; let return_field = Field::new("f", DataType::Float32, true).into(); - let config_options = Arc::new(ConfigOptions::default()); c.bench_function(&format!("trunc f32 array: {size}"), |b| { b.iter(|| { @@ -74,6 +73,51 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); } + + // Scalar benchmarks - to measure optimized performance + let scalar_f64_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float64(Some(std::f64::consts::PI)), + )]; + let scalar_arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let scalar_return_field = Field::new("f", DataType::Float64, false).into(); + + c.bench_function("trunc f64 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f64_args.clone(), + arg_fields: scalar_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); + + let scalar_f32_args = vec![ColumnarValue::Scalar( + datafusion_common::ScalarValue::Float32(Some(std::f32::consts::PI)), + )]; + let scalar_f32_arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let scalar_f32_return_field = Field::new("f", DataType::Float32, false).into(); + + c.bench_function("trunc f32 scalar", |b| { + b.iter(|| { + black_box( + trunc + .invoke_with_args(ScalarFunctionArgs { + args: scalar_f32_args.clone(), + arg_fields: scalar_f32_arg_fields.clone(), + number_rows: 1, + return_field: Arc::clone(&scalar_f32_return_field), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 7328b32574a4a..3f6fa36b18c13 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 1368e2f2af5d1..629fb950dd9ff 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index c4e58601cd106..0b67883c17c87 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -20,17 +20,15 @@ use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::error::ArrowError; use datafusion_common::{ - arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, + Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt, + exec_datafusion_err, exec_err, internal_err, types::logical_string, + utils::take_function_args, }; -use datafusion_common::{ - exec_datafusion_err, utils::take_function_args, DataFusionError, -}; -use std::any::Any; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, - ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -104,16 +102,18 @@ impl Default for ArrowCastFunc { impl ArrowCastFunc { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), } } } impl ScalarUDFImpl for ArrowCastFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "arrow_cast" } @@ -154,24 +154,21 @@ impl ScalarUDFImpl for ArrowCastFunc { fn simplify( &self, - mut args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { // convert this into a real cast - let target_type = data_type_from_args(&args)?; - // remove second (type) argument - args.pop().unwrap(); - let arg = args.pop().unwrap(); - - let source_type = info.get_data_type(&arg)?; + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = data_type_from_type_arg(self.name(), &type_arg)?; + let source_type = info.get_data_type(&source_arg)?; let new_expr = if source_type == target_type { // the argument's data type is already the correct type - arg + source_arg } else { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: target_type, + expr: Box::new(source_arg), + field: target_type.into_nullable_field_ref(), }) }; // return the newly written argument to DataFusion @@ -183,13 +180,11 @@ impl ScalarUDFImpl for ArrowCastFunc { } } -/// Returns the requested type from the arguments -fn data_type_from_args(args: &[Expr]) -> Result { - let [_, type_arg] = take_function_args("arrow_cast", args)?; - +/// Returns the requested type from the type argument +pub(crate) fn data_type_from_type_arg(name: &str, type_arg: &Expr) -> Result { let Expr::Literal(ScalarValue::Utf8(Some(val)), _) = type_arg else { return exec_err!( - "arrow_cast requires its second argument to be a constant string, got {:?}", + "{name} requires its second argument to be a constant string, got {:?}", type_arg ); }; diff --git a/datafusion/functions/src/core/arrow_field.rs b/datafusion/functions/src/core/arrow_field.rs new file mode 100644 index 0000000000000..dce7cff42ba80 --- /dev/null +++ b/datafusion/functions/src/core/arrow_field.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, BooleanArray, MapBuilder, StringArray, StringBuilder, StructArray, +}; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::{Result, ScalarValue, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Other Functions"), + description = "Returns a struct containing the Arrow field information of the expression, including name, data type, nullability, and metadata.", + syntax_example = "arrow_field(expression)", + sql_example = r#"```sql +> select arrow_field(1); ++-------------------------------------------------------------+ +| arrow_field(Int64(1)) | ++-------------------------------------------------------------+ +| {name: lit, data_type: Int64, nullable: false, metadata: {}} | ++-------------------------------------------------------------+ + +> select arrow_field(1)['data_type']; ++-----------------------------------+ +| arrow_field(Int64(1))[data_type] | ++-----------------------------------+ +| Int64 | ++-----------------------------------+ +```"#, + argument( + name = "expression", + description = "Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators." + ) +)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct ArrowFieldFunc { + signature: Signature, +} + +impl Default for ArrowFieldFunc { + fn default() -> Self { + Self::new() + } +} + +impl ArrowFieldFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } + + fn return_struct_type() -> DataType { + DataType::Struct(Fields::from(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("data_type", DataType::Utf8, false), + Field::new("nullable", DataType::Boolean, false), + Field::new( + "metadata", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Utf8, true), + ])), + false, + )), + false, + ), + false, + ), + ])) + } +} + +impl ScalarUDFImpl for ArrowFieldFunc { + fn name(&self) -> &str { + "arrow_field" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Self::return_struct_type()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let return_type = args.return_type().clone(); + let [field] = take_function_args(self.name(), args.arg_fields)?; + + // Build the name array + let name_array = + Arc::new(StringArray::from(vec![field.name().as_str()])) as Arc; + + // Build the data_type array + let data_type_str = field.data_type().to_string(); + let data_type_array = + Arc::new(StringArray::from(vec![data_type_str.as_str()])) as Arc; + + // Build the nullable array + let nullable_array = + Arc::new(BooleanArray::from(vec![field.is_nullable()])) as Arc; + + // Build the metadata map array (same pattern as arrow_metadata.rs) + let metadata = field.metadata(); + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), StringBuilder::new()); + + let mut entries: Vec<_> = metadata.iter().collect(); + entries.sort_by_key(|(k, _)| *k); + + for (k, v) in entries { + map_builder.keys().append_value(k); + map_builder.values().append_value(v); + } + map_builder.append(true)?; + + let metadata_array = Arc::new(map_builder.finish()) as Arc; + + // Build the struct + let DataType::Struct(fields) = return_type else { + unreachable!() + }; + + let struct_array = StructArray::new( + fields, + vec![name_array, data_type_array, nullable_array, metadata_array], + None, + ); + + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &struct_array, + 0, + )?)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/arrow_metadata.rs b/datafusion/functions/src/core/arrow_metadata.rs new file mode 100644 index 0000000000000..a80f66f396731 --- /dev/null +++ b/datafusion/functions/src/core/arrow_metadata.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{MapBuilder, StringBuilder}; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::types::logical_string; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +#[user_doc( + doc_section(label = "Other Functions"), + description = "Returns the metadata of the input expression. If a key is provided, returns the value for that key. If no key is provided, returns a Map of all metadata.", + syntax_example = "arrow_metadata(expression[, key])", + sql_example = r#"```sql +> select arrow_metadata(col) from table; ++----------------------------+ +| arrow_metadata(table.col) | ++----------------------------+ +| {k: v} | ++----------------------------+ +> select arrow_metadata(col, 'k') from table; ++-------------------------------+ +| arrow_metadata(table.col, 'k')| ++-------------------------------+ +| v | ++-------------------------------+ +```"#, + argument( + name = "expression", + description = "The expression to retrieve metadata from. Can be a column or other expression." + ), + argument( + name = "key", + description = "Optional. The specific metadata key to retrieve." + ) +)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ArrowMetadataFunc { + signature: Signature, +} + +impl ArrowMetadataFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Any, + )]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + ], + Volatility::Immutable, + ), + } + } +} + +impl Default for ArrowMetadataFunc { + fn default() -> Self { + Self::new() + } +} + +impl ScalarUDFImpl for ArrowMetadataFunc { + fn name(&self) -> &str { + "arrow_metadata" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() == 2 { + Ok(DataType::Utf8) + } else if arg_types.len() == 1 { + Ok(DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Utf8, true), + ])), + false, + )), + false, + )) + } else { + internal_err!("arrow_metadata requires 1 or 2 arguments") + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let metadata = args.arg_fields[0].metadata(); + + if args.args.len() == 2 { + let key = match &args.args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key, + _ => { + return exec_err!( + "Second argument to arrow_metadata must be a string literal key" + ); + } + }; + let value = metadata.get(key).cloned(); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(value))) + } else if args.args.len() == 1 { + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), StringBuilder::new()); + + let mut entries: Vec<_> = metadata.iter().collect(); + entries.sort_by_key(|(k, _)| *k); + + for (k, v) in entries { + map_builder.keys().append_value(k); + map_builder.values().append_value(v); + } + map_builder.append(true)?; + + let map_array = map_builder.finish(); + + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &map_array, 0, + )?)) + } else { + internal_err!("arrow_metadata requires 1 or 2 arguments") + } + } +} diff --git a/datafusion/functions/src/core/arrow_try_cast.rs b/datafusion/functions/src/core/arrow_try_cast.rs new file mode 100644 index 0000000000000..d27b29ba5736d --- /dev/null +++ b/datafusion/functions/src/core/arrow_try_cast.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ArrowTryCastFunc`]: Implementation of the `arrow_try_cast` + +use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::error::ArrowError; +use datafusion_common::{ + Result, arrow_datafusion_err, datatype::DataTypeExt, exec_datafusion_err, exec_err, + internal_err, types::logical_string, utils::take_function_args, +}; + +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +use super::arrow_cast::data_type_from_type_arg; + +/// Like [`arrow_cast`](super::arrow_cast::ArrowCastFunc) but returns NULL on cast failure instead of erroring. +/// +/// This is implemented by simplifying `arrow_try_cast(expr, 'Type')` into +/// `Expr::TryCast` during optimization. +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts a value to a specific Arrow data type, returning NULL if the cast fails.", + syntax_example = "arrow_try_cast(expression, datatype)", + sql_example = r#"```sql +> select arrow_try_cast('123', 'Int64') as a, + arrow_try_cast('not_a_number', 'Int64') as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +```"#, + argument( + name = "expression", + description = "Expression to cast. The expression can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "datatype", + description = "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]" + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrowTryCastFunc { + signature: Signature, +} + +impl Default for ArrowTryCastFunc { + fn default() -> Self { + Self::new() + } +} + +impl ArrowTryCastFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ArrowTryCastFunc { + fn name(&self) -> &str { + "arrow_try_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // TryCast can always return NULL (on cast failure), so always nullable + let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; + + type_arg + .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || { + exec_err!( + "{} requires its second argument to be a non-empty constant string", + self.name() + ) + }, + |casted_type| match casted_type.parse::() { + Ok(data_type) => { + Ok(Field::new(self.name(), data_type, true).into()) + } + Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), + Err(e) => Err(arrow_datafusion_err!(e)), + }, + ) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("arrow_try_cast should have been simplified to try_cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = data_type_from_type_arg(self.name(), &type_arg)?; + + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + source_arg + } else { + Expr::TryCast(datafusion_expr::TryCast { + expr: Box::new(source_arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index f178890f93704..d25db584f6ea0 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -16,11 +16,10 @@ // under the License. use arrow::datatypes::DataType; -use datafusion_common::{utils::take_function_args, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, utils::take_function_args}; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "Other Functions"), @@ -60,9 +59,6 @@ impl ArrowTypeOfFunc { } impl ScalarUDFImpl for ArrowTypeOfFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "arrow_typeof" } diff --git a/datafusion/functions/src/core/cast_to_type.rs b/datafusion/functions/src/core/cast_to_type.rs new file mode 100644 index 0000000000000..abc7d440e04ba --- /dev/null +++ b/datafusion/functions/src/core/cast_to_type.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`CastToTypeFunc`]: Implementation of the `cast_to_type` function + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, internal_err, utils::take_function_args}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Casts the first argument to the data type of the second argument. +/// +/// Only the type of the second argument is used; its value is ignored. +/// This is useful in macros or generic SQL where you need to preserve +/// or match types dynamically. +/// +/// For example: +/// ```sql +/// select cast_to_type('42', NULL::INTEGER); +/// ``` +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument. Only the type of the second argument is used; its value is ignored.", + syntax_example = "cast_to_type(expression, reference)", + sql_example = r#"```sql +> select cast_to_type('42', NULL::INTEGER) as a; ++----+ +| a | ++----+ +| 42 | ++----+ + +> select cast_to_type(1 + 2, NULL::DOUBLE) as b; ++-----+ +| b | ++-----+ +| 3.0 | ++-----+ +```"#, + argument( + name = "expression", + description = "The expression to cast. It can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CastToTypeFunc { + signature: Signature, +} + +impl Default for CastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl CastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CastToTypeFunc { + fn name(&self) -> &str { + "cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [source_field, reference_field] = + take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + // Nullability is inherited only from the first argument (the value + // being cast). The second argument is used solely for its type, so + // its own nullability is irrelevant. The one exception is when the + // target type is Null – that type is inherently nullable. + let nullable = source_field.is_nullable() || target_type == DataType::Null; + Ok(Field::new(self.name(), target_type, nullable).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("cast_to_type should have been simplified to cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = info.get_data_type(&type_arg)?; + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + source_arg + } else { + let nullable = info.nullable(&source_arg)? || target_type == DataType::Null; + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(source_arg), + field: Field::new("", target_type, nullable).into(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index aab1f445d5590..9cf3536443e6c 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -16,17 +16,16 @@ // under the License. use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::{exec_err, internal_err, plan_err, Result}; +use datafusion_common::{Result, exec_err, internal_err, plan_err}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::conditional_expressions::CaseBuilder; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use itertools::Itertools; -use std::any::Any; #[user_doc( doc_section(label = "Conditional Functions"), @@ -65,10 +64,6 @@ impl CoalesceFunc { } impl ScalarUDFImpl for CoalesceFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "coalesce" } @@ -97,7 +92,7 @@ impl ScalarUDFImpl for CoalesceFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { if args.is_empty() { return plan_err!("coalesce must have at least one argument"); diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 1194d11ba5a2d..93a4cddef453e 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,64 +15,78 @@ // specific language governing permissions and limitations // under the License. +use std::sync::{Arc, OnceLock}; + use arrow::array::{ - make_array, make_comparator, Array, BooleanArray, Capacities, MutableArrayData, - Scalar, + Array, BooleanArray, Capacities, MutableArrayData, Scalar, cast::AsArray, make_array, + make_comparator, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow_buffer::NullBuffer; + use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, - ScalarValue, + Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, }; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ExpressionPlacement, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; -use std::sync::Arc; + +use super::named_struct::NamedStructFunc; +use super::r#struct::StructFunc; #[user_doc( doc_section(label = "Other Functions"), description = r#"Returns a field within a map or a struct with the given key. + Supports nested field access by providing multiple field names. Note: most users invoke `get_field` indirectly via field access syntax such as `my_struct_col['field_name']` which results in a call to - `get_field(my_struct_col, 'field_name')`."#, - syntax_example = "get_field(expression1, expression2)", + `get_field(my_struct_col, 'field_name')`. + Nested access like `my_struct['a']['b']` is optimized to a single call: + `get_field(my_struct, 'a', 'b')`."#, + syntax_example = "get_field(expression, field_name[, field_name2, ...])", sql_example = r#"```sql -> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); -> select struct(idx, v) from t as c; -+-------------------------+ -| struct(c.idx,c.v) | -+-------------------------+ -| {c0: data, c1: fusion} | -| {c0: apache, c1: arrow} | -+-------------------------+ -> select get_field((select struct(idx, v) from t), 'c0'); -+-----------------------+ -| struct(t.idx,t.v)[c0] | -+-----------------------+ -| data | -| apache | -+-----------------------+ -> select get_field((select struct(idx, v) from t), 'c1'); -+-----------------------+ -| struct(t.idx,t.v)[c1] | -+-----------------------+ -| fusion | -| arrow | -+-----------------------+ +> -- Access a field from a struct column +> create table test( struct_col) as values + ({name: 'Alice', age: 30}), + ({name: 'Bob', age: 25}); +> select struct_col from test; ++-----------------------------+ +| struct_col | ++-----------------------------+ +| {name: Alice, age: 30} | +| {name: Bob, age: 25} | ++-----------------------------+ +> select struct_col['name'] as name from test; ++-------+ +| name | ++-------+ +| Alice | +| Bob | ++-------+ + +> -- Nested field access with multiple arguments +> create table test(struct_col) as values + ({outer: {inner_val: 42}}); +> select struct_col['outer']['inner_val'] as result from test; ++--------+ +| result | ++--------+ +| 42 | ++--------+ ```"#, argument( - name = "expression1", - description = "The map or struct to retrieve a field for." + name = "expression", + description = "The map or struct to retrieve a field from." ), argument( - name = "expression2", - description = "The field name in the map or struct to retrieve data for. Must evaluate to a string." + name = "field_name", + description = "The field name(s) to access, in order for nested access. Must evaluate to strings." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -86,43 +100,328 @@ impl Default for GetFieldFunc { } } +/// Process a map array by finding matching keys and extracting corresponding values. +/// +/// This function handles both simple (scalar) and nested key types by using +/// appropriate comparison strategies. +fn process_map_array( + array: &dyn Array, + key_array: Arc, +) -> Result { + let map_array = as_map_array(array)?; + let keys = if key_array.data_type().is_nested() { + let comparator = make_comparator( + map_array.keys().as_ref(), + key_array.as_ref(), + SortOptions::default(), + )?; + let len = map_array.keys().len().min(key_array.len()); + let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); + let nulls = NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); + BooleanArray::new(values, nulls) + } else { + let be_compared = Scalar::new(key_array); + arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? + }; + + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for entry in 0..map_array.len() { + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; + + let maybe_matched = keys + .slice(start, end - start) + .iter() + .enumerate() + .find(|(_, t)| t.unwrap()); + + if maybe_matched.is_none() { + mutable.extend_nulls(1); + continue; + } + let (match_offset, _) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); + } + + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) +} + +/// Process a map array with a nested key type by iterating through entries +/// and using a comparator for key matching. +/// +/// This specialized version is used when the key type is nested (e.g., struct, list). +fn process_map_with_nested_key( + array: &dyn Array, + key_array: &dyn Array, +) -> Result { + let map_array = as_map_array(array)?; + + let comparator = + make_comparator(map_array.keys().as_ref(), key_array, SortOptions::default())?; + + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for entry in 0..map_array.len() { + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; + + let mut found_match = false; + for i in start..end { + if comparator(i, 0).is_eq() { + mutable.extend(0, i, i + 1); + found_match = true; + break; + } + } + + if !found_match { + mutable.extend_nulls(1); + } + } + + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) +} + +/// Extract a single field from a struct or map array +fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result { + let arrays = ColumnarValue::values_to_arrays(&[base])?; + let array = Arc::clone(&arrays[0]); + + let string_value = name.try_as_str().flatten().map(|s| s.to_string()); + + match (array.data_type(), name, string_value) { + // Dictionary-encoded struct: extract the field from the dictionary's + // values (the deduplicated struct array) and rebuild a dictionary with + // the same keys. This preserves dictionary encoding without expanding. + (DataType::Dictionary(_, value_type), _, Some(field_name)) + if matches!(value_type.as_ref(), DataType::Struct(_)) => + { + let dict = array.as_any_dictionary(); + let values_struct = dict.values().as_struct(); + let field_col = + values_struct.column_by_name(&field_name).ok_or_else(|| { + exec_datafusion_err!( + "Field {field_name} not found in dictionary struct" + ) + })?; + Ok(ColumnarValue::Array( + dict.with_values(Arc::clone(field_col)), + )) + } + (DataType::Map(_, _), ScalarValue::List(arr), _) => { + let key_array: Arc = arr; + process_map_array(&array, key_array) + } + (DataType::Map(_, _), ScalarValue::Struct(arr), _) => { + process_map_array(&array, arr as Arc) + } + (DataType::Map(_, _), other, _) => { + let data_type = other.data_type(); + if data_type.is_nested() { + process_map_with_nested_key(&array, &other.to_array()?) + } else { + process_map_array(&array, other.to_array()?) + } + } + (DataType::Struct(_), _, Some(k)) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(&k) { + None => exec_err!("Field {k} not found in struct"), + Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))), + } + } + (DataType::Struct(_), name, _) => exec_err!( + "get_field is only possible on struct with utf8 indexes. \ + Received with {name:?} index" + ), + (DataType::Null, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + (dt, name, _) => exec_err!( + "get_field is only possible on maps or structs. Received {dt} with {name:?} index" + ), + } +} + +/// The shared `get_field` UDF, reused whenever simplification needs to build a +/// fresh `get_field` node (e.g. re-wrapping the remaining access path). +fn get_field_udf() -> Arc { + static GET_FIELD_UDF: OnceLock> = OnceLock::new(); + Arc::clone( + GET_FIELD_UDF + .get_or_init(|| Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new()))), + ) +} + +/// Try to simplify a `get_field` call whose base is an inline struct +/// constructor by resolving the field access at plan time. +/// +/// Handles both struct constructors: +/// * `named_struct('a', x, 'b', y)` — fields are looked up by name. +/// * `struct(x, y)` — fields are positional and named `c0`, `c1`, ... +/// +/// For example: +/// * `get_field(named_struct('min', a, 'max', b), 'max')` => `b` +/// * `get_field(struct(a, b), 'c1')` => `b` +/// +/// `args` is the (already flattened) argument list of the `get_field` call: +/// `[base, field_name, rest_of_path...]`. When extra path elements remain +/// after resolving the first one (`get_field(named_struct('s', inner), 's', 'k')`), +/// the resolved value is re-wrapped in a `get_field` call for the remaining +/// path so the simplifier can recurse into it on the next pass. +/// +/// Returns `None` — leaving the expression untouched — whenever the rewrite +/// cannot be proven safe, e.g. a non-literal field name, a `named_struct` +/// with a non-literal field name (which might shadow the requested field at +/// runtime), or a field the constructor does not produce. +/// +/// Replacing the access with the selected field expression drops the +/// expressions for the other (unaccessed) fields, so they are no longer +/// evaluated — e.g. `get_field(named_struct('a', 1/0, 'b', x), 'b')` becomes +/// `x` and the `1/0` is never evaluated. This is intentional and matches the +/// optimizer's contract for immutable expressions: a simplification may drop +/// sub-expressions whose value is not observed. +fn simplify_get_field_over_struct_constructor(args: &[Expr]) -> Option { + let [base, field_name, rest @ ..] = args else { + return None; + }; + + // The accessed field name must be a non-empty string literal. + let Expr::Literal(field_name, _) = field_name else { + return None; + }; + let field_name = field_name + .try_as_str() + .flatten() + .filter(|s| !s.is_empty())?; + + let Expr::ScalarFunction(ScalarFunction { + func, + args: ctor_args, + }) = base + else { + return None; + }; + + let value = if func.inner().is::() { + // named_struct(name1, value1, name2, value2, ...) + if !ctor_args.len().is_multiple_of(2) { + return None; + } + let mut matched = None; + for pair in ctor_args.chunks_exact(2) { + // Every name must be a literal string: a non-literal name appearing + // *before* the first match could evaluate to `field_name` at runtime + // and become the real first match (Arrow's `column_by_name` returns + // the first match), so we cannot resolve the access. + // + // We conservatively bail on *any* non-literal name. Once a literal + // match has been found, a later non-literal name is in fact harmless + // — it can never precede the first match — so bailing there is a + // deliberate approximation we accept to keep this check simple, not a + // correctness requirement. + let Expr::Literal(name, _) = &pair[0] else { + return None; + }; + let name = name.try_as_str().flatten()?; + // `column_by_name` resolves to the first match, so do the same. + if matched.is_none() && name == field_name { + matched = Some(&pair[1]); + } + } + matched?.clone() + } else if func.inner().is::() { + // struct(value0, value1, ...) produces fields named c0, c1, ... + let index: usize = field_name.strip_prefix('c')?.parse().ok()?; + // Reject non-canonical spellings (e.g. "c01") that name no real field. + if format!("c{index}") != field_name { + return None; + } + ctor_args.get(index)?.clone() + } else { + return None; + }; + + if rest.is_empty() { + return Some(value); + } + + // Remaining path elements: re-wrap as get_field(value, rest...) and let + // the simplifier resolve the rest on a subsequent pass. + let mut new_args = Vec::with_capacity(rest.len() + 1); + new_args.push(value); + new_args.extend_from_slice(rest); + Some(Expr::ScalarFunction(ScalarFunction::new_udf( + get_field_udf(), + new_args, + ))) +} + impl GetFieldFunc { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } // get_field(struct_array, field_name) impl ScalarUDFImpl for GetFieldFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "get_field" } fn display_name(&self, args: &[Expr]) -> Result { - let [base, field_name] = take_function_args(self.name(), args)?; + if args.len() < 2 { + return exec_err!( + "get_field requires at least 2 arguments, got {}", + args.len() + ); + } - let name = match field_name { - Expr::Literal(name, _) => name.to_string(), - other => other.schema_name().to_string(), - }; + let base = &args[0]; + let field_names: Vec = args[1..] + .iter() + .map(|f| match f { + Expr::Literal(name, _) => name.to_string(), + other => other.schema_name().to_string(), + }) + .collect(); - Ok(format!("{base}[{name}]")) + Ok(format!("{}[{}]", base, field_names.join("]["))) } fn schema_name(&self, args: &[Expr]) -> Result { - let [base, field_name] = take_function_args(self.name(), args)?; - let name = match field_name { - Expr::Literal(name, _) => name.to_string(), - other => other.schema_name().to_string(), - }; + if args.len() < 2 { + return exec_err!( + "get_field requires at least 2 arguments, got {}", + args.len() + ); + } + + let base = &args[0]; + let field_names: Vec = args[1..] + .iter() + .map(|f| match f { + Expr::Literal(name, _) => name.to_string(), + other => other.schema_name().to_string(), + }) + .collect(); - Ok(format!("{}[{}]", base.schema_name(), name)) + Ok(format!( + "{}[{}]", + base.schema_name(), + field_names.join("][") + )) } fn signature(&self) -> &Signature { @@ -134,193 +433,699 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - // Length check handled in the signature - debug_assert_eq!(args.scalar_arguments.len(), 2); - - match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) { - (DataType::Map(fields, _), _) => { - match fields.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - // Arrow's MapArray is essentially a ListArray of structs with two columns. They are - // often named "key", and "value", but we don't require any specific naming here; - // instead, we assume that the second column is the "value" column both here and in - // execution. - let value_field = fields.get(1).expect("fields should have exactly two members"); - - Ok(value_field.as_ref().clone().with_nullable(true).into()) - }, - _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), + // Validate minimum 2 arguments: base expression + at least one field name + if args.scalar_arguments.len() < 2 { + return exec_err!( + "get_field requires at least 2 arguments, got {}", + args.scalar_arguments.len() + ); + } + + let mut current_field = Arc::clone(&args.arg_fields[0]); + + // Iterate through each field name (starting from index 1) + for (i, sv) in args.scalar_arguments.iter().enumerate().skip(1) { + match current_field.data_type() { + DataType::Map(map_field, _) => { + match map_field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second column is the "value" column both here and in + // execution. + let value_field = fields + .get(1) + .expect("fields should have exactly two members"); + + current_field = Arc::new( + value_field.as_ref().clone().with_nullable(true), + ); + } + _ => { + return exec_err!( + "Map fields must contain a Struct with exactly 2 fields" + ); + } + } + } + // Dictionary-encoded struct: resolve the child field from + // the underlying struct, then wrap the result back in the + // same Dictionary type so the promised type matches execution. + DataType::Dictionary(key_type, value_type) + if matches!(value_type.as_ref(), DataType::Struct(_)) => + { + let DataType::Struct(fields) = value_type.as_ref() else { + unreachable!() + }; + let field_name = sv + .as_ref() + .and_then(|sv| { + sv.try_as_str().flatten().filter(|s| !s.is_empty()) + }) + .ok_or_else(|| { + exec_datafusion_err!("Field name must be a non-empty string") + })?; + + let child_field = fields + .iter() + .find(|f| f.name() == field_name) + .ok_or_else(|| { + plan_datafusion_err!("Field {field_name} not found in struct") + })?; + + let dict_type = DataType::Dictionary( + key_type.clone(), + Box::new(child_field.data_type().clone()), + ); + let mut new_field = + child_field.as_ref().clone().with_data_type(dict_type); + if current_field.is_nullable() { + new_field = new_field.with_nullable(true); + } + current_field = Arc::new(new_field); + } + DataType::Struct(fields) => { + let field_name = sv + .as_ref() + .and_then(|sv| { + sv.try_as_str().flatten().filter(|s| !s.is_empty()) + }) + .ok_or_else(|| { + datafusion_common::DataFusionError::Execution( + "Field name must be a non-empty string".to_string(), + ) + })?; + + let child_field = fields + .iter() + .find(|f| f.name() == field_name) + .ok_or_else(|| { + plan_datafusion_err!("Field {field_name} not found in struct") + })?; + + let mut new_field = child_field.as_ref().clone(); + + // If the parent is nullable, then getting the child must be nullable + if current_field.is_nullable() { + new_field = new_field.with_nullable(true); + } + current_field = Arc::new(new_field); + } + DataType::Null => { + return Ok(Field::new(self.name(), DataType::Null, true).into()); + } + other => { + return exec_err!( + "Cannot access field at argument {}: type {} is not Struct, Map, or Null", + i, + other + ); } } - (DataType::Struct(fields),sv) => { - sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) - .map_or_else( - || exec_err!("Field name must be a non-empty string"), - |field_name| { - fields.iter().find(|f| f.name() == field_name) - .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| { - let mut child_field = f.as_ref().clone(); - - // If the parent is nullable, then getting the child must be nullable, - // so potentially override the return value - - if args.arg_fields[0].is_nullable() { - child_field = child_field.with_nullable(true); - } - Arc::new(child_field) - }) - }) - }, - (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true).into()), - (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } + + Ok(current_field) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let [base, field_name] = take_function_args(self.name(), args.args)?; + if args.args.len() < 2 { + return exec_err!( + "get_field requires at least 2 arguments, got {}", + args.args.len() + ); + } + + let mut current = args.args[0].clone(); - if base.data_type().is_null() { + // Early exit for null base + if current.data_type().is_null() { return Ok(ColumnarValue::Scalar(ScalarValue::Null)); } - let arrays = - ColumnarValue::values_to_arrays(&[base.clone(), field_name.clone()])?; - let array = Arc::clone(&arrays[0]); - let name = match field_name { - ColumnarValue::Scalar(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); + // Iterate through each field name + for field_name in args.args.iter().skip(1) { + let field_name_scalar = match field_name { + ColumnarValue::Scalar(name) => name.clone(), + _ => { + return exec_err!( + "get_field function requires all field_name arguments to be scalars" + ); + } + }; + + current = extract_single_field(current, field_name_scalar)?; + + // Early exit if we hit null + if current.data_type().is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Null)); } - }; + } - fn process_map_array( - array: &dyn Array, - key_array: Arc, - ) -> Result { - let map_array = as_map_array(array)?; - let keys = if key_array.data_type().is_nested() { - let comparator = make_comparator( - map_array.keys().as_ref(), - key_array.as_ref(), - SortOptions::default(), - )?; - let len = map_array.keys().len().min(key_array.len()); - let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); - let nulls = - NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); - BooleanArray::new(values, nulls) - } else { - let be_compared = Scalar::new(key_array); - arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? - }; + Ok(current) + } - let original_data = map_array.entries().column(1).to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], true, capacity); + fn simplify( + &self, + args: Vec, + _info: &datafusion_expr::simplify::SimplifyContext, + ) -> Result { + // Need at least 2 args (base + field) + if args.len() < 2 { + return Ok(ExprSimplifyResult::Original(args)); + } - for entry in 0..map_array.len() { - let start = map_array.value_offsets()[entry] as usize; - let end = map_array.value_offsets()[entry + 1] as usize; + // Flatten all nested get_field calls in a single pass + // Pattern: get_field(get_field(get_field(base, a), b), c) => get_field(base, a, b, c) + // + // `path_args_stack` collects each level's field-name arguments, + // outermost first; it is reversed below to restore access order. + let mut path_args_stack = vec![&args[1..]]; + let mut current_expr = &args[0]; - let maybe_matched = keys - .slice(start, end - start) - .iter() - .enumerate() - .find(|(_, t)| t.unwrap()); + // Walk down the chain of nested get_field calls + let base_expr = loop { + if let Expr::ScalarFunction(ScalarFunction { + func, + args: inner_args, + }) = current_expr + && func.inner().is::() + { + // Store this level's path arguments (all except the first, which is base/nested call) + path_args_stack.push(&inner_args[1..]); - if maybe_matched.is_none() { - mutable.extend_nulls(1); - continue; - } - let (match_offset, _) = maybe_matched.unwrap(); - mutable.extend(0, start + match_offset, start + match_offset + 1); + // Move to the next level down + current_expr = &inner_args[0]; + continue; } + // Not a get_field call, this is the base expression + break current_expr; + }; + + // Whether any nested get_field calls were collapsed above. + let did_flatten = path_args_stack.len() > 1; - let data = mutable.freeze(); - let data = make_array(data); - Ok(ColumnarValue::Array(data)) + // Build merged args: [base, ...all path args in access order]. + // The stack holds path slices outermost-first, so iterate in reverse. + let mut merged_args = vec![base_expr.clone()]; + for path_slice in path_args_stack.iter().rev() { + merged_args.extend_from_slice(path_slice); } - fn process_map_with_nested_key( - array: &dyn Array, - key_array: &dyn Array, - ) -> Result { - let map_array = as_map_array(array)?; - - let comparator = make_comparator( - map_array.keys().as_ref(), - key_array, - SortOptions::default(), - )?; - - let original_data = map_array.entries().column(1).to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], true, capacity); - - for entry in 0..map_array.len() { - let start = map_array.value_offsets()[entry] as usize; - let end = map_array.value_offsets()[entry + 1] as usize; - - let mut found_match = false; - for i in start..end { - if comparator(i, 0).is_eq() { - mutable.extend(0, i, i + 1); - found_match = true; - break; - } - } + // Resolve field accesses against an inline struct constructor: + // get_field(named_struct('min', a, 'max', b), 'max') => b + if let Some(simplified) = simplify_get_field_over_struct_constructor(&merged_args) + { + return Ok(ExprSimplifyResult::Simplified(simplified)); + } - if !found_match { - mutable.extend_nulls(1); - } - } + if did_flatten { + return Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf(get_field_udf(), merged_args), + ))); + } + + Ok(ExprSimplifyResult::Original(args)) + } - let data = mutable.freeze(); - let data = make_array(data); - Ok(ColumnarValue::Array(data)) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() < 2 { + return exec_err!( + "get_field requires at least 2 arguments, got {}", + arg_types.len() + ); } + // Accept types as-is, validation happens in return_field_from_args + Ok(arg_types.to_vec()) + } - match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::List(arr)) => { - let key_array: Arc = arr; - process_map_array(&array, key_array) - } - (DataType::Map(_, _), ScalarValue::Struct(arr)) => { - process_map_array(&array, arr as Arc) - } - (DataType::Map(_, _), other) => { - let data_type = other.data_type(); - if data_type.is_nested() { - process_map_with_nested_key(&array, &other.to_array()?) - } else { - process_map_array(&array, other.to_array()?) - } + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } + + fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { + // get_field can be pushed to leaves if: + // 1. The base (first arg) is a column or already placeable at leaves + // 2. All field keys (remaining args) are literals + if args.is_empty() { + return ExpressionPlacement::KeepInPlace; + } + + let base_placement = args[0]; + let base_is_pushable = matches!( + base_placement, + ExpressionPlacement::Column | ExpressionPlacement::MoveTowardsLeafNodes + ); + + let all_keys_are_literals = args + .iter() + .skip(1) + .all(|p| *p == ExpressionPlacement::Literal); + + if base_is_pushable && all_keys_are_literals { + ExpressionPlacement::MoveTowardsLeafNodes + } else { + ExpressionPlacement::KeepInPlace + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ArrayRef, Int32Array, StructArray}; + use arrow::datatypes::Fields; + + #[test] + fn test_get_field_utf8view_key() -> Result<()> { + // Create a struct array with fields "a" and "b" + let a_values = Int32Array::from(vec![Some(1), Some(2), Some(3)]); + let b_values = Int32Array::from(vec![Some(10), Some(20), Some(30)]); + + let fields: Fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ] + .into(); + + let struct_array = StructArray::new( + fields, + vec![ + Arc::new(a_values) as ArrayRef, + Arc::new(b_values) as ArrayRef, + ], + None, + ); + + let base = ColumnarValue::Array(Arc::new(struct_array)); + + // Use Utf8View key to access field "a" + let key = ScalarValue::Utf8View(Some("a".to_string())); + + let result = extract_single_field(base, key)?; + + let result_array = result.into_array(3)?; + let expected = Int32Array::from(vec![Some(1), Some(2), Some(3)]); + + assert_eq!(result_array.as_ref(), &expected as &dyn Array); + + Ok(()) + } + + #[test] + fn test_get_field_dict_encoded_struct() -> Result<()> { + use arrow::array::{DictionaryArray, StringArray, UInt32Array}; + use arrow::datatypes::UInt32Type; + + let names = Arc::new(StringArray::from(vec!["main", "foo", "bar"])) as ArrayRef; + let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + + let struct_fields: Fields = vec![ + Field::new("name", DataType::Utf8, false), + Field::new("id", DataType::Int32, false), + ] + .into(); + + let values_struct = + Arc::new(StructArray::new(struct_fields, vec![names, ids], None)) as ArrayRef; + + let keys = UInt32Array::from(vec![0u32, 1, 2, 0, 1]); + let dict = DictionaryArray::::try_new(keys, values_struct)?; + + let base = ColumnarValue::Array(Arc::new(dict)); + let key = ScalarValue::Utf8(Some("name".to_string())); + + let result = extract_single_field(base, key)?; + let result_array = result.into_array(5)?; + + assert!( + matches!(result_array.data_type(), DataType::Dictionary(_, _)), + "expected dictionary output, got {:?}", + result_array.data_type() + ); + + let result_dict = result_array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(result_dict.values().len(), 3); + assert_eq!(result_dict.len(), 5); + + let resolved = arrow::compute::cast(&result_array, &DataType::Utf8)?; + let string_arr = resolved.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "main"); + assert_eq!(string_arr.value(1), "foo"); + assert_eq!(string_arr.value(2), "bar"); + assert_eq!(string_arr.value(3), "main"); + assert_eq!(string_arr.value(4), "foo"); + + Ok(()) + } + + #[test] + fn test_get_field_nested_dict_struct() -> Result<()> { + use arrow::array::{DictionaryArray, StringArray, UInt32Array}; + use arrow::datatypes::UInt32Type; + + let func_names = Arc::new(StringArray::from(vec!["main", "foo"])) as ArrayRef; + let func_files = Arc::new(StringArray::from(vec!["main.c", "foo.c"])) as ArrayRef; + let func_fields: Fields = vec![ + Field::new("name", DataType::Utf8, false), + Field::new("file", DataType::Utf8, false), + ] + .into(); + let func_struct = Arc::new(StructArray::new( + func_fields.clone(), + vec![func_names, func_files], + None, + )) as ArrayRef; + let func_dict = Arc::new(DictionaryArray::::try_new( + UInt32Array::from(vec![0u32, 1, 0]), + func_struct, + )?) as ArrayRef; + + let line_nums = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef; + let line_fields: Fields = vec![ + Field::new("num", DataType::Int32, false), + Field::new( + "function", + DataType::Dictionary( + Box::new(DataType::UInt32), + Box::new(DataType::Struct(func_fields)), + ), + false, + ), + ] + .into(); + let line_struct = StructArray::new(line_fields, vec![line_nums, func_dict], None); + + let base = ColumnarValue::Array(Arc::new(line_struct)); + + let func_result = + extract_single_field(base, ScalarValue::Utf8(Some("function".to_string())))?; + + let func_array = func_result.into_array(3)?; + assert!( + matches!(func_array.data_type(), DataType::Dictionary(_, _)), + "expected dictionary for function, got {:?}", + func_array.data_type() + ); + + let name_result = extract_single_field( + ColumnarValue::Array(func_array), + ScalarValue::Utf8(Some("name".to_string())), + )?; + let name_array = name_result.into_array(3)?; + + assert!( + matches!(name_array.data_type(), DataType::Dictionary(_, _)), + "expected dictionary for name, got {:?}", + name_array.data_type() + ); + + let name_dict = name_array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(name_dict.values().len(), 2); + assert_eq!(name_dict.len(), 3); + + let resolved = arrow::compute::cast(&name_array, &DataType::Utf8)?; + let strings = resolved.as_any().downcast_ref::().unwrap(); + assert_eq!(strings.value(0), "main"); + assert_eq!(strings.value(1), "foo"); + assert_eq!(strings.value(2), "main"); + + Ok(()) + } + + #[test] + fn test_placement_literal_key() { + let func = GetFieldFunc::new(); + + // get_field(col, 'literal') -> leaf-pushable (static field access) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Literal]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // get_field(col, 'a', 'b') -> leaf-pushable (nested static field access) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Literal, + ]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // get_field(get_field(col, 'a'), 'b') represented as MoveTowardsLeafNodes for base + let args = vec![ + ExpressionPlacement::MoveTowardsLeafNodes, + ExpressionPlacement::Literal, + ]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + } + + #[test] + fn test_placement_column_key() { + let func = GetFieldFunc::new(); + + // get_field(col, other_col) -> NOT leaf-pushable (dynamic per-row lookup) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Column]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + + // get_field(col, 'a', other_col) -> NOT leaf-pushable (dynamic nested lookup) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Column, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + #[test] + fn test_placement_root() { + let func = GetFieldFunc::new(); + + // get_field(root_expr, 'literal') -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::KeepInPlace, + ExpressionPlacement::Literal, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + + // get_field(col, root_expr) -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::KeepInPlace, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + #[test] + fn test_placement_edge_cases() { + let func = GetFieldFunc::new(); + + // Empty args -> NOT leaf-pushable + assert_eq!(func.placement(&[]), ExpressionPlacement::KeepInPlace); + + // Just base, no key -> MoveTowardsLeafNodes (not a valid call but should handle gracefully) + let args = vec![ExpressionPlacement::Column]; + assert_eq!( + func.placement(&args), + ExpressionPlacement::MoveTowardsLeafNodes + ); + + // Literal base with literal key -> NOT leaf-pushable (would be constant-folded) + let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal]; + assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); + } + + // --- get_field over struct constructor simplification -------------------- + + use datafusion_common::Column; + use datafusion_expr::simplify::SimplifyContext; + + /// A non-empty string literal expression. + fn lit_str(s: &str) -> Expr { + Expr::Literal(ScalarValue::Utf8(Some(s.to_string())), None) + } + + /// A column reference expression. + fn col(name: &str) -> Expr { + Expr::Column(Column::from_name(name)) + } + + fn scalar_fn(udf: ScalarUDF, args: Vec) -> Expr { + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) + } + + /// `named_struct(name1, value1, name2, value2, ...)`. + fn named_struct(pairs: Vec<(&str, Expr)>) -> Expr { + let args = pairs + .into_iter() + .flat_map(|(name, value)| [lit_str(name), value]) + .collect(); + scalar_fn(ScalarUDF::new_from_impl(NamedStructFunc::new()), args) + } + + /// `struct(value0, value1, ...)`. + fn struct_fn(values: Vec) -> Expr { + scalar_fn(ScalarUDF::new_from_impl(StructFunc::new()), values) + } + + /// `get_field(args...)`. + fn get_field(args: Vec) -> Expr { + scalar_fn(ScalarUDF::new_from_impl(GetFieldFunc::new()), args) + } + + /// Run `GetFieldFunc::simplify` once and return the rewritten expression, + /// panicking if the input was left unchanged. + fn simplified(args: Vec) -> Expr { + match GetFieldFunc::new() + .simplify(args, &SimplifyContext::default()) + .unwrap() + { + ExprSimplifyResult::Simplified(expr) => expr, + ExprSimplifyResult::Original(args) => { + panic!("expected the expression to be simplified, got {args:?}") } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(&k) { - None => exec_err!("get indexed field {k} not found in struct"), - Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))), - } + } + } + + /// Assert that `GetFieldFunc::simplify` leaves the arguments unchanged. + fn assert_not_simplified(args: Vec) { + match GetFieldFunc::new() + .simplify(args.clone(), &SimplifyContext::default()) + .unwrap() + { + ExprSimplifyResult::Original(unchanged) => assert_eq!(unchanged, args), + ExprSimplifyResult::Simplified(expr) => { + panic!("expected no simplification, got {expr:?}") } - (DataType::Struct(_), name) => exec_err!( - "get_field is only possible on struct with utf8 indexes. \ - Received with {name:?} index" - ), - (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), - (dt, name) => exec_err!( - "get_field is only possible on maps with utf8 indexes or struct \ - with utf8 indexes. Received {dt} with {name:?} index" - ), } } - fn documentation(&self) -> Option<&Documentation> { - self.doc() + #[test] + fn simplify_get_field_named_struct_returns_matching_value() { + // get_field(named_struct('min', a, 'max', b), 'max') => b + let args = vec![ + named_struct(vec![("min", col("a")), ("max", col("b"))]), + lit_str("max"), + ]; + assert_eq!(simplified(args), col("b")); + } + + #[test] + fn simplify_get_field_named_struct_first_field() { + // get_field(named_struct('min', a, 'max', b), 'min') => a + let args = vec![ + named_struct(vec![("min", col("a")), ("max", col("b"))]), + lit_str("min"), + ]; + assert_eq!(simplified(args), col("a")); + } + + #[test] + fn simplify_get_field_named_struct_duplicate_names_picks_first() { + // Arrow's `column_by_name` resolves to the first match; mirror that. + let args = vec![ + named_struct(vec![("k", col("a")), ("k", col("b"))]), + lit_str("k"), + ]; + assert_eq!(simplified(args), col("a")); + } + + #[test] + fn simplify_get_field_struct_positional() { + // get_field(struct(a, b), 'c1') => b + let args = vec![struct_fn(vec![col("a"), col("b")]), lit_str("c1")]; + assert_eq!(simplified(args), col("b")); + } + + #[test] + fn simplify_get_field_nested_named_struct() { + // get_field(named_struct('s', named_struct('k', x)), 's', 'k') + // => get_field(named_struct('k', x), 'k') (first pass) + // => x (second pass) + let args = vec![ + named_struct(vec![("s", named_struct(vec![("k", col("x"))]))]), + lit_str("s"), + lit_str("k"), + ]; + let first_pass = simplified(args); + let Expr::ScalarFunction(ScalarFunction { args, .. }) = first_pass else { + panic!("expected a get_field call after the first pass") + }; + assert_eq!(simplified(args), col("x")); + } + + #[test] + fn simplify_get_field_flattens_then_resolves_named_struct() { + // get_field(get_field(named_struct('s', named_struct('k', x)), 's'), 'k') + // flattens to get_field(named_struct(...), 's', 'k') and resolves 's'. + let args = vec![ + get_field(vec![ + named_struct(vec![("s", named_struct(vec![("k", col("x"))]))]), + lit_str("s"), + ]), + lit_str("k"), + ]; + let expected = get_field(vec![named_struct(vec![("k", col("x"))]), lit_str("k")]); + assert_eq!(simplified(args), expected); + } + + #[test] + fn simplify_get_field_dynamic_field_name_left_alone() { + // A non-literal field name cannot be resolved at plan time. + let args = vec![named_struct(vec![("a", col("x"))]), col("field_name")]; + assert_not_simplified(args); + } + + #[test] + fn simplify_get_field_null_field_name_left_alone() { + // A NULL string literal field name resolves to no field, so the + // `try_as_str().flatten()` guard must leave the expression untouched. + let null_field_name = Expr::Literal(ScalarValue::Utf8(None), None); + let args = vec![named_struct(vec![("a", col("x"))]), null_field_name]; + assert_not_simplified(args); + } + + #[test] + fn simplify_get_field_dynamic_struct_name_left_alone() { + // A non-literal name inside named_struct could shadow the requested + // field at runtime, so the rewrite must bail out entirely. + let named_struct_with_dynamic_name = scalar_fn( + ScalarUDF::new_from_impl(NamedStructFunc::new()), + vec![col("dynamic_name"), col("x")], + ); + let args = vec![named_struct_with_dynamic_name, lit_str("a")]; + assert_not_simplified(args); + } + + #[test] + fn simplify_get_field_missing_field_left_alone() { + // The named_struct does not produce field 'missing'. + let args = vec![named_struct(vec![("a", col("x"))]), lit_str("missing")]; + assert_not_simplified(args); + } + + #[test] + fn simplify_get_field_non_canonical_struct_field_left_alone() { + // 'c01' is not a real field name produced by `struct(...)`. + let args = vec![struct_fn(vec![col("a"), col("b")]), lit_str("c01")]; + assert_not_simplified(args); + } + + #[test] + fn simplify_get_field_column_base_left_alone() { + // A plain column base is not a struct constructor. + let args = vec![col("s"), lit_str("a")]; + assert_not_simplified(args); } } diff --git a/datafusion/functions/src/core/greatest.rs b/datafusion/functions/src/core/greatest.rs index 95fd8e64d7274..64eaefb9b887d 100644 --- a/datafusion/functions/src/core/greatest.rs +++ b/datafusion/functions/src/core/greatest.rs @@ -16,17 +16,16 @@ // under the License. use crate::core::greatest_least_utils::GreatestLeastOperator; -use arrow::array::{make_comparator, Array, BooleanArray}; +use arrow::array::{Array, BooleanArray, make_comparator}; use arrow::buffer::BooleanBuffer; -use arrow::compute::kernels::cmp; use arrow::compute::SortOptions; +use arrow::compute::kernels::cmp; use arrow::datatypes::DataType; -use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err}; use datafusion_doc::Documentation; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; const SORT_OPTIONS: SortOptions = SortOptions { // We want greatest first @@ -90,11 +89,7 @@ impl GreatestLeastOperator for GreatestFunc { SORT_OPTIONS, )?; - if cmp(0, 0).is_ge() { - Ok(lhs) - } else { - Ok(rhs) - } + if cmp(0, 0).is_ge() { Ok(lhs) } else { Ok(rhs) } } /// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array @@ -127,10 +122,6 @@ impl GreatestLeastOperator for GreatestFunc { } impl ScalarUDFImpl for GreatestFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "greatest" } diff --git a/datafusion/functions/src/core/greatest_least_utils.rs b/datafusion/functions/src/core/greatest_least_utils.rs index 78c6864d8c9f5..2714a01832175 100644 --- a/datafusion/functions/src/core/greatest_least_utils.rs +++ b/datafusion/functions/src/core/greatest_least_utils.rs @@ -18,9 +18,9 @@ use arrow::array::{Array, ArrayRef, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::{assert_or_internal_err, plan_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, assert_or_internal_err, plan_err}; use datafusion_expr_common::columnar_value::ColumnarValue; -use datafusion_expr_common::type_coercion::binary::type_union_resolution; +use datafusion_expr_common::type_coercion::binary::comparison_coercion; use std::sync::Arc; pub(super) trait GreatestLeastOperator { @@ -120,13 +120,17 @@ pub(super) fn find_coerced_type( data_types: &[DataType], ) -> Result { if data_types.is_empty() { - plan_err!( + return plan_err!( "{} was called without any arguments. It requires at least 1.", Op::NAME - ) - } else if let Some(coerced_type) = type_union_resolution(data_types) { - Ok(coerced_type) - } else { - plan_err!("Cannot find a common type for arguments") + ); + } + let mut coerced = data_types[0].clone(); + for dt in &data_types[1..] { + let Some(next) = comparison_coercion(&coerced, dt) else { + return plan_err!("Cannot find a common type for arguments to {}", Op::NAME); + }; + coerced = next; } + Ok(coerced) } diff --git a/datafusion/functions/src/core/least.rs b/datafusion/functions/src/core/least.rs index 602cd4169a3fd..8b84aa49ab82a 100644 --- a/datafusion/functions/src/core/least.rs +++ b/datafusion/functions/src/core/least.rs @@ -16,17 +16,16 @@ // under the License. use crate::core::greatest_least_utils::GreatestLeastOperator; -use arrow::array::{make_comparator, Array, BooleanArray}; +use arrow::array::{Array, BooleanArray, make_comparator}; use arrow::buffer::BooleanBuffer; -use arrow::compute::kernels::cmp; use arrow::compute::SortOptions; +use arrow::compute::kernels::cmp; use arrow::datatypes::DataType; -use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err}; use datafusion_doc::Documentation; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; const SORT_OPTIONS: SortOptions = SortOptions { // Having the smallest result first @@ -103,11 +102,7 @@ impl GreatestLeastOperator for LeastFunc { SORT_OPTIONS, )?; - if cmp(0, 0).is_le() { - Ok(lhs) - } else { - Ok(rhs) - } + if cmp(0, 0).is_le() { Ok(lhs) } else { Ok(rhs) } } /// Return boolean array where `arr[i] = lhs[i] <= rhs[i]` for all i, where `arr` is the result array @@ -140,10 +135,6 @@ impl GreatestLeastOperator for LeastFunc { } impl ScalarUDFImpl for LeastFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "least" } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 0c569d300656f..5657f9d88810c 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -21,7 +21,11 @@ use datafusion_expr::ScalarUDF; use std::sync::Arc; pub mod arrow_cast; +pub mod arrow_field; +pub mod arrow_metadata; +pub mod arrow_try_cast; pub mod arrowtypeof; +pub mod cast_to_type; pub mod coalesce; pub mod expr_ext; pub mod getfield; @@ -35,12 +39,17 @@ pub mod nvl2; pub mod overlay; pub mod planner; pub mod r#struct; +pub mod try_cast_to_type; pub mod union_extract; pub mod union_tag; pub mod version; +pub mod with_metadata; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, arrow_cast); +make_udf_function!(arrow_try_cast::ArrowTryCastFunc, arrow_try_cast); +make_udf_function!(cast_to_type::CastToTypeFunc, cast_to_type); +make_udf_function!(try_cast_to_type::TryCastToTypeFunc, try_cast_to_type); make_udf_function!(nullif::NullIfFunc, nullif); make_udf_function!(nvl::NVLFunc, nvl); make_udf_function!(nvl2::NVL2Func, nvl2); @@ -55,6 +64,9 @@ make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); +make_udf_function!(arrow_metadata::ArrowMetadataFunc, arrow_metadata); +make_udf_function!(with_metadata::WithMetadataFunc, with_metadata); +make_udf_function!(arrow_field::ArrowFieldFunc, arrow_field); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -65,7 +77,19 @@ pub mod expr_fn { arg1 arg2 ),( arrow_cast, - "Returns value2 if value1 is NULL; otherwise it returns value1", + "Casts a value to a specific Arrow data type", + arg1 arg2 + ),( + arrow_try_cast, + "Casts a value to a specific Arrow data type, returning NULL if the cast fails", + arg1 arg2 + ),( + cast_to_type, + "Casts the first argument to the data type of the second argument", + arg1 arg2 + ),( + try_cast_to_type, + "Casts the first argument to the data type of the second argument, returning NULL on failure", arg1 arg2 ),( nvl, @@ -83,6 +107,18 @@ pub mod expr_fn { arrow_typeof, "Returns the Arrow type of the input expression.", arg1 + ),( + arrow_field, + "Returns the Arrow field info (name, data_type, nullable, metadata) of the input expression.", + arg1 + ),( + arrow_metadata, + "Returns the metadata of the input expression", + args, + ),( + with_metadata, + "Attaches Arrow field metadata (key/value pairs) to the input expression", + args, ),( r#struct, "Returns a struct with the given arguments", @@ -115,6 +151,13 @@ pub mod expr_fn { super::get_field().call(vec![arg1, arg2.lit()]) } + #[doc = "Returns the value of nested fields by traversing multiple field names"] + pub fn get_field_path(base: Expr, field_names: Vec) -> Expr { + let mut args = vec![base]; + args.extend(field_names); + super::get_field().call(args) + } + #[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"] #[expect(clippy::needless_pass_by_value)] pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr { @@ -127,6 +170,12 @@ pub fn functions() -> Vec> { vec![ nullif(), arrow_cast(), + arrow_field(), + arrow_try_cast(), + cast_to_type(), + try_cast_to_type(), + arrow_metadata(), + with_metadata(), nvl(), nvl2(), overlay(), diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 1da5148474f8c..71c48ce89e26e 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -15,20 +15,23 @@ // specific language governing permissions and limitations // under the License. +use super::getfield::GetFieldFunc; use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, FieldRef, Fields}; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + StructFieldMapping, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( doc_section(label = "Struct Functions"), - description = "Returns an Arrow struct using the specified name and input expressions pairs.", + description = "Returns an Arrow struct using the specified name and input expressions pairs. +For information on comparing and ordering struct values (including `NULL` handling), +see [Comparison and Ordering](struct_coercion.md#comparison-and-ordering).", syntax_example = "named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])", sql_example = r#" For example, this query converts two columns `a` and `b` to a single column with @@ -78,10 +81,6 @@ impl NamedStructFunc { } impl ScalarUDFImpl for NamedStructFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "named_struct" } @@ -138,7 +137,7 @@ impl ScalarUDFImpl for NamedStructFunc { let return_fields = names .into_iter() - .zip(types.into_iter()) + .zip(types) .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; @@ -177,4 +176,31 @@ impl ScalarUDFImpl for NamedStructFunc { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + fn struct_field_mapping( + &self, + literal_args: &[Option], + ) -> Option { + if literal_args.is_empty() || !literal_args.len().is_multiple_of(2) { + return None; + } + + let mut fields = Vec::with_capacity(literal_args.len() / 2); + for (i, chunk) in literal_args.chunks(2).enumerate() { + match chunk { + [Some(ScalarValue::Utf8(Some(name))), _] => { + fields.push(( + vec![ScalarValue::Utf8(Some(name.clone()))], + i * 2 + 1, // index of the value argument + )); + } + _ => return None, + } + } + + Some(StructFieldMapping { + field_accessor: Arc::new(ScalarUDF::from(GetFieldFunc::new())), + fields, + }) + } } diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 69d86360cb3cb..f58ae857d4791 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -18,12 +18,11 @@ use arrow::datatypes::DataType; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; -use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::{utils::take_function_args, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, utils::take_function_args}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; +use datafusion_physical_expr_common::datum::compare_with_eq; #[user_doc( doc_section(label = "Conditional Functions"), @@ -86,9 +85,6 @@ impl NullIfFunc { } impl ScalarUDFImpl for NullIfFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "nullif" } @@ -115,25 +111,29 @@ impl ScalarUDFImpl for NullIfFunc { /// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed. fn nullif_func(args: &[ColumnarValue]) -> Result { let [lhs, rhs] = take_function_args("nullif", args)?; + let is_nested = lhs.data_type().is_nested(); match (lhs, rhs) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { let rhs = rhs.to_scalar()?; - let array = nullif(lhs, &eq(&lhs, &rhs)?)?; + let eq_array = compare_with_eq(lhs, &rhs, is_nested)?; + let array = nullif(lhs, &eq_array)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - let array = nullif(lhs, &eq(&lhs, &rhs)?)?; + let eq_array = compare_with_eq(lhs, rhs, is_nested)?; + let array = nullif(lhs, &eq_array)?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { let lhs_s = lhs.to_scalar()?; let lhs_a = lhs.to_array_of_size(rhs.len())?; + let eq_array = compare_with_eq(&lhs_s, rhs, is_nested)?; let array = nullif( // nullif in arrow-select does not support Datum, so we need to convert to array lhs_a.as_ref(), - &eq(&lhs_s, &rhs)?, + &eq_array, )?; Ok(ColumnarValue::Array(array)) } @@ -152,7 +152,12 @@ fn nullif_func(args: &[ColumnarValue]) -> Result { mod tests { use std::sync::Arc; - use arrow::array::*; + use arrow::{ + array::*, + buffer::NullBuffer, + datatypes::{Field, Fields, Int64Type}, + }; + use datafusion_common::DataFusionError; use super::*; @@ -255,6 +260,104 @@ mod tests { Ok(()) } + #[test] + fn nullif_struct() -> Result<()> { + let fields = Fields::from(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Utf8, true), + ]); + + let lhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + let lhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None])); + let lhs_nulls = Some(NullBuffer::from(vec![true, true, false])); + let lhs = ColumnarValue::Array(Arc::new(StructArray::new( + fields.clone(), + vec![lhs_a, lhs_b], + lhs_nulls, + ))); + + let rhs_a = Arc::new(Int64Array::from(vec![Some(1), Some(9), None])); + let rhs_b = Arc::new(StringArray::from(vec![Some("1"), Some("2"), None])); + let rhs_nulls = Some(NullBuffer::from(vec![true, true, false])); + let rhs = ColumnarValue::Array(Arc::new(StructArray::new( + fields.clone(), + vec![rhs_a, rhs_b], + rhs_nulls, + ))); + + let result = nullif_func(&[lhs, rhs])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected_arrays = vec![ + Arc::new(Int64Array::from(vec![None, Some(2), None])) as ArrayRef, + Arc::new(StringArray::from(vec![None, Some("2"), None])) as ArrayRef, + ]; + let expected_nulls = NullBuffer::from(vec![false, true, false]); + + let expected = Arc::new(StructArray::try_new( + fields, + expected_arrays, + Some(expected_nulls), + )?) as ArrayRef; + + assert_eq!(expected.as_ref(), result.as_ref()); + + Ok(()) + } + + #[test] + fn nullif_list() -> Result<()> { + let lhs = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3)]), + Some(vec![]), + Some(vec![Some(5), Some(6), Some(7)]), + None, + ])); + let lhs = ColumnarValue::Array(lhs); + + let rhs = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + ])); + let rhs = ColumnarValue::Scalar(ScalarValue::List(rhs)); + + let result = nullif_func(&[lhs, rhs])?; + let result = result.into_array(0).expect("Failed to convert to array"); + + let expected = Arc::new(ListArray::from_iter_primitive::(vec![ + None, + Some(vec![Some(3)]), + Some(vec![]), + Some(vec![Some(5), Some(6), Some(7)]), + None, + ])) as ArrayRef; + + assert_eq!(expected.as_ref(), result.as_ref()); + + Ok(()) + } + + #[test] + fn nullif_compare_nested_to_unnested() -> Result<()> { + let lhs = Arc::new(ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2)]), + Some(vec![Some(3)]), + Some(vec![]), + Some(vec![Some(5), Some(6), Some(7)]), + None, + ])); + let lhs = ColumnarValue::Array(lhs); + + let rhs = Arc::new(Int64Array::from(vec![Some(1), Some(3), None, None, None])); + let rhs = ColumnarValue::Array(rhs); + + let result = nullif_func(&[lhs, rhs]); + + assert!(matches!(result, Err(DataFusionError::ArrowError(_, _)))); + + Ok(()) + } + #[test] fn nullif_literal_first() -> Result<()> { let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 0b9968a88fc95..3b73dd0165143 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -18,7 +18,7 @@ use crate::core::coalesce::CoalesceFunc; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::Result; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -101,10 +101,6 @@ impl NVLFunc { } impl ScalarUDFImpl for NVLFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "nvl" } @@ -124,7 +120,7 @@ impl ScalarUDFImpl for NVLFunc { fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { self.coalesce.simplify(args, info) } diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 45cb6760d062d..d68296d9b862b 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -16,13 +16,13 @@ // under the License. use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::{internal_err, utils::take_function_args, Result}; +use datafusion_common::{Result, internal_err, utils::take_function_args}; use datafusion_expr::{ - conditional_expressions::CaseBuilder, - simplify::{ExprSimplifyResult, SimplifyInfo}, - type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + conditional_expressions::CaseBuilder, + simplify::{ExprSimplifyResult, SimplifyContext}, + type_coercion::binary::type_union_coercion, }; use datafusion_macros::user_doc; @@ -78,10 +78,6 @@ impl NVL2Func { } impl ScalarUDFImpl for NVL2Func { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "nvl2" } @@ -108,7 +104,7 @@ impl ScalarUDFImpl for NVL2Func { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { let [test, if_non_null, if_null] = take_function_args(self.name(), args)?; @@ -133,11 +129,9 @@ impl ScalarUDFImpl for NVL2Func { [if_non_null, if_null] .iter() .try_fold(tested.clone(), |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. - let coerced_type = comparison_coercion(&acc, x); + // `type_union_coercion` finds a loose common type; the actual + // coercion is done by `maybe_data_types`. + let coerced_type = type_union_coercion(&acc, x); if let Some(coerced_type) = coerced_type { Ok(coerced_type) } else { diff --git a/datafusion/functions/src/core/overlay.rs b/datafusion/functions/src/core/overlay.rs index 0b3bb2ce7413c..c1f3353a8f413 100644 --- a/datafusion/functions/src/core/overlay.rs +++ b/datafusion/functions/src/core/overlay.rs @@ -15,17 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType, + StringViewArray, +}; +use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringWriter, +}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_string_view_array, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -84,10 +88,6 @@ impl OverlayFunc { } impl ScalarUDFImpl for OverlayFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "overlay" } @@ -117,145 +117,215 @@ impl ScalarUDFImpl for OverlayFunc { } } -macro_rules! process_overlay { - // For the three-argument case - ($string_array:expr, $characters_array:expr, $pos_num:expr) => {{ - $string_array - .iter() - .zip($characters_array.iter()) - .zip($pos_num.iter()) - .map(|((string, characters), start_pos)| { - match (string, characters, start_pos) { - (Some(string), Some(characters), Some(start_pos)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = characters_len as i64; - let mut res = - String::with_capacity(string_len.max(characters_len)); +/// Computes the byte ranges of `string` to keep around the replaced span: the +/// prefix is `string[..prefix_end]` and the suffix is `string[suffix_start..]`. +/// +/// `start_pos` is a 1-based character position; the caller must ensure it is +/// `>= 1`. `replace_len` is the number of characters of `string` to replace, +/// and may be negative (in which case `suffix_start <= prefix_end` and the +/// result re-emits part of the original string). +/// +/// Matches PostgreSQL semantics for codepoint indices past the end of +/// `string`: `prefix_end` and `suffix_start` clamp to `string.len()`. +fn overlay_bounds(string: &str, start_pos: i64, replace_len: i64) -> (usize, usize) { + let start_char_idx = start_pos - 1; + let end_char_idx = start_char_idx.saturating_add(replace_len); - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>() - }}; + if string.is_ascii() { + // ASCII fast path: byte index == codepoint index. + let len = string.len() as i64; + let prefix_end = start_char_idx.clamp(0, len) as usize; + let suffix_start = end_char_idx.clamp(0, len) as usize; + return (prefix_end, suffix_start); + } - // For the four-argument case - ($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{ - $string_array - .iter() - .zip($characters_array.iter()) - .zip($pos_num.iter()) - .zip($len_num.iter()) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = - String::with_capacity(string_len.max(characters_len)); + let prefix_target = usize::try_from(start_char_idx).unwrap_or(usize::MAX); + let suffix_target = usize::try_from(end_char_idx.max(0)).unwrap_or(usize::MAX); + let target_max = prefix_target.max(suffix_target); - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>() - }}; + // Single forward pass over codepoint boundaries records both targets. + // Either target falls through to `string.len()` if past the codepoint + // count. + let mut prefix_byte = string.len(); + let mut suffix_byte = string.len(); + for (count, (byte_idx, _)) in string.char_indices().enumerate() { + if count == prefix_target { + prefix_byte = byte_idx; + } + if count == suffix_target { + suffix_byte = byte_idx; + } + if count == target_max { + break; + } + } + (prefix_byte, suffix_byte) } -/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) -/// Replaces a substring of string1 with string2 starting at the integer bit -/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas -/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead -fn overlay(args: &[ArrayRef]) -> Result { - let use_string_view = args[0].data_type() == &DataType::Utf8View; - if use_string_view { - string_view_overlay::(args) - } else { - string_overlay::(args) +/// Appends the overlay result for one non-null row into `builder`. +#[inline] +fn apply_overlay( + string: &str, + characters: &str, + start_pos: i64, + replace_len: i64, + builder: &mut B, +) -> Result<()> { + if start_pos < 1 { + return exec_err!("overlay start position must be at least 1: {start_pos}"); } + let (prefix_end, suffix_start) = overlay_bounds(string, start_pos, replace_len); + builder.append_with(|w| { + w.write_str(&string[..prefix_end]); + w.write_str(characters); + w.write_str(&string[suffix_start..]); + }); + Ok(()) } -fn string_overlay(args: &[ArrayRef]) -> Result { - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; +#[inline] +fn char_count(characters: &str) -> i64 { + if characters.is_ascii() { + characters.len() as i64 + } else { + characters.chars().count() as i64 + } +} - let result = process_overlay!(string_array, characters_array, pos_num)?; - Ok(Arc::new(result) as ArrayRef) - } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; +/// `OVERLAY(string PLACING substring FROM start [FOR count])` +/// +/// Replaces a region of `string` with `substring`, starting at the 1-based +/// character position `start`. If `count` is supplied, that many characters +/// of `string` are replaced; otherwise `count` defaults to the character +/// length of `substring`. +/// +/// ```text +/// overlay('Txxxxas' placing 'hom' from 2 for 4) → 'Thomas' +/// overlay('Txxxxas' placing 'hom' from 2) → 'Thomxas' +/// ``` +fn overlay(args: &[ArrayRef]) -> Result { + if !matches!(args.len(), 3 | 4) { + return exec_err!( + "overlay was called with {} arguments. It requires 3 or 4.", + args.len() + ); + } + let pos_array = as_int64_array(&args[2])?; + let len_array = if args.len() == 4 { + Some(as_int64_array(&args[3])?) + } else { + None + }; - let result = - process_overlay!(string_array, characters_array, pos_num, len_num)?; - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") - } + if args[0].data_type() == &DataType::Utf8View { + let string_array = as_string_view_array(&args[0])?; + let characters_array = as_string_view_array(&args[1])?; + let data_capacity = visible_view_bytes(string_array) + .saturating_add(visible_view_bytes(characters_array)); + let builder = GenericStringArrayBuilder::::with_capacity( + string_array.len(), + data_capacity, + ); + overlay_inner( + string_array, + characters_array, + pos_array, + len_array, + builder, + ) + } else { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let data_capacity = visible_offset_bytes(string_array) + .saturating_add(visible_offset_bytes(characters_array)); + let builder = GenericStringArrayBuilder::::with_capacity( + string_array.len(), + data_capacity, + ); + overlay_inner( + string_array, + characters_array, + pos_array, + len_array, + builder, + ) } } -fn string_view_overlay(args: &[ArrayRef]) -> Result { - match args.len() { - 3 => { - let string_array = as_string_view_array(&args[0])?; - let characters_array = as_string_view_array(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - - let result = process_overlay!(string_array, characters_array, pos_num)?; - Ok(Arc::new(result) as ArrayRef) - } - 4 => { - let string_array = as_string_view_array(&args[0])?; - let characters_array = as_string_view_array(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; +/// Drives the per-row OVERLAY computation. A null in any input array +/// produces a null output. +fn overlay_inner<'a, V, B>( + string_array: V, + characters_array: V, + pos_array: &Int64Array, + len_array: Option<&Int64Array>, + mut builder: B, +) -> Result +where + V: StringArrayType<'a, Item = &'a str> + Copy, + B: BulkNullStringArrayBuilder, +{ + let len = string_array.len(); + let nulls = NullBuffer::union_many([ + string_array.nulls(), + characters_array.nulls(), + pos_array.nulls(), + len_array.and_then(|a| a.nulls()), + ]); - let result = - process_overlay!(string_array, characters_array, pos_num, len_num)?; - Ok(Arc::new(result) as ArrayRef) + if let Some(nulls_ref) = nulls.as_ref() { + for i in 0..len { + if nulls_ref.is_null(i) { + builder.append_placeholder(); + continue; + } + // SAFETY: `i < len`, and null bitmap check implies not-null + let string = unsafe { string_array.value_unchecked(i) }; + let characters = unsafe { characters_array.value_unchecked(i) }; + let start_pos = unsafe { pos_array.value_unchecked(i) }; + let replace_len = match len_array { + Some(arr) => unsafe { arr.value_unchecked(i) }, + None => char_count(characters), + }; + apply_overlay(string, characters, start_pos, replace_len, &mut builder)?; } - other => { - exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") + } else { + for i in 0..len { + // SAFETY: `i < len`, and no null bitmap means no nulls + let string = unsafe { string_array.value_unchecked(i) }; + let characters = unsafe { characters_array.value_unchecked(i) }; + let start_pos = unsafe { pos_array.value_unchecked(i) }; + let replace_len = match len_array { + Some(arr) => unsafe { arr.value_unchecked(i) }, + None => char_count(characters), + }; + apply_overlay(string, characters, start_pos, replace_len, &mut builder)?; } } + builder.finish(nulls) +} + +/// Bytes referenced by the visible window of `array`, computed from the +/// per-view lengths. +fn visible_view_bytes(array: &StringViewArray) -> usize { + array.lengths().map(|l| l as usize).sum() +} + +/// Bytes referenced by the visible window of `array`, derived from the offset +/// buffer. +fn visible_offset_bytes(array: &GenericStringArray) -> usize { + let offsets = array.value_offsets(); + // `value_offsets()` always has `array.len() + 1` entries (≥1). + let first = offsets.first().copied().unwrap_or_default(); + let last = offsets.last().copied().unwrap_or_default(); + last.as_usize() - first.as_usize() } #[cfg(test)] mod tests { - use arrow::array::{Int64Array, StringArray}; + use std::sync::Arc; + + use arrow::array::StringArray; use super::*; @@ -270,7 +340,9 @@ mod tests { let res = overlay::(&[string, replace_string, start, end]).unwrap(); let result = as_generic_string_array::(&res).unwrap(); - let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + // First row: start=4 is past the end of "123" (len 3). PostgreSQL + // takes the whole string as prefix and appends the replacement. + let expected = StringArray::from(vec!["123abc", "qwertyasdfg", "ijkz", "Thomas"]); assert_eq!(&expected, result); Ok(()) diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs index 227e401156173..22d9b5c4284b2 100644 --- a/datafusion/functions/src/core/planner.rs +++ b/datafusion/functions/src/core/planner.rs @@ -20,7 +20,7 @@ use datafusion_common::Result; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr}; -use datafusion_expr::{lit, Expr}; +use datafusion_expr::{Expr, lit}; use super::named_struct; @@ -34,7 +34,7 @@ impl ExprPlanner for CoreFunctionPlanner { _schema: &DFSchema, ) -> Result> { let mut args = vec![]; - for (k, v) in expr.keys.into_iter().zip(expr.values.into_iter()) { + for (k, v) in expr.keys.into_iter().zip(expr.values) { args.push(k); args.push(v); } diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 32c7af80e397f..2697cb46b09f0 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -17,18 +17,19 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{Result, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( doc_section(label = "Struct Functions"), description = "Returns an Arrow struct using the specified input expressions optionally named. Fields in the returned struct use the optional name or the `cN` naming convention. -For example: `c0`, `c1`, `c2`, etc.", +For example: `c0`, `c1`, `c2`, etc. +For information on comparing and ordering struct values (including `NULL` handling), +see [Comparison and Ordering](struct_coercion.md#comparison-and-ordering).", syntax_example = "struct(expression1[, ..., expression_n])", sql_example = r#"For example, this query converts two columns `a` and `b` to a single column with a struct type of fields `field_a` and `c1`: @@ -86,9 +87,6 @@ impl StructFunc { } impl ScalarUDFImpl for StructFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "struct" } diff --git a/datafusion/functions/src/core/try_cast_to_type.rs b/datafusion/functions/src/core/try_cast_to_type.rs new file mode 100644 index 0000000000000..4c5af4cc6d228 --- /dev/null +++ b/datafusion/functions/src/core/try_cast_to_type.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`TryCastToTypeFunc`]: Implementation of the `try_cast_to_type` function + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{ + Result, datatype::DataTypeExt, internal_err, utils::take_function_args, +}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +/// Like [`cast_to_type`](super::cast_to_type::CastToTypeFunc) but returns NULL +/// on cast failure instead of erroring. +/// +/// This is implemented by simplifying `try_cast_to_type(expr, ref)` into +/// `Expr::TryCast` during optimization. +#[user_doc( + doc_section(label = "Other Functions"), + description = "Casts the first argument to the data type of the second argument, returning NULL if the cast fails. Only the type of the second argument is used; its value is ignored.", + syntax_example = "try_cast_to_type(expression, reference)", + sql_example = r#"```sql +> select try_cast_to_type('123', NULL::INTEGER) as a, + try_cast_to_type('not_a_number', NULL::INTEGER) as b; + ++-----+------+ +| a | b | ++-----+------+ +| 123 | NULL | ++-----+------+ +```"#, + argument( + name = "expression", + description = "The expression to cast. It can be a constant, column, or function, and any combination of operators." + ), + argument( + name = "reference", + description = "Reference expression whose data type determines the target cast type. The value is ignored." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct TryCastToTypeFunc { + signature: Signature, +} + +impl Default for TryCastToTypeFunc { + fn default() -> Self { + Self::new() + } +} + +impl TryCastToTypeFunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Any), + Coercion::new_exact(TypeSignatureClass::Any), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TryCastToTypeFunc { + fn name(&self) -> &str { + "try_cast_to_type" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // TryCast can always return NULL (on cast failure), so always nullable + let [_, reference_field] = take_function_args(self.name(), args.arg_fields)?; + let target_type = reference_field.data_type().clone(); + Ok(Field::new(self.name(), target_type, true).into()) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("try_cast_to_type should have been simplified to try_cast") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [source_arg, type_arg] = take_function_args(self.name(), args)?; + let target_type = info.get_data_type(&type_arg)?; + let source_type = info.get_data_type(&source_arg)?; + let new_expr = if source_type == target_type { + source_arg + } else { + Expr::TryCast(datafusion_expr::TryCast { + expr: Box::new(source_arg), + field: target_type.into_nullable_field_ref(), + }) + }; + Ok(ExprSimplifyResult::Simplified(new_expr)) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index a71e2e87388d5..9c0d42edf7fde 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::utils::take_function_args; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, + Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, }; use datafusion_doc::Documentation; use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; @@ -69,10 +69,6 @@ impl UnionExtractFun { } impl ScalarUDFImpl for UnionExtractFun { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "union_extract" } @@ -117,9 +113,16 @@ impl ScalarUDFImpl for UnionExtractFun { let [array, target_name] = take_function_args("union_extract", args.args)?; let target_name = match target_name { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), - _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", target_name.data_type()), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => { + Ok(target_name) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!( + "union_extract second argument must be a non-null string literal, got a null instead" + ), + _ => exec_err!( + "union_extract second argument must be a non-null string literal, got {} instead", + target_name.data_type() + ), }?; match array { @@ -182,13 +185,14 @@ mod tests { fn test_scalar_value() -> Result<()> { let fun = UnionExtractFun::new(); - let fields = UnionFields::new( + let fields = UnionFields::try_new( vec![1, 3], vec![ Field::new("str", DataType::Utf8, false), Field::new("int", DataType::Int32, false), ], - ); + ) + .unwrap(); let args = vec![ ColumnarValue::Scalar(ScalarValue::Union( diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs index aeadb8292ba1e..9a349a4b9a8eb 100644 --- a/datafusion/functions/src/core/union_tag.rs +++ b/datafusion/functions/src/core/union_tag.rs @@ -18,7 +18,7 @@ use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray}; use arrow::datatypes::DataType; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_datafusion_err, exec_err}; use datafusion_doc::Documentation; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -63,10 +63,6 @@ impl UnionTagFunc { } impl ScalarUDFImpl for UnionTagFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "union_tag" } @@ -143,7 +139,7 @@ impl ScalarUDFImpl for UnionTagFunc { args.return_field.data_type(), )?)), }, - v => exec_err!("union_tag only support unions, got {:?}", v.data_type()), + v => exec_err!("union_tag only support unions, got {}", v.data_type()), } } @@ -156,8 +152,8 @@ impl ScalarUDFImpl for UnionTagFunc { mod tests { use super::UnionTagFunc; use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; - use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 006da4b132ad3..1e8cc8683ab5b 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -18,13 +18,12 @@ //! [`VersionFunc`]: Implementation of the `version` function. use arrow::datatypes::DataType; -use datafusion_common::{utils::take_function_args, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, utils::take_function_args}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "Other Functions"), @@ -59,10 +58,6 @@ impl VersionFunc { } impl ScalarUDFImpl for VersionFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "version" } @@ -99,6 +94,7 @@ mod test { use super::*; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; + use datafusion_expr::ScalarFunctionArgs; use datafusion_expr::ScalarUDF; use std::sync::Arc; diff --git a/datafusion/functions/src/core/with_metadata.rs b/datafusion/functions/src/core/with_metadata.rs new file mode 100644 index 0000000000000..481ed713ed7ad --- /dev/null +++ b/datafusion/functions/src/core/with_metadata.rs @@ -0,0 +1,335 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, +}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Other Functions"), + description = "Attaches Arrow field metadata (key/value pairs) to the input expression. Keys must be non-empty constant strings and values must be constant strings (empty values are allowed). Existing metadata on the input field is preserved; new keys overwrite on collision. This is the inverse of `arrow_metadata`.", + syntax_example = "with_metadata(expression, key1, value1[, key2, value2, ...])", + sql_example = r#"```sql +> select arrow_metadata(with_metadata(column1, 'unit', 'ms'), 'unit') from (values (1)); ++---------------------------------------------------------------+ +| arrow_metadata(with_metadata(column1,Utf8("unit"),Utf8("ms")),Utf8("unit")) | ++---------------------------------------------------------------+ +| ms | ++---------------------------------------------------------------+ +> select arrow_metadata(with_metadata(column1, 'unit', 'ms', 'source', 'sensor')) from (values (1)); ++--------------------------+ +| {source: sensor, unit: ms} | ++--------------------------+ +```"#, + argument( + name = "expression", + description = "The expression whose output Arrow field should be annotated. Values flow through unchanged." + ), + argument( + name = "key", + description = "Metadata key. Must be a non-empty constant string literal." + ), + argument( + name = "value", + description = "Metadata value. Must be a constant string literal (may be empty)." + ) +)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct WithMetadataFunc { + signature: Signature, +} + +impl Default for WithMetadataFunc { + fn default() -> Self { + Self::new() + } +} + +impl WithMetadataFunc { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for WithMetadataFunc { + fn name(&self) -> &str { + "with_metadata" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!( + "with_metadata: return_type called instead of return_field_from_args" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // Require at least the value expression plus one (key, value) pair, + // and an odd total (1 + 2*N). + if args.arg_fields.len() < 3 { + return exec_err!( + "with_metadata requires the input expression plus at least one (key, value) pair (minimum 3 arguments), got {}", + args.arg_fields.len() + ); + } + if args.arg_fields.len().is_multiple_of(2) { + return exec_err!( + "with_metadata requires an odd number of arguments (expression followed by key/value pairs), got {}", + args.arg_fields.len() + ); + } + + let input_field = &args.arg_fields[0]; + let mut metadata = input_field.metadata().clone(); + + // Keys are at indices 1, 3, 5, ...; values at 2, 4, 6, ... + for pair_idx in 0..((args.scalar_arguments.len() - 1) / 2) { + let key_idx = 1 + pair_idx * 2; + let value_idx = key_idx + 1; + + let key = args.scalar_arguments[key_idx] + .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .ok_or_else(|| { + datafusion_common::DataFusionError::Execution(format!( + "with_metadata requires argument {key_idx} (key) to be a non-empty constant string" + )) + })?; + + let value = args.scalar_arguments[value_idx] + .and_then(|sv| sv.try_as_str().flatten()) + .ok_or_else(|| { + datafusion_common::DataFusionError::Execution(format!( + "with_metadata requires argument {value_idx} (value) to be a constant string" + )) + })?; + + metadata.insert(key.to_string(), value.to_string()); + } + + // Preserve the input field's name, data type, and nullability; only the + // metadata changes. This makes `with_metadata(col, ...)` a true + // pass-through annotation from a schema perspective. + let field = Field::new( + input_field.name(), + input_field.data_type().clone(), + input_field.is_nullable(), + ) + .with_metadata(metadata); + + Ok(field.into()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // Pure value pass-through. The metadata was attached to the return + // field during planning and flows through record batch schemas; the + // physical operator does not need to rebuild arrays. + Ok(args.args[0].clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Field; + use datafusion_common::ScalarValue; + use std::sync::Arc; + + fn field(name: &str, dt: DataType, nullable: bool) -> FieldRef { + Arc::new(Field::new(name, dt, nullable)) + } + + fn str_lit(s: &str) -> ScalarValue { + ScalarValue::Utf8(Some(s.to_string())) + } + + #[test] + fn attaches_single_key() { + let udf = WithMetadataFunc::new(); + let input = field("my_col", DataType::Int32, true); + let k = str_lit("unit"); + let v = str_lit("ms"); + let fields = [ + Arc::clone(&input), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + ]; + let scalars = [None, Some(&k), Some(&v)]; + let ret = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap(); + assert_eq!(ret.name(), "my_col"); + assert_eq!(ret.data_type(), &DataType::Int32); + assert!(ret.is_nullable()); + assert_eq!(ret.metadata().get("unit").map(String::as_str), Some("ms")); + } + + #[test] + fn merges_existing_metadata_and_overwrites_on_collision() { + let udf = WithMetadataFunc::new(); + let mut existing = Field::new("x", DataType::Float64, false); + existing.set_metadata( + [ + ("keep".to_string(), "yes".to_string()), + ("unit".to_string(), "old".to_string()), + ] + .into_iter() + .collect(), + ); + let input: FieldRef = Arc::new(existing); + let k = str_lit("unit"); + let v = str_lit("new"); + let fields = [ + Arc::clone(&input), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + ]; + let scalars = [None, Some(&k), Some(&v)]; + let ret = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap(); + assert_eq!(ret.name(), "x"); + assert!(!ret.is_nullable()); + assert_eq!(ret.metadata().get("keep").map(String::as_str), Some("yes")); + assert_eq!(ret.metadata().get("unit").map(String::as_str), Some("new")); + } + + #[test] + fn multiple_pairs() { + let udf = WithMetadataFunc::new(); + let input = field("c", DataType::Utf8, true); + let k1 = str_lit("a"); + let v1 = str_lit("1"); + let k2 = str_lit("b"); + let v2 = str_lit("2"); + let fields = [ + Arc::clone(&input), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + ]; + let scalars = [None, Some(&k1), Some(&v1), Some(&k2), Some(&v2)]; + let ret = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap(); + assert_eq!(ret.metadata().get("a").map(String::as_str), Some("1")); + assert_eq!(ret.metadata().get("b").map(String::as_str), Some("2")); + } + + #[test] + fn rejects_even_arity() { + let udf = WithMetadataFunc::new(); + let input = field("c", DataType::Int32, true); + let a = str_lit("a"); + let b = str_lit("b"); + let c = str_lit("c"); + // 4 args total: input + 3 literals (odd key count) + let fields = [ + Arc::clone(&input), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + ]; + let scalars = [None, Some(&a), Some(&b), Some(&c)]; + let err = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap_err(); + assert!(err.to_string().contains("odd number")); + } + + #[test] + fn rejects_too_few_args() { + let udf = WithMetadataFunc::new(); + let input = field("c", DataType::Int32, true); + let k = str_lit("a"); + let fields = [Arc::clone(&input), field("", DataType::Utf8, false)]; + let scalars = [None, Some(&k)]; + let err = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap_err(); + assert!(err.to_string().contains("at least one")); + } + + #[test] + fn allows_empty_value() { + let udf = WithMetadataFunc::new(); + let input = field("c", DataType::Int32, true); + let k = str_lit("unit"); + let v = str_lit(""); + let fields = [ + Arc::clone(&input), + field("", DataType::Utf8, false), + field("", DataType::Utf8, false), + ]; + let scalars = [None, Some(&k), Some(&v)]; + let ret = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap(); + assert_eq!(ret.metadata().get("unit").map(String::as_str), Some("")); + } + + #[test] + fn rejects_non_literal_key() { + let udf = WithMetadataFunc::new(); + let input = field("c", DataType::Int32, true); + let v = str_lit("v"); + let fields = [ + Arc::clone(&input), + field("", DataType::Utf8, true), + field("", DataType::Utf8, false), + ]; + let scalars = [None, None, Some(&v)]; + let err = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &scalars, + }) + .unwrap_err(); + assert!(err.to_string().contains("non-empty constant string")); + } +} diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index f1b6c71763cf3..e848daaed1cbf 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -17,86 +17,22 @@ //! "crypto" DataFusion functions -use arrow::array::{ - Array, ArrayRef, BinaryArray, BinaryArrayType, BinaryViewArray, GenericBinaryArray, - OffsetSizeTrait, -}; -use arrow::array::{AsArray, GenericStringArray, StringViewArray}; +use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, BinaryArrayType}; use arrow::datatypes::DataType; -use blake2::{Blake2b512, Blake2s256, Digest}; +use blake2::{Blake2b512, Blake2s256}; use blake3::Hasher as Blake3; -use datafusion_common::cast::as_binary_array; use arrow::compute::StringArrayType; -use datafusion_common::{ - exec_err, internal_err, plan_err, utils::take_function_args, DataFusionError, Result, - ScalarValue, -}; +use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err, plan_err}; use datafusion_expr::ColumnarValue; use md5::Md5; use sha2::{Sha224, Sha256, Sha384, Sha512}; -use std::fmt::{self, Write}; +use std::fmt; use std::str::FromStr; use std::sync::Arc; -macro_rules! define_digest_function { - ($NAME: ident, $METHOD: ident, $DOC: expr) => { - #[doc = $DOC] - pub fn $NAME(args: &[ColumnarValue]) -> Result { - let [data] = take_function_args(&DigestAlgorithm::$METHOD.to_string(), args)?; - digest_process(data, DigestAlgorithm::$METHOD) - } - }; -} -define_digest_function!( - sha224, - Sha224, - "computes sha224 hash digest of the given input" -); -define_digest_function!( - sha256, - Sha256, - "computes sha256 hash digest of the given input" -); -define_digest_function!( - sha384, - Sha384, - "computes sha384 hash digest of the given input" -); -define_digest_function!( - sha512, - Sha512, - "computes sha512 hash digest of the given input" -); -define_digest_function!( - blake2b, - Blake2b, - "computes blake2b hash digest of the given input" -); -define_digest_function!( - blake2s, - Blake2s, - "computes blake2s hash digest of the given input" -); -define_digest_function!( - blake3, - Blake3, - "computes blake3 hash digest of the given input" -); - -macro_rules! digest_to_scalar { - ($METHOD: ident, $INPUT:expr) => {{ - ScalarValue::Binary($INPUT.as_ref().map(|v| { - let mut digest = $METHOD::default(); - digest.update(v); - #[allow(deprecated)] - digest.finalize().as_slice().to_vec() - })) - }}; -} - -#[derive(Debug, Copy, Clone)] -pub enum DigestAlgorithm { +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub(crate) enum DigestAlgorithm { Md5, Sha224, Sha256, @@ -107,23 +43,6 @@ pub enum DigestAlgorithm { Blake3, } -/// Digest computes a binary hash of the given data, accepts Utf8 or LargeUtf8 and returns a [`ColumnarValue`]. -/// Second argument is the algorithm to use. -/// Standard algorithms are md5, sha1, sha224, sha256, sha384 and sha512. -pub fn digest(args: &[ColumnarValue]) -> Result { - let [data, digest_algorithm] = take_function_args("digest", args)?; - let digest_algorithm = match digest_algorithm { - ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { - Some(Some(method)) => method.parse::(), - _ => exec_err!("Unsupported data type {scalar:?} for function digest"), - }, - ColumnarValue::Array(_) => { - internal_err!("Digest using dynamically decided method is not yet supported") - } - }?; - digest_process(data, digest_algorithm) -} - impl FromStr for DigestAlgorithm { type Err = DataFusionError; fn from_str(name: &str) -> Result { @@ -165,84 +84,35 @@ impl fmt::Display for DigestAlgorithm { } } -/// computes md5 hash digest of the given input -pub fn md5(args: &[ColumnarValue]) -> Result { - let [data] = take_function_args("md5", args)?; - let value = digest_process(data, DigestAlgorithm::Md5)?; - - // md5 requires special handling because of its unique utf8view return type - Ok(match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringViewArray = binary_array - .iter() - .map(|opt| opt.map(hex_encode::<_>)) - .collect(); - ColumnarValue::Array(Arc::new(string_array)) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode::<_>))) - } - _ => return exec_err!("Impossibly got invalid results from digest"), - }) -} - -/// this function exists so that we do not need to pull in the crate hex. it is only used by md5 -/// function below -#[inline] -fn hex_encode>(data: T) -> String { - let mut s = String::with_capacity(data.as_ref().len() * 2); - for b in data.as_ref() { - // Writing to a string never errors, so we can unwrap here. - write!(&mut s, "{b:02x}").unwrap(); - } - s -} -pub fn utf8_or_binary_to_binary_type( - arg_type: &DataType, - name: &str, -) -> Result { - Ok(match arg_type { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary => DataType::Binary, - DataType::Null => DataType::Null, - _ => { - return plan_err!( - "The {name:?} function can only accept strings or binary arrays." - ); - } - }) -} macro_rules! digest_to_array { - ($METHOD:ident, $INPUT:expr) => {{ + ($MODULE:ident, $METHOD:ident, $INPUT:expr) => {{ + use $MODULE::Digest; let binary_array: BinaryArray = $INPUT .iter() - .map(|x| { - x.map(|x| { - let mut digest = $METHOD::default(); - digest.update(x); - digest.finalize() - }) - }) + .map(|x| x.map(|x| $METHOD::digest(x))) .collect(); Arc::new(binary_array) }}; } + +macro_rules! digest_to_scalar { + ($MODULE: ident, $METHOD: ident, $INPUT:expr) => {{ + use $MODULE::Digest; + ScalarValue::Binary($INPUT.map(|v| $METHOD::digest(v).as_slice().to_vec())) + }}; +} + impl DigestAlgorithm { /// digest an optional string to its hash value, null values are returned as is - pub fn digest_scalar(self, value: Option<&[u8]>) -> ColumnarValue { + fn digest_scalar(self, value: Option<&[u8]>) -> ColumnarValue { ColumnarValue::Scalar(match self { - Self::Md5 => digest_to_scalar!(Md5, value), - Self::Sha224 => digest_to_scalar!(Sha224, value), - Self::Sha256 => digest_to_scalar!(Sha256, value), - Self::Sha384 => digest_to_scalar!(Sha384, value), - Self::Sha512 => digest_to_scalar!(Sha512, value), - Self::Blake2b => digest_to_scalar!(Blake2b512, value), - Self::Blake2s => digest_to_scalar!(Blake2s256, value), + Self::Md5 => digest_to_scalar!(md5, Md5, value), + Self::Sha224 => digest_to_scalar!(sha2, Sha224, value), + Self::Sha256 => digest_to_scalar!(sha2, Sha256, value), + Self::Sha384 => digest_to_scalar!(sha2, Sha384, value), + Self::Sha512 => digest_to_scalar!(sha2, Sha512, value), + Self::Blake2b => digest_to_scalar!(blake2, Blake2b512, value), + Self::Blake2s => digest_to_scalar!(blake2, Blake2s256, value), Self::Blake3 => ScalarValue::Binary(value.map(|v| { let mut digest = Blake3::default(); digest.update(v); @@ -251,49 +121,7 @@ impl DigestAlgorithm { }) } - /// digest a binary array to their hash values - pub fn digest_binary_array(self, value: &dyn Array) -> Result - where - T: OffsetSizeTrait, - { - let array = match value.data_type() { - DataType::Binary | DataType::LargeBinary => { - let v = value.as_binary::(); - self.digest_binary_array_impl::<&GenericBinaryArray>(&v) - } - DataType::BinaryView => { - let v = value.as_binary_view(); - self.digest_binary_array_impl::<&BinaryViewArray>(&v) - } - other => { - return exec_err!("unsupported type for digest_utf_array: {other:?}") - } - }; - Ok(ColumnarValue::Array(array)) - } - - /// digest a string array to their hash values - pub fn digest_utf8_array(self, value: &dyn Array) -> Result - where - T: OffsetSizeTrait, - { - let array = match value.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { - let v = value.as_string::(); - self.digest_utf8_array_impl::<&GenericStringArray>(&v) - } - DataType::Utf8View => { - let v = value.as_string_view(); - self.digest_utf8_array_impl::<&StringViewArray>(&v) - } - other => { - return exec_err!("unsupported type for digest_utf_array: {other:?}") - } - }; - Ok(ColumnarValue::Array(array)) - } - - pub fn digest_utf8_array_impl<'a, StringArrType>( + fn digest_utf8_array_impl<'a, StringArrType>( self, input_value: &StringArrType, ) -> ArrayRef @@ -301,13 +129,13 @@ impl DigestAlgorithm { StringArrType: StringArrayType<'a>, { match self { - Self::Md5 => digest_to_array!(Md5, input_value), - Self::Sha224 => digest_to_array!(Sha224, input_value), - Self::Sha256 => digest_to_array!(Sha256, input_value), - Self::Sha384 => digest_to_array!(Sha384, input_value), - Self::Sha512 => digest_to_array!(Sha512, input_value), - Self::Blake2b => digest_to_array!(Blake2b512, input_value), - Self::Blake2s => digest_to_array!(Blake2s256, input_value), + Self::Md5 => digest_to_array!(md5, Md5, input_value), + Self::Sha224 => digest_to_array!(sha2, Sha224, input_value), + Self::Sha256 => digest_to_array!(sha2, Sha256, input_value), + Self::Sha384 => digest_to_array!(sha2, Sha384, input_value), + Self::Sha512 => digest_to_array!(sha2, Sha512, input_value), + Self::Blake2b => digest_to_array!(blake2, Blake2b512, input_value), + Self::Blake2s => digest_to_array!(blake2, Blake2s256, input_value), Self::Blake3 => { let binary_array: BinaryArray = input_value .iter() @@ -324,7 +152,7 @@ impl DigestAlgorithm { } } - pub fn digest_binary_array_impl<'a, BinaryArrType>( + fn digest_binary_array_impl<'a, BinaryArrType>( self, input_value: &BinaryArrType, ) -> ArrayRef @@ -332,13 +160,13 @@ impl DigestAlgorithm { BinaryArrType: BinaryArrayType<'a>, { match self { - Self::Md5 => digest_to_array!(Md5, input_value), - Self::Sha224 => digest_to_array!(Sha224, input_value), - Self::Sha256 => digest_to_array!(Sha256, input_value), - Self::Sha384 => digest_to_array!(Sha384, input_value), - Self::Sha512 => digest_to_array!(Sha512, input_value), - Self::Blake2b => digest_to_array!(Blake2b512, input_value), - Self::Blake2s => digest_to_array!(Blake2s256, input_value), + Self::Md5 => digest_to_array!(md5, Md5, input_value), + Self::Sha224 => digest_to_array!(sha2, Sha224, input_value), + Self::Sha256 => digest_to_array!(sha2, Sha256, input_value), + Self::Sha384 => digest_to_array!(sha2, Sha384, input_value), + Self::Sha512 => digest_to_array!(sha2, Sha512, input_value), + Self::Blake2b => digest_to_array!(blake2, Blake2b512, input_value), + Self::Blake2s => digest_to_array!(blake2, Blake2s256, input_value), Self::Blake3 => { let binary_array: BinaryArray = input_value .iter() @@ -355,26 +183,40 @@ impl DigestAlgorithm { } } } -pub fn digest_process( + +pub(crate) fn digest_process( value: &ColumnarValue, digest_algorithm: DigestAlgorithm, ) -> Result { match value { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8View => digest_algorithm.digest_utf8_array::(a.as_ref()), - DataType::Utf8 => digest_algorithm.digest_utf8_array::(a.as_ref()), - DataType::LargeUtf8 => digest_algorithm.digest_utf8_array::(a.as_ref()), - DataType::Binary => digest_algorithm.digest_binary_array::(a.as_ref()), - DataType::LargeBinary => { - digest_algorithm.digest_binary_array::(a.as_ref()) - } - DataType::BinaryView => { - digest_algorithm.digest_binary_array::(a.as_ref()) - } - other => exec_err!( - "Unsupported data type {other:?} for function {digest_algorithm}" - ), - }, + ColumnarValue::Array(a) => { + let output = match a.data_type() { + DataType::Utf8View => { + digest_algorithm.digest_utf8_array_impl(&a.as_string_view()) + } + DataType::Utf8 => { + digest_algorithm.digest_utf8_array_impl(&a.as_string::()) + } + DataType::LargeUtf8 => { + digest_algorithm.digest_utf8_array_impl(&a.as_string::()) + } + DataType::Binary => { + digest_algorithm.digest_binary_array_impl(&a.as_binary::()) + } + DataType::LargeBinary => { + digest_algorithm.digest_binary_array_impl(&a.as_binary::()) + } + DataType::BinaryView => { + digest_algorithm.digest_binary_array_impl(&a.as_binary_view()) + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {digest_algorithm}" + ); + } + }; + Ok(ColumnarValue::Array(output)) + } ColumnarValue::Scalar(scalar) => { match scalar { ScalarValue::Utf8View(a) diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index a4999f72f8d56..84b2c99b00087 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! "crypto" DataFusion functions -use super::basic::{digest, utf8_or_binary_to_binary_type}; +use crate::crypto::basic::{DigestAlgorithm, digest_process}; + use arrow::datatypes::DataType; use datafusion_common::{ + Result, exec_err, not_impl_err, types::{logical_binary, logical_string}, - Result, + utils::take_function_args, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -28,7 +29,6 @@ use datafusion_expr::{ }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "Hashing Functions"), @@ -36,16 +36,16 @@ use std::any::Any; syntax_example = "digest(expression, algorithm)", sql_example = r#"```sql > select digest('foo', 'sha256'); -+------------------------------------------+ -| digest(Utf8("foo"), Utf8("sha256")) | -+------------------------------------------+ -| | -+------------------------------------------+ ++------------------------------------------------------------------+ +| digest(Utf8("foo"),Utf8("sha256")) | ++------------------------------------------------------------------+ +| 2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae | ++------------------------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "String"), argument( name = "algorithm", - description = "String expression specifying algorithm to use. Must be one of: + description = "String expression specifying algorithm to use. Must be one of: - md5 - sha224 - sha256 @@ -60,6 +60,7 @@ use std::any::Any; pub struct DigestFunc { signature: Signature, } + impl Default for DigestFunc { fn default() -> Self { Self::new() @@ -85,11 +86,8 @@ impl DigestFunc { } } } -impl ScalarUDFImpl for DigestFunc { - fn as_any(&self) -> &dyn Any { - self - } +impl ScalarUDFImpl for DigestFunc { fn name(&self) -> &str { "digest" } @@ -98,14 +96,35 @@ impl ScalarUDFImpl for DigestFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_or_binary_to_binary_type(&arg_types[0], self.name()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Binary) } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - digest(&args.args) + let [data, digest_algorithm] = take_function_args(self.name(), &args.args)?; + digest(data, digest_algorithm) } fn documentation(&self) -> Option<&Documentation> { self.doc() } } + +/// Compute binary hash of the given `data` (String or Binary array), according +/// to the specified `digest_algorithm`. See [`DigestAlgorithm`] for supported +/// algorithms. +fn digest( + data: &ColumnarValue, + digest_algorithm: &ColumnarValue, +) -> Result { + let digest_algorithm = match digest_algorithm { + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(method)) => method.parse::(), + _ => exec_err!("Unsupported data type {scalar:?} for function digest"), + }, + ColumnarValue::Array(_) => { + not_impl_err!("Digest using dynamically decided method is not yet supported") + } + }?; + digest_process(data, digest_algorithm) +} diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 88859fdee34a7..178aebf0fbd41 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! "crypto" DataFusion functions -use crate::crypto::basic::md5; -use arrow::datatypes::DataType; +use arrow::{array::StringViewArray, datatypes::DataType}; use datafusion_common::{ - plan_err, - types::{logical_binary, logical_string, NativeType}, - Result, + Result, ScalarValue, + cast::as_binary_array, + internal_err, + types::{logical_binary, logical_string}, + utils::take_function_args, }; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -29,7 +29,9 @@ use datafusion_expr::{ }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; -use std::any::Any; +use std::sync::Arc; + +use crate::crypto::basic::{DigestAlgorithm, digest_process}; #[user_doc( doc_section(label = "Hashing Functions"), @@ -37,11 +39,11 @@ use std::any::Any; syntax_example = "md5(expression)", sql_example = r#"```sql > select md5('foo'); -+-------------------------------------+ -| md5(Utf8("foo")) | -+-------------------------------------+ -| | -+-------------------------------------+ ++----------------------------------+ +| md5(Utf8("foo")) | ++----------------------------------+ +| acbd18db4cc2f85cedef654fccc4a4d8 | ++----------------------------------+ ```"#, standard_argument(name = "expression", prefix = "String") )] @@ -49,6 +51,7 @@ use std::any::Any; pub struct Md5Func { signature: Signature, } + impl Default for Md5Func { fn default() -> Self { Self::new() @@ -60,15 +63,11 @@ impl Md5Func { Self { signature: Signature::one_of( vec![ - TypeSignature::Coercible(vec![Coercion::new_implicit( - TypeSignatureClass::Native(logical_binary()), - vec![TypeSignatureClass::Native(logical_string())], - NativeType::String, + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Native(logical_string()), )]), - TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignature::Coercible(vec![Coercion::new_exact( TypeSignatureClass::Native(logical_binary()), - vec![TypeSignatureClass::Native(logical_binary())], - NativeType::Binary, )]), ], Volatility::Immutable, @@ -76,11 +75,8 @@ impl Md5Func { } } } -impl ScalarUDFImpl for Md5Func { - fn as_any(&self) -> &dyn Any { - self - } +impl ScalarUDFImpl for Md5Func { fn name(&self) -> &str { "md5" } @@ -89,30 +85,10 @@ impl ScalarUDFImpl for Md5Func { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match &arg_types[0] { - LargeUtf8 | LargeBinary => Utf8View, - Utf8View | Utf8 | Binary | BinaryView => Utf8View, - Null => Null, - Dictionary(_, t) => match **t { - LargeUtf8 | LargeBinary => Utf8View, - Utf8 | Binary | BinaryView => Utf8View, - Null => Null, - _ => { - return plan_err!( - "the md5 can only accept strings but got {:?}", - **t - ); - } - }, - other => { - return plan_err!( - "The md5 function can only accept strings. Got {other}" - ); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { md5(&args.args) } @@ -121,3 +97,38 @@ impl ScalarUDFImpl for Md5Func { self.doc() } } + +/// Hex encoding lookup table for fast byte-to-hex conversion +const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; + +/// Fast hex encoding using a lookup table instead of format strings. +/// This is significantly faster than using `write!("{:02x}")` for each byte. +#[inline] +fn hex_encode(data: impl AsRef<[u8]>) -> String { + let bytes = data.as_ref(); + let mut s = String::with_capacity(bytes.len() * 2); + for &b in bytes { + s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char); + s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char); + } + s +} + +fn md5(args: &[ColumnarValue]) -> Result { + let [data] = take_function_args("md5", args)?; + let value = digest_process(data, DigestAlgorithm::Md5)?; + + // md5 requires special handling because of its unique utf8view return type + Ok(match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringViewArray = + binary_array.iter().map(|opt| opt.map(hex_encode)).collect(); + ColumnarValue::Array(Arc::new(string_array)) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { + ColumnarValue::Scalar(ScalarValue::Utf8View(opt.map(hex_encode))) + } + _ => return internal_err!("Impossibly got invalid results from digest"), + }) +} diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index 62ea3c2e27371..fd15db44c795d 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -23,16 +23,13 @@ use std::sync::Arc; pub mod basic; pub mod digest; pub mod md5; -pub mod sha224; -pub mod sha256; -pub mod sha384; -pub mod sha512; +pub mod sha; make_udf_function!(digest::DigestFunc, digest); make_udf_function!(md5::Md5Func, md5); -make_udf_function!(sha224::SHA224Func, sha224); -make_udf_function!(sha256::SHA256Func, sha256); -make_udf_function!(sha384::SHA384Func, sha384); -make_udf_function!(sha512::SHA512Func, sha512); +make_udf_function!(sha::SHAFunc, sha224, sha::SHAFunc::sha224); +make_udf_function!(sha::SHAFunc, sha256, sha::SHAFunc::sha256); +make_udf_function!(sha::SHAFunc, sha384, sha::SHAFunc::sha384); +make_udf_function!(sha::SHAFunc, sha512, sha::SHAFunc::sha512); pub mod expr_fn { export_functions!(( diff --git a/datafusion/functions/src/crypto/sha.rs b/datafusion/functions/src/crypto/sha.rs new file mode 100644 index 0000000000000..65153fa117eda --- /dev/null +++ b/datafusion/functions/src/crypto/sha.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::crypto::basic::{DigestAlgorithm, digest_process}; + +use arrow::datatypes::DataType; +use datafusion_common::{ + Result, + types::{logical_binary, logical_string}, + utils::take_function_args, +}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-224 hash of a binary string.", + syntax_example = "sha224(expression)", + sql_example = r#"```sql +> select sha224('foo'); ++----------------------------------------------------------+ +| sha224(Utf8("foo")) | ++----------------------------------------------------------+ +| 0808f64e60d58979fcb676c96ec938270dea42445aeefcd3a4e6f8db | ++----------------------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] +struct SHA224Doc; + +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-256 hash of a binary string.", + syntax_example = "sha256(expression)", + sql_example = r#"```sql +> select sha256('foo'); ++------------------------------------------------------------------+ +| sha256(Utf8("foo")) | ++------------------------------------------------------------------+ +| 2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7ae | ++------------------------------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] +struct SHA256Doc; + +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-384 hash of a binary string.", + syntax_example = "sha384(expression)", + sql_example = r#"```sql +> select sha384('foo'); ++--------------------------------------------------------------------------------------------------+ +| sha384(Utf8("foo")) | ++--------------------------------------------------------------------------------------------------+ +| 98c11ffdfdd540676b1a137cb1a22b2a70350c9a44171d6b1180c6be5cbb2ee3f79d532c8a1dd9ef2e8e08e752a3babb | ++--------------------------------------------------------------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] +struct SHA384Doc; + +#[user_doc( + doc_section(label = "Hashing Functions"), + description = "Computes the SHA-512 hash of a binary string.", + syntax_example = "sha512(expression)", + sql_example = r#"```sql +> select sha512('foo'); ++----------------------------------------------------------------------------------------------------------------------------------+ +| sha512(Utf8("foo")) | ++----------------------------------------------------------------------------------------------------------------------------------+ +| f7fbba6e0636f890e56fbbf3283e524c6fa3204ae298382d624741d0dc6638326e282c41be5e4254d8820772c5518a2c5a8c0c7f7eda19594a7eb539453e1ed7 | ++----------------------------------------------------------------------------------------------------------------------------------+ +```"#, + standard_argument(name = "expression", prefix = "String") +)] +struct SHA512Doc; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SHAFunc { + signature: Signature, + name: &'static str, + algorithm: DigestAlgorithm, +} + +impl SHAFunc { + pub fn sha224() -> Self { + Self::new("sha224", DigestAlgorithm::Sha224) + } + + pub fn sha256() -> Self { + Self::new("sha256", DigestAlgorithm::Sha256) + } + + pub fn sha384() -> Self { + Self::new("sha384", DigestAlgorithm::Sha384) + } + + pub fn sha512() -> Self { + Self::new("sha512", DigestAlgorithm::Sha512) + } + + fn new(name: &'static str, algorithm: DigestAlgorithm) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Native(logical_string()), + )]), + TypeSignature::Coercible(vec![Coercion::new_exact( + TypeSignatureClass::Native(logical_binary()), + )]), + ], + Volatility::Immutable, + ), + name, + algorithm, + } + } +} + +impl ScalarUDFImpl for SHAFunc { + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Binary) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [data] = take_function_args(self.name(), args.args)?; + digest_process(&data, self.algorithm) + } + + fn documentation(&self) -> Option<&Documentation> { + match self.algorithm { + DigestAlgorithm::Sha224 => SHA224Doc {}.doc(), + DigestAlgorithm::Sha256 => SHA256Doc {}.doc(), + DigestAlgorithm::Sha384 => SHA384Doc {}.doc(), + DigestAlgorithm::Sha512 => SHA512Doc {}.doc(), + DigestAlgorithm::Md5 + | DigestAlgorithm::Blake2s + | DigestAlgorithm::Blake2b + | DigestAlgorithm::Blake3 => unreachable!(), + } + } +} diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index a0daab66c7af3..2db64beafa9b7 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -15,31 +15,57 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; +use arrow::array::timezone::Tz; use arrow::array::{ Array, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, StringArrayType, StringViewArray, }; -use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; -use arrow::datatypes::DataType; -use chrono::format::{parse, Parsed, StrftimeItems}; +use arrow::compute::DecimalCast; +use arrow::compute::kernels::cast_utils::string_to_datetime; +use arrow::datatypes::{DataType, TimeUnit}; +use arrow_buffer::ArrowNativeType; use chrono::LocalResult::Single; +use chrono::format::{Parsed, StrftimeItems, parse}; use chrono::{DateTime, TimeZone, Utc}; - use datafusion_common::cast::as_generic_string_array; use datafusion_common::{ - exec_datafusion_err, exec_err, unwrap_or_internal_err, DataFusionError, Result, - ScalarType, ScalarValue, + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, + internal_datafusion_err, unwrap_or_internal_err, }; use datafusion_expr::ColumnarValue; /// Error message if nanosecond conversion request beyond supported interval const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; -/// Calls string_to_timestamp_nanos and converts the error type -pub(crate) fn string_to_timestamp_nanos_shim(s: &str) -> Result { - string_to_timestamp_nanos(s).map_err(|e| e.into()) +static UTC: LazyLock = LazyLock::new(|| "UTC".parse().expect("UTC is always valid")); + +/// Converts a string representation of a date‑time into a timestamp expressed in +/// nanoseconds since the Unix epoch. +/// +/// This helper is a thin wrapper around the more general `string_to_datetime` +/// function. It accepts an optional `timezone` which, if `None`, defaults to +/// Coordinated Universal Time (UTC). The string `s` must contain a valid +/// date‑time format that can be parsed by the underlying chrono parser. +/// +/// # Return Value +/// +/// * `Ok(i64)` – The number of nanoseconds since `1970‑01‑01T00:00:00Z`. +/// * `Err(DataFusionError)` – If the string cannot be parsed, the parsed +/// value is out of range (between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804) +/// or the parsed value does not correspond to an unambiguous time. +pub(crate) fn string_to_timestamp_nanos_with_timezone( + timezone: &Option, + s: &str, +) -> Result { + let tz = timezone.as_ref().unwrap_or(&UTC); + let dt = string_to_datetime(tz, s)?; + let parsed = dt + .timestamp_nanos_opt() + .ok_or_else(|| exec_datafusion_err!("{ERR_NANOSECONDS_NOT_SUPPORTED}"))?; + + Ok(parsed) } /// Checks that all the arguments from the second are of type [Utf8], [LargeUtf8] or [Utf8View] @@ -69,13 +95,12 @@ pub(crate) fn validate_data_types(args: &[ColumnarValue], name: &str) -> Result< /// Accepts a string and parses it using the [`chrono::format::strftime`] specifiers /// relative to the provided `timezone` /// -/// [IANA timezones] are only supported if the `arrow-array/chrono-tz` feature is enabled -/// -/// * `2023-01-01 040506 America/Los_Angeles` -/// /// If a timestamp is ambiguous, for example as a result of daylight-savings time, an error /// will be returned /// +/// Note that parsing [IANA timezones] is not supported yet in chrono - +/// and this implementation only supports named timezones at the end of the string preceded by a space. +/// /// [`chrono::format::strftime`]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html /// [IANA timezones]: https://www.iana.org/time-zones pub(crate) fn string_to_datetime_formatted( @@ -89,11 +114,55 @@ pub(crate) fn string_to_datetime_formatted( ) }; + let mut datetime_str = s; + let mut format = format; + + // Manually handle the most common case of a named timezone at the end of the timestamp. + // Note that %+ handles 'Z' at the end of the string without a space. This code doesn't + // handle named timezones with no preceding space since that would require writing a + // custom parser (or switching to Jiff) + let tz: Option = if format.trim_end().ends_with(" %Z") { + // grab the string after the last space as the named timezone + if let Some((dt_str, timezone_name)) = datetime_str.trim_end().rsplit_once(' ') { + datetime_str = dt_str; + + // attempt to parse the timezone name + let result: Result = + timezone_name.parse(); + let Ok(tz) = result else { + return Err(err(&result.unwrap_err().to_string())); + }; + + // successfully parsed the timezone name, remove the ' %Z' from the format + format = &format[..format.len() - 3]; + + Some(tz) + } else { + None + } + } else if format.contains("%Z") { + return Err(err( + "'%Z' is only supported at the end of the format string preceded by a space", + )); + } else { + None + }; + let mut parsed = Parsed::new(); - parse(&mut parsed, s, StrftimeItems::new(format)).map_err(|e| err(&e.to_string()))?; + parse(&mut parsed, datetime_str, StrftimeItems::new(format)) + .map_err(|e| err(&e.to_string()))?; - // attempt to parse the string assuming it has a timezone - let dt = parsed.to_datetime(); + let dt = match tz { + Some(tz) => { + // A timezone was manually parsed out, convert it to a fixed offset + match parsed.to_datetime_with_timezone(&tz) { + Ok(dt) => Ok(dt.fixed_offset()), + Err(e) => Err(e), + } + } + // default to parse the string assuming it has a timezone + None => parsed.to_datetime(), + }; if let Err(e) = &dt { // no timezone or other failure, try without a timezone @@ -115,7 +184,7 @@ pub(crate) fn string_to_datetime_formatted( } /// Accepts a string with a `chrono` format and converts it to a -/// nanosecond precision timestamp. +/// nanosecond precision timestamp relative to the provided `timezone`. /// /// See [`chrono::format::strftime`] for the full set of supported formats. /// @@ -141,19 +210,21 @@ pub(crate) fn string_to_datetime_formatted( /// /// [`chrono::format::strftime`]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html #[inline] -pub(crate) fn string_to_timestamp_nanos_formatted( +pub(crate) fn string_to_timestamp_nanos_formatted_with_timezone( + timezone: &Option, s: &str, format: &str, ) -> Result { - string_to_datetime_formatted(&Utc, s, format)? - .naive_utc() - .and_utc() + let dt = string_to_datetime_formatted(timezone.as_ref().unwrap_or(&UTC), s, format)?; + let parsed = dt .timestamp_nanos_opt() - .ok_or_else(|| exec_datafusion_err!("{ERR_NANOSECONDS_NOT_SUPPORTED}")) + .ok_or_else(|| exec_datafusion_err!("{ERR_NANOSECONDS_NOT_SUPPORTED}"))?; + + Ok(parsed) } /// Accepts a string with a `chrono` format and converts it to a -/// millisecond precision timestamp. +/// millisecond precision timestamp relative to the provided `timezone`. /// /// See [`chrono::format::strftime`] for the full set of supported formats. /// @@ -176,14 +247,14 @@ pub(crate) fn string_to_timestamp_millis_formatted(s: &str, format: &str) -> Res .timestamp_millis()) } -pub(crate) fn handle( +pub(crate) fn handle( args: &[ColumnarValue], op: F, name: &str, + dt: &DataType, ) -> Result where O: ArrowPrimitiveType, - S: ScalarType, F: Fn(&str) -> Result, { match &args[0] { @@ -210,8 +281,13 @@ where }, ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { Some(a) => { - let result = a.as_ref().map(|x| op(x)).transpose()?; - Ok(ColumnarValue::Scalar(S::scalar(result))) + let result = a + .as_ref() + .map(|x| op(x)) + .transpose()? + .and_then(|v| v.to_i64()); + let s = scalar_value(dt, result)?; + Ok(ColumnarValue::Scalar(s)) } _ => exec_err!("Unsupported data type {scalar:?} for function {name}"), }, @@ -221,15 +297,15 @@ where // Given a function that maps a `&str`, `&str` to an arrow native type, // returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` // depending on the `args`'s variant. -pub(crate) fn handle_multiple( +pub(crate) fn handle_multiple( args: &[ColumnarValue], op: F, op2: M, name: &str, + dt: &DataType, ) -> Result where O: ArrowPrimitiveType, - S: ScalarType, F: Fn(&str, &str) -> Result, M: Fn(O::Native) -> O::Native, { @@ -243,14 +319,24 @@ where DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } - other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {name}, arg # {pos}" + ); + } }, ColumnarValue::Scalar(arg) => { match arg.data_type() { - DataType::Utf8View| DataType::LargeUtf8 | DataType::Utf8 => { + DataType::Utf8View + | DataType::LargeUtf8 + | DataType::Utf8 => { // all good } - other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {name}, arg # {pos}" + ); + } } } } @@ -280,15 +366,17 @@ where | ScalarValue::Utf8(x), ) = v else { - return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); + return exec_err!( + "Unsupported data type {v:?} for function {name}, arg # {pos}" + ); }; if let Some(s) = x { match op(a, s.as_str()) { Ok(r) => { - ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some( - op2(r), - ))))); + let result = op2(r).to_i64(); + let s = scalar_value(dt, result)?; + ret = Some(Ok(ColumnarValue::Scalar(s))); break; } Err(e) => ret = Some(Err(e)), @@ -442,3 +530,16 @@ where // first map is the iterator, second is for the `Option<_>` array.iter().map(|x| x.map(&op).transpose()).collect() } + +fn scalar_value(dt: &DataType, r: Option) -> Result { + match dt { + DataType::Date32 => Ok(ScalarValue::Date32(r.and_then(|v| v.to_i32()))), + DataType::Timestamp(u, tz) => match u { + TimeUnit::Second => Ok(ScalarValue::TimestampSecond(r, tz.clone())), + TimeUnit::Millisecond => Ok(ScalarValue::TimestampMillisecond(r, tz.clone())), + TimeUnit::Microsecond => Ok(ScalarValue::TimestampMicrosecond(r, tz.clone())), + TimeUnit::Nanosecond => Ok(ScalarValue::TimestampNanosecond(r, tz.clone())), + }, + t => Err(internal_datafusion_err!("Unsupported data type: {t:?}")), + } +} diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index da690b4e6be18..d07a3b1caf13b 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -15,17 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::array::timezone::Tz; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Date32; use chrono::{Datelike, NaiveDate, TimeZone}; -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -38,7 +37,24 @@ The `current_date()` return value is determined at query time and will return th "#, syntax_example = r#"current_date() (optional) SET datafusion.execution.time_zone = '+00:00'; - SELECT current_date();"# + SELECT current_date();"#, + sql_example = r#"```sql +> SELECT current_date(); ++----------------+ +| current_date() | ++----------------+ +| 2024-12-23 | ++----------------+ + +-- The current date is based on the session time zone (UTC by default) +> SET datafusion.execution.time_zone = 'Asia/Tokyo'; +> SELECT current_date(); ++----------------+ +| current_date() | ++----------------+ +| 2024-12-24 | ++----------------+ +```"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CurrentDateFunc { @@ -68,10 +84,6 @@ impl CurrentDateFunc { /// wherever it appears within a single statement. This value is /// chosen during planning time. impl ScalarUDFImpl for CurrentDateFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "current_date" } @@ -84,10 +96,7 @@ impl ScalarUDFImpl for CurrentDateFunc { Ok(Date32) } - fn invoke_with_args( - &self, - _args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { internal_err!( "invoke should not be called on a simplified current_date() function" ) @@ -99,23 +108,20 @@ impl ScalarUDFImpl for CurrentDateFunc { fn simplify( &self, - _args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let now_ts = info.execution_props().query_execution_start_time; + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; // Get timezone from config and convert to local time let days = info - .execution_props() .config_options() - .and_then(|config| { - config - .execution - .time_zone - .as_ref() - .map(|tz| tz.parse::().ok()) - }) - .flatten() + .execution + .time_zone + .as_ref() + .and_then(|tz| tz.parse::().ok()) .map_or_else( || datetime_to_days(&now_ts), |tz| { diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 9f3456b8777f0..92f4ae5e66f02 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -21,13 +21,13 @@ use arrow::datatypes::DataType::Time64; use arrow::datatypes::TimeUnit::Nanosecond; use chrono::TimeZone; use chrono::Timelike; -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "Time and Date Functions"), @@ -40,7 +40,24 @@ The session time zone can be set using the statement 'SET datafusion.execution.t "#, syntax_example = r#"current_time() (optional) SET datafusion.execution.time_zone = '+00:00'; - SELECT current_time();"# + SELECT current_time();"#, + sql_example = r#"```sql +> SELECT current_time(); ++--------------------+ +| current_time() | ++--------------------+ +| 06:30:00.123456789 | ++--------------------+ + +-- The current time is based on the session time zone (UTC by default) +> SET datafusion.execution.time_zone = 'Asia/Tokyo'; +> SELECT current_time(); ++--------------------+ +| current_time() | ++--------------------+ +| 15:30:00.123456789 | ++--------------------+ +```"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct CurrentTimeFunc { @@ -68,10 +85,6 @@ impl CurrentTimeFunc { /// wherever it appears within a single statement. This value is /// chosen during planning time. impl ScalarUDFImpl for CurrentTimeFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "current_time" } @@ -84,10 +97,7 @@ impl ScalarUDFImpl for CurrentTimeFunc { Ok(Time64(Nanosecond)) } - fn invoke_with_args( - &self, - _args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { internal_err!( "invoke should not be called on a simplified current_time() function" ) @@ -95,23 +105,20 @@ impl ScalarUDFImpl for CurrentTimeFunc { fn simplify( &self, - _args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let now_ts = info.execution_props().query_execution_start_time; + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; // Try to get timezone from config and convert to local time let nano = info - .execution_props() .config_options() - .and_then(|config| { - config - .execution - .time_zone - .as_ref() - .map(|tz| tz.parse::().ok()) - }) - .flatten() + .execution + .time_zone + .as_ref() + .and_then(|tz| tz.parse::().ok()) .map_or_else( || datetime_to_time_nanos(&now_ts), |tz| { @@ -143,46 +150,24 @@ fn datetime_to_time_nanos(dt: &chrono::DateTime) -> Option Result { - Ok(false) - } - - fn nullable(&self, _expr: &Expr) -> Result { - Ok(true) - } - - fn execution_props(&self) -> &ExecutionProps { - &self.execution_props - } - - fn get_data_type(&self, _expr: &Expr) -> Result { - Ok(Time64(Nanosecond)) - } - } - - fn set_session_timezone_env(tz: &str, start_time: DateTime) -> MockSimplifyInfo { - let mut config = datafusion_common::config::ConfigOptions::default(); + fn set_session_timezone_env(tz: &str, start_time: DateTime) -> SimplifyContext { + let mut config = ConfigOptions::default(); config.execution.time_zone = if tz.is_empty() { None } else { Some(tz.to_string()) }; - let mut execution_props = - ExecutionProps::new().with_query_execution_start_time(start_time); - execution_props.config_options = Some(Arc::new(config)); - MockSimplifyInfo { execution_props } + let schema = Arc::new(DFSchema::empty()); + SimplifyContext::builder() + .with_schema(schema) + .with_config_options(Arc::new(config)) + .with_query_execution_start_time(Some(start_time)) + .build() } #[test] @@ -225,6 +210,9 @@ mod tests { // 10 hours in nanoseconds let expected_offset = 10i64 * 3600 * 1_000_000_000; - assert_eq!(difference, expected_offset, "Expected 10-hour offset difference in nanoseconds between UTC+05:00 and UTC-05:00"); + assert_eq!( + difference, expected_offset, + "Expected 10-hour offset difference in nanoseconds between UTC+05:00 and UTC-05:00" + ); } } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 92af123dbafac..38b491e42bcbd 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::temporal_conversions::NANOSECONDS; @@ -24,18 +23,25 @@ use arrow::array::types::{ TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow::array::{ArrayRef, PrimitiveArray}; -use arrow::datatypes::DataType::{Null, Timestamp, Utf8}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::DataType::{Time32, Time64, Timestamp}; use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; - +use arrow::datatypes::{ + DataType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, +}; +use arrow::error::ArrowError; +use arrow::temporal_conversions::NANOSECONDS_IN_DAY; use datafusion_common::cast::as_primitive_array; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; -use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, exec_err, not_impl_err, plan_err, +}; use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TIMEZONE_WILDCARD, Volatility, }; use datafusion_macros::user_doc; @@ -71,6 +77,17 @@ FROM VALUES ('2023-01-01T18:18:18Z'), ('2023-01-03T19:00:03Z') t(time); | 2023-01-03T03:00:00 | +---------------------+ 2 row(s) fetched. + +-- Bin the time into 15 minute intervals starting at 1 min +> SELECT date_bin(interval '15 minutes', time, TIME '00:01:00') as bin +FROM VALUES (TIME '02:18:18'), (TIME '19:00:03') t(time); ++----------+ +| bin | ++----------+ +| 02:16:00 | +| 18:46:00 | ++----------+ +2 row(s) fetched. ```"#, argument(name = "interval", description = "Bin interval."), argument( @@ -109,7 +126,7 @@ impl Default for DateBinFunc { impl DateBinFunc { pub fn new() -> Self { let base_sig = |array_type: TimeUnit| { - vec![ + let mut v = vec![ Exact(vec![ DataType::Interval(MonthDayNano), Timestamp(array_type, None), @@ -146,7 +163,44 @@ impl DateBinFunc { DataType::Interval(DayTime), Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), ]), - ] + ]; + + match array_type { + Second | Millisecond => { + v.append(&mut vec![ + Exact(vec![ + DataType::Interval(MonthDayNano), + Time32(array_type), + Time32(array_type), + ]), + Exact(vec![DataType::Interval(MonthDayNano), Time32(array_type)]), + Exact(vec![ + DataType::Interval(DayTime), + Time32(array_type), + Time32(array_type), + ]), + Exact(vec![DataType::Interval(DayTime), Time32(array_type)]), + ]); + } + Microsecond | Nanosecond => { + v.append(&mut vec![ + Exact(vec![ + DataType::Interval(DayTime), + Time64(array_type), + Time64(array_type), + ]), + Exact(vec![DataType::Interval(DayTime), Time64(array_type)]), + Exact(vec![ + DataType::Interval(MonthDayNano), + Time64(array_type), + Time64(array_type), + ]), + Exact(vec![DataType::Interval(MonthDayNano), Time64(array_type)]), + ]); + } + } + + v }; let full_sig = [Nanosecond, Microsecond, Millisecond, Second] @@ -162,10 +216,6 @@ impl DateBinFunc { } impl ScalarUDFImpl for DateBinFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "date_bin" } @@ -176,28 +226,39 @@ impl ScalarUDFImpl for DateBinFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[1] { - Timestamp(Nanosecond, None) | Utf8 | Null => Ok(Timestamp(Nanosecond, None)), - Timestamp(Nanosecond, tz_opt) => Ok(Timestamp(Nanosecond, tz_opt.clone())), - Timestamp(Microsecond, tz_opt) => Ok(Timestamp(Microsecond, tz_opt.clone())), - Timestamp(Millisecond, tz_opt) => Ok(Timestamp(Millisecond, tz_opt.clone())), - Timestamp(Second, tz_opt) => Ok(Timestamp(Second, tz_opt.clone())), + Timestamp(tu, tz_opt) => Ok(Timestamp(*tu, tz_opt.clone())), + Time32(tu) => Ok(Time32(*tu)), + Time64(tu) => Ok(Time64(*tu)), _ => plan_err!( - "The date_bin function can only accept timestamp as the second arg." + "The date_bin function can only accept timestamp or time as the second arg." ), } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; if args.len() == 2 { - // Default to unix EPOCH - let origin = ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(0), - Some("+00:00".into()), - )); + let origin = match args[1].data_type() { + Time32(Second) => { + ColumnarValue::Scalar(ScalarValue::Time32Second(Some(0))) + } + Time32(Millisecond) => { + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(Some(0))) + } + Time64(Microsecond) => { + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(Some(0))) + } + Time64(Nanosecond) => { + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(Some(0))) + } + _ => { + // Default to unix EPOCH + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(0), + Some("+00:00".into()), + )) + } + }; date_bin_impl(&args[0], &args[1], &origin) } else if args.len() == 3 { date_bin_impl(&args[0], &args[1], &args[2]) @@ -227,6 +288,18 @@ impl ScalarUDFImpl for DateBinFunc { } } +const NANOS_PER_MICRO: i64 = 1_000; +const NANOS_PER_MILLI: i64 = 1_000_000; +const NANOS_PER_SEC: i64 = NANOSECONDS; +/// Function type for binning timestamps into intervals +/// +/// Arguments: +/// * `stride` - Interval width (nanoseconds for time-based, months for month-based) +/// * `source` - Timestamp to bin (nanoseconds since epoch) +/// * `origin` - Origin timestamp (nanoseconds since epoch) +/// +/// Returns: Binned timestamp in nanoseconds, or error if out of range +type BinFunction = fn(i64, i64, i64) -> Result; enum Interval { Nanoseconds(i64), Months(i64), @@ -241,7 +314,7 @@ impl Interval { /// `source` is the timestamp being binned /// /// `origin` is the time, in nanoseconds, where windows are measured from - fn bin_fn(&self) -> (i64, fn(i64, i64, i64) -> i64) { + fn bin_fn(&self) -> (i64, BinFunction) { match self { Interval::Nanoseconds(nanos) => (*nanos, date_bin_nanos_interval), Interval::Months(months) => (*months, date_bin_months_interval), @@ -250,32 +323,55 @@ impl Interval { } // return time in nanoseconds that the source timestamp falls into based on the stride and origin -fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> i64 { - let time_diff = source - origin; +fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> Result { + let time_diff = source.checked_sub(origin).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "date_bin source timestamp {source} - origin {origin} overflows i64" + )) + })?; // distance from origin to bin - let time_delta = compute_distance(time_diff, stride_nanos); + let time_delta = compute_distance(time_diff, stride_nanos)?; - origin + time_delta + origin.checked_add(time_delta).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "date_bin origin {origin} + delta {time_delta} overflows i64" + )) + .into() + }) } // distance from origin to bin -fn compute_distance(time_diff: i64, stride: i64) -> i64 { - let time_delta = time_diff - (time_diff % stride); +fn compute_distance(time_diff: i64, stride: i64) -> Result { + let remainder = time_diff.checked_rem(stride).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "date_bin compute_distance time_diff {time_diff} % stride {stride} overflows i64" + )) + })?; + let time_delta = time_diff.checked_sub(remainder).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "date_bin compute_distance time_diff {time_diff} - remainder {remainder} overflows i64" + )) + })?; if time_diff < 0 && stride > 1 && time_delta != time_diff { // The origin is later than the source timestamp, round down to the previous bin - time_delta - stride + time_delta.checked_sub(stride).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "date_bin compute_distance time_delta {time_delta} - stride {stride} overflows i64" + )) + .into() + }) } else { - time_delta + Ok(time_delta) } } // return time in nanoseconds that the source timestamp falls into based on the stride and origin -fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 { +fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> Result { // convert source and origin to DateTime - let source_date = to_utc_date_time(source); - let origin_date = to_utc_date_time(origin); + let source_date = to_utc_date_time(source)?; + let origin_date = to_utc_date_time(origin)?; // calculate the number of months between the source and origin let month_diff = (source_date.year() - origin_date.year()) * 12 @@ -283,12 +379,20 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 - origin_date.month() as i32; // distance from origin to bin - let month_delta = compute_distance(month_diff as i64, stride_months); + let month_delta = compute_distance(month_diff as i64, stride_months)?; let mut bin_time = if month_delta < 0 { - origin_date - Months::new(month_delta.unsigned_abs() as u32) + match origin_date + .checked_sub_months(Months::new(month_delta.unsigned_abs() as u32)) + { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month subtraction out of range"), + } } else { - origin_date + Months::new(month_delta as u32) + match origin_date.checked_add_months(Months::new(month_delta as u32)) { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month addition out of range"), + } }; // If origin is not midnight of first date of the month, the bin_time may be larger than the source @@ -296,19 +400,32 @@ fn date_bin_months_interval(stride_months: i64, source: i64, origin: i64) -> i64 if bin_time > source_date { let month_delta = month_delta - stride_months; bin_time = if month_delta < 0 { - origin_date - Months::new(month_delta.unsigned_abs() as u32) + match origin_date + .checked_sub_months(Months::new(month_delta.unsigned_abs() as u32)) + { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month subtraction out of range"), + } } else { - origin_date + Months::new(month_delta as u32) + match origin_date.checked_add_months(Months::new(month_delta as u32)) { + Some(dt) => dt, + None => return exec_err!("DATE_BIN month addition out of range"), + } }; } - - bin_time.timestamp_nanos_opt().unwrap() + match bin_time.timestamp_nanos_opt() { + Some(nanos) => Ok(nanos), + None => exec_err!("DATE_BIN result timestamp out of range"), + } } -fn to_utc_date_time(nanos: i64) -> DateTime { - let secs = nanos / 1_000_000_000; - let nsec = (nanos % 1_000_000_000) as u32; - DateTime::from_timestamp(secs, nsec).unwrap() +fn to_utc_date_time(nanos: i64) -> Result> { + let secs = nanos / NANOS_PER_SEC; + let nsec = (nanos % NANOS_PER_SEC) as u32; + match DateTime::from_timestamp(secs, nsec) { + Some(dt) => Ok(dt), + None => exec_err!("Invalid timestamp value"), + } } // Supported intervals: @@ -323,6 +440,12 @@ fn date_bin_impl( origin: &ColumnarValue, ) -> Result { let stride = match stride { + ColumnarValue::Scalar(s) if s.is_null() => { + // NULL stride -> NULL result (standard SQL NULL propagation) + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + array.data_type(), + )?)); + } ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(v))) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); let nanos = (TimeDelta::try_days(days as i64).unwrap() @@ -365,23 +488,105 @@ fn date_bin_impl( } ColumnarValue::Array(_) => { return not_impl_err!( - "DATE_BIN only supports literal values for the stride argument, not arrays" - ); + "DATE_BIN only supports literal values for the stride argument, not arrays" + ); } }; - let origin = match origin { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(v), _)) => *v, + let (origin, is_time) = match origin { + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(v), _)) => { + (*v, false) + } + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(Some(v))) => { + match stride { + Interval::Months(m) => { + if m > 0 { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + Interval::Nanoseconds(ns) => { + if ns >= NANOSECONDS_IN_DAY { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + } + + (*v as i64 * NANOS_PER_MILLI, true) + } + ColumnarValue::Scalar(ScalarValue::Time32Second(Some(v))) => { + match stride { + Interval::Months(m) => { + if m > 0 { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + Interval::Nanoseconds(ns) => { + if ns >= NANOSECONDS_IN_DAY { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + } + + (*v as i64 * NANOS_PER_SEC, true) + } + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(Some(v))) => { + match stride { + Interval::Months(m) => { + if m > 0 { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + Interval::Nanoseconds(ns) => { + if ns >= NANOSECONDS_IN_DAY { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + } + + (*v * NANOS_PER_MICRO, true) + } + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(Some(v))) => { + match stride { + Interval::Months(m) => { + if m > 0 { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + Interval::Nanoseconds(ns) => { + if ns >= NANOSECONDS_IN_DAY { + return exec_err!( + "DATE_BIN stride for TIME input must be less than 1 day" + ); + } + } + } + + (*v, true) + } ColumnarValue::Scalar(v) => { return exec_err!( - "DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got {}", + "DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision or a TIME but got {}", v.data_type() ); } ColumnarValue::Array(_) => { return not_impl_err!( - "DATE_BIN only supports literal values for the origin argument, not arrays" - ); + "DATE_BIN only supports literal values for the origin argument, not arrays" + ); } }; @@ -392,59 +597,153 @@ fn date_bin_impl( return exec_err!("DATE_BIN stride must be non-zero"); } - fn stride_map_fn( - origin: i64, - stride: i64, - stride_fn: fn(i64, i64, i64) -> i64, - ) -> impl Fn(i64) -> i64 { - let scale = match T::UNIT { + fn timestamp_scale() -> i64 { + match T::UNIT { Nanosecond => 1, - Microsecond => NANOSECONDS / 1_000_000, - Millisecond => NANOSECONDS / 1_000, + Microsecond => NANOS_PER_MICRO, + Millisecond => NANOS_PER_MILLI, Second => NANOSECONDS, - }; - move |x: i64| stride_fn(stride, x * scale, origin) / scale + } + } + + fn timestamp_scale_overflow_error(x: i64) -> DataFusionError { + DataFusionError::Execution(format!( + "DATE_BIN source timestamp {x} cannot be represented in nanoseconds" + )) } Ok(match array { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); + let scale = timestamp_scale::(); ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - v.map(apply_stride_fn), + match *v { + Some(val) => { + let scaled = val + .checked_mul(scale) + .ok_or_else(|| timestamp_scale_overflow_error(val))?; + match stride_fn(stride, scaled, origin) { + Ok(result) => Some(result / scale), + Err(_) => None, + } + } + None => None, + }, tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); + let scale = timestamp_scale::(); ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - v.map(apply_stride_fn), + match *v { + Some(val) => { + let scaled = val + .checked_mul(scale) + .ok_or_else(|| timestamp_scale_overflow_error(val))?; + match stride_fn(stride, scaled, origin) { + Ok(result) => Some(result / scale), + Err(_) => None, + } + } + None => None, + }, tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); + let scale = timestamp_scale::(); ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - v.map(apply_stride_fn), + match *v { + Some(val) => { + let scaled = val + .checked_mul(scale) + .ok_or_else(|| timestamp_scale_overflow_error(val))?; + match stride_fn(stride, scaled, origin) { + Ok(result) => Some(result / scale), + Err(_) => None, + } + } + None => None, + }, tz_opt.clone(), )) } ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); + let scale = timestamp_scale::(); ColumnarValue::Scalar(ScalarValue::TimestampSecond( - v.map(apply_stride_fn), + match *v { + Some(val) => { + let scaled = val + .checked_mul(scale) + .ok_or_else(|| timestamp_scale_overflow_error(val))?; + match stride_fn(stride, scaled, origin) { + Ok(result) => Some(result / scale), + Err(_) => None, + } + } + None => None, + }, tz_opt.clone(), )) } - + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(v)) => { + if !is_time { + return exec_err!("DATE_BIN with Time32 source requires Time32 origin"); + } + let result = v.and_then(|x| { + match stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some((nanos / NANOS_PER_MILLI) as i32) + } + Err(_) => None, + } + }); + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(result)) + } + ColumnarValue::Scalar(ScalarValue::Time32Second(v)) => { + if !is_time { + return exec_err!("DATE_BIN with Time32 source requires Time32 origin"); + } + let result = v.and_then(|x| { + match stride_fn(stride, x as i64 * NANOS_PER_SEC, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some((nanos / NANOS_PER_SEC) as i32) + } + Err(_) => None, + } + }); + ColumnarValue::Scalar(ScalarValue::Time32Second(result)) + } + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v)) => { + if !is_time { + return exec_err!("DATE_BIN with Time64 source requires Time64 origin"); + } + let result = v.and_then(|x| match stride_fn(stride, x, origin) { + Ok(binned_nanos) => Some(binned_nanos % (NANOSECONDS_IN_DAY)), + Err(_) => None, + }); + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(result)) + } + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v)) => { + if !is_time { + return exec_err!("DATE_BIN with Time64 source requires Time64 origin"); + } + let result = + v.and_then(|x| match stride_fn(stride, x * NANOS_PER_MICRO, origin) { + Ok(binned_nanos) => { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + Some(nanos / NANOS_PER_MICRO) + } + Err(_) => None, + }); + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(result)) + } ColumnarValue::Array(array) => { fn transform_array_with_stride( origin: i64, stride: i64, - stride_fn: fn(i64, i64, i64) -> i64, + stride_fn: BinFunction, array: &ArrayRef, tz_opt: &Option>, ) -> Result @@ -452,11 +751,26 @@ fn date_bin_impl( T: ArrowTimestampType, { let array = as_primitive_array::(array)?; - let apply_stride_fn = stride_map_fn::(origin, stride, stride_fn); - let array: PrimitiveArray = array - .unary(apply_stride_fn) - .with_timezone_opt(tz_opt.clone()); + let scale = timestamp_scale::(); + let values = array + .iter() + .map(|val| match val { + Some(val) => { + let scaled = val + .checked_mul(scale) + .ok_or_else(|| timestamp_scale_overflow_error(val))?; + Ok(stride_fn(stride, scaled, origin) + .ok() + .map(|binned| binned / scale)) + } + None => Ok(None), + }) + .collect::>>()?; + + let result = PrimitiveArray::::from_iter(values); + + let array = result.with_timezone_opt(tz_opt.clone()); Ok(ColumnarValue::Array(Arc::new(array))) } @@ -481,9 +795,78 @@ fn date_bin_impl( origin, stride, stride_fn, array, tz_opt, )? } + Time32(Millisecond) => { + if !is_time { + return exec_err!( + "DATE_BIN with Time32 source requires Time32 origin" + ); + } + let array = array.as_primitive::(); + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x as i64 * NANOS_PER_MILLI, origin) + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + (nanos / NANOS_PER_MILLI) as i32 + }) + .map_err(|e| ArrowError::ComputeError(e.to_string())) + })?; + ColumnarValue::Array(Arc::new(result)) + } + Time32(Second) => { + if !is_time { + return exec_err!( + "DATE_BIN with Time32 source requires Time32 origin" + ); + } + let array = array.as_primitive::(); + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x as i64 * NANOS_PER_SEC, origin) + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + (nanos / NANOS_PER_SEC) as i32 + }) + .map_err(|e| ArrowError::ComputeError(e.to_string())) + })?; + ColumnarValue::Array(Arc::new(result)) + } + Time64(Microsecond) => { + if !is_time { + return exec_err!( + "DATE_BIN with Time64 source requires Time64 origin" + ); + } + let array = array.as_primitive::(); + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x * NANOS_PER_MICRO, origin) + .map(|binned_nanos| { + let nanos = binned_nanos % (NANOSECONDS_IN_DAY); + nanos / NANOS_PER_MICRO + }) + .map_err(|e| ArrowError::ComputeError(e.to_string())) + })?; + ColumnarValue::Array(Arc::new(result)) + } + Time64(Nanosecond) => { + if !is_time { + return exec_err!( + "DATE_BIN with Time64 source requires Time64 origin" + ); + } + let array = array.as_primitive::(); + let result: PrimitiveArray = + array.try_unary(|x| { + stride_fn(stride, x, origin) + .map(|binned_nanos| binned_nanos % (NANOSECONDS_IN_DAY)) + .map_err(|e| ArrowError::ComputeError(e.to_string())) + })?; + ColumnarValue::Array(Arc::new(result)) + } _ => { return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP but got {}", + "DATE_BIN expects source argument to be a TIMESTAMP or TIME but got {}", array.data_type() ); } @@ -491,7 +874,7 @@ fn date_bin_impl( } _ => { return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" + "DATE_BIN expects source argument to be a TIMESTAMP or TIME scalar or array" ); } }) @@ -501,7 +884,7 @@ fn date_bin_impl( mod tests { use std::sync::Arc; - use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc}; + use crate::datetime::date_bin::{DateBinFunc, date_bin_nanos_interval}; use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; @@ -509,7 +892,7 @@ mod tests { use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion_common::{DataFusionError, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use chrono::TimeDelta; use datafusion_common::config::ConfigOptions; @@ -524,7 +907,7 @@ mod tests { .map(|arg| Field::new("a", arg.data_type(), true).into()) .collect::>(); - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args, arg_fields, number_rows, @@ -687,7 +1070,7 @@ mod tests { let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), - "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(µs)" + "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision or a TIME but got Timestamp(µs)" ); args = vec![ @@ -935,7 +1318,7 @@ mod tests { let origin1 = string_to_timestamp_nanos(origin).unwrap(); let expected1 = string_to_timestamp_nanos(expected).unwrap(); - let result = date_bin_nanos_interval(stride1, source1, origin1); + let result = date_bin_nanos_interval(stride1, source1, origin1).unwrap(); assert_eq!(result, expected1, "{source} = {expected}"); }) } @@ -963,8 +1346,103 @@ mod tests { let source1 = string_to_timestamp_nanos(source).unwrap(); let expected1 = string_to_timestamp_nanos(expected).unwrap(); - let result = date_bin_nanos_interval(stride1, source1, 0); + let result = date_bin_nanos_interval(stride1, source1, 0).unwrap(); assert_eq!(result, expected1, "{source} = {expected}"); }) } + + #[test] + fn test_date_bin_out_of_range() { + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + )); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1637426858, 0, 0)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(1040292460), + None, + )), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos("1984-01-07 00:00:00").unwrap()), + None, + )), + ]; + + let result = invoke_date_bin_with_args(args, 1, return_field); + assert!(result.is_ok()); + if let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(val, _)) = + result.unwrap() + { + assert!(val.is_none(), "Expected None for out of range operation"); + } + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1637426858, 0, 0)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(-1040292460), + None, + )), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos("1984-01-07 00:00:00").unwrap()), + None, + )), + ]; + + let result = invoke_date_bin_with_args(args, 1, return_field); + assert!(result.is_ok()); + if let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(val, _)) = + result.unwrap() + { + assert!(val.is_none(), "Expected None for out of range operation"); + } + } + + #[test] + fn test_date_bin_compute_distance_i64_min() { + // Regression for #22215: date_bin_nanos_interval on a source near i64::MIN + // previously panicked inside compute_distance with "attempt to subtract with overflow". + // Now it must return a normal Err that the scalar pipeline maps to NULL. + let result = date_bin_nanos_interval(3, i64::MIN, 0); + assert!( + result.is_err(), + "expected Err for source=i64::MIN, got {result:?}" + ); + + let return_field = &Arc::new(Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, 0, 3)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(i64::MIN), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(0), None)), + ]; + let result = invoke_date_bin_with_args(args, 1, return_field); + assert!(result.is_ok(), "expected Ok with NULL, got {result:?}"); + if let ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(val, _)) = + result.unwrap() + { + assert!( + val.is_none(), + "Expected None for compute_distance overflow, got {val:?}" + ); + } else { + panic!("Expected TimestampNanosecond scalar"); + } + } + + #[test] + fn test_date_bin_compute_distance_rem_overflow() { + // Regression for #22215: `time_diff % stride` panics with "attempt to + // calculate the remainder with overflow" when `time_diff == i64::MIN` + // and `stride == -1`. Now it must return a normal Err that the scalar + // pipeline maps to NULL. + let result = date_bin_nanos_interval(-1, i64::MIN, 0); + assert!( + result.is_err(), + "expected Err for time_diff=i64::MIN, stride=-1, got {result:?}" + ); + } } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index aa23a5028dd81..3c405d388bcab 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -15,23 +15,30 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::str::FromStr; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, Float64Array, Int32Array}; +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, Float64Array, Int32Array, Int64Array}; use arrow::compute::kernels::cast_utils::IntervalUnit; -use arrow::compute::{binary, date_part, DatePart}; +use arrow::compute::{DatePart, binary, date_part}; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; -use datafusion_common::types::{logical_date, NativeType}; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Date32Type, Date64Type, Field, FieldRef, + IntervalUnit as ArrowIntervalUnit, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use chrono::{Datelike, NaiveDate}; +use datafusion_common::types::{NativeType, logical_date}; use datafusion_common::{ + Result, ScalarValue, cast::{ - as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_date32_array, as_date64_array, as_int32_array, as_interval_dt_array, + as_interval_mdn_array, as_interval_ym_array, as_time32_millisecond_array, as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, @@ -39,11 +46,12 @@ use datafusion_common::{ exec_err, internal_err, not_impl_err, types::logical_string, utils::take_function_args, - Result, ScalarValue, }; +use datafusion_expr::preimage::PreimageResult; +use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, Volatility, interval_arithmetic, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; @@ -56,8 +64,9 @@ use datafusion_macros::user_doc; argument( name = "part", description = r#"Part of the date to return. The following date parts are supported: - + - year + - isoyear (ISO 8601 week-numbering year) - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) - month - week (week of the year) @@ -70,14 +79,28 @@ use datafusion_macros::user_doc; - nanosecond - dow (day of the week where Sunday is 0) - doy (day of the year) - - epoch (seconds since Unix epoch) - - isodow (day of the week where Monday is 0) + - epoch (seconds since Unix epoch for timestamps/dates, total seconds for intervals) + - isodow (ISO 8601 day of the week where Monday is 1 and Sunday is 7) "# ), argument( name = "expression", description = "Time expression to operate on. Can be a constant, column, or function." - ) + ), + sql_example = r#"```sql +> SELECT date_part('year', '2024-05-01T00:00:00'); ++-----------------------------------------------------+ +| date_part(Utf8("year"),Utf8("2024-05-01T00:00:00")) | ++-----------------------------------------------------+ +| 2024 | ++-----------------------------------------------------+ +> SELECT extract(day FROM timestamp '2024-05-01T00:00:00'); ++----------------------------------------------------+ +| date_part(Utf8("DAY"),Utf8("2024-05-01T00:00:00")) | ++----------------------------------------------------+ +| 1 | ++----------------------------------------------------+ +```"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct DatePartFunc { @@ -130,10 +153,6 @@ impl DatePartFunc { } impl ScalarUDFImpl for DatePartFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "date_part" } @@ -148,6 +167,7 @@ impl ScalarUDFImpl for DatePartFunc { fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; + let nullable = args.arg_fields[1].is_nullable(); field .and_then(|sv| { @@ -156,9 +176,12 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - Field::new(self.name(), DataType::Float64, true) + Field::new(self.name(), DataType::Float64, nullable) + } else if is_nanosecond(part) { + // See notes on [seconds_ns] for rationale + Field::new(self.name(), DataType::Int64, nullable) } else { - Field::new(self.name(), DataType::Int32, true) + Field::new(self.name(), DataType::Int32, nullable) } }) }) @@ -169,10 +192,7 @@ impl ScalarUDFImpl for DatePartFunc { ) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = args.args; let [part, array] = take_function_args(self.name(), args)?; @@ -208,17 +228,29 @@ impl ScalarUDFImpl for DatePartFunc { IntervalUnit::Second => seconds_as_i32(array.as_ref(), Second)?, IntervalUnit::Millisecond => seconds_as_i32(array.as_ref(), Millisecond)?, IntervalUnit::Microsecond => seconds_as_i32(array.as_ref(), Microsecond)?, - IntervalUnit::Nanosecond => seconds_as_i32(array.as_ref(), Nanosecond)?, + IntervalUnit::Nanosecond => seconds_ns(array.as_ref())?, // century and decade are not supported by `DatePart`, although they are supported in postgres _ => return exec_err!("Date part '{part}' not supported"), } } else { // special cases that can be extracted (in postgres) but are not interval units match part_trim.to_lowercase().as_str() { + "isoyear" => date_part(array.as_ref(), DatePart::YearISO)?, "qtr" | "quarter" => date_part(array.as_ref(), DatePart::Quarter)?, "doy" => date_part(array.as_ref(), DatePart::DayOfYear)?, "dow" => date_part(array.as_ref(), DatePart::DayOfWeekSunday0)?, - "isodow" => date_part(array.as_ref(), DatePart::DayOfWeekMonday0)?, + "isodow" => { + // Postgres `isodow` is 1..=7 with Mon=1. Arrow's + // `DayOfWeekMonday0` returns 0..=6 with Mon=0; shift by + // +1 to match Postgres. TODO: switch to a future + // `DatePart::DayOfWeekMonday1` upstream variant once it + // exists, so this kernel-then-add becomes a single call. + let zero_based = + date_part(array.as_ref(), DatePart::DayOfWeekMonday0)?; + let int_arr = as_int32_array(&zero_based)?; + let one_based: Int32Array = int_arr.unary(|v| v + 1); + Arc::new(one_based) as ArrayRef + } "epoch" => epoch(array.as_ref())?, _ => return exec_err!("Date part '{part}' not supported"), } @@ -231,6 +263,71 @@ impl ScalarUDFImpl for DatePartFunc { }) } + // Only casting the year is supported since pruning other IntervalUnit is not possible + // date_part(col, YEAR) = 2024 => col >= '2024-01-01' and col < '2025-01-01' + // But for anything less than YEAR simplifying is not possible without specifying the bigger interval + // date_part(col, MONTH) = 1 => col = '2023-01-01' or col = '2024-01-01' or ... or col = '3000-01-01' + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + info: &SimplifyContext, + ) -> Result { + let [part, col_expr] = take_function_args(self.name(), args)?; + + // Get the interval unit from the part argument + let interval_unit = part + .as_literal() + .and_then(|sv| sv.try_as_str().flatten()) + .map(part_normalization) + .and_then(|s| IntervalUnit::from_str(s).ok()); + + // only support extracting year + match interval_unit { + Some(IntervalUnit::Year) => (), + _ => return Ok(PreimageResult::None), + } + + // Check if the argument is a literal (e.g. date_part(YEAR, col) = 2024) + let Some(argument_literal) = lit_expr.as_literal() else { + return Ok(PreimageResult::None); + }; + + // Extract i32 year from Scalar value + let year = match argument_literal { + ScalarValue::Int32(Some(y)) => *y, + _ => return Ok(PreimageResult::None), + }; + + // Can only extract year from Date32/64 and Timestamp column + let target_type = match info.get_data_type(col_expr)? { + Date32 | Date64 | Timestamp(_, _) => &info.get_data_type(col_expr)?, + _ => return Ok(PreimageResult::None), + }; + + // Compute the Interval bounds + let Some(start_time) = NaiveDate::from_ymd_opt(year, 1, 1) else { + return Ok(PreimageResult::None); + }; + let Some(end_time) = start_time.with_year(year + 1) else { + return Ok(PreimageResult::None); + }; + + // Convert to ScalarValues + let (Some(lower), Some(upper)) = ( + date_to_scalar(start_time, target_type), + date_to_scalar(end_time, target_type), + ) else { + return Ok(PreimageResult::None); + }; + let interval = Box::new(interval_arithmetic::Interval::try_new(lower, upper)?); + + Ok(PreimageResult::Range { + expr: col_expr.clone(), + interval, + }) + } + fn aliases(&self) -> &[String] { &self.aliases } @@ -245,6 +342,53 @@ fn is_epoch(part: &str) -> bool { matches!(part.to_lowercase().as_str(), "epoch") } +fn is_nanosecond(part: &str) -> bool { + IntervalUnit::from_str(part_normalization(part)) + .map(|p| matches!(p, IntervalUnit::Nanosecond)) + .unwrap_or(false) +} + +fn date_to_scalar(date: NaiveDate, target_type: &DataType) -> Option { + Some(match target_type { + Date32 => ScalarValue::Date32(Some(Date32Type::from_naive_date(date))), + Date64 => ScalarValue::Date64(Some(Date64Type::from_naive_date(date))), + + Timestamp(unit, tz_opt) => { + let naive_midnight = date.and_hms_opt(0, 0, 0)?; + let tz: Option = tz_opt.clone().and_then(|s| s.parse().ok()); + + match unit { + Second => ScalarValue::TimestampSecond( + TimestampSecondType::from_naive_datetime(naive_midnight, tz.as_ref()), + tz_opt.clone(), + ), + Millisecond => ScalarValue::TimestampMillisecond( + TimestampMillisecondType::from_naive_datetime( + naive_midnight, + tz.as_ref(), + ), + tz_opt.clone(), + ), + Microsecond => ScalarValue::TimestampMicrosecond( + TimestampMicrosecondType::from_naive_datetime( + naive_midnight, + tz.as_ref(), + ), + tz_opt.clone(), + ), + Nanosecond => ScalarValue::TimestampNanosecond( + TimestampNanosecondType::from_naive_datetime( + naive_midnight, + tz.as_ref(), + ), + tz_opt.clone(), + ), + } + } + _ => return None, + }) +} + // Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error fn part_normalization(part: &str) -> &str { part.strip_prefix(|c| c == '\'' || c == '\"') @@ -349,6 +493,11 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { fn epoch(array: &dyn Array) -> Result { const SECONDS_IN_A_DAY: f64 = 86400_f64; + // Note: Month-to-second conversion uses 30 days as an approximation. + // This matches PostgreSQL's behavior for interval epoch extraction, + // but does not represent exact calendar months (which vary 28-31 days). + // See: https://doxygen.postgresql.org/datatype_2timestamp_8h.html + const DAYS_PER_MONTH: f64 = 30_f64; let f: Float64Array = match array.data_type() { Timestamp(Second, _) => as_timestamp_second_array(array)?.unary(|x| x as f64), @@ -373,8 +522,56 @@ fn epoch(array: &dyn Array) -> Result { Time64(Nanosecond) => { as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - Interval(_) | Duration(_) => return seconds(array, Second), + Interval(ArrowIntervalUnit::YearMonth) => as_interval_ym_array(array)? + .unary(|x| x as f64 * DAYS_PER_MONTH * SECONDS_IN_A_DAY), + Interval(ArrowIntervalUnit::DayTime) => as_interval_dt_array(array)?.unary(|x| { + x.days as f64 * SECONDS_IN_A_DAY + x.milliseconds as f64 / 1_000_f64 + }), + Interval(ArrowIntervalUnit::MonthDayNano) => { + as_interval_mdn_array(array)?.unary(|x| { + x.months as f64 * DAYS_PER_MONTH * SECONDS_IN_A_DAY + + x.days as f64 * SECONDS_IN_A_DAY + + x.nanoseconds as f64 / 1_000_000_000_f64 + }) + } + Duration(_) => return seconds(array, Second), d => return exec_err!("Cannot convert {d:?} to epoch"), }; Ok(Arc::new(f)) } + +/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// result to a total number of nanoseconds as an Int64 array. +/// +/// This returns an Int64 rather than Int32 because there 1 billion +/// `nanosecond`s in each second, so representing up to 60 seconds as +/// nanoseconds can be values up to 60 billion, which does not fit in Int32. +fn seconds_ns(array: &dyn Array) -> Result { + let secs = date_part(array, DatePart::Second)?; + // This assumes array is primitive and not a dictionary + let secs = as_int32_array(secs.as_ref())?; + let subsecs = date_part(array, DatePart::Nanosecond)?; + let subsecs = as_int32_array(subsecs.as_ref())?; + + // Special case where there are no nulls. + if subsecs.null_count() == 0 { + let r: Int64Array = binary(secs, subsecs, |secs, subsecs| { + (secs as i64) * 1_000_000_000 + (subsecs as i64) + })?; + Ok(Arc::new(r)) + } else { + // Nulls in secs are preserved, nulls in subsecs are treated as zero to account for the case + // where the number of nanoseconds overflows. + let r: Int64Array = secs + .iter() + .zip(subsecs) + .map(|(secs, subsecs)| { + secs.map(|secs| { + let subsecs = subsecs.unwrap_or(0); + (secs as i64) * 1_000_000_000 + (subsecs as i64) + }) + }) + .collect(); + Ok(Arc::new(r)) + } +} diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index c8376cf84415f..a4b244405cc22 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -15,32 +15,37 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +use std::fmt; use std::num::NonZeroI64; use std::ops::{Add, Sub}; use std::str::FromStr; use std::sync::Arc; use arrow::array::temporal_conversions::{ - as_datetime_with_timezone, timestamp_ns_to_datetime, + MICROSECONDS, MILLISECONDS, NANOSECONDS, as_datetime_with_timezone, + timestamp_ns_to_datetime, }; use arrow::array::timezone::Tz; use arrow::array::types::{ - ArrowTimestampType, TimestampMicrosecondType, TimestampMillisecondType, + ArrowTimestampType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; use arrow::array::{Array, ArrayRef, PrimitiveArray}; -use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8, Utf8View}; +use arrow::datatypes::DataType::{self, Time32, Time64, Timestamp}; use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second}; +use arrow::datatypes::{Field, FieldRef}; use datafusion_common::cast::as_primitive_array; +use datafusion_common::types::{NativeType, logical_date, logical_string}; use datafusion_common::{ - exec_datafusion_err, exec_err, plan_err, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, }; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use chrono::{ @@ -116,16 +121,48 @@ impl DateTruncGranularity { fn is_fine_granularity_utc(&self) -> bool { self.is_fine_granularity() || matches!(self, Self::Hour | Self::Day) } + + /// Returns true if this granularity is valid for Time types + /// Time types don't have date components, so day/week/month/quarter/year are not valid + fn valid_for_time(&self) -> bool { + matches!( + self, + Self::Hour + | Self::Minute + | Self::Second + | Self::Millisecond + | Self::Microsecond + ) + } +} + +impl fmt::Display for DateTruncGranularity { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let value = match self { + Self::Microsecond => "microsecond", + Self::Millisecond => "millisecond", + Self::Second => "second", + Self::Minute => "minute", + Self::Hour => "hour", + Self::Day => "day", + Self::Week => "week", + Self::Month => "month", + Self::Quarter => "quarter", + Self::Year => "year", + }; + f.write_str(value) + } } #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Truncates a timestamp value to a specified precision.", + description = "Truncates a timestamp or time value to a specified precision.", syntax_example = "date_trunc(precision, expression)", argument( name = "precision", description = r#"Time precision to truncate to. The following precisions are supported: + For Timestamp types: - year / YEAR - quarter / QUARTER - month / MONTH @@ -136,12 +173,33 @@ impl DateTruncGranularity { - second / SECOND - millisecond / MILLISECOND - microsecond / MICROSECOND + + For Time types (hour, minute, second, millisecond, microsecond only): + - hour / HOUR + - minute / MINUTE + - second / SECOND + - millisecond / MILLISECOND + - microsecond / MICROSECOND "# ), argument( name = "expression", - description = "Time expression to operate on. Can be a constant, column, or function." - ) + description = "Timestamp or time expression to operate on. Can be a constant, column, or function." + ), + sql_example = r#"```sql +> SELECT date_trunc('month', '2024-05-15T10:30:00'); ++-----------------------------------------------+ +| date_trunc(Utf8("month"),Utf8("2024-05-15T10:30:00")) | ++-----------------------------------------------+ +| 2024-05-01T00:00:00 | ++-----------------------------------------------+ +> SELECT date_trunc('hour', '2024-05-15T10:30:00'); ++----------------------------------------------+ +| date_trunc(Utf8("hour"),Utf8("2024-05-15T10:30:00")) | ++----------------------------------------------+ +| 2024-05-15T10:00:00 | ++----------------------------------------------+ +```"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct DateTruncFunc { @@ -160,45 +218,21 @@ impl DateTruncFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Timestamp(Nanosecond, None)]), - Exact(vec![Utf8View, Timestamp(Nanosecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Microsecond, None)]), - Exact(vec![Utf8View, Timestamp(Microsecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + // Allow implicit cast from string and date to timestamp for backward compatibility + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Native(logical_date()), + ], + NativeType::Timestamp(Nanosecond, None), + ), ]), - Exact(vec![ - Utf8View, - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Millisecond, None)]), - Exact(vec![Utf8View, Timestamp(Millisecond, None)]), - Exact(vec![ - Utf8, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![Utf8, Timestamp(Second, None)]), - Exact(vec![Utf8View, Timestamp(Second, None)]), - Exact(vec![ - Utf8, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), - ]), - Exact(vec![ - Utf8View, - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), ]), ], Volatility::Immutable, @@ -209,10 +243,6 @@ impl DateTruncFunc { } impl ScalarUDFImpl for DateTruncFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "date_trunc" } @@ -221,25 +251,25 @@ impl ScalarUDFImpl for DateTruncFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match &arg_types[1] { - Timestamp(Nanosecond, None) | Utf8 | DataType::Date32 | Null => { - Ok(Timestamp(Nanosecond, None)) - } - Timestamp(Nanosecond, tz_opt) => Ok(Timestamp(Nanosecond, tz_opt.clone())), - Timestamp(Microsecond, tz_opt) => Ok(Timestamp(Microsecond, tz_opt.clone())), - Timestamp(Millisecond, tz_opt) => Ok(Timestamp(Millisecond, tz_opt.clone())), - Timestamp(Second, tz_opt) => Ok(Timestamp(Second, tz_opt.clone())), - _ => plan_err!( - "The date_trunc function can only accept timestamp as the second arg." - ), - } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be called instead") } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let field = &args.arg_fields[1]; + let return_type = if field.data_type().is_null() { + Timestamp(Nanosecond, None) + } else { + field.data_type().clone() + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + field.is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = args.args; let (granularity, array) = (&args[0], &args[1]); @@ -248,6 +278,9 @@ impl ScalarUDFImpl for DateTruncFunc { { v.to_lowercase() } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = granularity + { + v.to_lowercase() + } else if let ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) = granularity { v.to_lowercase() } else { @@ -256,6 +289,15 @@ impl ScalarUDFImpl for DateTruncFunc { let granularity = DateTruncGranularity::from_str(&granularity_str)?; + // Check upfront if granularity is valid for Time types + let is_time_type = matches!(array.data_type(), Time64(_) | Time32(_)); + if is_time_type && !granularity.valid_for_time() { + return exec_err!( + "date_trunc does not support '{}' granularity for Time types. Valid values are: hour, minute, second, millisecond, microsecond", + granularity_str + ); + } + fn process_array( array: &dyn Array, granularity: DateTruncGranularity, @@ -303,6 +345,10 @@ impl ScalarUDFImpl for DateTruncFunc { } Ok(match array { + ColumnarValue::Scalar(ScalarValue::Null) => { + // NULL input returns NULL timestamp + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(None, None)) + } ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { process_scalar::(v, granularity, tz_opt)? } @@ -315,38 +361,77 @@ impl ScalarUDFImpl for DateTruncFunc { ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { process_scalar::(v, granularity, tz_opt)? } + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(v)) => { + let truncated = v.map(|val| truncate_time_nanos(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(truncated)) + } + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(v)) => { + let truncated = v.map(|val| truncate_time_micros(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time64Microsecond(truncated)) + } + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(v)) => { + let truncated = v.map(|val| truncate_time_millis(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time32Millisecond(truncated)) + } + ColumnarValue::Scalar(ScalarValue::Time32Second(v)) => { + let truncated = v.map(|val| truncate_time_secs(val, granularity)); + ColumnarValue::Scalar(ScalarValue::Time32Second(truncated)) + } ColumnarValue::Array(array) => { let array_type = array.data_type(); - if let Timestamp(unit, tz_opt) = array_type { - match unit { - Second => process_array::( - array, - granularity, - tz_opt, - )?, - Millisecond => process_array::( - array, - granularity, - tz_opt, - )?, - Microsecond => process_array::( - array, - granularity, - tz_opt, - )?, - Nanosecond => process_array::( - array, - granularity, - tz_opt, - )?, + match array_type { + Timestamp(Second, tz_opt) => { + process_array::(array, granularity, tz_opt)? + } + Timestamp(Millisecond, tz_opt) => process_array::< + TimestampMillisecondType, + >( + array, granularity, tz_opt + )?, + Timestamp(Microsecond, tz_opt) => process_array::< + TimestampMicrosecondType, + >( + array, granularity, tz_opt + )?, + Timestamp(Nanosecond, tz_opt) => process_array::< + TimestampNanosecondType, + >( + array, granularity, tz_opt + )?, + Time64(Nanosecond) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_nanos(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + Time64(Microsecond) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_micros(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + Time32(Millisecond) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_millis(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + Time32(Second) => { + let arr = as_primitive_array::(array)?; + let result: PrimitiveArray = + arr.unary(|v| truncate_time_secs(v, granularity)); + ColumnarValue::Array(Arc::new(result)) + } + _ => { + return exec_err!( + "second argument of `date_trunc` is an unsupported array type: {array_type}" + ); } - } else { - return exec_err!("second argument of `date_trunc` is an unsupported array type: {array_type}"); } } _ => { return exec_err!( - "second argument of `date_trunc` must be timestamp scalar or array" + "second argument of `date_trunc` must be timestamp, time scalar or array" ); } }) @@ -372,6 +457,76 @@ impl ScalarUDFImpl for DateTruncFunc { } } +const NANOS_PER_MICROSECOND: i64 = NANOSECONDS / MICROSECONDS; +const NANOS_PER_MILLISECOND: i64 = NANOSECONDS / MILLISECONDS; +const NANOS_PER_SECOND: i64 = NANOSECONDS; +const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND; +const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE; + +const MICROS_PER_MILLISECOND: i64 = MICROSECONDS / MILLISECONDS; +const MICROS_PER_SECOND: i64 = MICROSECONDS; +const MICROS_PER_MINUTE: i64 = 60 * MICROS_PER_SECOND; +const MICROS_PER_HOUR: i64 = 60 * MICROS_PER_MINUTE; + +const MILLIS_PER_SECOND: i32 = MILLISECONDS as i32; +const MILLIS_PER_MINUTE: i32 = 60 * MILLIS_PER_SECOND; +const MILLIS_PER_HOUR: i32 = 60 * MILLIS_PER_MINUTE; + +const SECS_PER_MINUTE: i32 = 60; +const SECS_PER_HOUR: i32 = 60 * SECS_PER_MINUTE; + +/// Truncate time in nanoseconds to the specified granularity +fn truncate_time_nanos(value: i64, granularity: DateTruncGranularity) -> i64 { + match granularity { + DateTruncGranularity::Hour => value - (value % NANOS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % NANOS_PER_MINUTE), + DateTruncGranularity::Second => value - (value % NANOS_PER_SECOND), + DateTruncGranularity::Millisecond => value - (value % NANOS_PER_MILLISECOND), + DateTruncGranularity::Microsecond => value - (value % NANOS_PER_MICROSECOND), + // Other granularities are not valid for time - should be caught earlier + _ => value, + } +} + +/// Truncate time in microseconds to the specified granularity +fn truncate_time_micros(value: i64, granularity: DateTruncGranularity) -> i64 { + match granularity { + DateTruncGranularity::Hour => value - (value % MICROS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % MICROS_PER_MINUTE), + DateTruncGranularity::Second => value - (value % MICROS_PER_SECOND), + DateTruncGranularity::Millisecond => value - (value % MICROS_PER_MILLISECOND), + DateTruncGranularity::Microsecond => value, // Already at microsecond precision + // Other granularities are not valid for time + _ => value, + } +} + +/// Truncate time in milliseconds to the specified granularity +fn truncate_time_millis(value: i32, granularity: DateTruncGranularity) -> i32 { + match granularity { + DateTruncGranularity::Hour => value - (value % MILLIS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % MILLIS_PER_MINUTE), + DateTruncGranularity::Second => value - (value % MILLIS_PER_SECOND), + DateTruncGranularity::Millisecond => value, // Already at millisecond precision + DateTruncGranularity::Microsecond => value, // Can't truncate to finer precision + // Other granularities are not valid for time + _ => value, + } +} + +/// Truncate time in seconds to the specified granularity +fn truncate_time_secs(value: i32, granularity: DateTruncGranularity) -> i32 { + match granularity { + DateTruncGranularity::Hour => value - (value % SECS_PER_HOUR), + DateTruncGranularity::Minute => value - (value % SECS_PER_MINUTE), + DateTruncGranularity::Second => value, // Already at second precision + DateTruncGranularity::Millisecond => value, // Can't truncate to finer precision + DateTruncGranularity::Microsecond => value, // Can't truncate to finer precision + // Other granularities are not valid for time + _ => value, + } +} + fn _date_trunc_coarse( granularity: DateTruncGranularity, value: Option, @@ -493,6 +648,7 @@ fn date_trunc_coarse( value: i64, tz: Option, ) -> Result { + let input = value; let value = match tz { Some(tz) => { // Use chrono DateTime to clear the various fields because need to clear per timezone, @@ -509,8 +665,11 @@ fn date_trunc_coarse( } }?; - // `with_x(0)` are infallible because `0` are always a valid - Ok(value.unwrap()) + value.ok_or_else(|| { + exec_datafusion_err!( + "Timestamp {input} out of range after truncating to {granularity}" + ) + }) } /// Fast path for fine granularities (hour and smaller) that can be handled @@ -582,7 +741,13 @@ fn general_date_trunc( }; // convert to nanoseconds - let nano = date_trunc_coarse(granularity, scale * value, tz)?; + let nano = date_trunc_coarse( + granularity, + value + .checked_mul(scale) + .ok_or_else(|| exec_datafusion_err!("Timestamp {value} out of range"))?, + tz, + )?; let result = match tu { Second => match granularity { @@ -629,7 +794,7 @@ mod tests { use std::sync::Arc; use crate::datetime::date_trunc::{ - date_trunc_coarse, DateTruncFunc, DateTruncGranularity, + DateTruncFunc, DateTruncGranularity, date_trunc_coarse, }; use arrow::array::cast::as_primitive_array; @@ -637,9 +802,9 @@ mod tests { use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, Field, TimeUnit}; - use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; #[test] fn date_trunc_test() { @@ -737,6 +902,19 @@ mod tests { }); } + #[test] + fn date_trunc_out_of_range_lower_bound_returns_error() { + let timestamp = string_to_timestamp_nanos("1677-09-22T00:00:00Z").unwrap(); + let err = date_trunc_coarse(DateTruncGranularity::Year, timestamp, None) + .unwrap_err() + .to_string(); + + assert!( + err.contains("out of range after truncating to year"), + "{err}" + ); + } + #[test] fn test_date_trunc_timezones() { let cases = [ @@ -881,7 +1059,7 @@ mod tests { Field::new("a", DataType::Utf8, false).into(), Field::new("b", input.data_type().clone(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), @@ -1069,7 +1247,7 @@ mod tests { Field::new("a", DataType::Utf8, false).into(), Field::new("b", input.data_type().clone(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), @@ -1237,7 +1415,7 @@ mod tests { Field::new("a", DataType::Utf8, false).into(), Field::new("b", input.data_type().clone(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from(*granularity)), ColumnarValue::Array(Arc::new(input)), diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 5d6adfb6f119a..4787c75b610b6 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_macros::user_doc; @@ -69,10 +69,6 @@ impl FromUnixtimeFunc { } impl ScalarUDFImpl for FromUnixtimeFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "from_unixtime" } @@ -118,10 +114,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { internal_err!("call return_field_from_args instead") } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = args.args; let len = args.len(); if len != 1 && len != 2 { @@ -164,16 +157,16 @@ mod test { use crate::datetime::from_unixtime::FromUnixtimeFunc; use arrow::datatypes::TimeUnit::Second; use arrow::datatypes::{DataType, Field}; - use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Int64; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] fn test_without_timezone() { let arg_field = Arc::new(Field::new("a", DataType::Int64, true)); - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], arg_fields: vec![arg_field], number_rows: 1, @@ -196,7 +189,7 @@ mod test { Field::new("a", DataType::Int64, true).into(), Field::new("a", DataType::Utf8, true).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(Int64(Some(1729900800))), ColumnarValue::Scalar(ScalarValue::Utf8(Some( diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 0fe5d156a8383..dc1328742f24e 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -15,20 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::builder::PrimitiveBuilder; use arrow::array::cast::AsArray; use arrow::array::types::{Date32Type, Int32Type}; -use arrow::array::PrimitiveArray; +use arrow::array::{Array, PrimitiveArray}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf8View}; +use arrow::datatypes::DataType::Date32; use chrono::prelude::*; -use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -51,7 +52,7 @@ use datafusion_macros::user_doc; +-----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "year", @@ -79,21 +80,21 @@ impl Default for MakeDateFunc { impl MakeDateFunc { pub fn new() -> Self { + let int = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![ + TypeSignatureClass::Integer, + TypeSignatureClass::Native(logical_string()), + ], + NativeType::Int32, + ); Self { - signature: Signature::uniform( - 3, - vec![Int32, Int64, UInt32, UInt64, Utf8, Utf8View], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![int; 3], Volatility::Immutable), } } } impl ScalarUDFImpl for MakeDateFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "make_date" } @@ -106,91 +107,60 @@ impl ScalarUDFImpl for MakeDateFunc { Ok(Date32) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let args = args.args; - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let [years, months, days] = take_function_args(self.name(), args)?; - - if matches!(years, ColumnarValue::Scalar(ScalarValue::Null)) - || matches!(months, ColumnarValue::Scalar(ScalarValue::Null)) - || matches!(days, ColumnarValue::Scalar(ScalarValue::Null)) - { - return Ok(ColumnarValue::Scalar(ScalarValue::Null)); - } - - let years = years.cast_to(&Int32, None)?; - let months = months.cast_to(&Int32, None)?; - let days = days.cast_to(&Int32, None)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [years, months, days] = take_function_args(self.name(), args.args)?; - let scalar_value_fn = |col: &ColumnarValue| -> Result { - let ColumnarValue::Scalar(s) = col else { - return exec_err!("Expected scalar value"); - }; - let ScalarValue::Int32(Some(i)) = s else { - return exec_err!("Unable to parse date from null/empty value"); - }; - Ok(*i) - }; - - let value = if let Some(array_size) = len { - let to_primitive_array_fn = - |col: &ColumnarValue| -> PrimitiveArray { - match col { - ColumnarValue::Array(a) => { - a.as_primitive::().to_owned() - } - _ => { - let v = scalar_value_fn(col).unwrap(); - PrimitiveArray::::from_value(v, array_size) - } + match (years, months, days) { + (ColumnarValue::Scalar(y), _, _) if y.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) + } + (_, ColumnarValue::Scalar(m), _) if m.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) + } + (_, _, ColumnarValue::Scalar(d)) if d.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) + } + ( + ColumnarValue::Scalar(ScalarValue::Int32(Some(years))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(months))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(days))), + ) => { + let mut value = 0; + make_date_inner(years, months, days, |days: i32| value = days)?; + Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(value)))) + } + (years, months, days) => { + let len = args.number_rows; + let years = years.into_array(len)?; + let months = months.into_array(len)?; + let days = days.into_array(len)?; + + let years = years.as_primitive::(); + let months = months.as_primitive::(); + let days = days.as_primitive::(); + + let mut builder: PrimitiveBuilder = + PrimitiveArray::builder(len); + + for i in 0..len { + // match postgresql behaviour which returns null for any null input + if years.is_null(i) || months.is_null(i) || days.is_null(i) { + builder.append_null(); + } else { + make_date_inner( + years.value(i), + months.value(i), + days.value(i), + |days: i32| builder.append_value(days), + )?; } - }; + } - let years = to_primitive_array_fn(&years); - let months = to_primitive_array_fn(&months); - let days = to_primitive_array_fn(&days); - - let mut builder: PrimitiveBuilder = - PrimitiveArray::builder(array_size); - for i in 0..array_size { - make_date_inner( - years.value(i), - months.value(i), - days.value(i), - |days: i32| builder.append_value(days), - )?; + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) } - - let arr = builder.finish(); - - ColumnarValue::Array(Arc::new(arr)) - } else { - // For scalar only columns the operation is faster without using the PrimitiveArray. - // Also, keep the output as scalar since all inputs are scalar. - let mut value = 0; - make_date_inner( - scalar_value_fn(&years)?, - scalar_value_fn(&months)?, - scalar_value_fn(&days)?, - |days: i32| value = days, - )?; - - ColumnarValue::Scalar(ScalarValue::Date32(Some(value))) - }; - - Ok(value) + } } + fn documentation(&self) -> Option<&Documentation> { self.doc() } @@ -204,11 +174,13 @@ fn make_date_inner( day: i32, mut date_consumer_fn: F, ) -> Result<()> { - let Ok(m) = u32::try_from(month) else { - return exec_err!("Month value '{month:?}' is out of range"); + let m = match month { + 1..=12 => month as u32, + _ => return exec_err!("Month value '{month:?}' is out of range"), }; - let Ok(d) = u32::try_from(day) else { - return exec_err!("Day value '{day:?}' is out of range"); + let d = match day { + 1..=31 => day as u32, + _ => return exec_err!("Day value '{day:?}' is out of range"), }; if let Some(date) = NaiveDate::from_ymd_opt(year, m, d) { @@ -225,180 +197,3 @@ fn make_date_inner( exec_err!("Unable to parse date from {year}, {month}, {day}") } } - -#[cfg(test)] -mod tests { - use crate::datetime::make_date::MakeDateFunc; - use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; - use arrow::datatypes::{DataType, Field}; - use datafusion_common::config::ConfigOptions; - use datafusion_common::{DataFusionError, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use std::sync::Arc; - - fn invoke_make_date_with_args( - args: Vec, - number_rows: usize, - ) -> Result { - let arg_fields = args - .iter() - .map(|arg| Field::new("a", arg.data_type(), true).into()) - .collect::>(); - let args = datafusion_expr::ScalarFunctionArgs { - args, - arg_fields, - number_rows, - return_field: Field::new("f", DataType::Date32, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - MakeDateFunc::new().invoke_with_args(args) - } - - #[test] - fn test_make_date() { - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), - ], - 1, - ) - .expect("that make_date parsed values without error"); - - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); - } else { - panic!("Expected a scalar value") - } - - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), - ], - 1, - ) - .expect("that make_date parsed values without error"); - - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); - } else { - panic!("Expected a scalar value") - } - - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), - ], - 1, - ) - .expect("that make_date parsed values without error"); - - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); - } else { - panic!("Expected a scalar value") - } - - let years = Arc::new((2021..2025).map(Some).collect::()); - let months = Arc::new((1..5).map(Some).collect::()); - let days = Arc::new((11..15).map(Some).collect::()); - let batch_len = years.len(); - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Array(years), - ColumnarValue::Array(months), - ColumnarValue::Array(days), - ], - batch_len, - ) - .unwrap(); - - if let ColumnarValue::Array(array) = res { - assert_eq!(array.len(), 4); - let mut builder = Date32Array::builder(4); - builder.append_value(18_638); - builder.append_value(19_035); - builder.append_value(19_429); - builder.append_value(19_827); - assert_eq!(&builder.finish() as &dyn Array, array.as_ref()); - } else { - panic!("Expected a columnar array") - } - - // - // Fallible test cases - // - - // invalid number of arguments - let res = invoke_make_date_with_args( - vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - 1, - ); - assert_eq!( - res.err().unwrap().strip_backtrace(), - "Execution error: make_date function requires 3 arguments, got 1" - ); - - // invalid type - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - 1, - ); - assert_eq!( - res.err().unwrap().strip_backtrace(), - "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" - ); - - // overflow of month - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), - ], - 1, - ); - assert_eq!( - res.err().unwrap().strip_backtrace(), - "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" - ); - - // overflow of day - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), - ], - 1, - ); - assert_eq!( - res.err().unwrap().strip_backtrace(), - "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" - ); - } - - #[test] - fn test_make_date_null_param() { - let res = invoke_make_date_with_args( - vec![ - ColumnarValue::Scalar(ScalarValue::Null), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), - ], - 1, - ) - .expect("that make_date parsed values without error"); - - assert!(matches!(res, ColumnarValue::Scalar(ScalarValue::Null))); - } -} diff --git a/datafusion/functions/src/datetime/make_time.rs b/datafusion/functions/src/datetime/make_time.rs new file mode 100644 index 0000000000000..d9e827ac23cbc --- /dev/null +++ b/datafusion/functions/src/datetime/make_time.rs @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::cast::AsArray; +use arrow::array::types::Int32Type; +use arrow::array::{Array, PrimitiveArray}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType::Time32; +use arrow::datatypes::{DataType, Time32SecondType, TimeUnit}; +use chrono::prelude::*; + +use datafusion_common::types::{NativeType, logical_int32, logical_string}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = "Make a time from hour/minute/second component parts.", + syntax_example = "make_time(hour, minute, second)", + sql_example = r#"```sql +> select make_time(13, 23, 1); ++-------------------------------------------+ +| make_time(Int64(13),Int64(23),Int64(1)) | ++-------------------------------------------+ +| 13:23:01 | ++-------------------------------------------+ +> select make_time('23', '01', '31'); ++-----------------------------------------------+ +| make_time(Utf8("23"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 23:01:31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) +"#, + argument( + name = "hour", + description = "Hour to use when making the time. Can be a constant, column or function, and any combination of arithmetic operators." + ), + argument( + name = "minute", + description = "Minute to use when making the time. Can be a constant, column or function, and any combination of arithmetic operators." + ), + argument( + name = "second", + description = "Second to use when making the time. Can be a constant, column or function, and any combination of arithmetic operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct MakeTimeFunc { + signature: Signature, +} + +impl Default for MakeTimeFunc { + fn default() -> Self { + Self::new() + } +} + +impl MakeTimeFunc { + pub fn new() -> Self { + let int = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![ + TypeSignatureClass::Integer, + TypeSignatureClass::Native(logical_string()), + ], + NativeType::Int32, + ); + Self { + signature: Signature::coercible(vec![int; 3], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MakeTimeFunc { + fn name(&self) -> &str { + "make_time" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Time32(TimeUnit::Second)) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [hours, minutes, seconds] = take_function_args(self.name(), args.args)?; + + match (hours, minutes, seconds) { + (ColumnarValue::Scalar(h), _, _) if h.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Time32Second(None))) + } + (_, ColumnarValue::Scalar(m), _) if m.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Time32Second(None))) + } + (_, _, ColumnarValue::Scalar(s)) if s.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Time32Second(None))) + } + ( + ColumnarValue::Scalar(ScalarValue::Int32(Some(hours))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(minutes))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(seconds))), + ) => { + let mut value = 0; + make_time_inner(hours, minutes, seconds, |seconds: i32| value = seconds)?; + Ok(ColumnarValue::Scalar(ScalarValue::Time32Second(Some( + value, + )))) + } + (hours, minutes, seconds) => { + let len = args.number_rows; + let hours = hours.into_array(len)?; + let minutes = minutes.into_array(len)?; + let seconds = seconds.into_array(len)?; + + let hours = hours.as_primitive::(); + let minutes = minutes.as_primitive::(); + let seconds = seconds.as_primitive::(); + + let nulls = NullBuffer::union_many([ + hours.nulls(), + minutes.nulls(), + seconds.nulls(), + ]); + + let mut values = Vec::with_capacity(len); + for i in 0..len { + // Match Postgres behaviour which returns null for any null input + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + values.push(0); + } else { + make_time_inner( + hours.value(i), + minutes.value(i), + seconds.value(i), + |seconds: i32| values.push(seconds), + )?; + } + } + + Ok(ColumnarValue::Array(Arc::new(PrimitiveArray::< + Time32SecondType, + >::new( + values.into(), nulls + )))) + } + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Converts the hour/minute/second fields to an `i32` representing the seconds from +/// midnight and invokes `time_consumer_fn` with the value +fn make_time_inner( + hour: i32, + minute: i32, + second: i32, + mut time_consumer_fn: F, +) -> Result<()> { + let h = match hour { + 0..=24 => hour as u32, + _ => return exec_err!("Hour value '{hour:?}' is out of range"), + }; + let m = match minute { + 0..=60 => minute as u32, + _ => return exec_err!("Minute value '{minute:?}' is out of range"), + }; + let s = match second { + 0..=60 => second as u32, + _ => return exec_err!("Second value '{second:?}' is out of range"), + }; + + if let Some(time) = NaiveTime::from_hms_opt(h, m, s) { + time_consumer_fn(time.num_seconds_from_midnight() as i32); + Ok(()) + } else { + exec_err!("Unable to parse time from {hour}, {minute}, {second}") + } +} + +#[cfg(test)] +mod tests { + use crate::datetime::make_time::MakeTimeFunc; + use arrow::array::{Array, Int32Array, Time32SecondArray}; + use arrow::datatypes::TimeUnit::Second; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::DataFusionError; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; + + fn invoke_make_time_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + let args = ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Time32(Second), true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + MakeTimeFunc::new().invoke_with_args(args) + } + + #[test] + fn test_make_time() { + let hours = Arc::new((4..8).map(Some).collect::()); + let minutes = Arc::new((1..5).map(Some).collect::()); + let seconds = Arc::new((11..15).map(Some).collect::()); + let batch_len = hours.len(); + let res = invoke_make_time_with_args( + vec![ + ColumnarValue::Array(hours), + ColumnarValue::Array(minutes), + ColumnarValue::Array(seconds), + ], + batch_len, + ) + .unwrap(); + + if let ColumnarValue::Array(array) = res { + assert_eq!(array.len(), 4); + + let mut builder = Time32SecondArray::builder(4); + builder.append_value(14_471); + builder.append_value(18_132); + builder.append_value(21_793); + builder.append_value(25_454); + assert_eq!(&builder.finish() as &dyn Array, array.as_ref()); + } else { + panic!("Expected a columnar array") + } + } +} diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index d80f14facf822..39b9453295df6 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -29,11 +29,13 @@ pub mod date_part; pub mod date_trunc; pub mod from_unixtime; pub mod make_date; +pub mod make_time; pub mod now; pub mod planner; pub mod to_char; pub mod to_date; pub mod to_local_time; +pub mod to_time; pub mod to_timestamp; pub mod to_unixtime; @@ -44,16 +46,21 @@ make_udf_function!(date_bin::DateBinFunc, date_bin); make_udf_function!(date_part::DatePartFunc, date_part); make_udf_function!(date_trunc::DateTruncFunc, date_trunc); make_udf_function!(make_date::MakeDateFunc, make_date); +make_udf_function!(make_time::MakeTimeFunc, make_time); make_udf_function!(from_unixtime::FromUnixtimeFunc, from_unixtime); make_udf_function!(to_char::ToCharFunc, to_char); make_udf_function!(to_date::ToDateFunc, to_date); make_udf_function!(to_local_time::ToLocalTimeFunc, to_local_time); +make_udf_function!(to_time::ToTimeFunc, to_time); make_udf_function!(to_unixtime::ToUnixtimeFunc, to_unixtime); -make_udf_function!(to_timestamp::ToTimestampFunc, to_timestamp); -make_udf_function!(to_timestamp::ToTimestampSecondsFunc, to_timestamp_seconds); -make_udf_function!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); -make_udf_function!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); -make_udf_function!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); +make_udf_function_with_config!(to_timestamp::ToTimestampFunc, to_timestamp); +make_udf_function_with_config!( + to_timestamp::ToTimestampSecondsFunc, + to_timestamp_seconds +); +make_udf_function_with_config!(to_timestamp::ToTimestampMillisFunc, to_timestamp_millis); +make_udf_function_with_config!(to_timestamp::ToTimestampMicrosFunc, to_timestamp_micros); +make_udf_function_with_config!(to_timestamp::ToTimestampNanosFunc, to_timestamp_nanos); // create UDF with config make_udf_function_with_config!(now::NowFunc, now); @@ -90,6 +97,10 @@ pub mod expr_fn { make_date, "make a date from year, month and day component parts", year month day + ),( + make_time, + "make a time from hour, minute and second component parts", + hour minute second ),( now, "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement", @@ -102,28 +113,32 @@ pub mod expr_fn { ), ( to_unixtime, - "converts a string and optional formats to a Unixtime", + "converts a value to seconds since the unix epoch", args, ),( - to_timestamp, - "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`", + to_time, + "converts a string and optional formats to a `Time64(Nanoseconds)`", args, + ),( + to_timestamp, + "converts a string and optional formats to a `Timestamp(Nanoseconds, TimeZone)`", + @config args, ),( to_timestamp_seconds, - "converts a string and optional formats to a `Timestamp(Seconds, None)`", - args, + "converts a string and optional formats to a `Timestamp(Seconds, TimeZone)`", + @config args, ),( to_timestamp_millis, - "converts a string and optional formats to a `Timestamp(Milliseconds, None)`", - args, + "converts a string and optional formats to a `Timestamp(Milliseconds, TimeZone)`", + @config args, ),( to_timestamp_micros, - "converts a string and optional formats to a `Timestamp(Microseconds, None)`", - args, + "converts a string and optional formats to a `Timestamp(Microseconds, TimeZone)`", + @config args, ),( to_timestamp_nanos, - "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`", - args, + "converts a string and optional formats to a `Timestamp(Nanoseconds, TimeZone)`", + @config args, )); /// Returns a string representation of a date, time, timestamp or duration based @@ -259,6 +274,7 @@ pub mod expr_fn { /// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { use datafusion_common::config::ConfigOptions; + let config = ConfigOptions::default(); vec![ current_date(), current_time(), @@ -267,15 +283,17 @@ pub fn functions() -> Vec> { date_trunc(), from_unixtime(), make_date(), - now(&ConfigOptions::default()), + make_time(), + now(&config), to_char(), to_date(), to_local_time(), + to_time(), to_unixtime(), - to_timestamp(), - to_timestamp_seconds(), - to_timestamp_millis(), - to_timestamp_micros(), - to_timestamp_nanos(), + to_timestamp(&config), + to_timestamp_seconds(&config), + to_timestamp_millis(&config), + to_timestamp_micros(&config), + to_timestamp_nanos(&config), ] } diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 4723548a45584..82bb1251b2045 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -18,15 +18,14 @@ use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; use arrow::datatypes::{DataType, Field, FieldRef}; -use std::any::Any; use std::sync::Arc; use datafusion_common::config::ConfigOptions; -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -37,7 +36,24 @@ Returns the current timestamp in the system configured timezone (None by default The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. "#, - syntax_example = "now()" + syntax_example = "now()", + sql_example = r#"```sql +> SELECT now(); ++----------------------------------+ +| now() | ++----------------------------------+ +| 2024-12-23T06:30:00.123456789 | ++----------------------------------+ + +-- The timezone of the returned timestamp depends on the session time zone +> SET datafusion.execution.time_zone = 'America/New_York'; +> SELECT now(); ++--------------------------------------+ +| now() | ++--------------------------------------+ +| 2024-12-23T01:30:00.123456789-05:00 | ++--------------------------------------+ +```"# )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct NowFunc { @@ -83,10 +99,6 @@ impl NowFunc { /// wherever it appears within a single statement. This value is /// chosen during planning time. impl ScalarUDFImpl for NowFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "now" } @@ -112,25 +124,24 @@ impl ScalarUDFImpl for NowFunc { internal_err!("return_field_from_args should be called instead") } - fn invoke_with_args( - &self, - _args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { internal_err!("invoke should not be called on a simplified now() function") } fn simplify( &self, - _args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let now_ts = info - .execution_props() - .query_execution_start_time - .timestamp_nanos_opt(); + let Some(now_ts) = info.query_execution_start_time() else { + return Ok(ExprSimplifyResult::Original(args)); + }; Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::TimestampNanosecond(now_ts, self.timezone.clone()), + ScalarValue::TimestampNanosecond( + now_ts.timestamp_nanos_opt(), + self.timezone.clone(), + ), None, ))) } @@ -148,7 +159,7 @@ impl ScalarUDFImpl for NowFunc { mod tests { use super::*; - #[allow(deprecated)] + #[expect(deprecated)] #[test] fn now_func_default_matches_config() { let default_config = ConfigOptions::default(); diff --git a/datafusion/functions/src/datetime/planner.rs b/datafusion/functions/src/datetime/planner.rs index f4b64c3711e2c..f2b8ef9d1d310 100644 --- a/datafusion/functions/src/datetime/planner.rs +++ b/datafusion/functions/src/datetime/planner.rs @@ -16,9 +16,9 @@ // under the License. //! SQL planning extensions like [`DatetimeFunctionPlanner`] +use datafusion_expr::Expr; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::planner::{ExprPlanner, PlannerResult}; -use datafusion_expr::Expr; #[derive(Default, Debug)] pub struct DatetimeFunctionPlanner; diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index ed8090c9a2399..5accddd07f2b4 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -15,23 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; +use arrow::array::builder::StringBuilder; use arrow::array::cast::AsArray; -use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; +use arrow::array::{Array, ArrayRef}; use arrow::compute::cast; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Time32, Time64, Timestamp, Utf8, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; -use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TIMEZONE_WILDCARD, Volatility, }; use datafusion_macros::user_doc; @@ -48,7 +48,7 @@ use datafusion_macros::user_doc; +----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "expression", @@ -119,10 +119,6 @@ impl ToCharFunc { } impl ScalarUDFImpl for ToCharFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_char" } @@ -135,28 +131,22 @@ impl ScalarUDFImpl for ToCharFunc { Ok(Utf8) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = args.args; let [date_time, format] = take_function_args(self.name(), &args)?; match format { - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Null) => to_char_scalar(date_time, None), - // constant format - ColumnarValue::Scalar(ScalarValue::Utf8(Some(format))) => { - // invoke to_char_scalar with the known string, without converting to array - to_char_scalar(date_time, Some(format)) + ColumnarValue::Scalar(ScalarValue::Null | ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - ColumnarValue::Array(_) => to_char_array(&args), - _ => { - exec_err!( - "Format for `to_char` must be non-null Utf8, received {:?}", - format.data_type() - ) + ColumnarValue::Scalar(ScalarValue::Utf8(Some(fmt))) => { + to_char_scalar(date_time, fmt) } + ColumnarValue::Array(_) => to_char_array(&args), + _ => exec_err!( + "Format for `to_char` must be non-null Utf8, received {}", + format.data_type() + ), } } @@ -171,11 +161,8 @@ impl ScalarUDFImpl for ToCharFunc { fn build_format_options<'a>( data_type: &DataType, - format: Option<&'a str>, -) -> Result, Result> { - let Some(format) = format else { - return Ok(FormatOptions::new()); - }; + format: &'a str, +) -> Result> { let format_options = match data_type { Date32 => FormatOptions::new() .with_date_format(Some(format)) @@ -194,144 +181,124 @@ fn build_format_options<'a>( }, ), other => { - return Err(exec_err!( + return exec_err!( "to_char only supports date, time, timestamp and duration data types, received {other:?}" - )); + ); } }; Ok(format_options) } -/// Special version when arg\[1] is a scalar -fn to_char_scalar( - expression: &ColumnarValue, - format: Option<&str>, -) -> Result { - // it's possible that the expression is a scalar however because - // of the implementation in arrow-rs we need to convert it to an array +/// Formats `expression` using a constant `format` string. +fn to_char_scalar(expression: &ColumnarValue, format: &str) -> Result { + // ArrayFormatter requires an array, so scalar expressions must be + // converted to a 1-element array first. let data_type = &expression.data_type(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); - let array = expression.clone().into_array(1)?; + let array = expression.to_array(1)?; - if format.is_none() { - return if is_scalar_expression { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } else { - Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))) - }; - } + let format_options = build_format_options(data_type, format)?; + let formatter = ArrayFormatter::try_new(array.as_ref(), &format_options)?; - let format_options = match build_format_options(data_type, format) { - Ok(value) => value, - Err(value) => return value, - }; + // Pad the preallocated capacity a bit because format specifiers often + // expand the string (e.g., %Y -> "2026") + let fmt_len = format.len() + 10; + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * fmt_len); - let formatter = ArrayFormatter::try_new(array.as_ref(), &format_options)?; - let formatted: Result>, ArrowError> = (0..array.len()) - .map(|i| { - if array.is_null(i) { - Ok(None) - } else { - formatter.value(i).try_to_string().map(Some) - } - }) - .collect(); - - if let Ok(formatted) = formatted { - if is_scalar_expression { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - formatted.first().unwrap().clone(), - ))) + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); } else { - Ok(ColumnarValue::Array( - Arc::new(StringArray::from(formatted)) as ArrayRef - )) - } - } else { - // if the data type was a Date32, formatting could have failed because the format string - // contained datetime specifiers, so we'll retry by casting the date array as a timestamp array - if data_type == &Date32 { - return to_char_scalar(&expression.cast_to(&Date64, None)?, format); + // Write directly into the builder's internal buffer, then + // commit the value with append_value(""). + match formatter.value(i).write(&mut builder) { + Ok(()) => builder.append_value(""), + // Arrow's Date32 formatter only handles date specifiers + // (%Y, %m, %d, ...). Format strings with time specifiers + // (%H, %M, %S, ...) cause it to fail. When this happens, + // we retry by casting to Date64, whose datetime formatter + // handles both date and time specifiers (with zero for + // the time components). + Err(_) if data_type == &Date32 => { + return to_char_scalar(&expression.cast_to(&Date64, None)?, format); + } + Err(e) => return Err(e.into()), + } } + } - exec_err!("{}", formatted.unwrap_err()) + let result = builder.finish(); + if is_scalar_expression { + let val = result.is_valid(0).then(|| result.value(0).to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(val))) + } else { + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } } fn to_char_array(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; - let mut results: Vec> = vec![]; + let data_array = &arrays[0]; let format_array = arrays[1].as_string::(); - let data_type = arrays[0].data_type(); - - for idx in 0..arrays[0].len() { - let format = if format_array.is_null(idx) { - None - } else { - Some(format_array.value(idx)) - }; - if format.is_none() { - results.push(None); + let data_type = data_array.data_type(); + + // Arbitrary guess for the length of a typical formatted datetime string + let fmt_len = 30; + let mut builder = + StringBuilder::with_capacity(data_array.len(), data_array.len() * fmt_len); + let mut buffer = String::with_capacity(fmt_len); + + // Lazily computed Date64 cast of the entire array, used when a Date32 + // format string contains time specifiers that the Date32 formatter + // cannot handle. Cast once and reuse for all subsequent rows + let mut date64_array: Option = None; + + for idx in 0..data_array.len() { + if format_array.is_null(idx) || data_array.is_null(idx) { + builder.append_null(); continue; } - let format_options = match build_format_options(data_type, format) { - Ok(value) => value, - Err(value) => return value, - }; - // this isn't ideal but this can't use ValueFormatter as it isn't independent - // from ArrayFormatter - let formatter = ArrayFormatter::try_new(arrays[0].as_ref(), &format_options)?; - let result = formatter.value(idx).try_to_string(); - match result { - Ok(value) => results.push(Some(value)), - Err(e) => { - // if the data type was a Date32, formatting could have failed because the format string - // contained datetime specifiers, so we'll treat this specific date element as a timestamp - if data_type == &Date32 { - let failed_date_value = arrays[0].slice(idx, 1); - - match retry_date_as_timestamp(&failed_date_value, &format_options) { - Ok(value) => { - results.push(Some(value)); - continue; - } - Err(e) => { - return exec_err!("{}", e); - } - } - } - return exec_err!("{}", e); + let format = format_array.value(idx); + let format_options = build_format_options(data_type, format)?; + let formatter = ArrayFormatter::try_new(data_array.as_ref(), &format_options)?; + + buffer.clear(); + + // We'd prefer to write directly to the StringBuilder's internal buffer, + // but the write might fail, and there's no easy way to ensure a partial + // write is removed from the buffer. So instead we write to a temporary + // buffer and `append_value` on success. + match formatter.value(idx).write(&mut buffer) { + Ok(()) => builder.append_value(&buffer), + Err(_) if data_type == &Date32 => { + buffer.clear(); + let date64_ref = match &date64_array { + Some(arr) => arr.as_ref(), + None => { + date64_array = Some(cast(data_array.as_ref(), &Date64)?); + date64_array.as_ref().unwrap().as_ref() + } + }; + let retry_options = build_format_options(&Date64, format)?; + let retry_fmt = ArrayFormatter::try_new(date64_ref, &retry_options)?; + retry_fmt.value(idx).write(&mut buffer)?; + builder.append_value(&buffer); } + Err(e) => return Err(e.into()), } } + let result = builder.finish(); match args[0] { - ColumnarValue::Array(_) => Ok(ColumnarValue::Array(Arc::new(StringArray::from( - results, - )) as ArrayRef)), - ColumnarValue::Scalar(_) => match results.first().unwrap() { - Some(value) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - value.to_string(), - )))), - None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - }, + ColumnarValue::Scalar(_) => { + let val = result.is_valid(0).then(|| result.value(0).to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(val))) + } + ColumnarValue::Array(_) => Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)), } } -fn retry_date_as_timestamp( - array_ref: &ArrayRef, - format_options: &FormatOptions, -) -> Result { - let target_data_type = Date64; - - let date_value = cast(&array_ref, &target_data_type)?; - let formatter = ArrayFormatter::try_new(date_value.as_ref(), format_options)?; - let result = formatter.value(0).try_to_string()?; - - Ok(result) -} - #[cfg(test)] mod tests { use crate::datetime::to_char::ToCharFunc; @@ -343,25 +310,35 @@ mod tests { }; use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::{NaiveDateTime, Timelike}; - use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] fn test_array_array() { - let array_array_data = vec![( - Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, - StringArray::from(vec!["%Y::%m::%d", "%Y::%m::%d %S::%M::%H %f"]), - StringArray::from(vec!["2020::09::01", "2020::09::02 00::00::00 000000000"]), - )]; + let array_array_data = vec![ + ( + Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, + StringArray::from(vec!["%Y::%m::%d", "%Y::%m::%d %S::%M::%H %f"]), + StringArray::from(vec![ + "2020::09::01", + "2020::09::02 00::00::00 000000000", + ]), + ), + ( + Arc::new(Date32Array::from(vec![18506, 18507])) as ArrayRef, + StringArray::from(vec!["%Y::%m::%d %H:%M:%S", "%d-%m-%Y %H:%M"]), + StringArray::from(vec!["2020::09::01 00:00:00", "02-09-2020 00:00"]), + ), + ]; for (value, format, expected) in array_array_data { let batch_len = value.len(); let value_data_type = value.data_type().clone(); let format_data_type = format.data_type().clone(); - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), @@ -472,7 +449,7 @@ mod tests { Field::new("a", value.data_type(), false).into(), Field::new("a", format.data_type(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], arg_fields, number_rows: 1, @@ -563,7 +540,7 @@ mod tests { Field::new("a", value.data_type(), false).into(), Field::new("a", format.data_type().to_owned(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), @@ -727,7 +704,7 @@ mod tests { Field::new("a", value.data_type().clone(), false).into(), Field::new("a", format.data_type(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), @@ -755,7 +732,7 @@ mod tests { Field::new("a", value.data_type().clone(), false).into(), Field::new("a", format.data_type().clone(), false).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), @@ -783,7 +760,7 @@ mod tests { // invalid number of arguments let arg_field = Field::new("a", DataType::Int32, true).into(); - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], arg_fields: vec![arg_field], number_rows: 1, @@ -801,7 +778,7 @@ mod tests { Field::new("a", DataType::Utf8, true).into(), Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true).into(), ]; - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), @@ -814,7 +791,7 @@ mod tests { let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( result.err().unwrap().strip_backtrace(), - "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(Nanosecond, None)" + "Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(ns)" ); } } diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 3840c8d8bbb94..cd75ac6bed3ac 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -16,22 +16,23 @@ // under the License. use crate::datetime::common::*; +use arrow::compute::cast_with_options; use arrow::datatypes::DataType; use arrow::datatypes::DataType::*; use arrow::error::ArrowError::ParseError; use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser}; -use datafusion_common::error::DataFusionError; -use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result}; +use datafusion_common::format::DEFAULT_CAST_OPTIONS; +use datafusion_common::{Result, arrow_err, exec_err, internal_datafusion_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "Time and Date Functions"), description = r"Converts a value to a date (`YYYY-MM-DD`). -Supports strings, integer and double types as input. +Supports strings, numeric and timestamp types as input. Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding date. @@ -53,7 +54,7 @@ Note: `to_date` returns Date32, which represents its values as the number of day +---------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, standard_argument(name = "expression", prefix = "String"), argument( @@ -83,7 +84,7 @@ impl ToDateFunc { fn to_date(&self, args: &[ColumnarValue]) -> Result { match args.len() { - 1 => handle::( + 1 => handle::( args, |s| match Date32Type::parse(s) { Some(v) => Ok(v), @@ -93,8 +94,9 @@ impl ToDateFunc { )), }, "to_date", + &Date32, ), - 2.. => handle_multiple::( + 2.. => handle_multiple::( args, |s, format| { string_to_timestamp_millis_formatted(s, format) @@ -107,6 +109,7 @@ impl ToDateFunc { }, |n| n, "to_date", + &Date32, ), 0 => exec_err!("Unsupported 0 argument count for function to_date"), } @@ -114,10 +117,6 @@ impl ToDateFunc { } impl ScalarUDFImpl for ToDateFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_date" } @@ -130,10 +129,7 @@ impl ScalarUDFImpl for ToDateFunc { Ok(Date32) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = args.args; if args.is_empty() { return exec_err!("to_date function requires 1 or more arguments, got 0"); @@ -145,9 +141,42 @@ impl ScalarUDFImpl for ToDateFunc { } match args[0].data_type() { - Int32 | Int64 | Null | Float64 | Date32 | Date64 => { + Null | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => { args[0].cast_to(&Date32, None) } + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 => { + // Arrow cast doesn't support direct casting of these types to date32 + // as it only supports Int32 and Int64. To work around that limitation, + // use cast_with_options to cast to Int32 and then cast the result of + // that to Date32. + match &args[0] { + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(cast_with_options( + &cast_with_options(&array, &Int32, &DEFAULT_CAST_OPTIONS)?, + &Date32, + &DEFAULT_CAST_OPTIONS, + )?)) + } + ColumnarValue::Scalar(scalar) => { + let sv = + scalar.cast_to_with_options(&Int32, &DEFAULT_CAST_OPTIONS)?; + Ok(ColumnarValue::Scalar( + sv.cast_to_with_options(&Date32, &DEFAULT_CAST_OPTIONS)?, + )) + } + } + } + Float16 + | Float32 + | Float64 + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { + // The only way this makes sense is to get the Int64 value of the float + // or decimal and then cast that to Date32. + args[0].cast_to(&Int64, None)?.cast_to(&Date32, None) + } Utf8View | LargeUtf8 | Utf8 => self.to_date(&args), other => { exec_err!("Unsupported data type {} for function to_date", other) @@ -168,7 +197,7 @@ mod tests { use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; fn invoke_to_date_with_args( @@ -180,7 +209,7 @@ mod tests { .map(|arg| Field::new("a", arg.data_type(), true).into()) .collect::>(); - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args, arg_fields, number_rows, @@ -352,7 +381,11 @@ mod tests { match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); - assert_eq!(date_val, expected, "{}: to_date created wrong value for date '{}' with format string '{}'", tc.name, tc.formatted_date, tc.format_str); + assert_eq!( + date_val, expected, + "{}: to_date created wrong value for date '{}' with format string '{}'", + tc.name, tc.formatted_date, tc.format_str + ); } _ => panic!( "Could not convert '{}' with format string '{}'to Date", @@ -386,7 +419,8 @@ mod tests { builder.append_value(expected.unwrap()); assert_eq!( - &builder.finish() as &dyn Array, a.as_ref(), + &builder.finish() as &dyn Array, + a.as_ref(), "{}: to_date created wrong value for date '{}' with format string '{}'", tc.name, tc.formatted_date, diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 6e0a150b0a35f..5bd1978893d54 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::ops::Add; use std::sync::Arc; use arrow::array::timezone::Tz; -use arrow::array::{Array, ArrayRef, PrimitiveBuilder}; +use arrow::array::{ArrayRef, PrimitiveArray}; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ @@ -31,11 +30,12 @@ use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{ - exec_err, internal_datafusion_err, plan_err, utils::take_function_args, Result, - ScalarValue, + Result, ScalarValue, exec_err, internal_datafusion_err, internal_err, + utils::take_function_args, }; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -111,131 +111,145 @@ impl Default for ToLocalTimeFunc { impl ToLocalTimeFunc { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Timestamp)], + Volatility::Immutable, + ), } } +} - fn to_local_time(&self, args: &[ColumnarValue]) -> Result { - let [time_value] = take_function_args(self.name(), args)?; +impl ScalarUDFImpl for ToLocalTimeFunc { + fn name(&self) -> &str { + "to_local_time" + } - let arg_type = time_value.data_type(); - match arg_type { - Timestamp(_, None) => { - // if no timezone specified, just return the input - Ok(time_value.clone()) - } - // If has timezone, adjust the underlying time value. The current time value - // is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore, - // we need to adjust the time value to the local time. See [`adjust_to_local_time`] - // for more details. - // - // Then remove the timezone in return type, i.e. return None - Timestamp(_, Some(timezone)) => { - let tz: Tz = timezone.parse()?; - - match time_value { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Array(array) => { - fn transform_array( - array: &ArrayRef, - tz: Tz, - ) -> Result { - let mut builder = PrimitiveBuilder::::new(); - - let primitive_array = as_primitive_array::(array)?; - for ts_opt in primitive_array.iter() { - match ts_opt { - None => builder.append_null(), - Some(ts) => { - let adjusted_ts: i64 = - adjust_to_local_time::(ts, tz)?; - builder.append_value(adjusted_ts) - } - } - } - - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) - } - - match array.data_type() { - Timestamp(_, None) => { - // if no timezone specified, just return the input - Ok(time_value.clone()) - } - Timestamp(Nanosecond, Some(_)) => { - transform_array::(array, tz) - } - Timestamp(Microsecond, Some(_)) => { - transform_array::(array, tz) - } - Timestamp(Millisecond, Some(_)) => { - transform_array::(array, tz) - } - Timestamp(Second, Some(_)) => { - transform_array::(array, tz) - } - _ => { - exec_err!("to_local_time function requires timestamp argument in array, got {:?}", array.data_type()) - } - } - } - _ => { - exec_err!( - "to_local_time function requires timestamp argument, got {:?}", - time_value.data_type() - ) - } - } - } - _ => { - exec_err!( - "to_local_time function requires timestamp argument, got {:?}", - arg_type - ) - } + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Null => Ok(Timestamp(Nanosecond, None)), + Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)), + dt => internal_err!( + "The to_local_time function can only accept timestamp as the arg, got {dt}" + ), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [time_value] = take_function_args(self.name(), &args.args)?; + to_local_time(time_value) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn transform_array( + array: &ArrayRef, + tz: Tz, +) -> Result { + let primitive_array = as_primitive_array::(array)?; + let result: PrimitiveArray = + primitive_array.try_unary(|ts| adjust_to_local_time::(ts, tz))?; + Ok(ColumnarValue::Array(Arc::new(result))) +} + +fn to_local_time(time_value: &ColumnarValue) -> Result { + let arg_type = time_value.data_type(); + + let tz: Tz = match &arg_type { + Timestamp(_, Some(timezone)) => timezone.parse()?, + Timestamp(_, None) => { + // if no timezone specified, just return the input + return Ok(time_value.clone()); + } + DataType::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + None, None, + ))); + } + dt => { + return internal_err!( + "to_local_time function requires timestamp argument, got {dt}" + ); + } + }; + + // If has timezone, adjust the underlying time value. The current time value + // is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore, + // we need to adjust the time value to the local time. See [`adjust_to_local_time`] + // for more details. + // + // Then remove the timezone in return type, i.e. return None + match time_value { + ColumnarValue::Scalar(ScalarValue::TimestampSecond(None, Some(_))) => Ok( + ColumnarValue::Scalar(ScalarValue::TimestampSecond(None, None)), + ), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(None, Some(_))) => Ok( + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(None, None)), + ), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(None, Some(_))) => Ok( + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(None, None)), + ), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(None, Some(_))) => Ok( + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(None, None)), + ), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(ts), Some(_))) => { + let adjusted_ts = adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(ts), Some(_))) => { + let adjusted_ts = adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(Some(ts), Some(_))) => { + let adjusted_ts = adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(ts), Some(_))) => { + let adjusted_ts = adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Array(array) + if matches!(array.data_type(), Timestamp(Nanosecond, Some(_))) => + { + transform_array::(array, tz) + } + ColumnarValue::Array(array) + if matches!(array.data_type(), Timestamp(Microsecond, Some(_))) => + { + transform_array::(array, tz) + } + ColumnarValue::Array(array) + if matches!(array.data_type(), Timestamp(Millisecond, Some(_))) => + { + transform_array::(array, tz) + } + ColumnarValue::Array(array) + if matches!(array.data_type(), Timestamp(Second, Some(_))) => + { + transform_array::(array, tz) + } + _ => { + internal_err!( + "to_local_time function requires timestamp argument, got {arg_type}" + ) } } } @@ -293,7 +307,7 @@ impl ToLocalTimeFunc { /// ``` /// /// See `test_adjust_to_local_time()` for example -fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { +pub fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { fn convert_timestamp(ts: i64, converter: F) -> Result> where F: Fn(i64) -> MappedLocalTime>, @@ -343,81 +357,19 @@ fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { } } -impl ScalarUDFImpl for ToLocalTimeFunc { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "to_local_time" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [time_value] = take_function_args(self.name(), arg_types)?; - - match time_value { - Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)), - _ => exec_err!( - "The to_local_time function can only accept timestamp as the arg, got {:?}", time_value - ) - } - } - - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let [time_value] = take_function_args(self.name(), args.args)?; - - self.to_local_time(std::slice::from_ref(&time_value)) - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 1 { - return plan_err!( - "to_local_time function requires 1 argument, got {:?}", - arg_types.len() - ); - } - - let first_arg = arg_types[0].clone(); - match &first_arg { - DataType::Null => Ok(vec![Timestamp(Nanosecond, None)]), - Timestamp(Nanosecond, timezone) => { - Ok(vec![Timestamp(Nanosecond, timezone.clone())]) - } - Timestamp(Microsecond, timezone) => { - Ok(vec![Timestamp(Microsecond, timezone.clone())]) - } - Timestamp(Millisecond, timezone) => { - Ok(vec![Timestamp(Millisecond, timezone.clone())]) - } - Timestamp(Second, timezone) => Ok(vec![Timestamp(Second, timezone.clone())]), - _ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"), - } - } - fn documentation(&self) -> Option<&Documentation> { - self.doc() - } -} - #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::{types::TimestampNanosecondType, Array, TimestampNanosecondArray}; + use arrow::array::{Array, TimestampNanosecondArray, types::TimestampNanosecondType}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; - use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; - use super::{adjust_to_local_time, ToLocalTimeFunc}; + use super::{ToLocalTimeFunc, adjust_to_local_time}; #[test] fn test_adjust_to_local_time() { diff --git a/datafusion/functions/src/datetime/to_time.rs b/datafusion/functions/src/datetime/to_time.rs new file mode 100644 index 0000000000000..94aa49fbbad2f --- /dev/null +++ b/datafusion/functions/src/datetime/to_time.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::datetime::common::*; +use arrow::array::cast::AsArray; +use arrow::array::temporal_conversions::time_to_time64ns; +use arrow::array::types::Time64NanosecondType; +use arrow::array::{Array, PrimitiveArray, StringArrayType}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::*; +use chrono::NaiveTime; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::sync::Arc; + +/// Default time formats to try when parsing without an explicit format +const DEFAULT_TIME_FORMATS: &[&str] = &[ + "%H:%M:%S%.f", // 12:30:45.123456789 + "%H:%M:%S", // 12:30:45 + "%H:%M", // 12:30 +]; + +#[user_doc( + doc_section(label = "Time and Date Functions"), + description = r"Converts a value to a time (`HH:MM:SS.nnnnnnnnn`). +Supports strings and timestamps as input. +Strings are parsed as `HH:MM:SS`, `HH:MM:SS.nnnnnnnnn`, or `HH:MM` if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. +Timestamps will have the time portion extracted. +Returns the corresponding time. + +Note: `to_time` returns Time64(Nanosecond), which represents the time of day in nanoseconds since midnight.", + syntax_example = "to_time('12:30:45', '%H:%M:%S')", + sql_example = r#"```sql +> select to_time('12:30:45'); ++---------------------------+ +| to_time(Utf8("12:30:45")) | ++---------------------------+ +| 12:30:45 | ++---------------------------+ +> select to_time('12-30-45', '%H-%M-%S'); ++--------------------------------------------+ +| to_time(Utf8("12-30-45"),Utf8("%H-%M-%S")) | ++--------------------------------------------+ +| 12:30:45 | ++--------------------------------------------+ +> select to_time('2024-01-15 14:30:45'::timestamp); ++--------------------------------------------------+ +| to_time(Utf8("2024-01-15 14:30:45")) | ++--------------------------------------------------+ +| 14:30:45 | ++--------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) +"#, + standard_argument(name = "expression", prefix = "String or Timestamp"), + argument( + name = "format_n", + description = r"Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ToTimeFunc { + signature: Signature, +} + +impl Default for ToTimeFunc { + fn default() -> Self { + Self::new() + } +} + +impl ToTimeFunc { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ToTimeFunc { + fn name(&self) -> &str { + "to_time" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Time64(arrow::datatypes::TimeUnit::Nanosecond)) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = args.args; + if args.is_empty() { + return exec_err!("to_time function requires 1 or more arguments, got 0"); + } + + // validate that any args after the first one are Utf8 + if args.len() > 1 { + validate_data_types(&args, "to_time")?; + } + + match args[0].data_type() { + Utf8View | LargeUtf8 | Utf8 => string_to_time(&args), + Null => Ok(ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(None))), + // Support timestamp input by extracting time portion using Arrow cast + Timestamp(_, _) => timestamp_to_time(&args[0]), + other => { + exec_err!("Unsupported data type {} for function to_time", other) + } + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +/// Convert string arguments to time (standalone function, not a method on ToTimeFunc) +fn string_to_time(args: &[ColumnarValue]) -> Result { + let formats = collect_formats(args)?; + + match &args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(s)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(s)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(s)) => { + let result = s + .as_ref() + .map(|s| parse_time_with_formats(s, &formats)) + .transpose()?; + Ok(ColumnarValue::Scalar(ScalarValue::Time64Nanosecond(result))) + } + ColumnarValue::Array(array) => { + let result = match array.data_type() { + Utf8 => parse_time_array(&array.as_string::(), &formats)?, + LargeUtf8 => parse_time_array(&array.as_string::(), &formats)?, + Utf8View => parse_time_array(&array.as_string_view(), &formats)?, + other => return exec_err!("Unsupported type for to_time: {other}"), + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + other => exec_err!("Unsupported argument for to_time: {other:?}"), + } +} + +/// Collect format strings from arguments, erroring on non-scalar inputs +fn collect_formats(args: &[ColumnarValue]) -> Result> { + if args.len() <= 1 { + return Ok(DEFAULT_TIME_FORMATS.to_vec()); + } + + let mut formats = Vec::with_capacity(args.len() - 1); + for (i, arg) in args[1..].iter().enumerate() { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) => { + formats.push(s.as_str()); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => { + // Skip null format strings + } + ColumnarValue::Array(_) => { + return exec_err!( + "to_time format argument {} must be a scalar, not an array", + i + 2 // argument position (1-indexed, +1 for the first arg) + ); + } + other => { + return exec_err!( + "to_time format argument {} has unsupported type: {:?}", + i + 2, + other.data_type() + ); + } + } + } + Ok(formats) +} + +/// Extract time portion from timestamp using Arrow cast kernel +fn timestamp_to_time(arg: &ColumnarValue) -> Result { + arg.cast_to(&Time64(arrow::datatypes::TimeUnit::Nanosecond), None) +} + +/// Parse time array using the provided formats +fn parse_time_array<'a, A: StringArrayType<'a>>( + array: &A, + formats: &[&str], +) -> Result> { + let mut values = Vec::with_capacity(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + values.push(0); + } else { + values.push(parse_time_with_formats(array.value(i), formats)?); + } + } + Ok(PrimitiveArray::new(values.into(), array.nulls().cloned())) +} + +/// Parse time string using provided formats +fn parse_time_with_formats(s: &str, formats: &[&str]) -> Result { + for format in formats { + if let Ok(time) = NaiveTime::parse_from_str(s, format) { + // Use Arrow's time_to_time64ns function instead of custom implementation + return Ok(time_to_time64ns(time)); + } + } + exec_err!( + "Error parsing '{}' as time. Tried formats: {:?}", + s, + formats + ) +} diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 0a0700097770f..f4507ab250559 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -15,30 +15,45 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use crate::datetime::common::*; -use arrow::array::Float64Array; +use arrow::array::timezone::Tz; +use arrow::array::{ + Array, Decimal128Array, Float16Array, Float32Array, Float64Array, + TimestampNanosecondArray, +}; use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ - ArrowTimestampType, DataType, TimeUnit, TimestampMicrosecondType, + ArrowTimestampType, DECIMAL128_MAX_PRECISION, DataType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use datafusion_common::format::DEFAULT_CAST_OPTIONS; -use datafusion_common::{exec_err, Result, ScalarType, ScalarValue}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, ScalarType, ScalarValue, exec_datafusion_err, exec_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Time and Date Functions"), description = r#" -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. - -Note: `to_timestamp` returns `Timestamp(ns)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, +integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the +session time zone, or UTC if no session time zone is set. +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). + +Note: `to_timestamp` returns `Timestamp(ns, TimeZone)` where the time zone is the session time zone. The supported range +for integer input is between`-9223372037` and `9223372036`. Supported range for string input is between +`1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` +for the input outside of supported bounds. + +The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. +The time zone can be a value like +00:00, 'Europe/London' etc. "#, syntax_example = "to_timestamp(expression[, ..., format_n])", sql_example = r#"```sql @@ -55,7 +70,7 @@ Note: `to_timestamp` returns `Timestamp(ns)`. The supported range for integer in | 2023-05-17T03:59:00.123456789 | +--------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "expression", @@ -63,17 +78,33 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ), argument( name = "format_n", - description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + description = r#" +Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. +Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully +parse the expression an error will be returned. Note: parsing of named timezones (e.g. 'America/New_York') using %Z is +only supported at the end of the string preceded by a space. +"# ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampFunc { signature: Signature, + timezone: Option>, } #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.", + description = r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00`) in the session time zone. Supports strings, +integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the +session time zone, or UTC if no session time zone is set. +Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). + +The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. +The time zone can be a value like +00:00, 'Europe/London' etc. +"#, syntax_example = "to_timestamp_seconds(expression[, ..., format_n])", sql_example = r#"```sql > select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); @@ -89,7 +120,7 @@ pub struct ToTimestampFunc { | 2023-05-17T03:59:00 | +----------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "expression", @@ -97,17 +128,33 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ), argument( name = "format_n", - description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + description = r#" +Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. +Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully +parse the expression an error will be returned. Note: parsing of named timezones (e.g. 'America/New_York') using %Z is +only supported at the end of the string preceded by a space. +"# ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampSecondsFunc { signature: Signature, + timezone: Option>, } #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.", + description = r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000`) in the session time zone. Supports strings, +integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the +session time zone, or UTC if no session time zone is set. +Integers, unsigned integers, and doubles are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). + +The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. +The time zone can be a value like +00:00, 'Europe/London' etc. +"#, syntax_example = "to_timestamp_millis(expression[, ..., format_n])", sql_example = r#"```sql > select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); @@ -123,7 +170,7 @@ pub struct ToTimestampSecondsFunc { | 2023-05-17T03:59:00.123 | +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "expression", @@ -131,17 +178,33 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ), argument( name = "format_n", - description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + description = r#" +Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. +Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully +parse the expression an error will be returned. Note: parsing of named timezones (e.g. 'America/New_York') using %Z is +only supported at the end of the string preceded by a space. +"# ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampMillisFunc { signature: Signature, + timezone: Option>, } #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp.", + description = r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000`) in the session time zone. Supports strings, +integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the +session time zone, or UTC if no session time zone is set. +Integers, unsigned integers, and doubles are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`). + +The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. +The time zone can be a value like +00:00, 'Europe/London' etc. +"#, syntax_example = "to_timestamp_micros(expression[, ..., format_n])", sql_example = r#"```sql > select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); @@ -157,7 +220,7 @@ pub struct ToTimestampMillisFunc { | 2023-05-17T03:59:00.123456 | +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "expression", @@ -165,17 +228,32 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ), argument( name = "format_n", - description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + description = r#" +Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. +Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully +parse the expression an error will be returned. Note: parsing of named timezones (e.g. 'America/New_York') using %Z is +only supported at the end of the string preceded by a space. +"# ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampMicrosFunc { signature: Signature, + timezone: Option>, } #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.", + description = r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000`) in the session time zone. Supports strings, +integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Strings that parse without a time zone are treated as if they are in the +session time zone. Integers, unsigned integers, and doubles are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). + +The session time zone can be set using the statement `SET TIMEZONE = 'desired time zone'`. +The time zone can be a value like +00:00, 'Europe/London' etc. +"#, syntax_example = "to_timestamp_nanos(expression[, ..., format_n])", sql_example = r#"```sql > select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); @@ -191,7 +269,7 @@ pub struct ToTimestampMicrosFunc { | 2023-05-17T03:59:00.123456789 | +---------------------------------------------------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/date_time_functions.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/date_time.rs) "#, argument( name = "expression", @@ -199,81 +277,114 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ), argument( name = "format_n", - description = "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned." + description = r#" +Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. +Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully +parse the expression an error will be returned. Note: parsing of named timezones (e.g. 'America/New_York') using %Z is +only supported at the end of the string preceded by a space. +"# ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct ToTimestampNanosFunc { signature: Signature, + timezone: Option>, } -impl Default for ToTimestampFunc { - fn default() -> Self { - Self::new() - } -} - -impl ToTimestampFunc { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), +/// Macro to generate boilerplate constructors and config methods for ToTimestamp* functions. +/// Generates: Default impl, deprecated new(), new_with_config(), and extracts timezone from ConfigOptions. +macro_rules! impl_to_timestamp_constructors { + ($func:ty) => { + impl Default for $func { + fn default() -> Self { + Self::new_with_config(&ConfigOptions::default()) + } } - } -} -impl Default for ToTimestampSecondsFunc { - fn default() -> Self { - Self::new() - } -} + impl $func { + #[deprecated(since = "52.0.0", note = "use `new_with_config` instead")] + /// Deprecated constructor retained for backwards compatibility. + /// + /// Prefer `new_with_config` which allows specifying the + /// timezone via [`ConfigOptions`]. This helper now mirrors the + /// canonical default offset (None) provided by `ConfigOptions::default()`. + pub fn new() -> Self { + Self::new_with_config(&ConfigOptions::default()) + } -impl ToTimestampSecondsFunc { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), + pub fn new_with_config(config: &ConfigOptions) -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + timezone: config + .execution + .time_zone + .as_ref() + .map(|tz| Arc::from(tz.as_str())), + } + } } - } -} - -impl Default for ToTimestampMillisFunc { - fn default() -> Self { - Self::new() - } + }; } -impl ToTimestampMillisFunc { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - } - } -} +impl_to_timestamp_constructors!(ToTimestampFunc); +impl_to_timestamp_constructors!(ToTimestampSecondsFunc); +impl_to_timestamp_constructors!(ToTimestampMillisFunc); +impl_to_timestamp_constructors!(ToTimestampMicrosFunc); +impl_to_timestamp_constructors!(ToTimestampNanosFunc); + +fn decimal_to_nanoseconds(value: i128, scale: i8) -> Result { + let nanos_exponent = 9_i16 - scale as i16; + let power = 10_i128 + .checked_pow(nanos_exponent.unsigned_abs() as u32) + .ok_or_else(|| { + exec_datafusion_err!( + "Decimal value {value} with scale {scale} overflows timestamp nanoseconds" + ) + })?; + + let timestamp_nanos = if nanos_exponent >= 0 { + value.checked_mul(power).ok_or_else(|| { + exec_datafusion_err!( + "Decimal value {value} with scale {scale} overflows timestamp nanoseconds" + ) + })? + } else { + value / power + }; -impl Default for ToTimestampMicrosFunc { - fn default() -> Self { - Self::new() - } + i64::try_from(timestamp_nanos).map_err(|_| { + exec_datafusion_err!( + "Decimal value {value} with scale {scale} overflows timestamp nanoseconds" + ) + }) } -impl ToTimestampMicrosFunc { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), +fn decimal128_to_timestamp_nanos( + arg: &ColumnarValue, + tz: Option>, +) -> Result { + match arg { + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(value), _, scale)) => { + let timestamp_nanos = decimal_to_nanoseconds(*value, *scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(timestamp_nanos), + tz, + ))) } - } -} - -impl Default for ToTimestampNanosFunc { - fn default() -> Self { - Self::new() - } -} - -impl ToTimestampNanosFunc { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), + ColumnarValue::Scalar(ScalarValue::Decimal128(None, _, _)) => Ok( + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(None, tz)), + ), + ColumnarValue::Array(arr) => { + let decimal_arr = downcast_arg!(arr, Decimal128Array); + let scale = decimal_arr.scale(); + let result: TimestampNanosecondArray = decimal_arr + .iter() + .map(|v| v.map(|val| decimal_to_nanoseconds(val, scale)).transpose()) + .collect::>()?; + let result = result.with_timezone_opt(tz); + Ok(ColumnarValue::Array(Arc::new(result))) } + _ => exec_err!("Invalid Decimal128 value for to_timestamp"), } } @@ -283,11 +394,16 @@ impl ToTimestampNanosFunc { /// The supported range for integer input is between `-9223372037` and `9223372036`. /// Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. /// Please use `to_timestamp_seconds` for the input outside of supported bounds. -impl ScalarUDFImpl for ToTimestampFunc { - fn as_any(&self) -> &dyn Any { - self - } +/// Macro to generate the with_updated_config method for ToTimestamp* functions. +macro_rules! impl_with_updated_config { + () => { + fn with_updated_config(&self, config: &ConfigOptions) -> Option { + Some(Self::new_with_config(config).into()) + } + }; +} +impl ScalarUDFImpl for ToTimestampFunc { fn name(&self) -> &str { "to_timestamp" } @@ -296,15 +412,15 @@ impl ScalarUDFImpl for ToTimestampFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(return_type_for(&arg_types[0], Nanosecond)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Timestamp(Nanosecond, self.timezone.clone())) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = args.args; + impl_with_updated_config!(); + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + if args.is_empty() { return exec_err!( "to_timestamp function requires 1 or more arguments, got {}", @@ -317,71 +433,84 @@ impl ScalarUDFImpl for ToTimestampFunc { validate_data_types(&args, "to_timestamp")?; } + let tz = self.timezone.clone(); + match args[0].data_type() { - Int32 | Int64 => args[0] + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => args[0] .cast_to(&Timestamp(Second, None), None)? - .cast_to(&Timestamp(Nanosecond, None), None), - Null | Timestamp(_, None) => { - args[0].cast_to(&Timestamp(Nanosecond, None), None) - } - Float64 => { - let rescaled = arrow::compute::kernels::numeric::mul( - &args[0].to_array(1)?, - &arrow::array::Scalar::new(Float64Array::from(vec![ - 1_000_000_000f64, - ])), - )?; - Ok(ColumnarValue::Array(arrow::compute::cast_with_options( - &rescaled, - &Timestamp(Nanosecond, None), - &DEFAULT_CAST_OPTIONS, - )?)) - } - Timestamp(_, Some(tz)) => { - args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) + .cast_to(&Timestamp(Nanosecond, tz), None), + Null | Timestamp(_, _) => args[0].cast_to(&Timestamp(Nanosecond, tz), None), + Float16 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float16(value)) => { + let timestamp_nanos = + value.map(|v| (v.to_f64() * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f16_arr = downcast_arg!(arr, Float16Array); + let result: TimestampNanosecondArray = + f16_arr.unary(|x| (x.to_f64() * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float16 value for to_timestamp"), + }, + Float32 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float32(value)) => { + let timestamp_nanos = + value.map(|v| (v as f64 * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f32_arr = downcast_arg!(arr, Float32Array); + let result: TimestampNanosecondArray = + f32_arr.unary(|x| (x as f64 * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float32 value for to_timestamp"), + }, + Float64 => match &args[0] { + ColumnarValue::Scalar(ScalarValue::Float64(value)) => { + let timestamp_nanos = value.map(|v| (v * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + timestamp_nanos, + tz, + ))) + } + ColumnarValue::Array(arr) => { + let f64_arr = downcast_arg!(arr, Float64Array); + let result: TimestampNanosecondArray = + f64_arr.unary(|x| (x * 1_000_000_000.0) as i64); + Ok(ColumnarValue::Array(Arc::new(result.with_timezone_opt(tz)))) + } + _ => exec_err!("Invalid Float64 value for to_timestamp"), + }, + Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _) => { + let arg = + args[0].cast_to(&Decimal128(DECIMAL128_MAX_PRECISION, 9), None)?; + decimal128_to_timestamp_nanos(&arg, tz) } + Decimal128(_, _) => decimal128_to_timestamp_nanos(&args[0], tz), Utf8View | LargeUtf8 | Utf8 => { - to_timestamp_impl::(&args, "to_timestamp") - } - Decimal128(_, _) => { - match &args[0] { - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(value), - _, - scale, - )) => { - // Convert decimal to seconds and nanoseconds - let scale_factor = 10_i128.pow(*scale as u32); - let seconds = value / scale_factor; - let fraction = value % scale_factor; - - let nanos = (fraction * 1_000_000_000) / scale_factor; - - let timestamp_nanos = seconds * 1_000_000_000 + nanos; - - Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(timestamp_nanos as i64), - None, - ))) - } - _ => exec_err!("Invalid decimal value"), - } + to_timestamp_impl::(&args, "to_timestamp", &tz) } other => { exec_err!("Unsupported data type {other} for function to_timestamp") } } } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } impl ScalarUDFImpl for ToTimestampSecondsFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_timestamp_seconds" } @@ -390,15 +519,15 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(return_type_for(&arg_types[0], Second)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Timestamp(Second, self.timezone.clone())) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = args.args; + impl_with_updated_config!(); + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + if args.is_empty() { return exec_err!( "to_timestamp_seconds function requires 1 or more arguments, got {}", @@ -411,14 +540,31 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { validate_data_types(&args, "to_timestamp")?; } + let tz = self.timezone.clone(); + match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, None) | Decimal128(_, _) => { - args[0].cast_to(&Timestamp(Second, None), None) - } - Timestamp(_, Some(tz)) => args[0].cast_to(&Timestamp(Second, Some(tz)), None), - Utf8View | LargeUtf8 | Utf8 => { - to_timestamp_impl::(&args, "to_timestamp_seconds") - } + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => args[0].cast_to(&Timestamp(Second, tz), None), + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? + .cast_to(&Timestamp(Second, tz), None), + Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( + &args, + "to_timestamp_seconds", + &self.timezone, + ), other => { exec_err!( "Unsupported data type {} for function to_timestamp_seconds", @@ -427,16 +573,13 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } } } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } impl ScalarUDFImpl for ToTimestampMillisFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_timestamp_millis" } @@ -445,15 +588,15 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(return_type_for(&arg_types[0], Millisecond)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Timestamp(Millisecond, self.timezone.clone())) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = args.args; + impl_with_updated_config!(); + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + if args.is_empty() { return exec_err!( "to_timestamp_millis function requires 1 or more arguments, got {}", @@ -467,15 +610,29 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, None) => { - args[0].cast_to(&Timestamp(Millisecond, None), None) - } - Timestamp(_, Some(tz)) => { - args[0].cast_to(&Timestamp(Millisecond, Some(tz)), None) + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { + args[0].cast_to(&Timestamp(Millisecond, self.timezone.clone()), None) } + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? + .cast_to(&Timestamp(Millisecond, self.timezone.clone()), None), Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( &args, "to_timestamp_millis", + &self.timezone, ), other => { exec_err!( @@ -485,16 +642,13 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } } } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } impl ScalarUDFImpl for ToTimestampMicrosFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_timestamp_micros" } @@ -503,15 +657,15 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(return_type_for(&arg_types[0], Microsecond)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Timestamp(Microsecond, self.timezone.clone())) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = args.args; + impl_with_updated_config!(); + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + if args.is_empty() { return exec_err!( "to_timestamp_micros function requires 1 or more arguments, got {}", @@ -525,15 +679,29 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, None) => { - args[0].cast_to(&Timestamp(Microsecond, None), None) - } - Timestamp(_, Some(tz)) => { - args[0].cast_to(&Timestamp(Microsecond, Some(tz)), None) + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { + args[0].cast_to(&Timestamp(Microsecond, self.timezone.clone()), None) } + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? + .cast_to(&Timestamp(Microsecond, self.timezone.clone()), None), Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( &args, "to_timestamp_micros", + &self.timezone, ), other => { exec_err!( @@ -543,16 +711,13 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } impl ScalarUDFImpl for ToTimestampNanosFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_timestamp_nanos" } @@ -561,15 +726,15 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(return_type_for(&arg_types[0], Nanosecond)) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Timestamp(Nanosecond, self.timezone.clone())) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = args.args; + impl_with_updated_config!(); + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + if args.is_empty() { return exec_err!( "to_timestamp_nanos function requires 1 or more arguments, got {}", @@ -583,15 +748,30 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } match args[0].data_type() { - Null | Int32 | Int64 | Timestamp(_, None) => { - args[0].cast_to(&Timestamp(Nanosecond, None), None) - } - Timestamp(_, Some(tz)) => { - args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) - } - Utf8View | LargeUtf8 | Utf8 => { - to_timestamp_impl::(&args, "to_timestamp_nanos") + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Timestamp(_, _) + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) + | Decimal256(_, _) => { + args[0].cast_to(&Timestamp(Nanosecond, self.timezone.clone()), None) } + Float16 | Float32 | Float64 => args[0] + .cast_to(&Int64, None)? + .cast_to(&Timestamp(Nanosecond, self.timezone.clone()), None), + Utf8View | LargeUtf8 | Utf8 => to_timestamp_impl::( + &args, + "to_timestamp_nanos", + &self.timezone, + ), other => { exec_err!( "Unsupported data type {} for function to_timestamp_nanos", @@ -600,23 +780,16 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { self.doc() } } -/// Returns the return type for the to_timestamp_* function, preserving -/// the timezone if it exists. -fn return_type_for(arg: &DataType, unit: TimeUnit) -> DataType { - match arg { - Timestamp(_, Some(tz)) => Timestamp(unit, Some(Arc::clone(tz))), - _ => Timestamp(unit, None), - } -} - fn to_timestamp_impl>( args: &[ColumnarValue], name: &str, + timezone: &Option>, ) -> Result { let factor = match T::UNIT { Second => 1_000_000_000, @@ -625,17 +798,26 @@ fn to_timestamp_impl>( Nanosecond => 1, }; + let tz = match timezone.clone() { + Some(tz) => Some(tz.parse::()?), + None => None, + }; + match args.len() { - 1 => handle::( + 1 => handle::( args, - |s| string_to_timestamp_nanos_shim(s).map(|n| n / factor), + move |s| string_to_timestamp_nanos_with_timezone(&tz, s).map(|n| n / factor), name, + &Timestamp(T::UNIT, timezone.clone()), ), - n if n >= 2 => handle_multiple::( + n if n >= 2 => handle_multiple::( args, - string_to_timestamp_nanos_formatted, + move |s, format| { + string_to_timestamp_nanos_formatted_with_timezone(&tz, s, format) + }, |n| n / factor, name, + &Timestamp(T::UNIT, timezone.clone()), ), _ => exec_err!("Unsupported 0 argument count for function {name}"), } @@ -643,7 +825,6 @@ fn to_timestamp_impl>( #[cfg(test)] mod tests { - use std::sync::Arc; use arrow::array::types::Int64Type; use arrow::array::{ @@ -652,35 +833,109 @@ mod tests { }; use arrow::array::{ArrayRef, Int64Array, StringBuilder}; use arrow::datatypes::{Field, TimeUnit}; - use chrono::Utc; - use datafusion_common::config::ConfigOptions; - use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; + use chrono::{DateTime, FixedOffset, Utc}; + use datafusion_common::{DataFusionError, assert_contains}; use datafusion_expr::ScalarFunctionImplementation; use super::*; fn to_timestamp(args: &[ColumnarValue]) -> Result { - to_timestamp_impl::(args, "to_timestamp") + let timezone: Option> = Some("UTC".into()); + to_timestamp_impl::(args, "to_timestamp", &timezone) } /// to_timestamp_millis SQL function fn to_timestamp_millis(args: &[ColumnarValue]) -> Result { - to_timestamp_impl::(args, "to_timestamp_millis") + let timezone: Option> = Some("UTC".into()); + to_timestamp_impl::( + args, + "to_timestamp_millis", + &timezone, + ) } /// to_timestamp_micros SQL function fn to_timestamp_micros(args: &[ColumnarValue]) -> Result { - to_timestamp_impl::(args, "to_timestamp_micros") + let timezone: Option> = Some("UTC".into()); + to_timestamp_impl::( + args, + "to_timestamp_micros", + &timezone, + ) } /// to_timestamp_nanos SQL function fn to_timestamp_nanos(args: &[ColumnarValue]) -> Result { - to_timestamp_impl::(args, "to_timestamp_nanos") + let timezone: Option> = Some("UTC".into()); + to_timestamp_impl::( + args, + "to_timestamp_nanos", + &timezone, + ) } /// to_timestamp_seconds SQL function fn to_timestamp_seconds(args: &[ColumnarValue]) -> Result { - to_timestamp_impl::(args, "to_timestamp_seconds") + let timezone: Option> = Some("UTC".into()); + to_timestamp_impl::(args, "to_timestamp_seconds", &timezone) + } + + fn udfs_and_timeunit() -> Vec<(Box, TimeUnit)> { + let udfs: Vec<(Box, TimeUnit)> = vec![ + ( + Box::new(ToTimestampFunc::new_with_config(&ConfigOptions::default())), + Nanosecond, + ), + ( + Box::new(ToTimestampSecondsFunc::new_with_config( + &ConfigOptions::default(), + )), + Second, + ), + ( + Box::new(ToTimestampMillisFunc::new_with_config( + &ConfigOptions::default(), + )), + Millisecond, + ), + ( + Box::new(ToTimestampMicrosFunc::new_with_config( + &ConfigOptions::default(), + )), + Microsecond, + ), + ( + Box::new(ToTimestampNanosFunc::new_with_config( + &ConfigOptions::default(), + )), + Nanosecond, + ), + ]; + udfs + } + + fn validate_expected_error( + options: &mut ConfigOptions, + args: ScalarFunctionArgs, + expected_err: &str, + ) { + let udfs = udfs_and_timeunit(); + + for (udf, _) in udfs { + match udf + .with_updated_config(options) + .unwrap() + .invoke_with_args(args.clone()) + { + Ok(_) => panic!("Expected error but got success"), + Err(e) => { + assert!( + e.to_string().contains(expected_err), + "Can not find expected error '{expected_err}'. Actual error '{e}'" + ); + } + } + } } #[test] @@ -710,6 +965,37 @@ mod tests { Ok(()) } + #[test] + fn to_timestamp_decimal128_overflow_returns_error() { + let value = "99999999999999999999999999999999999999" + .parse::() + .unwrap(); + let err = decimal128_to_timestamp_nanos( + &ColumnarValue::Scalar(ScalarValue::Decimal128(Some(value), 38, 0)), + None, + ) + .unwrap_err() + .to_string(); + + assert_contains!(err, "overflows timestamp nanoseconds"); + } + + #[test] + fn to_timestamp_decimal128_array_overflow_returns_error() { + let value = "99999999999999999999999999999999999999" + .parse::() + .unwrap(); + let array = Decimal128Array::from(vec![Some(value)]) + .with_precision_and_scale(38, 0) + .unwrap(); + let err = + decimal128_to_timestamp_nanos(&ColumnarValue::Array(Arc::new(array)), None) + .unwrap_err() + .to_string(); + + assert_contains!(err, "overflows timestamp nanoseconds"); + } + #[test] fn to_timestamp_with_formats_arrays_and_nulls() -> Result<()> { // ensure that arrow array implementation is wired up and handles nulls correctly @@ -751,6 +1037,368 @@ mod tests { Ok(()) } + #[test] + fn to_timestamp_respects_execution_timezone() -> Result<()> { + let udfs = udfs_and_timeunit(); + + let mut options = ConfigOptions::default(); + options.execution.time_zone = Some("-05:00".to_string()); + + let time_zone: Option> = options + .execution + .time_zone + .as_ref() + .map(|tz| Arc::from(tz.as_str())); + + for (udf, time_unit) in udfs { + let field = Field::new("arg", Utf8, true).into(); + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "2020-09-08T13:42:29".to_string(), + )))], + arg_fields: vec![field], + number_rows: 1, + return_field: Field::new( + "f", + Timestamp(time_unit, Some("-05:00".into())), + true, + ) + .into(), + config_options: Arc::new(options.clone()), + }; + + let result = udf + .with_updated_config(&options.clone()) + .unwrap() + .invoke_with_args(args)?; + let result = match time_unit { + Second => { + let ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + Millisecond => { + let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + Microsecond => { + let ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + Nanosecond => { + let ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + }; + + let scale = match time_unit { + Second => 1_000_000_000, + Millisecond => 1_000_000, + Microsecond => 1_000, + Nanosecond => 1, + }; + + let offset = FixedOffset::west_opt(5 * 3600).unwrap(); + let result = Some( + DateTime::::from_timestamp_nanos(result * scale) + .with_timezone(&offset) + .to_string(), + ); + + assert_eq!(result, Some("2020-09-08 13:42:29 -05:00".to_string())); + } + + Ok(()) + } + + #[test] + fn to_timestamp_formats_respects_execution_timezone() -> Result<()> { + let udfs = udfs_and_timeunit(); + + let mut options = ConfigOptions::default(); + options.execution.time_zone = Some("-05:00".to_string()); + + let time_zone: Option> = options + .execution + .time_zone + .as_ref() + .map(|tz| Arc::from(tz.as_str())); + + let expr_field = Field::new("arg", Utf8, true).into(); + let format_field: Arc = Field::new("fmt", Utf8, true).into(); + + for (udf, time_unit) in udfs { + for (value, format, expected_str) in [ + ( + "2020-09-08 09:42:29 -05:00", + "%Y-%m-%d %H:%M:%S %z", + Some("2020-09-08 09:42:29 -05:00"), + ), + ( + "2020-09-08T13:42:29Z", + "%+", + Some("2020-09-08 08:42:29 -05:00"), + ), + ( + "2020-09-08 13:42:29 UTC", + "%Y-%m-%d %H:%M:%S %Z", + Some("2020-09-08 08:42:29 -05:00"), + ), + ( + "+0000 2024-01-01 12:00:00", + "%z %Y-%m-%d %H:%M:%S", + Some("2024-01-01 07:00:00 -05:00"), + ), + ( + "20200908134229+0100", + "%Y%m%d%H%M%S%z", + Some("2020-09-08 07:42:29 -05:00"), + ), + ( + "2020-09-08+0230 13:42", + "%Y-%m-%d%z %H:%M", + Some("2020-09-08 06:12:00 -05:00"), + ), + ] { + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + format.to_string(), + ))), + ], + arg_fields: vec![Arc::clone(&expr_field), Arc::clone(&format_field)], + number_rows: 1, + return_field: Field::new( + "f", + Timestamp(time_unit, Some("-05:00".into())), + true, + ) + .into(), + config_options: Arc::new(options.clone()), + }; + let result = udf + .with_updated_config(&options.clone()) + .unwrap() + .invoke_with_args(args)?; + let result = match time_unit { + Second => { + let ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + Millisecond => { + let ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + Microsecond => { + let ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + Nanosecond => { + let ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(value), + tz, + )) = result + else { + panic!("expected scalar timestamp"); + }; + + assert_eq!(tz, time_zone); + + value + } + }; + + let scale = match time_unit { + Second => 1_000_000_000, + Millisecond => 1_000_000, + Microsecond => 1_000, + Nanosecond => 1, + }; + let offset = FixedOffset::west_opt(5 * 3600).unwrap(); + let result = Some( + DateTime::::from_timestamp_nanos(result * scale) + .with_timezone(&offset) + .to_string(), + ); + + assert_eq!(result, expected_str.map(|s| s.to_string())); + } + } + + Ok(()) + } + + #[test] + fn to_timestamp_invalid_execution_timezone_behavior() -> Result<()> { + let field: Arc = Field::new("arg", Utf8, true).into(); + let return_field: Arc = + Field::new("f", Timestamp(Nanosecond, None), true).into(); + + let mut options = ConfigOptions::default(); + options.execution.time_zone = Some("Invalid/Timezone".to_string()); + + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "2020-09-08T13:42:29Z".to_string(), + )))], + arg_fields: vec![Arc::clone(&field)], + number_rows: 1, + return_field: Arc::clone(&return_field), + config_options: Arc::new(options.clone()), + }; + + let expected_err = + "Invalid timezone \"Invalid/Timezone\": failed to parse timezone"; + + validate_expected_error(&mut options, args, expected_err); + + Ok(()) + } + + #[test] + fn to_timestamp_formats_invalid_execution_timezone_behavior() -> Result<()> { + let expr_field: Arc = Field::new("arg", Utf8, true).into(); + let format_field: Arc = Field::new("fmt", Utf8, true).into(); + let return_field: Arc = + Field::new("f", Timestamp(Nanosecond, None), true).into(); + + let mut options = ConfigOptions::default(); + options.execution.time_zone = Some("Invalid/Timezone".to_string()); + + let expected_err = + "Invalid timezone \"Invalid/Timezone\": failed to parse timezone"; + + let make_args = |value: &str, format: &str| ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(format.to_string()))), + ], + arg_fields: vec![Arc::clone(&expr_field), Arc::clone(&format_field)], + number_rows: 1, + return_field: Arc::clone(&return_field), + config_options: Arc::new(options.clone()), + }; + + for (value, format, _expected_str) in [ + ( + "2020-09-08 09:42:29 -05:00", + "%Y-%m-%d %H:%M:%S %z", + Some("2020-09-08 09:42:29 -05:00"), + ), + ( + "2020-09-08T13:42:29Z", + "%+", + Some("2020-09-08 08:42:29 -05:00"), + ), + ( + "2020-09-08 13:42:29 +0000", + "%Y-%m-%d %H:%M:%S %z", + Some("2020-09-08 08:42:29 -05:00"), + ), + ( + "+0000 2024-01-01 12:00:00", + "%z %Y-%m-%d %H:%M:%S", + Some("2024-01-01 07:00:00 -05:00"), + ), + ( + "20200908134229+0100", + "%Y%m%d%H%M%S%z", + Some("2020-09-08 07:42:29 -05:00"), + ), + ( + "2020-09-08+0230 13:42", + "%Y-%m-%d%z %H:%M", + Some("2020-09-08 06:12:00 -05:00"), + ), + ] { + let args = make_args(value, format); + validate_expected_error(&mut options.clone(), args, expected_err); + } + + let args = ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "2020-09-08T13:42:29".to_string(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "%Y-%m-%dT%H:%M:%S".to_string(), + ))), + ], + arg_fields: vec![Arc::clone(&expr_field), Arc::clone(&format_field)], + number_rows: 1, + return_field: Arc::clone(&return_field), + config_options: Arc::new(options.clone()), + }; + + validate_expected_error(&mut options.clone(), args, expected_err); + + Ok(()) + } + #[test] fn to_timestamp_invalid_input_type() -> Result<()> { // pass the wrong type of input array to to_timestamp and test @@ -811,8 +1459,7 @@ mod tests { let string_array = ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef); - let expected_err = - "Arrow error: Parser error: Error parsing timestamp from '2020-09-08 - 13:42:29.19085Z': error parsing time"; + let expected_err = "Arrow error: Parser error: Error parsing timestamp from '2020-09-08 - 13:42:29.19085Z': error parsing time"; match to_timestamp(&[string_array]) { Ok(_) => panic!("Expected error but got success"), Err(e) => { @@ -836,8 +1483,7 @@ mod tests { let string_array = ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef); - let expected_err = - "Arrow error: Parser error: Invalid timezone \"ZZ\": failed to parse timezone"; + let expected_err = "Arrow error: Parser error: Invalid timezone \"ZZ\": failed to parse timezone"; match to_timestamp(&[string_array]) { Ok(_) => panic!("Expected error but got success"), Err(e) => { @@ -874,8 +1520,7 @@ mod tests { ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), ]; - let expected_err = - "Execution error: Error parsing timestamp from '2020-09-08T13:42:29.19085Z' using format '%H:%M:%S': input contains invalid characters"; + let expected_err = "Execution error: Error parsing timestamp from '2020-09-08T13:42:29.19085Z' using format '%H:%M:%S': input contains invalid characters"; match to_timestamp(&string_array) { Ok(_) => panic!("Expected error but got success"), Err(e) => { @@ -923,7 +1568,11 @@ mod tests { } fn parse_timestamp_formatted(s: &str, format: &str) -> Result { - let result = string_to_timestamp_nanos_formatted(s, format); + let result = string_to_timestamp_nanos_formatted_with_timezone( + &Some("UTC".parse()?), + s, + format, + ); if let Err(e) = &result { eprintln!("Error parsing timestamp '{s}' using format '{format}': {e:?}"); } @@ -950,7 +1599,9 @@ mod tests { ]; for (s, f, ctx) in cases { - let expected = format!("Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}"); + let expected = format!( + "Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}" + ); let actual = string_to_datetime_formatted(&Utc, s, f) .unwrap_err() .strip_backtrace(); @@ -978,7 +1629,9 @@ mod tests { ]; for (s, f, ctx) in cases { - let expected = format!("Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}"); + let expected = format!( + "Execution error: Error parsing timestamp from '{s}' using format '{f}': {ctx}" + ); let actual = string_to_datetime_formatted(&Utc, s, f) .unwrap_err() .strip_backtrace(); @@ -987,13 +1640,21 @@ mod tests { } #[test] - fn test_tz() { + fn test_no_tz() { let udfs: Vec> = vec![ - Box::new(ToTimestampFunc::new()), - Box::new(ToTimestampSecondsFunc::new()), - Box::new(ToTimestampMillisFunc::new()), - Box::new(ToTimestampNanosFunc::new()), - Box::new(ToTimestampSecondsFunc::new()), + Box::new(ToTimestampFunc::new_with_config(&ConfigOptions::default())), + Box::new(ToTimestampSecondsFunc::new_with_config( + &ConfigOptions::default(), + )), + Box::new(ToTimestampMillisFunc::new_with_config( + &ConfigOptions::default(), + )), + Box::new(ToTimestampNanosFunc::new_with_config( + &ConfigOptions::default(), + )), + Box::new(ToTimestampSecondsFunc::new_with_config( + &ConfigOptions::default(), + )), ]; let mut nanos_builder = TimestampNanosecondArray::builder(2); @@ -1026,8 +1687,8 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); let arg_field = Field::new("arg", array.data_type().clone(), true).into(); - assert!(matches!(rt, Timestamp(_, Some(_)))); - let args = datafusion_expr::ScalarFunctionArgs { + assert!(matches!(rt, Timestamp(_, None))); + let args = ScalarFunctionArgs { args: vec![array.clone()], arg_fields: vec![arg_field], number_rows: 4, @@ -1042,7 +1703,7 @@ mod tests { _ => panic!("Expected a columnar array"), }; let ty = array.data_type(); - assert!(matches!(ty, Timestamp(_, Some(_)))); + assert!(matches!(ty, Timestamp(_, None))); } } @@ -1077,7 +1738,7 @@ mod tests { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); let arg_field = Field::new("arg", array.data_type().clone(), true).into(); - let args = datafusion_expr::ScalarFunctionArgs { + let args = ScalarFunctionArgs { args: vec![array.clone()], arg_fields: vec![arg_field], number_rows: 5, @@ -1214,4 +1875,23 @@ mod tests { assert_contains!(actual, expected); } } + + #[test] + fn test_decimal_to_nanoseconds_negative_scale() { + // scale -2: internal value 5 represents 5 * 10^2 = 500 seconds + let nanos = decimal_to_nanoseconds(5, -2).unwrap(); + assert_eq!(nanos, 500_000_000_000); // 500 seconds in nanoseconds + + // scale -1: internal value 10 represents 10 * 10^1 = 100 seconds + let nanos = decimal_to_nanoseconds(10, -1).unwrap(); + assert_eq!(nanos, 100_000_000_000); + + // scale 0: internal value 5 represents 5 seconds + let nanos = decimal_to_nanoseconds(5, 0).unwrap(); + assert_eq!(nanos, 5_000_000_000); + + // scale 3: internal value 1500 represents 1.5 seconds + let nanos = decimal_to_nanoseconds(1500, 3).unwrap(); + assert_eq!(nanos, 1_500_000_000); + } } diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 42651cd537162..9fcfd254ca74d 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -18,16 +18,21 @@ use super::to_timestamp::ToTimestampSecondsFunc; use crate::datetime::common::*; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "Time and Date Functions"), - description = "Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.", + description = r#" +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00`). +Supports strings, dates, timestamps, integer, unsigned integer, and float types as input. +Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') +if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. +Integers, unsigned integers, and floats are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00`)."#, syntax_example = "to_unixtime(expression[, ..., format_n])", sql_example = r#" ```sql @@ -74,10 +79,6 @@ impl ToUnixtimeFunc { } impl ScalarUDFImpl for ToUnixtimeFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_unixtime" } @@ -90,10 +91,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc { Ok(DataType::Int64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let arg_args = &args.args; if arg_args.is_empty() { return exec_err!("to_unixtime function requires 1 or more arguments, got 0"); @@ -101,22 +99,44 @@ impl ScalarUDFImpl for ToUnixtimeFunc { // validate that any args after the first one are Utf8 if arg_args.len() > 1 { - validate_data_types(arg_args, "to_unixtime")?; + // Format arguments only make sense for string inputs + match arg_args[0].data_type() { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { + validate_data_types(arg_args, "to_unixtime")?; + } + _ => { + return exec_err!( + "to_unixtime function only accepts format arguments with string input, got {} arguments", + arg_args.len() + ); + } + } } match arg_args[0].data_type() { - DataType::Int32 | DataType::Int64 | DataType::Null | DataType::Float64 => { - arg_args[0].cast_to(&DataType::Int64, None) - } + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Null => arg_args[0].cast_to(&DataType::Int64, None), DataType::Date64 | DataType::Date32 => arg_args[0] .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)? .cast_to(&DataType::Int64, None), DataType::Timestamp(_, tz) => arg_args[0] .cast_to(&DataType::Timestamp(TimeUnit::Second, tz), None)? .cast_to(&DataType::Int64, None), - DataType::Utf8 => ToTimestampSecondsFunc::new() - .invoke_with_args(args)? - .cast_to(&DataType::Int64, None), + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { + ToTimestampSecondsFunc::new_with_config(args.config_options.as_ref()) + .invoke_with_args(args)? + .cast_to(&DataType::Int64, None) + } other => { exec_err!("Unsupported data type {} for function to_unixtime", other) } diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index e5314ad220c8f..ad156f735b33b 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -19,29 +19,29 @@ use arrow::{ array::{ - Array, ArrayRef, BinaryArray, GenericByteArray, OffsetSizeTrait, StringArray, + Array, ArrayRef, AsArray, BinaryArrayType, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }, - datatypes::{ByteArrayType, DataType}, + datatypes::DataType, }; use arrow_buffer::{Buffer, OffsetBufferBuilder}; use base64::{ - engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}, Engine as _, + engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}, }; use datafusion_common::{ - cast::{as_generic_binary_array, as_generic_string_array}, + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, not_impl_err, plan_err, + types::{NativeType, logical_string}, utils::take_function_args, }; -use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{ColumnarValue, Documentation}; -use std::sync::Arc; -use std::{fmt, str::FromStr}; - -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; -use std::any::Any; +use std::fmt; +use std::sync::Arc; // Allow padding characters, but don't require them, and don't generate them. const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( @@ -51,6 +51,12 @@ const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( .with_decode_padding_mode(DecodePaddingMode::Indifferent), ); +// Generate padding characters when encoding +const BASE64_ENGINE_PADDED: GeneralPurpose = GeneralPurpose::new( + &base64::alphabet::STANDARD, + GeneralPurposeConfig::new().with_encode_padding(true), +); + #[user_doc( doc_section(label = "Binary String Functions"), description = "Encode binary data into a textual representation.", @@ -61,7 +67,7 @@ const BASE64_ENGINE: GeneralPurpose = GeneralPurpose::new( ), argument( name = "format", - description = "Supported formats are: `base64`, `hex`" + description = "Supported formats are: `base64`, `base64pad`, `hex`" ), related_udf(name = "decode") )] @@ -79,15 +85,22 @@ impl Default for EncodeFunc { impl EncodeFunc { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), } } } impl ScalarUDFImpl for EncodeFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "encode" } @@ -97,48 +110,21 @@ impl ScalarUDFImpl for EncodeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(match arg_types[0] { - Utf8 => Utf8, - LargeUtf8 => LargeUtf8, - Utf8View => Utf8, - Binary => Utf8, - LargeBinary => LargeUtf8, - Null => Null, - _ => { - return plan_err!( - "The encode function can only accept Utf8 or Binary or Null." - ); - } - }) - } - - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - encode(&args.args) - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [expression, format] = take_function_args(self.name(), arg_types)?; - - if format != &DataType::Utf8 { - return Err(DataFusionError::Plan("2nd argument should be Utf8".into())); + match &arg_types[0] { + DataType::LargeBinary => Ok(DataType::LargeUtf8), + _ => Ok(DataType::Utf8), } + } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [expression, encoding] = take_function_args("encode", &args.args)?; + let encoding = Encoding::try_from(encoding)?; match expression { - DataType::Utf8 | DataType::Utf8View | DataType::Null => { - Ok(vec![DataType::Utf8; 2]) + _ if expression.data_type().is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) } - DataType::LargeUtf8 => Ok(vec![DataType::LargeUtf8, DataType::Utf8]), - DataType::Binary => Ok(vec![DataType::Binary, DataType::Utf8]), - DataType::LargeBinary => Ok(vec![DataType::LargeBinary, DataType::Utf8]), - _ => plan_err!( - "1st argument should be Utf8 or Binary or Null, got {:?}", - arg_types[0] - ), + ColumnarValue::Array(array) => encode_array(array, encoding), + ColumnarValue::Scalar(scalar) => encode_scalar(scalar, encoding), } } @@ -172,15 +158,22 @@ impl Default for DecodeFunc { impl DecodeFunc { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), } } } impl ScalarUDFImpl for DecodeFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "decode" } @@ -190,40 +183,21 @@ impl ScalarUDFImpl for DecodeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].to_owned()) - } - - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - decode(&args.args) - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return plan_err!( - "{} expects to get 2 arguments, but got {}", - self.name(), - arg_types.len() - ); - } - - if arg_types[1] != DataType::Utf8 { - return plan_err!("2nd argument should be Utf8"); + match &arg_types[0] { + DataType::LargeBinary => Ok(DataType::LargeBinary), + _ => Ok(DataType::Binary), } + } - match arg_types[0] { - DataType::Utf8 | DataType::Utf8View | DataType::Binary | DataType::Null => { - Ok(vec![DataType::Binary, DataType::Utf8]) - } - DataType::LargeUtf8 | DataType::LargeBinary => { - Ok(vec![DataType::LargeBinary, DataType::Utf8]) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [expression, encoding] = take_function_args("decode", &args.args)?; + let encoding = Encoding::try_from(encoding)?; + match expression { + _ if expression.data_type().is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) } - _ => plan_err!( - "1st argument should be Utf8 or Binary or Null, got {:?}", - arg_types[0] - ), + ColumnarValue::Array(array) => decode_array(array, encoding), + ColumnarValue::Scalar(scalar) => decode_scalar(scalar, encoding), } } @@ -232,324 +206,317 @@ impl ScalarUDFImpl for DecodeFunc { } } -#[derive(Debug, Copy, Clone)] -enum Encoding { - Base64, - Hex, -} - -fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result { +fn encode_scalar(value: &ScalarValue, encoding: Encoding) -> Result { match value { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => encoding.encode_utf8_array::(a.as_ref()), - DataType::LargeUtf8 => encoding.encode_utf8_array::(a.as_ref()), - DataType::Utf8View => encoding.encode_utf8_array::(a.as_ref()), - DataType::Binary => encoding.encode_binary_array::(a.as_ref()), - DataType::LargeBinary => encoding.encode_binary_array::(a.as_ref()), - other => exec_err!( - "Unsupported data type {other:?} for function encode({encoding})" - ), - }, - ColumnarValue::Scalar(scalar) => { - match scalar { - ScalarValue::Utf8(a) => { - Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) - } - ScalarValue::LargeUtf8(a) => Ok(encoding - .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), - ScalarValue::Utf8View(a) => { - Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) - } - ScalarValue::Binary(a) => Ok( - encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - ), - ScalarValue::LargeBinary(a) => Ok(encoding - .encode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), - other => exec_err!( - "Unsupported data type {other:?} for function encode({encoding})" - ), - } + ScalarValue::Binary(maybe_bytes) + | ScalarValue::BinaryView(maybe_bytes) + | ScalarValue::FixedSizeBinary(_, maybe_bytes) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + maybe_bytes + .as_ref() + .map(|bytes| encoding.encode_bytes(bytes)), + ))) } + ScalarValue::LargeBinary(maybe_bytes) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8( + maybe_bytes + .as_ref() + .map(|bytes| encoding.encode_bytes(bytes)), + ))) + } + v => internal_err!("Unexpected value for encode: {v}"), } } -fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result { - match value { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => encoding.decode_utf8_array::(a.as_ref()), - DataType::LargeUtf8 => encoding.decode_utf8_array::(a.as_ref()), - DataType::Utf8View => encoding.decode_utf8_array::(a.as_ref()), - DataType::Binary => encoding.decode_binary_array::(a.as_ref()), - DataType::LargeBinary => encoding.decode_binary_array::(a.as_ref()), - other => exec_err!( - "Unsupported data type {other:?} for function decode({encoding})" - ), - }, - ColumnarValue::Scalar(scalar) => { - match scalar { - ScalarValue::Utf8(a) => { - encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) - } - ScalarValue::LargeUtf8(a) => encoding - .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())), - ScalarValue::Utf8View(a) => { - encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) - } - ScalarValue::Binary(a) => { - encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - } - ScalarValue::LargeBinary(a) => encoding - .decode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice())), - other => exec_err!( - "Unsupported data type {other:?} for function decode({encoding})" - ), - } +fn encode_array(array: &ArrayRef, encoding: Encoding) -> Result { + let array = match array.data_type() { + DataType::Binary => encoding.encode_array::<_, i32>(&array.as_binary::()), + DataType::BinaryView => encoding.encode_array::<_, i32>(&array.as_binary_view()), + DataType::LargeBinary => { + encoding.encode_array::<_, i64>(&array.as_binary::()) } - } + DataType::FixedSizeBinary(_) => { + encoding.encode_array::<_, i32>(&array.as_fixed_size_binary()) + } + dt => { + internal_err!("Unexpected data type for encode: {dt}") + } + }; + array.map(ColumnarValue::Array) } -fn hex_encode(input: &[u8]) -> String { - hex::encode(input) +fn decode_scalar(value: &ScalarValue, encoding: Encoding) -> Result { + match value { + ScalarValue::Binary(maybe_bytes) + | ScalarValue::BinaryView(maybe_bytes) + | ScalarValue::FixedSizeBinary(_, maybe_bytes) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary( + maybe_bytes + .as_ref() + .map(|x| encoding.decode_bytes(x)) + .transpose()?, + ))) + } + ScalarValue::LargeBinary(maybe_bytes) => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary( + maybe_bytes + .as_ref() + .map(|x| encoding.decode_bytes(x)) + .transpose()?, + ))) + } + v => internal_err!("Unexpected value for decode: {v}"), + } } -fn base64_encode(input: &[u8]) -> String { - BASE64_ENGINE.encode(input) +/// Estimate how many bytes are actually represented by the array; in case the +/// the array slices it's internal buffer, this returns the byte size of that slice +/// but not the byte size of the entire buffer. +/// +/// This is an estimation only as it can estimate higher if null slots are non-zero +/// sized. +fn estimate_byte_data_size(array: &GenericBinaryArray) -> usize { + let offsets = array.value_offsets(); + // Unwraps are safe as should always have 1 element in offset buffer + let start = *offsets.first().unwrap(); + let end = *offsets.last().unwrap(); + let data_size = end - start; + data_size.as_usize() } -fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { - // only write input / 2 bytes to buf - let out_len = input.len() / 2; - let buf = &mut buf[..out_len]; - hex::decode_to_slice(input, buf) - .map_err(|e| internal_datafusion_err!("Failed to decode from hex: {e}"))?; - Ok(out_len) +fn decode_array(array: &ArrayRef, encoding: Encoding) -> Result { + let array = match array.data_type() { + DataType::Binary => { + let array = array.as_binary::(); + encoding.decode_array::<_, i32>(&array, estimate_byte_data_size(array)) + } + DataType::BinaryView => { + let array = array.as_binary_view(); + // Don't know if there is a more strict upper bound we can infer + // for view arrays byte data size. + encoding.decode_array::<_, i32>(&array, array.get_buffer_memory_size()) + } + DataType::LargeBinary => { + let array = array.as_binary::(); + encoding.decode_array::<_, i64>(&array, estimate_byte_data_size(array)) + } + DataType::FixedSizeBinary(size) => { + let array = array.as_fixed_size_binary(); + // TODO: could we be more conservative by accounting for nulls? + let estimate = array.len().saturating_mul(*size as usize); + encoding.decode_array::<_, i32>(&array, estimate) + } + dt => { + internal_err!("Unexpected data type for decode: {dt}") + } + }; + array.map(ColumnarValue::Array) } -fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { - BASE64_ENGINE - .decode_slice(input, buf) - .map_err(|e| internal_datafusion_err!("Failed to decode from base64: {e}")) +#[derive(Debug, Copy, Clone)] +enum Encoding { + Base64, + Base64Padded, + Hex, } -macro_rules! encode_to_array { - ($METHOD: ident, $INPUT:expr) => {{ - let utf8_array: StringArray = $INPUT - .iter() - .map(|x| x.map(|x| $METHOD(x.as_ref()))) - .collect(); - Arc::new(utf8_array) - }}; +impl fmt::Display for Encoding { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let name = match self { + Self::Base64 => "base64", + Self::Base64Padded => "base64pad", + Self::Hex => "hex", + }; + write!(f, "{name}") + } } -fn decode_to_array( - method: F, - input: &GenericByteArray, - conservative_upper_bound_size: usize, -) -> Result -where - F: Fn(&[u8], &mut [u8]) -> Result, -{ - let mut values = vec![0; conservative_upper_bound_size]; - let mut offsets = OffsetBufferBuilder::new(input.len()); - let mut total_bytes_decoded = 0; - for v in input { - if let Some(v) = v { - let cursor = &mut values[total_bytes_decoded..]; - let decoded = method(v.as_ref(), cursor)?; - total_bytes_decoded += decoded; - offsets.push_length(decoded); - } else { - offsets.push_length(0); +impl TryFrom<&ColumnarValue> for Encoding { + type Error = DataFusionError; + + fn try_from(encoding: &ColumnarValue) -> Result { + let encoding = match encoding { + ColumnarValue::Scalar(encoding) => match encoding.try_as_str().flatten() { + Some(encoding) => encoding, + _ => return exec_err!("Encoding must be a non-null string"), + }, + ColumnarValue::Array(_) => { + return not_impl_err!( + "Encoding must be a scalar; array specified encoding is not yet supported" + ); + } + }; + match encoding { + "base64" => Ok(Self::Base64), + "base64pad" => Ok(Self::Base64Padded), + "hex" => Ok(Self::Hex), + _ => { + let options = [Self::Base64, Self::Base64Padded, Self::Hex] + .iter() + .map(|i| i.to_string()) + .collect::>() + .join(", "); + plan_err!( + "There is no built-in encoding named '{encoding}', currently supported encodings are: {options}" + ) + } } } - // We reserved an upper bound size for the values buffer, but we only use the actual size - values.truncate(total_bytes_decoded); - let binary_array = BinaryArray::try_new( - offsets.finish(), - Buffer::from_vec(values), - input.nulls().cloned(), - )?; - Ok(Arc::new(binary_array)) } impl Encoding { - fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::Scalar(match self { - Self::Base64 => ScalarValue::Utf8(value.map(|v| BASE64_ENGINE.encode(v))), - Self::Hex => ScalarValue::Utf8(value.map(hex::encode)), - }) + fn encode_bytes(self, value: &[u8]) -> String { + match self { + Self::Base64 => BASE64_ENGINE.encode(value), + Self::Base64Padded => BASE64_ENGINE_PADDED.encode(value), + Self::Hex => hex::encode(value), + } } - fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::Scalar(match self { - Self::Base64 => { - ScalarValue::LargeUtf8(value.map(|v| BASE64_ENGINE.encode(v))) + fn decode_bytes(self, value: &[u8]) -> Result> { + match self { + Self::Base64 | Self::Base64Padded => { + BASE64_ENGINE.decode(value).map_err(|e| { + exec_datafusion_err!("Failed to decode value using {self}: {e}") + }) } - Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)), - }) - } - - fn encode_binary_array(self, value: &dyn Array) -> Result - where - T: OffsetSizeTrait, - { - let input_value = as_generic_binary_array::(value)?; - let array: ArrayRef = match self { - Self::Base64 => encode_to_array!(base64_encode, input_value), - Self::Hex => encode_to_array!(hex_encode, input_value), - }; - Ok(ColumnarValue::Array(array)) - } - - fn encode_utf8_array(self, value: &dyn Array) -> Result - where - T: OffsetSizeTrait, - { - let input_value = as_generic_string_array::(value)?; - let array: ArrayRef = match self { - Self::Base64 => encode_to_array!(base64_encode, input_value), - Self::Hex => encode_to_array!(hex_encode, input_value), - }; - Ok(ColumnarValue::Array(array)) - } - - fn decode_scalar(self, value: Option<&[u8]>) -> Result { - let value = match value { - Some(value) => value, - None => return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))), - }; - - let out = match self { - Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| { - internal_datafusion_err!("Failed to decode value using base64: {e}") - })?, - Self::Hex => hex::decode(value).map_err(|e| { - internal_datafusion_err!("Failed to decode value using hex: {e}") - })?, - }; - - Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out)))) - } - - fn decode_large_scalar(self, value: Option<&[u8]>) -> Result { - let value = match value { - Some(value) => value, - None => return Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(None))), - }; - - let out = match self { - Self::Base64 => BASE64_ENGINE.decode(value).map_err(|e| { - internal_datafusion_err!("Failed to decode value using base64: {e}") - })?, Self::Hex => hex::decode(value).map_err(|e| { - internal_datafusion_err!("Failed to decode value using hex: {e}") - })?, - }; - - Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(out)))) + exec_datafusion_err!("Failed to decode value using hex: {e}") + }), + } } - fn decode_binary_array(self, value: &dyn Array) -> Result + // OutputOffset important to ensure Large types output Large arrays + fn encode_array<'a, InputBinaryArray, OutputOffset>( + self, + array: &InputBinaryArray, + ) -> Result where - T: OffsetSizeTrait, + InputBinaryArray: BinaryArrayType<'a>, + OutputOffset: OffsetSizeTrait, { - let input_value = as_generic_binary_array::(value)?; - let array = self.decode_byte_array(input_value)?; - Ok(ColumnarValue::Array(array)) + match self { + Self::Base64 => { + let array: GenericStringArray = array + .iter() + .map(|x| x.map(|x| BASE64_ENGINE.encode(x))) + .collect(); + Ok(Arc::new(array)) + } + Self::Base64Padded => { + let array: GenericStringArray = array + .iter() + .map(|x| x.map(|x| BASE64_ENGINE_PADDED.encode(x))) + .collect(); + Ok(Arc::new(array)) + } + Self::Hex => { + let array: GenericStringArray = + array.iter().map(|x| x.map(hex::encode)).collect(); + Ok(Arc::new(array)) + } + } } - fn decode_utf8_array(self, value: &dyn Array) -> Result + // OutputOffset important to ensure Large types output Large arrays + fn decode_array<'a, InputBinaryArray, OutputOffset>( + self, + value: &InputBinaryArray, + approx_data_size: usize, + ) -> Result where - T: OffsetSizeTrait, + InputBinaryArray: BinaryArrayType<'a>, + OutputOffset: OffsetSizeTrait, { - let input_value = as_generic_string_array::(value)?; - let array = self.decode_byte_array(input_value)?; - Ok(ColumnarValue::Array(array)) - } + fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { + // only write input / 2 bytes to buf + let out_len = input.len() / 2; + let buf = &mut buf[..out_len]; + hex::decode_to_slice(input, buf) + .map_err(|e| exec_datafusion_err!("Failed to decode from hex: {e}"))?; + Ok(out_len) + } + + fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { + BASE64_ENGINE + .decode_slice(input, buf) + .map_err(|e| exec_datafusion_err!("Failed to decode from base64: {e}")) + } - fn decode_byte_array( - &self, - input_value: &GenericByteArray, - ) -> Result { match self { - Self::Base64 => { - let upper_bound = - base64::decoded_len_estimate(input_value.values().len()); - decode_to_array(base64_decode, input_value, upper_bound) + Self::Base64 | Self::Base64Padded => { + let upper_bound = base64::decoded_len_estimate(approx_data_size); + delegated_decode::<_, _, OutputOffset>(base64_decode, value, upper_bound) } Self::Hex => { // Calculate the upper bound for decoded byte size // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded // So the upper bound is half the length of the input values. - let upper_bound = input_value.values().len() / 2; - decode_to_array(hex_decode, input_value, upper_bound) + let upper_bound = approx_data_size / 2; + delegated_decode::<_, _, OutputOffset>(hex_decode, value, upper_bound) } } } } -impl fmt::Display for Encoding { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", format!("{self:?}").to_lowercase()) +fn delegated_decode<'a, DecodeFunction, InputBinaryArray, OutputOffset>( + decode: DecodeFunction, + input: &InputBinaryArray, + conservative_upper_bound_size: usize, +) -> Result +where + DecodeFunction: Fn(&[u8], &mut [u8]) -> Result, + InputBinaryArray: BinaryArrayType<'a>, + OutputOffset: OffsetSizeTrait, +{ + let mut values = vec![0; conservative_upper_bound_size]; + let mut offsets = OffsetBufferBuilder::new(input.len()); + let mut total_bytes_decoded = 0; + for v in input.iter() { + if let Some(v) = v { + let cursor = &mut values[total_bytes_decoded..]; + let decoded = decode(v, cursor)?; + total_bytes_decoded += decoded; + offsets.push_length(decoded); + } else { + offsets.push_length(0); + } } + // We reserved an upper bound size for the values buffer, but we only use the actual size + values.truncate(total_bytes_decoded); + let binary_array = GenericBinaryArray::::try_new( + offsets.finish(), + Buffer::from_vec(values), + input.nulls().cloned(), + )?; + Ok(Arc::new(binary_array)) } -impl FromStr for Encoding { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - "base64" => Self::Base64, - "hex" => Self::Hex, - _ => { - let options = [Self::Base64, Self::Hex] - .iter() - .map(|i| i.to_string()) - .collect::>() - .join(", "); - return plan_err!( - "There is no built-in encoding named '{name}', currently supported encodings are: {options}" - ); - } - }) +#[cfg(test)] +mod tests { + use arrow::array::BinaryArray; + use arrow_buffer::OffsetBuffer; + + use super::*; + + #[test] + fn test_estimate_byte_data_size() { + // Offsets starting at 0, but don't count entire data buffer size + let array = BinaryArray::new( + OffsetBuffer::new(vec![0, 5, 10, 15].into()), + vec![0; 100].into(), + None, + ); + let size = estimate_byte_data_size(&array); + assert_eq!(size, 15); + + // Offsets starting at 0, but don't count entire data buffer size + let array = BinaryArray::new( + OffsetBuffer::new(vec![50, 51, 51, 60, 80, 81].into()), + vec![0; 100].into(), + Some(vec![true, false, false, true, true].into()), + ); + let size = estimate_byte_data_size(&array); + assert_eq!(size, 31); } } - -/// Encodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. -/// Second argument is the encoding to use. -/// Standard encodings are base64 and hex. -fn encode(args: &[ColumnarValue]) -> Result { - let [expression, format] = take_function_args("encode", args)?; - - let encoding = match format { - ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { - Some(Some(method)) => method.parse::(), - _ => not_impl_err!( - "Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}" - ), - }, - ColumnarValue::Array(_) => not_impl_err!( - "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" - ), - }?; - encode_process(expression, encoding) -} - -/// Decodes the given data, accepts Binary, LargeBinary, Utf8, Utf8View or LargeUtf8 and returns a [`ColumnarValue`]. -/// Second argument is the encoding to use. -/// Standard encodings are base64 and hex. -fn decode(args: &[ColumnarValue]) -> Result { - let [expression, format] = take_function_args("decode", args)?; - - let encoding = match format { - ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { - Some(Some(method))=> method.parse::(), - _ => not_impl_err!( - "Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. Got {scalar:?}" - ), - }, - ColumnarValue::Array(_) => not_impl_err!( - "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" - ), - }?; - decode_process(expression, encoding) -} diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 252914026befd..7e753d7f35eb3 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Function packages for [DataFusion]. @@ -195,7 +193,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { } #[cfg(test)] -#[ctor::ctor] +#[ctor::ctor(unsafe)] fn init() { // Enable RUST_LOG logging configuration for test let _ = env_logger::try_init(); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 471b0a053ed6f..f196870e97228 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -41,6 +41,17 @@ /// - `Vec` argument (single argument followed by a comma) /// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) /// - Functions that require config (marked with `@config` prefix) +/// +/// Note on configuration construction paths: +/// - The convenience wrappers generated for `@config` functions call the inner +/// constructor with `ConfigOptions::default()`. These wrappers are intended +/// primarily for programmatic `Expr` construction and convenience usage. +/// - When functions are registered in a session, DataFusion will call +/// `with_updated_config()` to create a `ScalarUDF` instance using the session's +/// actual `ConfigOptions`. This also happens when configuration changes at runtime +/// (e.g., via `SET` statements). In short: the macro uses the default config for +/// convenience constructors; the session config is applied when functions are +/// registered or when configuration is updated. #[macro_export] macro_rules! export_functions { ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { @@ -59,6 +70,24 @@ macro_rules! export_functions { } }; + // function that requires config and takes a vector argument + (single $FUNC:ident, $DOC:expr, @config $arg:ident,) => { + #[doc = $DOC] + pub fn $FUNC($arg: Vec) -> datafusion_expr::Expr { + use datafusion_common::config::ConfigOptions; + super::$FUNC(&ConfigOptions::default()).call($arg) + } + }; + + // function that requires config and variadic arguments + (single $FUNC:ident, $DOC:expr, @config $($arg:ident)*) => { + #[doc = $DOC] + pub fn $FUNC($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + use datafusion_common::config::ConfigOptions; + super::$FUNC(&ConfigOptions::default()).call(vec![$($arg),*]) + } + }; + // single vector argument (a single argument followed by a comma) (single $FUNC:ident, $DOC:expr, $arg:ident,) => { #[doc = $DOC] @@ -84,7 +113,6 @@ macro_rules! export_functions { #[macro_export] macro_rules! make_udf_function { ($UDF:ty, $NAME:ident, $CTOR:expr) => { - #[allow(rustdoc::redundant_explicit_links)] #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME() -> std::sync::Arc { // Singleton instance of the function @@ -109,7 +137,6 @@ macro_rules! make_udf_function { #[macro_export] macro_rules! make_udf_function_with_config { ($UDF:ty, $NAME:ident) => { - #[allow(rustdoc::redundant_explicit_links)] #[doc = concat!("Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation of ", stringify!($NAME))] pub fn $NAME(config: &datafusion_common::config::ConfigOptions) -> std::sync::Arc { std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( @@ -168,9 +195,7 @@ macro_rules! downcast_named_arg { /// $ARRAY_TYPE: the type of array to cast the argument to #[macro_export] macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $crate::downcast_named_arg!($ARG, "", $ARRAY_TYPE) - }}; + ($ARG:expr, $ARRAY_TYPE:ident) => {{ $crate::downcast_named_arg!($ARG, "", $ARRAY_TYPE) }}; } /// Macro to create a unary math UDF. @@ -185,15 +210,27 @@ macro_rules! downcast_arg { /// $GET_DOC: the function to get the documentation of the UDF macro_rules! make_math_unary_udf { ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { + make_math_unary_udf!( + $UDF, + $NAME, + $UNARY_FUNC, + $OUTPUT_ORDERING, + $EVALUATE_BOUNDS, + $GET_DOC, + None:: Result<()>> + ); + }; + ($UDF:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr, $VALIDATOR:expr) => { $crate::make_udf_function!($NAME::$UDF, $NAME); mod $NAME { - use std::any::Any; + use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; - use datafusion_common::{exec_err, Result}; + use arrow::error::ArrowError; + use datafusion_common::{Result, exec_err}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -208,11 +245,10 @@ macro_rules! make_math_unary_udf { impl $UDF { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::uniform( 1, - vec![Float64, Float32], + vec![DataType::Float64, DataType::Float32], Volatility::Immutable, ), } @@ -220,9 +256,6 @@ macro_rules! make_math_unary_udf { } impl ScalarUDFImpl for $UDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { stringify!($NAME) } @@ -236,7 +269,6 @@ macro_rules! make_math_unary_udf { match arg_type { DataType::Float32 => Ok(DataType::Float32), - // For other types (possible values float64/null/int), use Float64 _ => Ok(DataType::Float64), } } @@ -258,21 +290,43 @@ macro_rules! make_math_unary_udf { ) -> Result { let args = ColumnarValue::values_to_arrays(&args.args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new( - args[0] + DataType::Float64 => { + let values = args[0] .as_primitive::() - .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), - ) as ArrayRef, - DataType::Float32 => Arc::new( - args[0] + .try_unary::<_, Float64Type, _>( + |x: f64| -> std::result::Result { + if let Some(validate) = $VALIDATOR { + validate(x).map_err(|error| { + ArrowError::ComputeError(error.to_string()) + })?; + } + + Ok(f64::$UNARY_FUNC(x)) + }, + )?; + Arc::new(values) as ArrayRef + } + DataType::Float32 => { + let values = args[0] .as_primitive::() - .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), - ) as ArrayRef, + .try_unary::<_, Float32Type, _>( + |x: f32| -> std::result::Result { + if let Some(validate) = $VALIDATOR { + validate(x as f64).map_err(|error| { + ArrowError::ComputeError(error.to_string()) + })?; + } + + Ok(f32::$UNARY_FUNC(x)) + }, + )?; + Arc::new(values) as ArrayRef + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", self.name() - ) + ); } }; @@ -289,8 +343,12 @@ macro_rules! make_math_unary_udf { /// Macro to create a binary math UDF. /// -/// A binary math function takes two arguments of types Float32 or Float64, -/// applies a binary floating function to the argument, and returns a value of the same type. +/// A binary math function takes two numeric arguments. When both arguments are +/// Float32 the function is evaluated in single precision and returns Float32. +/// Any other combination of numeric (or null) argument types is coerced to +/// Float64 and returns Float64; in particular integers are widened to Float64 +/// rather than Float32 so that values needing more than 24 bits of mantissa are +/// not silently rounded. /// /// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` /// $NAME: the name of the function @@ -302,14 +360,14 @@ macro_rules! make_math_binary_udf { $crate::make_udf_function!($NAME::$UDF, $NAME); mod $NAME { - use std::any::Any; + use std::sync::Arc; use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; - use datafusion_common::{exec_err, Result}; + use datafusion_common::utils::take_function_args; + use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; - use datafusion_expr::TypeSignature; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -322,13 +380,18 @@ macro_rules! make_math_binary_udf { impl $UDF { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Float32, Float32]), - TypeSignature::Exact(vec![Float64, Float64]), - ], + // Float64 is listed first so that integer (and other + // non-float) arguments coerce to Float64 rather than + // Float32; genuine Float32 arguments still match + // exactly and stay in single precision. Coercing + // integers to Float64 matters for correctness: Float32 + // has only a 24-bit mantissa, so widening a large + // integer to Float32 would round it before the function + // is ever applied. + signature: Signature::uniform( + 2, + vec![DataType::Float64, DataType::Float32], Volatility::Immutable, ), } @@ -336,9 +399,6 @@ macro_rules! make_math_binary_udf { } impl ScalarUDFImpl for $UDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { stringify!($NAME) } @@ -348,11 +408,8 @@ macro_rules! make_math_binary_udf { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_type = &arg_types[0]; - - match arg_type { - DataType::Float32 => Ok(DataType::Float32), - // For other types (possible values float64/null/int), use Float64 + match (&arg_types[0], &arg_types[1]) { + (DataType::Float32, DataType::Float32) => Ok(DataType::Float32), _ => Ok(DataType::Float64), } } @@ -368,37 +425,76 @@ macro_rules! make_math_binary_udf { &self, args: ScalarFunctionArgs, ) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - y, - x, - |y, x| f64::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ + let ScalarFunctionArgs { + args, return_field, .. + } = args; + let return_type = return_field.data_type(); + let [y, x] = take_function_args(self.name(), args)?; + + match (y, x) { + ( + ColumnarValue::Scalar(y_scalar), + ColumnarValue::Scalar(x_scalar), + ) => match (&y_scalar, &x_scalar) { + (y, x) if y.is_null() || x.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_type, None) + } + ( + ScalarValue::Float64(Some(yv)), + ScalarValue::Float64(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + f64::$BINARY_FUNC(*yv, *xv), + )))), + ( + ScalarValue::Float32(Some(yv)), + ScalarValue::Float32(Some(xv)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some( + f32::$BINARY_FUNC(*yv, *xv), + )))), + _ => internal_err!( + "Unexpected scalar types for function {}: {:?}, {:?}", + self.name(), + y_scalar.data_type(), + x_scalar.data_type() + ), + }, + (y, x) => { + let args = ColumnarValue::values_to_arrays(&[y, x])?; + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = + arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) } - DataType::Float32 => { - let y = args[0].as_primitive::(); - let x = args[1].as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float32Type>( - y, - x, - |y, x| f32::$BINARY_FUNC(y, x), - )?; - Arc::new(result) as _ - } - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ) - } - }; - - Ok(ColumnarValue::Array(arr)) + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 35d0f3eccf573..02ac89756d919 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -17,17 +17,16 @@ //! math expressions -use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, - Float16Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, + ArrayRef, Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + Float16Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, + Int64Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{not_impl_err, utils::take_function_args, Result}; +use datafusion_common::{Result, not_impl_err, utils::take_function_args}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ @@ -50,6 +49,7 @@ macro_rules! make_abs_function { }}; } +#[macro_export] macro_rules! make_try_abs_function { ($ARRAY_TYPE:ident) => {{ |input: &ArrayRef| { @@ -62,7 +62,8 @@ macro_rules! make_try_abs_function { x )) }) - })?; + }) + .and_then(|v| Ok(v.with_data_type(input.data_type().clone())))?; // maintain decimal's precision and scale Ok(Arc::new(res) as ArrayRef) } }}; @@ -145,10 +146,6 @@ impl AbsFunc { } impl ScalarUDFImpl for AbsFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "abs" } diff --git a/datafusion/functions/src/math/ceil.rs b/datafusion/functions/src/math/ceil.rs new file mode 100644 index 0000000000000..395cb4eae03f5 --- /dev/null +++ b/datafusion/functions/src/math/ceil.rs @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type, + Float64Type, +}; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; + +use super::decimal::{apply_decimal_op, ceil_decimal_value}; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the nearest integer greater than or equal to a number.", + syntax_example = "ceil(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric"), + sql_example = r#"```sql +> SELECT ceil(3.14); ++------------+ +| ceil(3.14) | ++------------+ +| 4.0 | ++------------+ +```"# +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct CeilFunc { + signature: Signature, +} + +impl Default for CeilFunc { + fn default() -> Self { + Self::new() + } +} + +impl CeilFunc { + pub fn new() -> Self { + let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal); + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![decimal_sig]), + TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for CeilFunc { + fn name(&self) -> &str { + "ceil" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Null => Ok(DataType::Float64), + other => Ok(other.clone()), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let arg = &args.args[0]; + + // Scalar fast path for float types - avoid array conversion overhead entirely + if let ColumnarValue::Scalar(scalar) = arg { + match scalar { + ScalarValue::Float64(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64( + v.map(f64::ceil), + ))); + } + ScalarValue::Float32(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32( + v.map(f32::ceil), + ))); + } + ScalarValue::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + // For decimals: convert to array of size 1, process, then extract scalar + // This ensures we don't expand the array while reusing overflow validation + _ => {} + } + } + + // Track if input was a scalar to convert back at the end + let is_scalar = matches!(arg, ColumnarValue::Scalar(_)); + + // Array path (also handles decimal scalars converted to size-1 arrays) + let value = arg.to_array(args.number_rows)?; + + let result: ArrayRef = match value.data_type() { + DataType::Float64 => Arc::new( + value + .as_primitive::() + .unary::<_, Float64Type>(f64::ceil), + ), + DataType::Float32 => Arc::new( + value + .as_primitive::() + .unary::<_, Float32Type>(f32::ceil), + ), + DataType::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + DataType::Decimal32(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + DataType::Decimal64(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + DataType::Decimal128(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + DataType::Decimal256(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + ceil_decimal_value, + )? + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + // If input was a scalar, convert result back to scalar + if is_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + let data_type = inputs[0].data_type(); + Interval::make_unbounded(&data_type) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 43f2012d073dd..24f0a412e3a8a 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -72,10 +71,6 @@ impl CotFunc { } impl ScalarUDFImpl for CotFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "cot" } @@ -96,24 +91,47 @@ impl ScalarUDFImpl for CotFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(cot, vec![])(&args.args) - } -} + let return_field = args.return_field; + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(return_field.data_type(), None); + } -///cot SQL function -fn cot(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>(|x: f64| compute_cot64(x)), - ) as ArrayRef), - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>(|x: f32| compute_cot32(x)), - ) as ArrayRef), - other => exec_err!("Unsupported data type {other:?} for function cot"), + match scalar { + ScalarValue::Float64(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Float64(Some(compute_cot64(v))), + )), + ScalarValue::Float32(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Float32(Some(compute_cot32(v))), + )), + _ => { + internal_err!( + "Unexpected scalar type for cot: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float64Type>(compute_cot64), + ))), + Float32 => Ok(ColumnarValue::Array(Arc::new( + array + .as_primitive::() + .unary::<_, Float32Type>(compute_cot32), + ))), + other => { + internal_err!("Unexpected data type {other:?} for function cot") + } + }, + } } } @@ -129,54 +147,212 @@ fn compute_cot64(x: f64) -> f64 { #[cfg(test)] mod test { - use crate::math::cot::cot; + use std::sync::Arc; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::ScalarValue; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use crate::math::cot::CotFunc; #[test] fn test_cot_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float32_array(&result).expect("failed to initialize function cot"); - - let expected = Float32Array::from(vec![ - -1.986_460_4, - -0.156_119_96, - -0.501_202_8, - 0.156_119_96, - ]); - - let eps = 1e-6; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); + let array = Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0])); + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, + number_rows: array.len(), + return_field: Field::new("f", DataType::Float32, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function cot"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + let expected = Float32Array::from(vec![ + -1.986_460_4, + -0.156_119_96, + -0.501_202_8, + 0.156_119_96, + ]); + + let eps = 1e-6; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } #[test] fn test_cot_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))]; - let result = cot(&args).expect("failed to initialize function cot"); - let floats = - as_float64_array(&result).expect("failed to initialize function cot"); - - let expected = Float64Array::from(vec![ - -1.986_458_685_881_4, - -0.156_119_952_161_6, - -0.501_202_783_380_1, - 0.156_119_952_161_6, - ]); - - let eps = 1e-12; - assert_eq!(floats.len(), 4); - assert!((floats.value(0) - expected.value(0)).abs() < eps); - assert!((floats.value(1) - expected.value(1)).abs() < eps); - assert!((floats.value(2) - expected.value(2)).abs() < eps); - assert!((floats.value(3) - expected.value(3)).abs() < eps); + let array = Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0])); + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields, + number_rows: array.len(), + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("failed to initialize function cot"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + let expected = Float64Array::from(vec![ + -1.986_458_685_881_4, + -0.156_119_952_161_6, + -0.501_202_783_380_1, + 0.156_119_952_161_6, + ]); + + let eps = 1e-12; + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - expected.value(0)).abs() < eps); + assert!((floats.value(1) - expected.value(1)).abs() < eps); + assert!((floats.value(2) - expected.value(2)).abs() < eps); + assert!((floats.value(3) - expected.value(3)).abs() < eps); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_cot_scalar_f64() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot scalar should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(1.0) = 1/tan(1.0) ≈ 0.6420926159343306 + let expected = 1.0_f64 / 1.0_f64.tan(); + assert!((v - expected).abs() < 1e-12); + } + _ => panic!("Expected Float64 scalar"), + } + } + + #[test] + fn test_cot_scalar_f32() { + let arg_fields = vec![Field::new("a", DataType::Float32, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float32, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot scalar should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float32(Some(v))) => { + let expected = 1.0_f32 / 1.0_f32.tan(); + assert!((v - expected).abs() < 1e-6); + } + _ => panic!("Expected Float32 scalar"), + } + } + + #[test] + fn test_cot_scalar_null() { + let arg_fields = vec![Field::new("a", DataType::Float64, true).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(None))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot null should succeed"); + + match result { + ColumnarValue::Scalar(scalar) => { + assert!(scalar.is_null()); + } + _ => panic!("Expected scalar result"), + } + } + + #[test] + fn test_cot_scalar_zero() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some(0.0)))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot zero should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(0) = 1/tan(0) = infinity + assert!(v.is_infinite()); + } + _ => panic!("Expected Float64 scalar"), + } + } + + #[test] + fn test_cot_scalar_pi() { + let arg_fields = vec![Field::new("a", DataType::Float64, false).into()]; + let args = ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))], + arg_fields, + number_rows: 1, + return_field: Field::new("f", DataType::Float64, false).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = CotFunc::new() + .invoke_with_args(args) + .expect("cot pi should succeed"); + + match result { + ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => { + // cot(PI) = 1/tan(PI) - very large negative number due to floating point + let expected = 1.0_f64 / std::f64::consts::PI.tan(); + assert!((v - expected).abs() < 1e-6); + } + _ => panic!("Expected Float64 scalar"), + } } } diff --git a/datafusion/functions/src/math/decimal.rs b/datafusion/functions/src/math/decimal.rs new file mode 100644 index 0000000000000..abaded4568a93 --- /dev/null +++ b/datafusion/functions/src/math/decimal.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::{ArrowNativeTypeOp, DecimalType}; +use arrow::error::ArrowError; +use arrow_buffer::ArrowNativeType; +use datafusion_common::{DataFusionError, Result}; + +pub(super) fn apply_decimal_op( + array: &ArrayRef, + precision: u8, + scale: i8, + fn_name: &str, + op: F, +) -> Result +where + T: DecimalType, + T::Native: ArrowNativeType + ArrowNativeTypeOp, + F: Fn(T::Native, T::Native) -> T::Native, +{ + if scale <= 0 { + return Ok(Arc::clone(array)); + } + + let factor = decimal_scale_factor::(scale, fn_name)?; + let decimal = array.as_primitive::(); + let data_type = array.data_type().clone(); + + let result: PrimitiveArray = decimal.try_unary(|value| { + let new_value = op(value, factor); + T::validate_decimal_precision(new_value, precision, scale).map_err(|_| { + ArrowError::ComputeError(format!("Decimal overflow while applying {fn_name}")) + })?; + Ok::<_, ArrowError>(new_value) + })?; + + let result = result.with_data_type(data_type); + + Ok(Arc::new(result)) +} + +fn decimal_scale_factor(scale: i8, fn_name: &str) -> Result +where + T: DecimalType, + T::Native: ArrowNativeType + ArrowNativeTypeOp, +{ + let base = ::from_usize(10).ok_or_else(|| { + DataFusionError::Execution(format!( + "Cannot get 10_{} from usize: {:?}", + std::any::type_name::(), + 10_usize + )) + })?; + + base.pow_checked(scale as u32).map_err(|_| { + DataFusionError::Execution(format!("Decimal overflow while applying {fn_name}")) + }) +} + +pub(super) fn ceil_decimal_value(value: T, factor: T) -> T +where + T: ArrowNativeTypeOp + std::ops::Rem, +{ + let remainder = value % factor; + + if remainder == T::ZERO { + return value; + } + + if value >= T::ZERO { + let increment = factor.sub_wrapping(remainder); + value.add_wrapping(increment) + } else { + value.sub_wrapping(remainder) + } +} + +pub(super) fn floor_decimal_value(value: T, factor: T) -> T +where + T: ArrowNativeTypeOp + std::ops::Rem, +{ + let remainder = value % factor; + + if remainder == T::ZERO { + return value; + } + + if value >= T::ZERO { + value.sub_wrapping(remainder) + } else { + let adjustment = factor.add_wrapping(remainder); + value.sub_wrapping(adjustment) + } +} diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 79f6da94dd0e1..3b4f973f19d62 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -15,18 +15,15 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::{ArrayRef, Int64Array}, - error::ArrowError, -}; -use std::any::Any; +use arrow::array::{ArrayRef, AsArray, Int64Array}; use std::sync::Arc; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; +use arrow::datatypes::{DataType, Int64Type}; -use crate::utils::make_scalar_function; -use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -35,7 +32,7 @@ use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Math Functions"), - description = "Factorial. Returns 1 if value is less than 2.", + description = "Factorial of a non-negative integer. Errors if the argument is negative or the result overflows.", syntax_example = "factorial(numeric_expression)", sql_example = r#"```sql > SELECT factorial(5); @@ -67,10 +64,6 @@ impl FactorialFunc { } impl ScalarUDFImpl for FactorialFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "factorial" } @@ -84,7 +77,39 @@ impl ScalarUDFImpl for FactorialFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(factorial, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))); + } + + match scalar { + ScalarValue::Int64(Some(v)) => { + let result = compute_factorial(v)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function factorial", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Int64 => { + let result: Int64Array = array + .as_primitive::() + .try_unary(compute_factorial)?; + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } + other => { + internal_err!("Unexpected data type {other:?} for function factorial") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -92,50 +117,36 @@ impl ScalarUDFImpl for FactorialFunc { } } -/// Factorial SQL function -fn factorial(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Int64 => { - let arg = downcast_named_arg!((&args[0]), "value", Int64Array); - Ok(arg - .iter() - .map(|a| match a { - Some(a) => (2..=a) - .try_fold(1i64, i64::checked_mul) - .ok_or_else(|| { - arrow_datafusion_err!(ArrowError::ComputeError(format!( - "Overflow happened on FACTORIAL({a})" - ))) - }) - .map(Some), - _ => Ok(None), - }) - .collect::>() - .map(Arc::new)? as ArrayRef) - } - other => exec_err!("Unsupported data type {other:?} for function factorial."), - } -} - -#[cfg(test)] -mod test { - - use datafusion_common::cast::as_int64_array; - - use super::*; - - #[test] - fn test_factorial_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 1, 2, 4])), // input - ]; - - let result = factorial(&args).expect("failed to initialize function factorial"); - let ints = - as_int64_array(&result).expect("failed to initialize function factorial"); - - let expected = Int64Array::from(vec![1, 1, 2, 24]); - - assert_eq!(ints, &expected); +const FACTORIALS: [i64; 21] = [ + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800, + 87178291200, + 1307674368000, + 20922789888000, + 355687428096000, + 6402373705728000, + 121645100408832000, + 2432902008176640000, +]; // if return type changes, this constant needs to be updated accordingly + +fn compute_factorial(n: i64) -> Result { + if n < 0 { + exec_err!("factorial of a negative number is undefined") + } else if n < FACTORIALS.len() as i64 { + Ok(FACTORIALS[n as usize]) + } else { + exec_err!("Overflow happened on FACTORIAL({n})") } } diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs new file mode 100644 index 0000000000000..e02aa141c5b71 --- /dev/null +++ b/datafusion/functions/src/math/floor.rs @@ -0,0 +1,684 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::{DecimalCast, rescale_decimal}; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, +}; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::preimage::PreimageResult; +use datafusion_expr::simplify::SimplifyContext; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, +}; +use datafusion_macros::user_doc; +use num_traits::{CheckedAdd, Float, One}; + +use super::decimal::{apply_decimal_op, floor_decimal_value}; + +#[user_doc( + doc_section(label = "Math Functions"), + description = "Returns the nearest integer less than or equal to a number.", + syntax_example = "floor(numeric_expression)", + standard_argument(name = "numeric_expression", prefix = "Numeric"), + sql_example = r#"```sql +> SELECT floor(3.14); ++-------------+ +| floor(3.14) | ++-------------+ +| 3.0 | ++-------------+ +```"# +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct FloorFunc { + signature: Signature, +} + +impl Default for FloorFunc { + fn default() -> Self { + Self::new() + } +} + +impl FloorFunc { + pub fn new() -> Self { + let decimal_sig = Coercion::new_exact(TypeSignatureClass::Decimal); + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![decimal_sig]), + TypeSignature::Uniform(1, vec![DataType::Float64, DataType::Float32]), + ], + Volatility::Immutable, + ), + } + } +} + +// ============ Macro for preimage bounds ============ +/// Generates the code to call the appropriate bounds function and wrap results. +macro_rules! preimage_bounds { + // Float types: call float_preimage_bounds and wrap in ScalarValue + (float: $variant:ident, $value:expr) => { + float_preimage_bounds($value).map(|(lo, hi)| { + ( + ScalarValue::$variant(Some(lo)), + ScalarValue::$variant(Some(hi)), + ) + }) + }; + + // Integer types: call int_preimage_bounds and wrap in ScalarValue + (int: $variant:ident, $value:expr) => { + int_preimage_bounds($value).map(|(lo, hi)| { + ( + ScalarValue::$variant(Some(lo)), + ScalarValue::$variant(Some(hi)), + ) + }) + }; + + // Decimal types: call decimal_preimage_bounds with precision/scale and wrap in ScalarValue + (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => { + decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale).map( + |(lo, hi)| { + ( + ScalarValue::$variant(Some(lo), $precision, $scale), + ScalarValue::$variant(Some(hi), $precision, $scale), + ) + }, + ) + }; +} + +impl ScalarUDFImpl for FloorFunc { + fn name(&self) -> &str { + "floor" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Null => Ok(DataType::Float64), + other => Ok(other.clone()), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let arg = &args.args[0]; + + // Scalar fast path for float types - avoid array conversion overhead entirely + if let ColumnarValue::Scalar(scalar) = arg { + match scalar { + ScalarValue::Float64(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64( + v.map(f64::floor), + ))); + } + ScalarValue::Float32(v) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float32( + v.map(f32::floor), + ))); + } + ScalarValue::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + // For decimals: convert to array of size 1, process, then extract scalar + // This ensures we don't expand the array while reusing overflow validation + _ => {} + } + } + + // Track if input was a scalar to convert back at the end + let is_scalar = matches!(arg, ColumnarValue::Scalar(_)); + + // Array path (also handles decimal scalars converted to size-1 arrays) + let value = arg.to_array(args.number_rows)?; + + let result: ArrayRef = match value.data_type() { + DataType::Float64 => Arc::new( + value + .as_primitive::() + .unary::<_, Float64Type>(f64::floor), + ), + DataType::Float32 => Arc::new( + value + .as_primitive::() + .unary::<_, Float32Type>(f32::floor), + ), + DataType::Null => { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); + } + DataType::Decimal32(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + DataType::Decimal64(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + DataType::Decimal128(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + DataType::Decimal256(precision, scale) => { + apply_decimal_op::( + &value, + *precision, + *scale, + self.name(), + floor_decimal_value, + )? + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + // If input was a scalar, convert result back to scalar + if is_scalar { + ScalarValue::try_from_array(&result, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(result)) + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + let data_type = inputs[0].data_type(); + Interval::make_unbounded(&data_type) + } + + /// Compute the preimage for floor function. + /// + /// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1` + /// because floor(x) = N for all x in [N, N+1). + /// + /// This enables predicate pushdown optimizations, transforming: + /// `floor(col) = 100` into `col >= 100 AND col < 101` + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + // floor takes exactly one argument and we do not expect to reach here with multiple arguments. + debug_assert!(args.len() == 1, "floor() takes exactly one argument"); + + let arg = args[0].clone(); + + // Extract the literal value being compared to + let Expr::Literal(lit_value, _) = lit_expr else { + return Ok(PreimageResult::None); + }; + + // Compute lower bound (N) and upper bound (N + 1) using helper functions + let Some((lower, upper)) = (match lit_value { + // Floating-point types + ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n), + ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n), + + // Integer types (not reachable from SQL/SLT: floor() only accepts Float64/Float32/Decimal, + // so the RHS literal is always coerced to one of those before preimage runs; kept for + // programmatic use and unit tests) + ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n), + ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n), + ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n), + ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n), + + // Decimal types + // DECIMAL(precision, scale) where precision ≤ 38 -> Decimal128(precision, scale) + // DECIMAL(precision, scale) where precision > 38 -> Decimal256(precision, scale) + // Decimal32 and Decimal64 are unreachable from SQL/SLT. + ScalarValue::Decimal32(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale) + } + ScalarValue::Decimal64(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale) + } + ScalarValue::Decimal128(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale) + } + ScalarValue::Decimal256(Some(n), precision, scale) => { + preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale) + } + + // Unsupported types + _ => None, + }) else { + return Ok(PreimageResult::None); + }; + + Ok(PreimageResult::Range { + expr: arg, + interval: Box::new(Interval::try_new(lower, upper)?), + }) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +// ============ Helper functions for preimage bounds ============ + +/// Compute preimage bounds for floor function on floating-point types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value is non-finite (infinity, NaN) +/// - The value is not an integer (floor always returns integers, so floor(x) = 1.3 has no solution) +/// - Adding 1 would lose precision at extreme values +fn float_preimage_bounds(n: F) -> Option<(F, F)> { + let one = F::one(); + // Check for non-finite values (infinity, NaN) + if !n.is_finite() { + return None; + } + // floor always returns an integer, so if n has a fractional part, there's no solution + if n.fract() != F::zero() { + return None; + } + // Check for precision loss at extreme values + if n + one <= n { + return None; + } + Some((n, n + one)) +} + +/// Compute preimage bounds for floor function on integer types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if adding 1 would overflow. +fn int_preimage_bounds(n: I) -> Option<(I, I)> { + let upper = n.checked_add(&I::one())?; + Some((n, upper)) +} + +/// Compute preimage bounds for floor function on decimal types. +/// For floor(x) = n, the preimage is [n, n+1). +/// Returns None if: +/// - The value has a fractional part (floor always returns integers) +/// - Adding 1 would overflow +fn decimal_preimage_bounds( + value: D::Native, + precision: u8, + scale: i8, +) -> Option<(D::Native, D::Native)> +where + D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem, +{ + // Use rescale_decimal to compute "1" at target scale (avoids manual pow) + // Convert integer 1 (scale=0) to the target scale + let one_scaled: D::Native = rescale_decimal::( + D::Native::ONE, // value = 1 + 1, // input_precision = 1 + 0, // input_scale = 0 (integer) + precision, // output_precision + scale, // output_scale + )?; + + // floor always returns an integer, so if value has a fractional part, there's no solution + // Check: value % one_scaled != 0 means fractional part exists + if scale > 0 && value % one_scaled != D::Native::ZERO { + return None; + } + + // Compute upper bound using checked addition + // Before preimage stage, the internal i128/i256(value) is validated based on the precision and scale. + // MAX_DECIMAL128_FOR_EACH_PRECISION and MAX_DECIMAL256_FOR_EACH_PRECISION are used to validate the internal i128/i256. + // Any invalid i128/i256 will not reach here. + // Therefore, the add_checked will always succeed if tested via SQL/SLT path. + let upper = value.add_checked(one_scaled).ok()?; + + Some((value, upper)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_buffer::i256; + use datafusion_expr::col; + + /// Helper to test valid preimage cases that should return a Range + fn assert_preimage_range( + input: ScalarValue, + expected_lower: ScalarValue, + expected_upper: ScalarValue, + ) { + let floor_func = FloorFunc::new(); + let args = vec![col("x")]; + let lit_expr = Expr::Literal(input.clone(), None); + let info = SimplifyContext::default(); + + let result = floor_func.preimage(&args, &lit_expr, &info).unwrap(); + + match result { + PreimageResult::Range { expr, interval } => { + assert_eq!(expr, col("x")); + assert_eq!(interval.lower().clone(), expected_lower); + assert_eq!(interval.upper().clone(), expected_upper); + } + PreimageResult::None => { + panic!("Expected Range, got None for input {input:?}") + } + } + } + + /// Helper to test cases that should return None + fn assert_preimage_none(input: ScalarValue) { + let floor_func = FloorFunc::new(); + let args = vec![col("x")]; + let lit_expr = Expr::Literal(input.clone(), None); + let info = SimplifyContext::default(); + + let result = floor_func.preimage(&args, &lit_expr, &info).unwrap(); + assert!( + matches!(result, PreimageResult::None), + "Expected None for input {input:?}" + ); + } + + #[test] + fn test_floor_preimage_valid_cases() { + // Float64 + assert_preimage_range( + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(101.0)), + ); + // Float32 + assert_preimage_range( + ScalarValue::Float32(Some(50.0)), + ScalarValue::Float32(Some(50.0)), + ScalarValue::Float32(Some(51.0)), + ); + // Int64 + assert_preimage_range( + ScalarValue::Int64(Some(42)), + ScalarValue::Int64(Some(42)), + ScalarValue::Int64(Some(43)), + ); + // Int32 + assert_preimage_range( + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(101)), + ); + // Negative values + assert_preimage_range( + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-5.0)), + ScalarValue::Float64(Some(-4.0)), + ); + // Zero + assert_preimage_range( + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(0.0)), + ScalarValue::Float64(Some(1.0)), + ); + } + + #[test] + fn test_floor_preimage_non_integer_float() { + // floor(x) = 1.3 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer literals + assert_preimage_none(ScalarValue::Float64(Some(1.3))); + assert_preimage_none(ScalarValue::Float64(Some(-2.5))); + assert_preimage_none(ScalarValue::Float32(Some(3.7))); + } + + #[test] + fn test_floor_preimage_integer_overflow() { + // All integer types at MAX value should return None + assert_preimage_none(ScalarValue::Int64(Some(i64::MAX))); + assert_preimage_none(ScalarValue::Int32(Some(i32::MAX))); + assert_preimage_none(ScalarValue::Int16(Some(i16::MAX))); + assert_preimage_none(ScalarValue::Int8(Some(i8::MAX))); + } + + #[test] + fn test_floor_preimage_float_edge_cases() { + // Float64 edge cases + assert_preimage_none(ScalarValue::Float64(Some(f64::INFINITY))); + assert_preimage_none(ScalarValue::Float64(Some(f64::NEG_INFINITY))); + assert_preimage_none(ScalarValue::Float64(Some(f64::NAN))); + assert_preimage_none(ScalarValue::Float64(Some(f64::MAX))); // precision loss + + // Float32 edge cases + assert_preimage_none(ScalarValue::Float32(Some(f32::INFINITY))); + assert_preimage_none(ScalarValue::Float32(Some(f32::NEG_INFINITY))); + assert_preimage_none(ScalarValue::Float32(Some(f32::NAN))); + assert_preimage_none(ScalarValue::Float32(Some(f32::MAX))); // precision loss + } + + #[test] + fn test_floor_preimage_null_values() { + assert_preimage_none(ScalarValue::Float64(None)); + assert_preimage_none(ScalarValue::Float32(None)); + assert_preimage_none(ScalarValue::Int64(None)); + } + + // ============ Decimal32 Tests (mirrors float/int tests) ============ + + #[test] + fn test_floor_preimage_decimal_valid_cases() { + // ===== Decimal32 ===== + // Positive integer decimal: 100.00 (scale=2, so raw=10000) + // floor(x) = 100.00 -> x in [100.00, 101.00) + assert_preimage_range( + ScalarValue::Decimal32(Some(10000), 9, 2), + ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00 + ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00 + ); + + // Smaller positive: 50.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(5000), 9, 2), + ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00 + ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00 + ); + + // Negative integer decimal: -5.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(-500), 9, 2), + ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00 + ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00 + ); + + // Zero: 0.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(0), 9, 2), + ScalarValue::Decimal32(Some(0), 9, 2), // 0.00 + ScalarValue::Decimal32(Some(100), 9, 2), // 1.00 + ); + + // Scale 0 (pure integer): 42 + assert_preimage_range( + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(42), 9, 0), + ScalarValue::Decimal32(Some(43), 9, 0), + ); + + // ===== Decimal64 ===== + assert_preimage_range( + ScalarValue::Decimal64(Some(10000), 18, 2), + ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00 + ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal64(Some(-500), 18, 2), + ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00 + ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(0), 18, 2), + ScalarValue::Decimal64(Some(100), 18, 2), + ); + + // ===== Decimal128 ===== + assert_preimage_range( + ScalarValue::Decimal128(Some(10000), 38, 2), + ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00 + ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal128(Some(-500), 38, 2), + ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00 + ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(0), 38, 2), + ScalarValue::Decimal128(Some(100), 38, 2), + ); + + // ===== Decimal256 ===== + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00 + ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00 + ); + + // Negative + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), + ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00 + ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00 + ); + + // Zero + assert_preimage_range( + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::ZERO), 76, 2), + ScalarValue::Decimal256(Some(i256::from(100)), 76, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_non_integer() { + // floor(x) = 1.30 has NO SOLUTION because floor always returns an integer + // Therefore preimage should return None for non-integer decimals + + // Decimal32 + assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50 + assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70 + assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01 + + // Decimal64 + assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50 + + // Decimal128 + assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50 + + // Decimal256 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 76, 2)); // 1.30 + assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50 + + // Decimal32: i32::MAX - 50 + // This return None because the value is not an integer, not because it is out of range. + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 10, 2)); + + // Decimal64: i64::MAX - 50 + // This return None because the value is not an integer, not because it is out of range. + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 19, 2)); + } + + #[test] + fn test_floor_preimage_decimal_overflow() { + // Test near MAX where adding scale_factor would overflow + + // Decimal32: i32::MAX + assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 10, 0)); + + // Decimal64: i64::MAX + assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 19, 0)); + } + + #[test] + fn test_floor_preimage_decimal_edge_cases() { + // ===== Decimal32 ===== + // Large value that doesn't overflow + // Decimal(9,2) max value is 9,999,999.99 (stored as 999,999,999) + // Use a large value that fits Decimal(9,2) and is divisible by 100 + let safe_max_aligned_32 = 999_999_900; // 9,999,999.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(safe_max_aligned_32 + 100), 9, 2), + ); + + // Negative edge: use a large negative value that fits Decimal(9,2) + // Decimal(9,2) min value is -9,999,999.99 (stored as -999,999,999) + let min_aligned_32 = -999_999_900; // -9,999,999.00 + assert_preimage_range( + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32), 9, 2), + ScalarValue::Decimal32(Some(min_aligned_32 + 100), 9, 2), + ); + } + + #[test] + fn test_floor_preimage_decimal_null() { + assert_preimage_none(ScalarValue::Decimal32(None, 9, 2)); + assert_preimage_none(ScalarValue::Decimal64(None, 18, 2)); + assert_preimage_none(ScalarValue::Decimal128(None, 38, 2)); + assert_preimage_none(ScalarValue::Decimal256(None, 76, 2)); + } +} diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 0b85e7b54a782..8b92c454d9b4c 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::compute::try_binary; use arrow::datatypes::{DataType, Int64Type}; use arrow::error::ArrowError; -use std::any::Any; use std::mem::swap; use std::sync::Arc; -use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -69,10 +68,6 @@ impl GcdFunc { } impl ScalarUDFImpl for GcdFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "gcd" } @@ -94,20 +89,23 @@ impl ScalarUDFImpl for GcdFunc { [ColumnarValue::Array(a), ColumnarValue::Array(b)] => { compute_gcd_for_arrays(&a, &b) } - [ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Scalar(ScalarValue::Int64(b))] => { - match (a, b) { - (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - Some(compute_gcd(a, b)?), - ))), - _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), - } - } - [ColumnarValue::Array(a), ColumnarValue::Scalar(ScalarValue::Int64(b))] => { - compute_gcd_with_scalar(&a, b) - } - [ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Array(b)] => { - compute_gcd_with_scalar(&b, a) - } + [ + ColumnarValue::Scalar(ScalarValue::Int64(a)), + ColumnarValue::Scalar(ScalarValue::Int64(b)), + ] => match (a, b) { + (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + Some(compute_gcd(a, b)?), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), + }, + [ + ColumnarValue::Array(a), + ColumnarValue::Scalar(ScalarValue::Int64(b)), + ] => compute_gcd_with_scalar(&a, b), + [ + ColumnarValue::Scalar(ScalarValue::Int64(a)), + ColumnarValue::Array(b), + ] => compute_gcd_with_scalar(&b, a), _ => exec_err!("Unsupported argument types for function gcd"), } } @@ -128,23 +126,25 @@ fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result { } fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option) -> Result { + let prim = arr.as_primitive::(); match scalar { + Some(scalar_value) if scalar_value != 0 && scalar_value != i64::MIN => { + // The gcd result divides both inputs' absolute values. When the + // scalar is neither 0 nor i64::MIN, the gcd's absolute value fits + // in i64, so the cast to i64 below cannot overflow. This allows us + // to use `unary` instead of `try_unary`, which allows LLVM to + // vectorize more effectively. + let sv = scalar_value.unsigned_abs(); + let result: PrimitiveArray = + prim.unary(|val| unsigned_gcd(val.unsigned_abs(), sv) as i64); + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) + } Some(scalar_value) => { - let result: Result = arr - .as_primitive::() - .iter() - .map(|val| match val { - Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)), - _ => Ok(None), - }) - .collect(); - - result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + let result: PrimitiveArray = + prim.try_unary(|val| compute_gcd(val, scalar_value))?; + Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } - None => Ok(ColumnarValue::Array(new_null_array( - &DataType::Int64, - arr.len(), - ))), + None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), } } @@ -176,7 +176,8 @@ pub fn compute_gcd(x: i64, y: i64) -> Result { let a = x.unsigned_abs(); let b = y.unsigned_abs(); let r = unsigned_gcd(a, b); - // gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64 + // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or + // gcd(i64::MIN, i64::MIN)), which does not fit into i64. r.try_into().map_err(|_| { ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})")) }) diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index f053256a4870a..de6fc669692ee 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -15,23 +15,28 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::datatypes::DataType::{Boolean, Float32, Float64}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use arrow::array::{ArrowNativeTypeOp, AsArray, BooleanArray}; +use arrow::datatypes::DataType::{ + Boolean, Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, + Int8, Int16, Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::Exact; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::{Coercion, TypeSignatureClass}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", @@ -59,21 +64,15 @@ impl Default for IsZeroFunc { impl IsZeroFunc { pub fn new() -> Self { - use DataType::*; + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } impl ScalarUDFImpl for IsZeroFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "iszero" } @@ -87,70 +86,155 @@ impl ScalarUDFImpl for IsZeroFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(iszero, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0)))) + } + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0.0)))) + } + ScalarValue::Float16(Some(v)) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + + ScalarValue::Int8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Int64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt8(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::UInt64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + + ScalarValue::Decimal32(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal64(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal128(Some(v), ..) => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(v == 0)))) + } + ScalarValue::Decimal256(Some(v), ..) => Ok(ColumnarValue::Scalar( + ScalarValue::Boolean(Some(v.is_zero())), + )), + + _ => { + internal_err!( + "Unexpected scalar type for iszero: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Null => Ok(ColumnarValue::Array(Arc::new(BooleanArray::new_null( + array.len(), + )))), + + Float64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0.0, + )))), + Float32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0.0, + )))), + Float16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))), + + Int8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + Int64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt8 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt16 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt32 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + UInt64 => Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))), + + Decimal32(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal64(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal128(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x == 0, + )))) + } + Decimal256(_, _) => { + Ok(ColumnarValue::Array(Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_zero(), + )))) + } + + other => { + internal_err!("Unexpected data type {other:?} for function iszero") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { self.doc() } } - -/// Iszero SQL function -fn iszero(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x == 0.0, - )) as ArrayRef), - - Float32 => Ok(Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - |x| x == 0.0, - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function iszero"), - } -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - - use datafusion_common::cast::as_boolean_array; - - use crate::math::iszero::iszero; - - #[test] - fn test_iszero_f64() { - let args: Vec = - vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_iszero_f32() { - let args: Vec = - vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; - - let result = iszero(&args).expect("failed to initialize function iszero"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function iszero"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index bfb20dfd5ce41..9398e9f8d6e00 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::compute::try_binary; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; +use arrow::datatypes::Int64Type; use arrow::error::ArrowError; -use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -69,10 +70,6 @@ impl LcmFunc { } impl ScalarUDFImpl for LcmFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "lcm" } @@ -96,7 +93,7 @@ impl ScalarUDFImpl for LcmFunc { /// Lcm SQL function fn lcm(args: &[ArrayRef]) -> Result { - let compute_lcm = |x: i64, y: i64| { + let compute_lcm = |x: i64, y: i64| -> Result { if x == 0 || y == 0 { return Ok(0); } @@ -110,55 +107,20 @@ fn lcm(args: &[ArrayRef]) -> Result { .checked_mul(b) .and_then(|v| i64::try_from(v).ok()) .ok_or_else(|| { - arrow_datafusion_err!(ArrowError::ComputeError(format!( + ArrowError::ComputeError(format!( "Signed integer overflow in LCM({x}, {y})" - ))) + )) }) }; match args[0].data_type() { Int64 => { - let arg1 = downcast_named_arg!(&args[0], "x", Int64Array); - let arg2 = downcast_named_arg!(&args[1], "y", Int64Array); - - Ok(arg1 - .iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Ok(Some(compute_lcm(a1, a2)?)), - _ => Ok(None), - }) - .collect::>() - .map(Arc::new)? as ArrayRef) + let arg1 = args[0].as_primitive::(); + let arg2 = args[1].as_primitive::(); + + let result: PrimitiveArray = try_binary(arg1, arg2, compute_lcm)?; + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function lcm"), } } - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::{ArrayRef, Int64Array}; - - use datafusion_common::cast::as_int64_array; - - use crate::math::lcm::lcm; - - #[test] - fn test_lcm_i64() { - let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 3, 25, -16])), // x - Arc::new(Int64Array::from(vec![0, -2, 15, 8])), // y - ]; - - let result = lcm(&args).expect("failed to initialize function lcm"); - let ints = as_int64_array(&result).expect("failed to initialize function lcm"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 0); - assert_eq!(ints.value(1), 6); - assert_eq!(ints.value(2), 75); - assert_eq!(ints.value(3), 16); - } -} diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 24000a3876bd2..2ca2ed1b572be 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -17,32 +17,30 @@ //! Math function: `log()`. -use std::any::Any; - use super::power::PowerFunc; -use crate::utils::{calculate_binary_math, decimal128_to_i128}; +use crate::utils::calculate_binary_math; use arrow::array::{Array, ArrayRef}; -use arrow::compute::kernels::cast; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, }; use arrow::error::ArrowError; use arrow_buffer::i256; use datafusion_common::types::NativeType; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, + Result, ScalarValue, exec_err, internal_err, plan_datafusion_err, plan_err, }; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - lit, Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, - TypeSignature, TypeSignatureClass, + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + TypeSignature, TypeSignatureClass, lit, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use num_traits::Float; +use num_traits::{Float, ToPrimitive}; #[user_doc( doc_section(label = "Math Functions"), @@ -102,45 +100,92 @@ impl LogFunc { } } -/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base -/// Returns error if base is invalid -fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { - if !base.is_finite() || base.trunc() != base { - return Err(ArrowError::ComputeError(format!( - "Log cannot use non-integer base: {base}" - ))); +/// Checks if the base is valid for the efficient integer logarithm algorithm. +#[inline] +fn is_valid_integer_base(base: f64) -> bool { + base.trunc() == base && base >= 2.0 && base <= u32::MAX as f64 +} + +/// Calculate logarithm for Decimal32 values. +/// For integer bases >= 2 with zero scale, return an exact integer log when the +/// value is a perfect power of the base. Otherwise falls back to f64 computation. +fn log_decimal32(value: i32, scale: i8, base: f64) -> Result { + if scale == 0 + && is_valid_integer_base(base) + && let Ok(unscaled) = u32::try_from(value) + && unscaled > 0 + { + let base_u32 = base as u32; + let int_log = unscaled.ilog(base_u32); + if base_u32.checked_pow(int_log) == Some(unscaled) { + return Ok(int_log as f64); + } } - if (base as u32) < 2 { - return Err(ArrowError::ComputeError(format!( - "Log base must be greater than 1: {base}" - ))); + decimal_to_f64(value, scale).map(|v| v.log(base)) +} + +/// Calculate logarithm for Decimal64 values. +/// For integer bases >= 2 with zero scale, return an exact integer log when the +/// value is a perfect power of the base. Otherwise falls back to f64 computation. +fn log_decimal64(value: i64, scale: i8, base: f64) -> Result { + if scale == 0 + && is_valid_integer_base(base) + && let Ok(unscaled) = u64::try_from(value) + && unscaled > 0 + { + let base_u64 = base as u64; + let int_log = unscaled.ilog(base_u64); + if base_u64.checked_pow(int_log) == Some(unscaled) { + return Ok(int_log as f64); + } } + decimal_to_f64(value, scale).map(|v| v.log(base)) +} - let unscaled_value = decimal128_to_i128(value, scale)?; - if unscaled_value > 0 { - let log_value: u32 = unscaled_value.ilog(base as i128); - Ok(log_value as f64) - } else { - // Reflect f64::log behaviour - Ok(f64::NAN) +/// Calculate logarithm for Decimal128 values. +/// For integer bases >= 2 with zero scale, return an exact integer log when the +/// value is a perfect power of the base. Otherwise falls back to f64 computation. +fn log_decimal128(value: i128, scale: i8, base: f64) -> Result { + if scale == 0 + && is_valid_integer_base(base) + && let Ok(unscaled) = u128::try_from(value) + && unscaled > 0 + { + let base_u128 = base as u128; + let int_log = unscaled.ilog(base_u128); + if base_u128.checked_pow(int_log) == Some(unscaled) { + return Ok(int_log as f64); + } } + decimal_to_f64(value, scale).map(|v| v.log(base)) +} + +/// Convert a scaled decimal value to f64. +#[inline] +fn decimal_to_f64(value: T, scale: i8) -> Result { + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError("Cannot convert value to f64".to_string()) + })?; + let scale_factor = 10f64.powi(scale as i32); + Ok(value_f64 / scale_factor) } -/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base -/// Returns error if base is invalid or if value is out of bounds of Decimal128 fn log_decimal256(value: i256, scale: i8, base: f64) -> Result { + // Try to convert to i128 for the optimized path match value.to_i128() { - Some(value) => log_decimal128(value, scale, base), - None => Err(ArrowError::NotYetImplemented(format!( - "Log of Decimal256 larger than Decimal128 is not yet supported: {value}" - ))), + Some(v) => log_decimal128(v, scale, base), + None => { + // For very large Decimal256 values, use f64 computation + let value_f64 = value.to_f64().ok_or_else(|| { + ArrowError::ComputeError(format!("Cannot convert {value} to f64")) + })?; + let scale_factor = 10f64.powi(scale as i32); + Ok((value_f64 / scale_factor).log(base)) + } } } impl ScalarUDFImpl for LogFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "log" } @@ -223,15 +268,18 @@ impl ScalarUDFImpl for LogFunc { |value, base| Ok(value.log(base)), )? } - // TODO: native log support for decimal 32 & 64; right now upcast - // to decimal128 to calculate - // https://github.com/apache/datafusion/issues/17555 - DataType::Decimal32(precision, scale) - | DataType::Decimal64(precision, scale) => { - calculate_binary_math::( - &cast(&value, &DataType::Decimal128(*precision, *scale))?, + DataType::Decimal32(_, scale) => { + calculate_binary_math::( + &value, &base, - |value, base| log_decimal128(value, *scale, base), + |value, base| log_decimal32(value, *scale, base), + )? + } + DataType::Decimal64(_, scale) => { + calculate_binary_math::( + &value, + &base, + |value, base| log_decimal64(value, *scale, base), )? } DataType::Decimal128(_, scale) => { @@ -249,7 +297,7 @@ impl ScalarUDFImpl for LogFunc { )? } other => { - return exec_err!("Unsupported data type {other:?} for function log") + return exec_err!("Unsupported data type {other:?} for function log"); } }; @@ -267,7 +315,7 @@ impl ScalarUDFImpl for LogFunc { fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { let mut arg_types = args .iter() @@ -289,6 +337,19 @@ impl ScalarUDFImpl for LogFunc { if num_args != 1 && num_args != 2 { return plan_err!("Expected log to have 1 or 2 arguments, got {num_args}"); } + + match arg_types.last().unwrap() { + DataType::Decimal32(_, scale) + | DataType::Decimal64(_, scale) + | DataType::Decimal128(_, scale) + | DataType::Decimal256(_, scale) + if *scale < 0 => + { + return Ok(ExprSimplifyResult::Original(args)); + } + _ => (), + }; + let number = args.pop().unwrap(); let number_datatype = arg_types.pop().unwrap(); // default to base 10 @@ -324,7 +385,7 @@ impl ScalarUDFImpl for LogFunc { _ => { return internal_err!( "Unexpected number of arguments in log::simplify" - ) + ); } }; Ok(ExprSimplifyResult::Original(args)) @@ -336,12 +397,11 @@ impl ScalarUDFImpl for LogFunc { /// Returns true if the function is `PowerFunc` fn is_pow(func: &ScalarUDF) -> bool { - func.inner().as_any().downcast_ref::().is_some() + func.inner().is::() } #[cfg(test)] mod tests { - use std::collections::HashMap; use std::sync::Arc; use super::*; @@ -350,23 +410,16 @@ mod tests { Date32Array, Decimal128Array, Decimal256Array, Float32Array, Float64Array, }; use arrow::compute::SortOptions; - use arrow::datatypes::{Field, DECIMAL256_MAX_PRECISION}; + use arrow::datatypes::{DECIMAL256_MAX_PRECISION, Field}; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::config::ConfigOptions; - use datafusion_common::DFSchema; - use datafusion_expr::execution_props::ExecutionProps; - use datafusion_expr::simplify::SimplifyContext; #[test] fn test_log_decimal_native() { let value = 10_i128.pow(35); - assert_eq!((value as f64).log2(), 116.26748332105768); - assert_eq!( - log_decimal128(value, 0, 2.0).unwrap(), - // TODO: see we're losing our decimal points compared to above - // https://github.com/apache/datafusion/issues/18524 - 116.0 - ); + let expected = (value as f64).log2(); + let actual = log_decimal128(value, 0, 2.0).unwrap(); + assert!((actual - expected).abs() < 1e-10); } #[test] @@ -695,10 +748,7 @@ mod tests { #[test] // Test log() simplification errors fn test_log_simplify_errors() { - let props = ExecutionProps::new(); - let schema = - Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::default(); // Expect 0 args to error let _ = LogFunc::new().simplify(vec![], &context).unwrap_err(); // Expect 3 args to error @@ -710,10 +760,7 @@ mod tests { #[test] // Test that non-simplifiable log() expressions are unchanged after simplification fn test_log_simplify_original() { - let props = ExecutionProps::new(); - let schema = - Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); - let context = SimplifyContext::new(&props).with_schema(schema); + let context = SimplifyContext::default(); // One argument with no simplifications let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap(); let ExprSimplifyResult::Original(args) = result else { @@ -934,7 +981,8 @@ mod tests { assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 3.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); - assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding + let expected = 12600_f64.log(10.0); + assert!((floats.value(4) - expected).abs() < 1e-10); assert!(floats.value(5).is_nan()); } ColumnarValue::Scalar(_) => { @@ -1064,13 +1112,14 @@ mod tests { .expect("failed to convert result to a Float64Array"); assert_eq!(floats.len(), 7); - eprintln!("floats {:?}", &floats); assert!((floats.value(0) - 1.0).abs() < 1e-10); assert!((floats.value(1) - 2.0).abs() < 1e-10); assert!((floats.value(2) - 3.0).abs() < 1e-10); assert!((floats.value(3) - 4.0).abs() < 1e-10); - assert!((floats.value(4) - 4.0).abs() < 1e-10); // Integer rounding for float log - assert!((floats.value(5) - 38.0).abs() < 1e-10); + let expected = 12600_f64.log(10.0); + assert!((floats.value(4) - expected).abs() < 1e-10); + let expected = ((i128::MAX - 1000) as f64).log(10.0); + assert!((floats.value(5) - expected).abs() < 1e-10); assert!(floats.value(6).is_nan()); } ColumnarValue::Scalar(_) => { @@ -1080,7 +1129,8 @@ mod tests { } #[test] - fn test_log_decimal128_wrong_base() { + fn test_log_decimal128_invalid_base() { + // Invalid base (-2.0) should return NaN, matching f64::log behavior let arg_fields = vec![ Field::new("b", DataType::Float64, false).into(), Field::new("x", DataType::Decimal128(38, 0), false).into(), @@ -1095,17 +1145,32 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new().invoke_with_args(args); - assert!(result.is_err()); - assert_eq!( - "Arrow error: Compute error: Log base must be greater than 1: -2", - result.unwrap_err().to_string().lines().next().unwrap() - ); + let result = LogFunc::new() + .invoke_with_args(args) + .expect("should not error on invalid base"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 1); + assert!(floats.value(0).is_nan()); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } #[test] - fn test_log_decimal256_error() { - let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into(); + fn test_log_decimal256_large() { + // Large Decimal256 values that don't fit in i128 now use f64 fallback + let arg_field = Field::new( + "a", + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), + false, + ) + .into(); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Decimal256Array::from(vec![ @@ -1118,10 +1183,26 @@ mod tests { return_field: Field::new("f", DataType::Float64, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - let result = LogFunc::new().invoke_with_args(args); - assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string().lines().next().unwrap(), - "Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727" - ); + let result = LogFunc::new() + .invoke_with_args(args) + .expect("should handle large Decimal256 via f64 fallback"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + assert_eq!(floats.len(), 1); + // The f64 fallback may lose some precision for very large numbers, + // but we verify we get a reasonable positive result (not NaN/infinity) + let log_result = floats.value(0); + assert!( + log_result.is_finite() && log_result > 0.0, + "Expected positive finite log result, got {log_result}" + ); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } } } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 4eb337a30110e..1754ccb43488a 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,13 +18,17 @@ //! "math" DataFusion functions use crate::math::monotonicity::*; +use datafusion_common::{Result, exec_err}; use datafusion_expr::ScalarUDF; use std::sync::Arc; pub mod abs; pub mod bounds; +pub mod ceil; pub mod cot; +mod decimal; pub mod factorial; +pub mod floor; pub mod gcd; pub mod iszero; pub mod lcm; @@ -39,6 +43,14 @@ pub mod round; pub mod signum; pub mod trunc; +fn validate_sqrt_input(value: f64) -> Result<()> { + if value < 0.0 { + exec_err!("cannot take square root of a negative number") + } else { + Ok(()) + } +} + // Create UDFs make_udf_function!(abs::AbsFunc, abs); make_math_unary_udf!( @@ -104,14 +116,7 @@ make_math_unary_udf!( super::bounds::unbounded_bounds, super::get_cbrt_doc ); -make_math_unary_udf!( - CeilFunc, - ceil, - ceil, - super::ceil_order, - super::bounds::unbounded_bounds, - super::get_ceil_doc -); +make_udf_function!(ceil::CeilFunc, ceil); make_math_unary_udf!( CosFunc, cos, @@ -146,14 +151,7 @@ make_math_unary_udf!( super::get_exp_doc ); make_udf_function!(factorial::FactorialFunc, factorial); -make_math_unary_udf!( - FloorFunc, - floor, - floor, - super::floor_order, - super::bounds::unbounded_bounds, - super::get_floor_doc -); +make_udf_function!(floor::FloorFunc, floor); make_udf_function!(log::LogFunc, log); make_udf_function!(gcd::GcdFunc, gcd); make_udf_function!(nans::IsNanFunc, isnan); @@ -219,7 +217,8 @@ make_math_unary_udf!( sqrt, super::sqrt_order, super::bounds::sqrt_bounds, - super::get_sqrt_doc + super::get_sqrt_doc, + Some(super::validate_sqrt_input) ); make_math_unary_udf!( TanFunc, diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs index ffb3df4196d20..52449f9c9e0b9 100644 --- a/datafusion/functions/src/math/monotonicity.rs +++ b/datafusion/functions/src/math/monotonicity.rs @@ -17,11 +17,11 @@ use std::sync::LazyLock; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_doc::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::Documentation; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::Documentation; /// Non-increasing on the interval \[−1, 1\], undefined otherwise. pub fn acos_order(input: &[ExprProperties]) -> Result { @@ -262,11 +262,11 @@ Can be a constant, column, or function, and any combination of arithmetic operat ) .with_sql_example(r#"```sql > SELECT atan2(1, 1); -+------------+ -| atan2(1,1) | -+------------+ -| 0.7853982 | -+------------+ ++--------------------+ +| atan2(1,1) | ++--------------------+ +| 0.7853981633974483 | ++--------------------+ ```"#) .build() }); @@ -309,30 +309,6 @@ pub fn ceil_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } -static DOCUMENTATION_CEIL: LazyLock = LazyLock::new(|| { - Documentation::builder( - DOC_SECTION_MATH, - "Returns the nearest integer greater than or equal to a number.", - "ceil(numeric_expression)", - ) - .with_standard_argument("numeric_expression", Some("Numeric")) - .with_sql_example( - r#"```sql - > SELECT ceil(3.14); -+------------+ -| ceil(3.14) | -+------------+ -| 4.0 | -+------------+ -```"#, - ) - .build() -}); - -pub fn get_ceil_doc() -> &'static Documentation { - &DOCUMENTATION_CEIL -} - /// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\]. /// This pattern repeats periodically with a period of 2π. // TODO: Implement ordering rule of the ATAN2 function. @@ -467,30 +443,6 @@ pub fn floor_order(input: &[ExprProperties]) -> Result { Ok(input[0].sort_properties) } -static DOCUMENTATION_FLOOR: LazyLock = LazyLock::new(|| { - Documentation::builder( - DOC_SECTION_MATH, - "Returns the nearest integer less than or equal to a number.", - "floor(numeric_expression)", - ) - .with_standard_argument("numeric_expression", Some("Numeric")) - .with_sql_example( - r#"```sql -> SELECT floor(3.14); -+-------------+ -| floor(3.14) | -+-------------+ -| 3.0 | -+-------------+ -```"#, - ) - .build() -}); - -pub fn get_floor_doc() -> &'static Documentation { - &DOCUMENTATION_FLOOR -} - /// Non-decreasing for x ≥ 0, undefined otherwise. pub fn ln_order(input: &[ExprProperties]) -> Result { let arg = &input[0]; @@ -785,7 +737,6 @@ pub fn get_tanh_doc() -> &'static Documentation { #[cfg(test)] mod tests { use arrow::compute::SortOptions; - use datafusion_common::Result; use super::*; diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 759b0f5fd50ac..c5ea2fa079a45 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,14 +17,22 @@ //! Math function: `isnan()`. -use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature}; - use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::DataType::{ + Decimal32, Decimal64, Decimal128, Decimal256, Float16, Float32, Float64, Int8, Int16, + Int32, Int64, Null, UInt8, UInt16, UInt32, UInt64, +}; +use arrow::datatypes::{ + DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type, + Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, + UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( @@ -54,23 +62,15 @@ impl Default for IsNanFunc { impl IsNanFunc { pub fn new() -> Self { - use DataType::*; + // Accept any numeric type (ints, uints, floats, decimals) without implicit casts. + let numeric = Coercion::new_exact(TypeSignatureClass::Numeric); Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Float32]), - TypeSignature::Exact(vec![Float64]), - ], - Volatility::Immutable, - ), + signature: Signature::coercible(vec![numeric], Volatility::Immutable), } } } impl ScalarUDFImpl for IsNanFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "isnan" } @@ -84,26 +84,123 @@ impl ScalarUDFImpl for IsNanFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f64::is_nan, - )) as ArrayRef, - - DataType::Float32 => Arc::new(BooleanArray::from_unary( - args[0].as_primitive::(), - f32::is_nan, - )) as ArrayRef, - other => { - return exec_err!( - "Unsupported data type {other:?} for function {}", - self.name() - ) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + let result = match scalar { + ScalarValue::Float64(Some(v)) => Some(v.is_nan()), + ScalarValue::Float32(Some(v)) => Some(v.is_nan()), + ScalarValue::Float16(Some(v)) => Some(v.is_nan()), + + // Non-float numeric inputs are never NaN + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) => Some(false), + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result))) } - }; - Ok(ColumnarValue::Array(arr)) + ColumnarValue::Array(array) => { + // NOTE: BooleanArray::from_unary preserves nulls. + let arr: ArrayRef = match array.data_type() { + Null => Arc::new(BooleanArray::new_null(array.len())) as ArrayRef, + + Float64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f64::is_nan, + )) as ArrayRef, + Float32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + f32::is_nan, + )) as ArrayRef, + Float16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |x| x.is_nan(), + )) as ArrayRef, + + // Non-float numeric arrays are never NaN + Decimal32(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal64(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal128(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Decimal256(_, _) => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + Int8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + Int64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt8 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt16 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt32 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + UInt64 => Arc::new(BooleanArray::from_unary( + array.as_primitive::(), + |_| false, + )) as ArrayRef, + + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + + Ok(ColumnarValue::Array(arr)) + } + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index f0835b4d48a0c..251e98bb72c03 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -15,15 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use crate::utils::make_scalar_function; - -use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType::{Float32, Float64}; -use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use arrow::array::{ArrayRef, AsArray, Float16Array, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float16, Float32, Float64}; +use arrow::datatypes::{DataType, Float16Type, Float32Type, Float64Type}; +use datafusion_common::{Result, ScalarValue, exec_err, utils::take_function_args}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -66,10 +63,13 @@ impl Default for NanvlFunc { impl NanvlFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + vec![ + Exact(vec![Float16, Float16]), + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], Volatility::Immutable, ), } @@ -77,10 +77,6 @@ impl NanvlFunc { } impl ScalarUDFImpl for NanvlFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "nanvl" } @@ -91,13 +87,31 @@ impl ScalarUDFImpl for NanvlFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { + Float16 => Ok(Float16), Float32 => Ok(Float32), _ => Ok(Float64), } } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(nanvl, vec![])(&args.args) + let [x, y] = take_function_args(self.name(), args.args)?; + + match (x, y) { + (ColumnarValue::Scalar(ScalarValue::Float16(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), y) if v.is_nan() => { + Ok(y) + } + (x @ ColumnarValue::Scalar(_), _) => Ok(x), + (x, y) => { + let args = ColumnarValue::values_to_arrays(&[x, y])?; + Ok(ColumnarValue::Array(nanvl(&args)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -106,37 +120,49 @@ impl ScalarUDFImpl for NanvlFunc { } /// Nanvl SQL function +/// +/// - x is NaN -> output is y (which may itself be NULL) +/// - otherwise -> output is x (which may itself be NULL) fn nanvl(args: &[ArrayRef]) -> Result { match args[0].data_type() { Float64 => { - let compute_nanvl = |x: f64, y: f64| { - if x.is_nan() { - y - } else { - x - } - }; - - let x = args[0].as_primitive() as &Float64Array; - let y = args[1].as_primitive() as &Float64Array; - arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float64Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) } Float32 => { - let compute_nanvl = |x: f32, y: f32| { - if x.is_nan() { - y - } else { - x - } - }; - - let x = args[0].as_primitive() as &Float32Array; - let y = args[1].as_primitive() as &Float32Array; - arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) - .map(|res| Arc::new(res) as _) - .map_err(DataFusionError::from) + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float32Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) + } + Float16 => { + let x = args[0].as_primitive::(); + let y = args[1].as_primitive::(); + let result: Float16Array = x + .iter() + .zip(y.iter()) + .map(|(x_value, y_value)| match x_value { + Some(x_value) if x_value.is_nan() => y_value, + _ => x_value, + }) + .collect(); + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -154,8 +180,8 @@ mod test { #[test] fn test_nanvl_f64() { let args: Vec = vec![ - Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y - Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // x + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // y ]; let result = nanvl(&args).expect("failed to initialize function nanvl"); @@ -172,8 +198,8 @@ mod test { #[test] fn test_nanvl_f32() { let args: Vec = vec![ - Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y - Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // x + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // y ]; let result = nanvl(&args).expect("failed to initialize function nanvl"); diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index 92a27932e1649..55707f2c71feb 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use datafusion_common::{assert_or_internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, assert_or_internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -52,10 +50,6 @@ impl PiFunc { } impl ScalarUDFImpl for PiFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "pi" } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 5d6a8bcfdef2d..252a3ea0b31d7 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,28 +16,34 @@ // under the License. //! Math function: `power()`. -use std::any::Any; - use super::log::LogFunc; -use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; +use crate::utils::calculate_binary_math; use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{ - ArrowNativeTypeOp, DataType, Decimal128Type, Decimal256Type, Decimal32Type, - Decimal64Type, Float64Type, Int64Type, -}; +use arrow::datatypes::{DataType, Float64Type}; use arrow::error::ArrowError; +use datafusion_common::types::{NativeType, logical_float64}; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, plan_datafusion_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::type_coercion::is_decimal; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + Cast, Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, lit, }; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; +/// Matches PostgreSQL: `power(0::float8, negative)` is undefined (IEEE 754 would yield infinity). +#[inline] +fn float64_power_checked(base: f64, exp: f64) -> Result { + if base == 0.0 && exp < 0.0 { + return Err(ArrowError::ComputeError( + "zero raised to a negative power is undefined".to_string(), + )); + } + Ok(base.powf(exp)) +} + #[user_doc( doc_section(label = "Math Functions"), description = "Returns a base expression raised to the power of an exponent.", @@ -67,92 +73,19 @@ impl Default for PowerFunc { impl PowerFunc { pub fn new() -> Self { + let float = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible(vec![float; 2], Volatility::Immutable), aliases: vec![String::from("pow")], } } } -/// Binary function to calculate a math power to integer exponent -/// for scaled integer types. -/// -/// Formula -/// The power for a scaled integer `b` is -/// -/// ```text -/// (b * 10^(-s)) ^ e -/// ``` -/// However, the result should be scaled back from scale 0 to scale `s`, -/// which is done by multiplying by `10^s`. -/// At the end, the formula is: -/// -/// ```text -/// b^e * 10^(-s * e) * 10^s = b^e / 10^(s * (e-1)) -/// ``` -/// Example of 2.5 ^ 4 = 39: -/// 2.5 is represented as 25 with scale 1 -/// The unscaled result is 25^4 = 390625 -/// Scale it back to 1: 390625 / 10^4 = 39 -/// -/// Returns error if base is invalid -fn pow_decimal_int(base: T, scale: i8, exp: i64) -> Result -where - T: From + ArrowNativeTypeOp, -{ - let scale: u32 = scale.try_into().map_err(|_| { - ArrowError::NotYetImplemented(format!( - "Negative scale is not yet supported value: {scale}" - )) - })?; - if exp == 0 { - // Edge case to provide 1 as result (10^s with scale) - let result: T = T::from(10).pow_checked(scale).map_err(|_| { - ArrowError::ArithmeticOverflow(format!( - "Cannot make unscale factor for {scale} and {exp}" - )) - })?; - return Ok(result); - } - let exp: u32 = exp.try_into().map_err(|_| { - ArrowError::ArithmeticOverflow(format!("Unsupported exp value: {exp}")) - })?; - let powered: T = base.pow_checked(exp).map_err(|_| { - ArrowError::ArithmeticOverflow(format!("Cannot raise base {base:?} to exp {exp}")) - })?; - let unscale_factor: T = T::from(10).pow_checked(scale * (exp - 1)).map_err(|_| { - ArrowError::ArithmeticOverflow(format!( - "Cannot make unscale factor for {scale} and {exp}" - )) - })?; - - powered.div_checked(unscale_factor) -} - -/// Binary function to calculate a math power to float exponent -/// for scaled integer types. -/// Returns error if exponent is negative or non-integer, or base invalid -fn pow_decimal_float(base: T, scale: i8, exp: f64) -> Result -where - T: From + ArrowNativeTypeOp, -{ - if !exp.is_finite() || exp.trunc() != exp { - return Err(ArrowError::ComputeError(format!( - "Cannot use non-integer exp: {exp}" - ))); - } - if exp < 0f64 || exp >= u32::MAX as f64 { - return Err(ArrowError::ArithmeticOverflow(format!( - "Unsupported exp value: {exp}" - ))); - } - pow_decimal_int(base, scale, exp as i64) -} - impl ScalarUDFImpl for PowerFunc { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "power" } @@ -162,184 +95,93 @@ impl ScalarUDFImpl for PowerFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + let [_base, _exponent] = take_function_args(self.name(), arg_types)?; + Ok(DataType::Float64) } fn aliases(&self) -> &[String] { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [arg1, arg2] = take_function_args(self.name(), arg_types)?; - - fn coerced_type_base(name: &str, data_type: &DataType) -> Result { - match data_type { - DataType::Null => Ok(DataType::Int64), - d if d.is_floating() => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(d.clone()), - other => { - exec_err!("Unsupported data type {other:?} for {} function", name) - } - } - } - - fn coerced_type_exp(name: &str, data_type: &DataType) -> Result { - match data_type { - DataType::Null => Ok(DataType::Int64), - d if d.is_floating() => Ok(DataType::Float64), - d if d.is_integer() => Ok(DataType::Int64), - d if is_decimal(d) => Ok(DataType::Float64), - other => { - exec_err!("Unsupported data type {other:?} for {} function", name) - } - } - } - - Ok(vec![ - coerced_type_base(self.name(), arg1)?, - coerced_type_exp(self.name(), arg2)?, - ]) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let base = &args.args[0].to_array(args.number_rows)?; - let exponent = &args.args[1]; + let [base, exponent] = take_function_args(self.name(), &args.args)?; + let base = base.to_array(args.number_rows)?; let arr: ArrayRef = match (base.data_type(), exponent.data_type()) { - (DataType::Float64, _) => { + (DataType::Float64, DataType::Float64) => { calculate_binary_math::( &base, exponent, - |b, e| Ok(f64::powf(b, e)), - )? - } - (DataType::Int64, _) => { - calculate_binary_math::( - &base, - exponent, - |b, e| match e.try_into() { - Ok(exp_u32) => b.pow_checked(exp_u32), - Err(_) => Err(ArrowError::ArithmeticOverflow(format!( - "Exponent {e} in integer computation is out of bounds." - ))), - }, + float64_power_checked, )? } - (DataType::Decimal32(precision, scale), DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal32(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal64(precision, scale), DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal64(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_float(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal128(precision, scale), DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal128(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::< - Decimal128Type, - Float64Type, - Decimal128Type, - _, - >(&base, exponent, |b, e| - pow_decimal_float(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal256(precision, scale),DataType::Int64) => - calculate_binary_decimal_math::( - &base, - exponent, - |b, e| pow_decimal_int(b, *scale, e), - *precision, - *scale, - )?, - (DataType::Decimal256(precision, scale), DataType::Float64) => - calculate_binary_decimal_math::< - Decimal256Type, - Float64Type, - Decimal256Type, - _, - >(&base, exponent, |b, e| - pow_decimal_float(b, *scale, e) , - *precision, - *scale, - )?, (base_type, exp_type) => { - return exec_err!( - "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for function {}", - self.name() - ) + return internal_err!( + "Unsupported data types for base {base_type:?} and exponent {exp_type:?} for power" + ); } }; Ok(ColumnarValue::Array(arr)) } /// Simplify the `power` function by the relevant rules: - /// 1. Power(a, 0) ===> 0 + /// 1. Power(a, 0) ===> 1 /// 2. Power(a, 1) ===> a /// 3. Power(a, Log(a, b)) ===> b fn simplify( &self, - mut args: Vec, - info: &dyn SimplifyInfo, + args: Vec, + info: &SimplifyContext, ) -> Result { - let exponent = args.pop().ok_or_else(|| { - plan_datafusion_err!("Expected power to have 2 arguments, got 0") - })?; - let base = args.pop().ok_or_else(|| { - plan_datafusion_err!("Expected power to have 2 arguments, got 1") - })?; - + let [base, exponent] = take_function_args("power", args)?; + let base_type = info.get_data_type(&base)?; let exponent_type = info.get_data_type(&exponent)?; + let return_type = + self.return_type(&[base_type.clone(), exponent_type.clone()])?; + + // Null propagation + if base_type.is_null() || exponent_type.is_null() { + return Ok(ExprSimplifyResult::Simplified(lit( + ScalarValue::Null.cast_to(&return_type)? + ))); + } + + // `simplify` runs on the logical expression *before* type coercion, + // so a simplified sub-expression may still carry its original type + // rather than the Float64 that `power` is declared to return. Cast it + // back when needed to preserve the schema the optimizer already + // committed to — e.g. `power(int_col, 1)` simplifies to `int_col`, + // and the `b` in `power(b, log(b, uint_col))` simplifies to `uint_col`, + // both of which must become Float64. + let cast_to_return_type = |expr: Expr, expr_type: &DataType| { + if expr_type == &return_type { + expr + } else { + Expr::Cast(Cast::new(Box::new(expr), return_type.clone())) + } + }; + match exponent { Expr::Literal(value, _) if value == ScalarValue::new_zero(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::new_one(&info.get_data_type(&base)?)?, - None, - ))) + Ok(ExprSimplifyResult::Simplified(lit(ScalarValue::new_one( + &return_type, + )?))) } Expr::Literal(value, _) if value == ScalarValue::new_one(&exponent_type)? => { - Ok(ExprSimplifyResult::Simplified(base)) + Ok(ExprSimplifyResult::Simplified(cast_to_return_type( + base, &base_type, + ))) } Expr::ScalarFunction(ScalarFunction { func, mut args }) if is_log(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above - Ok(ExprSimplifyResult::Simplified(b)) + let b_type = info.get_data_type(&b)?; + Ok(ExprSimplifyResult::Simplified(cast_to_return_type( + b, &b_type, + ))) } _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])), } @@ -352,267 +194,20 @@ impl ScalarUDFImpl for PowerFunc { /// Return true if this function call is a call to `Log` fn is_log(func: &ScalarUDF) -> bool { - func.inner().as_any().downcast_ref::().is_some() + func.inner().is::() } #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, Decimal128Array, Float64Array, Int64Array}; - use arrow::datatypes::{Field, DECIMAL128_MAX_SCALE}; - use arrow_buffer::NullBuffer; - use datafusion_common::cast::{ - as_decimal128_array, as_float64_array, as_int64_array, - }; - use datafusion_common::config::ConfigOptions; - use std::sync::Arc; - - #[cfg(test)] - #[ctor::ctor] - fn init() { - // Enable RUST_LOG logging configuration for test - let _ = env_logger::try_init(); - } - - #[test] - fn test_power_f64() { - let arg_fields = vec![ - Field::new("a", DataType::Float64, true).into(), - Field::new("a", DataType::Float64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 2.0, 2.0, 3.0, 5.0, - ]))), // base - ColumnarValue::Array(Arc::new(Float64Array::from(vec![ - 3.0, 2.0, 4.0, 4.0, - ]))), // exponent - ], - arg_fields, - number_rows: 4, - return_field: Field::new("f", DataType::Float64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let floats = as_float64_array(&arr) - .expect("failed to convert result to a Float64Array"); - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 8.0); - assert_eq!(floats.value(1), 4.0); - assert_eq!(floats.value(2), 81.0); - assert_eq!(floats.value(3), 625.0); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_i64() { - let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent - ], - arg_fields, - number_rows: 4, - return_field: Field::new("f", DataType::Int64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_int64_array(&arr) - .expect("failed to convert result to a Int64Array"); - - assert_eq!(ints.len(), 4); - assert_eq!(ints.value(0), 8); - assert_eq!(ints.value(1), 4); - assert_eq!(ints.value(2), 81); - assert_eq!(ints.value(3), 625); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } #[test] - fn test_power_i128() { - let arg_fields = vec![ - Field::new( - "a", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new( - Decimal128Array::from(vec![2, 2, 3, 5, 0, 5]) - .with_precision_and_scale(DECIMAL128_MAX_SCALE as u8, 0) - .unwrap(), - )), // base - ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4, 4, 0]))), // exponent - ], - arg_fields, - number_rows: 6, - return_field: Field::new( - "f", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 6); - assert_eq!(ints.value(0), i128::from(8)); - assert_eq!(ints.value(1), i128::from(4)); - assert_eq!(ints.value(2), i128::from(81)); - assert_eq!(ints.value(3), i128::from(625)); - assert_eq!(ints.value(4), i128::from(0)); - assert_eq!(ints.value(5), i128::from(1)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } + fn test_float64_power_checked_zero_negative_exp() { + assert_eq!(float64_power_checked(0.0, 1.0).unwrap(), 0.0); + assert_eq!(float64_power_checked(2.0, -1.0).unwrap(), 0.5); + for base in [0.0f64, -0.0] { + assert!(float64_power_checked(base, -1.0).is_err()); + assert!(float64_power_checked(base, -0.5).is_err()); } } - - #[test] - fn test_power_array_null() { - let arg_fields = vec![ - Field::new("a", DataType::Int64, true).into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 2]))), // base - ColumnarValue::Array(Arc::new(Int64Array::from_iter_values_with_nulls( - vec![1, 2, 3], - Some(NullBuffer::from(vec![true, false, true])), - ))), // exponent - ], - arg_fields, - number_rows: 1, - return_field: Field::new("f", DataType::Int64, true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = - as_int64_array(&arr).expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 3); - assert!(!ints.is_null(0)); - assert_eq!(ints.value(0), i64::from(2)); - assert!(ints.is_null(1)); - assert!(!ints.is_null(2)); - assert_eq!(ints.value(2), i64::from(8)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_power_decimal_with_scale() { - // 2.5 ^ 4 = 39 - // 2.5 is 25 in Decimal128(2, 1) by parsing rules - // Signature is Decimal128(2, 1) -> Int64 -> Decimal128(2, 1), therefore - // result is 390 in Decimal128(2, 1) aka 39 in unscaled Decimal128(2, 0) - let arg_fields = vec![ - Field::new( - "a", - DataType::Decimal128(DECIMAL128_MAX_SCALE as u8, 0), - true, - ) - .into(), - Field::new("a", DataType::Int64, true).into(), - ]; - let args = ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(i128::from(25)), - 2, - 1, - )), // base - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), // exponent - ], - arg_fields, - number_rows: 1, - return_field: Field::new("f", DataType::Decimal128(2, 1), true).into(), - config_options: Arc::new(ConfigOptions::default()), - }; - let result = PowerFunc::new() - .invoke_with_args(args) - .expect("failed to initialize function power"); - - match result { - ColumnarValue::Array(arr) => { - let ints = as_decimal128_array(&arr) - .expect("failed to convert result to an array"); - - assert_eq!(ints.len(), 1); - assert_eq!(ints.value(0), i128::from(390)); - // Signature stays the same as input - assert_eq!(*arr.data_type(), DataType::Decimal128(2, 1)); - } - ColumnarValue::Scalar(_) => { - panic!("Expected an array value") - } - } - } - - #[test] - fn test_pow_decimal128_helper() { - // Expression: 2.5 ^ 4 = 39.0625 - assert_eq!(pow_decimal_int(25, 1, 4).unwrap(), i128::from(390)); - assert_eq!(pow_decimal_int(2500, 3, 4).unwrap(), i128::from(39062)); - assert_eq!(pow_decimal_int(25000, 4, 4).unwrap(), i128::from(390625)); - - // Expression: 25 ^ 4 = 390625 - assert_eq!(pow_decimal_int(25, 0, 4).unwrap(), i128::from(390625)); - - // Expressions for edge cases - assert_eq!(pow_decimal_int(25, 1, 1).unwrap(), i128::from(25)); - assert_eq!(pow_decimal_int(25, 0, 1).unwrap(), i128::from(25)); - assert_eq!(pow_decimal_int(25, 0, 0).unwrap(), i128::from(1)); - assert_eq!(pow_decimal_int(25, 1, 0).unwrap(), i128::from(10)); - - assert_eq!( - pow_decimal_int(25, -1, 4).unwrap_err().to_string(), - "Not yet implemented: Negative scale is not yet supported value: -1" - ); - } } diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 4270eff665728..f1833c23c305a 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::Float64Array; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; -use rand::{rng, Rng}; +use rand::{Rng, rng}; -use datafusion_common::{assert_or_internal_err, Result}; +use datafusion_common::{Result, assert_or_internal_err}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -45,6 +44,7 @@ The random seed is unique to each row."#, #[derive(Debug, PartialEq, Eq, Hash)] pub struct RandomFunc { signature: Signature, + aliases: Vec, } impl Default for RandomFunc { @@ -57,15 +57,12 @@ impl RandomFunc { pub fn new() -> Self { Self { signature: Signature::nullary(Volatility::Volatile), + aliases: vec![String::from("rand")], } } } impl ScalarUDFImpl for RandomFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "random" } @@ -78,6 +75,10 @@ impl ScalarUDFImpl for RandomFunc { Ok(Float64) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { assert_or_internal_err!( args.args.is_empty(), diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 5f9b1eb6ad58b..78016c0f52f71 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -15,23 +15,138 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use crate::utils::make_scalar_function; +use crate::utils::{calculate_binary_decimal_math, calculate_binary_math}; -use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; -use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType::{Float32, Float64, Int32}; -use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; -use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use arrow::array::ArrayRef; +use arrow::datatypes::DataType::{ + Decimal32, Decimal64, Decimal128, Decimal256, Float32, Float64, +}; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type, +}; +use arrow::datatypes::{Field, FieldRef}; +use arrow::error::ArrowError; +use datafusion_common::types::{ + NativeType, logical_float32, logical_float64, logical_int32, +}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + Coercion, ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use std::sync::Arc; + +fn output_scale_for_decimal(precision: u8, input_scale: i8, decimal_places: i32) -> i8 { + // `decimal_places` controls the maximum output scale, but scale cannot exceed the input scale. + // + // For negative-scale decimals, allow further scale reduction to match negative `decimal_places` + // (e.g. scale -2 rounded to -3 becomes scale -3). This preserves fixed precision by + // representing the rounded result at a coarser scale. + if input_scale < 0 { + // Decimal scales must be within [-precision, precision] and fit in i8. For negative-scale + // decimals, allow rounding to move the output scale further negative, but cap it at + // `-precision` (beyond that, the rounded result is always 0). + let min_scale = -i32::from(precision); + let new_scale = i32::from(input_scale).min(decimal_places).max(min_scale); + return new_scale as i8; + } + + // The `min` ensures the result is always within i8 range because `input_scale` is i8. + let decimal_places = decimal_places.max(0); + i32::from(input_scale).min(decimal_places) as i8 +} + +fn normalize_decimal_places_for_decimal( + decimal_places: i32, + precision: u8, + scale: i8, +) -> Option { + if decimal_places >= 0 { + return Some(decimal_places); + } + + // For fixed precision decimals, the absolute value is strictly less than 10^(precision - scale). + // If the rounding position is beyond that (abs(decimal_places) > precision - scale), the + // rounded result is always 0, and we can avoid overflow in intermediate 10^n computations. + let max_rounding_pow10 = i64::from(precision) - i64::from(scale); + if max_rounding_pow10 <= 0 { + return None; + } + + let abs_decimal_places = i64::from(decimal_places.unsigned_abs()); + (abs_decimal_places <= max_rounding_pow10).then_some(decimal_places) +} + +fn validate_decimal_precision( + value: T::Native, + precision: u8, + scale: i8, +) -> Result { + T::validate_decimal_precision(value, precision, scale).map_err(|e| { + ArrowError::ComputeError(format!( + "Decimal overflow: rounded value exceeds precision {precision}: {e}" + )) + })?; + Ok(value) +} + +fn calculate_new_precision_scale( + precision: u8, + scale: i8, + decimal_places: Option, +) -> Result { + if let Some(decimal_places) = decimal_places { + let new_scale = output_scale_for_decimal(precision, scale, decimal_places); + + // When rounding an integer decimal (scale == 0) to a negative `decimal_places`, a carry can + // add an extra digit to the integer part (e.g. 99 -> 100 when rounding to -1). This can + // only happen when the rounding position is within the existing precision. + let abs_decimal_places = decimal_places.unsigned_abs(); + let new_precision = if scale == 0 + && decimal_places < 0 + && abs_decimal_places <= u32::from(precision) + { + precision.saturating_add(1).min(T::MAX_PRECISION) + } else { + precision + }; + Ok(T::TYPE_CONSTRUCTOR(new_precision, new_scale)) + } else { + let new_precision = precision.saturating_add(1).min(T::MAX_PRECISION); + Ok(T::TYPE_CONSTRUCTOR(new_precision, scale)) + } +} + +fn decimal_places_from_scalar(scalar: &ScalarValue) -> Result { + let out_of_range = |value: String| { + datafusion_common::DataFusionError::Execution(format!( + "round decimal_places {value} is out of supported i32 range" + )) + }; + match scalar { + ScalarValue::Int8(Some(v)) => Ok(i32::from(*v)), + ScalarValue::Int16(Some(v)) => Ok(i32::from(*v)), + ScalarValue::Int32(Some(v)) => Ok(*v), + ScalarValue::Int64(Some(v)) => { + i32::try_from(*v).map_err(|_| out_of_range(v.to_string())) + } + ScalarValue::UInt8(Some(v)) => Ok(i32::from(*v)), + ScalarValue::UInt16(Some(v)) => Ok(i32::from(*v)), + ScalarValue::UInt32(Some(v)) => { + i32::try_from(*v).map_err(|_| out_of_range(v.to_string())) + } + ScalarValue::UInt64(Some(v)) => { + i32::try_from(*v).map_err(|_| out_of_range(v.to_string())) + } + other => exec_err!( + "Unexpected datatype for decimal_places: {}", + other.data_type() + ), + } +} #[user_doc( doc_section(label = "Math Functions"), @@ -64,14 +179,33 @@ impl Default for RoundFunc { impl RoundFunc { pub fn new() -> Self { - use DataType::*; + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let decimal_places = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); Self { signature: Signature::one_of( vec![ - Exact(vec![Float64, Int64]), - Exact(vec![Float32, Int64]), - Exact(vec![Float64]), - Exact(vec![Float32]), + TypeSignature::Coercible(vec![ + decimal.clone(), + decimal_places.clone(), + ]), + TypeSignature::Coercible(vec![decimal]), + TypeSignature::Coercible(vec![ + float32.clone(), + decimal_places.clone(), + ]), + TypeSignature::Coercible(vec![float32]), + TypeSignature::Coercible(vec![float64.clone(), decimal_places]), + TypeSignature::Coercible(vec![float64]), ], Volatility::Immutable, ), @@ -80,10 +214,6 @@ impl RoundFunc { } impl ScalarUDFImpl for RoundFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "round" } @@ -92,15 +222,218 @@ impl ScalarUDFImpl for RoundFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - } + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let input_field = &args.arg_fields[0]; + let input_type = input_field.data_type(); + + // If decimal_places is a scalar literal, we can incorporate it into the output type + // (scale reduction). Otherwise, keep the input scale as we can't pick a per-row scale. + // + // Note: `scalar_arguments` contains the original literal values (pre-coercion), so + // integer literals may appear as Int64 even though the signature coerces them to Int32. + let decimal_places: Option = match args.scalar_arguments.get(1) { + None => Some(0), // No dp argument means default to 0 + Some(None) => None, // dp is not a literal (e.g. column) + Some(Some(scalar)) if scalar.is_null() => Some(0), // null dp => default to 0 + Some(Some(scalar)) => Some(decimal_places_from_scalar(scalar)?), + }; + + // Calculate return type based on input type + // For decimals: reduce scale to decimal_places (reclaims precision for integer part) + // This matches Spark/DuckDB behavior where ROUND adjusts the scale + // BUT only if dp is a scalar literal - otherwise keep original scale and add + // extra precision to accommodate potential carry-over. + let return_type = + match input_type { + Float32 => Float32, + Decimal32(precision, scale) => calculate_new_precision_scale::< + Decimal32Type, + >( + *precision, *scale, decimal_places + )?, + Decimal64(precision, scale) => calculate_new_precision_scale::< + Decimal64Type, + >( + *precision, *scale, decimal_places + )?, + Decimal128(precision, scale) => calculate_new_precision_scale::< + Decimal128Type, + >( + *precision, *scale, decimal_places + )?, + Decimal256(precision, scale) => calculate_new_precision_scale::< + Decimal256Type, + >( + *precision, *scale, decimal_places + )?, + _ => Float64, + }; + + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), return_type, nullable))) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("use return_field_from_args instead") } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(round, vec![])(&args.args) + if args.arg_fields.iter().any(|a| a.data_type().is_null()) { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(args.return_type(), None); + } + + let default_decimal_places = ColumnarValue::Scalar(ScalarValue::Int32(Some(0))); + let decimal_places = if args.args.len() == 2 { + &args.args[1] + } else { + &default_decimal_places + }; + + if let (ColumnarValue::Scalar(value_scalar), ColumnarValue::Scalar(dp_scalar)) = + (&args.args[0], decimal_places) + { + if value_scalar.is_null() || dp_scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(args.return_type(), None); + } + + let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar { + *dp + } else { + return internal_err!( + "Unexpected datatype for decimal_places: {}", + dp_scalar.data_type() + ); + }; + + match (value_scalar, args.return_type()) { + (ScalarValue::Float32(Some(v)), _) => { + let rounded = round_float(*v, dp)?; + Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) + } + (ScalarValue::Float64(Some(v)), _) => { + let rounded = round_float(*v, dp)?; + Ok(ColumnarValue::Scalar(ScalarValue::from(rounded))) + } + ( + ScalarValue::Decimal32(Some(v), in_precision, scale), + Decimal32(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal32Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // With scale == 0 and negative dp, rounding can carry into an additional + // digit (e.g. 99 -> 100). If we're already at max precision we can't widen + // the type, so validate and error rather than producing an invalid decimal. + validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = + ScalarValue::Decimal32(Some(rounded), *out_precision, *out_scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ( + ScalarValue::Decimal64(Some(v), in_precision, scale), + Decimal64(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal64Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = + ScalarValue::Decimal64(Some(rounded), *out_precision, *out_scale); + Ok(ColumnarValue::Scalar(scalar)) + } + ( + ScalarValue::Decimal128(Some(v), in_precision, scale), + Decimal128(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal128Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = ScalarValue::Decimal128( + Some(rounded), + *out_precision, + *out_scale, + ); + Ok(ColumnarValue::Scalar(scalar)) + } + ( + ScalarValue::Decimal256(Some(v), in_precision, scale), + Decimal256(out_precision, out_scale), + ) => { + let rounded = + round_decimal_or_zero(*v, *in_precision, *scale, *out_scale, dp)?; + let rounded = if *out_precision == Decimal256Type::MAX_PRECISION + && *scale == 0 + && dp < 0 + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, + *out_precision, + *out_scale, + ) + } else { + Ok(rounded) + }?; + let scalar = ScalarValue::Decimal256( + Some(rounded), + *out_precision, + *out_scale, + ); + Ok(ColumnarValue::Scalar(scalar)) + } + (ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(args.return_type(), None), + (value_scalar, return_type) => { + internal_err!( + "Unexpected datatype for round(value, decimal_places): value {}, return type {}", + value_scalar.data_type(), + return_type + ) + } + } + } else { + round_columnar( + &args.args[0], + decimal_places, + args.number_rows, + args.return_type(), + ) + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -123,107 +456,270 @@ impl ScalarUDFImpl for RoundFunc { } } -/// Round SQL function -fn round(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!( - "round function requires one or two arguments, got {}", - args.len() - ); +fn round_columnar( + value: &ColumnarValue, + decimal_places: &ColumnarValue, + number_rows: usize, + return_type: &DataType, +) -> Result { + let value_array = value.to_array(number_rows)?; + let both_scalars = matches!(value, ColumnarValue::Scalar(_)) + && matches!(decimal_places, ColumnarValue::Scalar(_)); + let decimal_places_is_array = matches!(decimal_places, ColumnarValue::Array(_)); + + let arr: ArrayRef = match (value_array.data_type(), return_type) { + (Float64, _) => { + let result = calculate_binary_math::( + value_array.as_ref(), + decimal_places, + round_float::, + )?; + result as _ + } + (Float32, _) => { + let result = calculate_binary_math::( + value_array.as_ref(), + decimal_places, + round_float::, + )?; + result as _ + } + (Decimal32(input_precision, scale), Decimal32(precision, new_scale)) => { + // reduce scale to reclaim integer precision + let result = calculate_binary_decimal_math::< + Decimal32Type, + Int32Type, + Decimal32Type, + _, + >( + value_array.as_ref(), + decimal_places, + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal32Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // If we're already at max precision, we can't widen the result type. For + // dp arrays, or for scale == 0 with negative dp, rounding can overflow the + // fixed-precision type. Validate per-row and return an error instead of + // producing an invalid decimal that Arrow may display incorrectly. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, + *precision, + *new_scale, + )?; + result as _ + } + (Decimal64(input_precision, scale), Decimal64(precision, new_scale)) => { + let result = calculate_binary_decimal_math::< + Decimal64Type, + Int32Type, + Decimal64Type, + _, + >( + value_array.as_ref(), + decimal_places, + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal64Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, + *precision, + *new_scale, + )?; + result as _ + } + (Decimal128(input_precision, scale), Decimal128(precision, new_scale)) => { + let result = calculate_binary_decimal_math::< + Decimal128Type, + Int32Type, + Decimal128Type, + _, + >( + value_array.as_ref(), + decimal_places, + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal128Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, + *precision, + *new_scale, + )?; + result as _ + } + (Decimal256(input_precision, scale), Decimal256(precision, new_scale)) => { + let result = calculate_binary_decimal_math::< + Decimal256Type, + Int32Type, + Decimal256Type, + _, + >( + value_array.as_ref(), + decimal_places, + |v, dp| { + let rounded = round_decimal_or_zero( + v, + *input_precision, + *scale, + *new_scale, + dp, + )?; + if *precision == Decimal256Type::MAX_PRECISION + && (decimal_places_is_array || (*scale == 0 && dp < 0)) + { + // See Decimal32 branch for details. + validate_decimal_precision::( + rounded, *precision, *new_scale, + ) + } else { + Ok(rounded) + } + }, + *precision, + *new_scale, + )?; + result as _ + } + (other, _) => exec_err!("Unsupported data type {other:?} for function round")?, + }; + + if both_scalars { + ScalarValue::try_from_array(&arr, 0).map(ColumnarValue::Scalar) + } else { + Ok(ColumnarValue::Array(arr)) } +} - let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); +fn round_float(value: T, decimal_places: i32) -> Result +where + T: num_traits::Float, +{ + let factor = T::from(10_f64.powi(decimal_places)).ok_or_else(|| { + ArrowError::ComputeError(format!( + "Invalid value for decimal places: {decimal_places}" + )) + })?; + Ok((value * factor).round() / factor) +} - if args.len() == 2 { - decimal_places = ColumnarValue::Array(Arc::clone(&args[1])); +fn round_decimal( + value: V, + input_scale: i8, + output_scale: i8, + decimal_places: i32, +) -> Result { + let diff = i64::from(input_scale) - i64::from(decimal_places); + if diff <= 0 { + return Ok(value); } - match args[0].data_type() { - Float64 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places: i32 = decimal_places.try_into().map_err(|e| { - exec_datafusion_err!( - "Invalid value for decimal places: {decimal_places}: {e}" - ) - })?; - - let result = args[0] - .as_primitive::() - .unary::<_, Float64Type>(|value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - }); - Ok(Arc::new(result) as _) - } - ColumnarValue::Array(decimal_places) => { - let options = CastOptions { - safe: false, // raise error if the cast is not possible - ..Default::default() - }; - let decimal_places = cast_with_options(&decimal_places, &Int32, &options) - .map_err(|e| { - exec_datafusion_err!("Invalid values for decimal places: {e}") - })?; - - let values = args[0].as_primitive::(); - let decimal_places = decimal_places.as_primitive::(); - let result = arrow::compute::binary::<_, _, _, Float64Type>( - values, - decimal_places, - |value, decimal_places| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - }, - )?; - Ok(Arc::new(result) as _) - } - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } - }, + debug_assert!(diff <= i64::from(u32::MAX)); + let diff = diff as u32; + + let one = V::ONE; + let two = V::from_usize(2).ok_or_else(|| { + ArrowError::ComputeError("Internal error: could not create constant 2".into()) + })?; + let ten = V::from_usize(10).ok_or_else(|| { + ArrowError::ComputeError("Internal error: could not create constant 10".into()) + })?; + + let factor = ten.pow_checked(diff).map_err(|_| { + ArrowError::ComputeError(format!( + "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}" + )) + })?; + + let mut quotient = value.div_wrapping(factor); + let remainder = value.mod_wrapping(factor); + + // `factor` is an even number (10^n, n > 0), so `factor / 2` is the tie threshold + let threshold = factor.div_wrapping(two); + if remainder >= threshold { + quotient = quotient.add_checked(one).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding decimal".into()) + })?; + } else if remainder <= threshold.neg_wrapping() { + quotient = quotient.sub_checked(one).map_err(|_| { + ArrowError::ComputeError("Overflow while rounding decimal".into()) + })?; + } - Float32 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places: i32 = decimal_places.try_into().map_err(|e| { - exec_datafusion_err!( - "Invalid value for decimal places: {decimal_places}: {e}" - ) - })?; - let result = args[0] - .as_primitive::() - .unary::<_, Float32Type>(|value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - }); - Ok(Arc::new(result) as _) - } - ColumnarValue::Array(_) => { - let ColumnarValue::Array(decimal_places) = - decimal_places.cast_to(&Int32, None).map_err(|e| { - exec_datafusion_err!("Invalid values for decimal places: {e}") - })? - else { - panic!("Unexpected result of ColumnarValue::Array.cast") - }; - - let values = args[0].as_primitive::(); - let decimal_places = decimal_places.as_primitive::(); - let result: PrimitiveArray = arrow::compute::binary( - values, - decimal_places, - |value, decimal_places| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - }, - )?; - Ok(Arc::new(result) as _) - } - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } - }, + // `quotient` is the rounded value at scale `decimal_places`. Rescale to the desired + // `output_scale` (which is always >= `decimal_places` in cases where diff > 0). + let scale_shift = i64::from(output_scale) - i64::from(decimal_places); + if scale_shift == 0 { + return Ok(quotient); + } - other => exec_err!("Unsupported data type {other:?} for function round"), + debug_assert!(scale_shift > 0); + debug_assert!(scale_shift <= i64::from(u32::MAX)); + let scale_shift = scale_shift as u32; + let shift_factor = ten.pow_checked(scale_shift).map_err(|_| { + ArrowError::ComputeError(format!( + "Overflow while rounding decimal with scale {input_scale} and decimal places {decimal_places}" + )) + })?; + quotient + .mul_checked(shift_factor) + .map_err(|_| ArrowError::ComputeError("Overflow while rounding decimal".into())) +} + +fn round_decimal_or_zero( + value: V, + precision: u8, + input_scale: i8, + output_scale: i8, + decimal_places: i32, +) -> Result { + if let Some(dp) = + normalize_decimal_places_for_decimal(decimal_places, precision, input_scale) + { + round_decimal(value, input_scale, output_scale, dp) + } else { + V::from_usize(0).ok_or_else(|| { + ArrowError::ComputeError("Internal error: could not create constant 0".into()) + }) } } @@ -231,11 +727,33 @@ fn round(args: &[ArrayRef]) -> Result { mod test { use std::sync::Arc; - use crate::math::round::round; - use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; - use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DataFusionError; + use datafusion_common::ScalarValue; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_expr::ColumnarValue; + + fn round_arrays( + value: ArrayRef, + decimal_places: Option, + ) -> Result { + let number_rows = value.len(); + // NOTE: For decimal inputs, the actual ROUND return type can differ from the + // input type (scale reduction for literal `decimal_places`). These unit tests + // only exercise Float32/Float64 behavior. + let return_type = value.data_type().clone(); + let value = ColumnarValue::Array(value); + let decimal_places = decimal_places + .map(ColumnarValue::Array) + .unwrap_or_else(|| ColumnarValue::Scalar(ScalarValue::Int32(Some(0)))); + + let result = + super::round_columnar(&value, &decimal_places, number_rows, &return_type)?; + match result { + ColumnarValue::Array(array) => Ok(array), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(1), + } + } #[test] fn test_round_f32() { @@ -244,7 +762,8 @@ mod test { Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places ]; - let result = round(&args).expect("failed to initialize function round"); + let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1]))) + .expect("failed to initialize function round"); let floats = as_float32_array(&result).expect("failed to initialize function round"); @@ -262,7 +781,8 @@ mod test { Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places ]; - let result = round(&args).expect("failed to initialize function round"); + let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1]))) + .expect("failed to initialize function round"); let floats = as_float64_array(&result).expect("failed to initialize function round"); @@ -279,7 +799,8 @@ mod test { Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input ]; - let result = round(&args).expect("failed to initialize function round"); + let result = round_arrays(Arc::clone(&args[0]), None) + .expect("failed to initialize function round"); let floats = as_float32_array(&result).expect("failed to initialize function round"); @@ -294,7 +815,8 @@ mod test { Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input ]; - let result = round(&args).expect("failed to initialize function round"); + let result = round_arrays(Arc::clone(&args[0]), None) + .expect("failed to initialize function round"); let floats = as_float64_array(&result).expect("failed to initialize function round"); @@ -310,9 +832,12 @@ mod test { Arc::new(Int64Array::from(vec![2147483648])), // decimal_places ]; - let result = round(&args); + let result = round_arrays(Arc::clone(&args[0]), Some(Arc::clone(&args[1]))); assert!(result.is_err()); - assert!(matches!(result, Err(DataFusionError::Execution(_)))); + assert!(matches!( + result, + Err(DataFusionError::ArrowError(_, _)) | Err(DataFusionError::Execution(_)) + )); } } diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index 2e616fe0fe357..8c8eeacf12394 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, AsArray}; +use arrow::array::AsArray; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -30,8 +30,6 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -use crate::utils::make_scalar_function; - #[user_doc( doc_section(label = "Math Functions"), description = r#"Returns the sign of a number. @@ -73,10 +71,6 @@ impl SignumFunc { } impl ScalarUDFImpl for SignumFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "signum" } @@ -98,7 +92,53 @@ impl ScalarUDFImpl for SignumFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(signum, vec![])(&args.args) + let return_type = args.return_type().clone(); + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return ColumnarValue::Scalar(ScalarValue::Null) + .cast_to(&return_type, None); + } + + match scalar { + ScalarValue::Float64(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = if v == 0.0 { 0.0 } else { v.signum() }; + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected scalar type for signum: {:?}", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + Float64 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float64Type>( + |x: f64| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + Float32 => Ok(ColumnarValue::Array(Arc::new( + array.as_primitive::().unary::<_, Float32Type>( + |x: f32| { + if x == 0.0 { 0.0 } else { x.signum() } + }, + ), + ))), + other => { + internal_err!("Unsupported data type {other:?} for function signum") + } + }, + } } fn documentation(&self) -> Option<&Documentation> { @@ -106,41 +146,6 @@ impl ScalarUDFImpl for SignumFunc { } } -/// signum SQL function -fn signum(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - Float64 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float64Type>( - |x: f64| { - if x == 0_f64 { - 0_f64 - } else { - x.signum() - } - }, - ), - ) as ArrayRef), - - Float32 => Ok(Arc::new( - args[0] - .as_primitive::() - .unary::<_, Float32Type>( - |x: f32| { - if x == 0_f32 { - 0_f32 - } else { - x.signum() - } - }, - ), - ) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function signum"), - } -} - #[cfg(test)] mod test { use std::sync::Arc; diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 9d1b4336f6389..991ad0e9c470d 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use crate::utils::make_scalar_function; @@ -24,9 +23,9 @@ use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; use arrow::datatypes::DataType::{Float32, Float64}; use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -90,10 +89,6 @@ impl TruncFunc { } impl ScalarUDFImpl for TruncFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "trunc" } @@ -110,7 +105,50 @@ impl ScalarUDFImpl for TruncFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(trunc, vec![])(&args.args) + // Extract precision from second argument (default 0) + let precision = match args.args.get(1) { + Some(ColumnarValue::Scalar(Int64(Some(p)))) => Some(*p), + Some(ColumnarValue::Scalar(Int64(None))) => None, // null precision + Some(ColumnarValue::Array(_)) => { + // Precision is an array - use array path + return make_scalar_function(trunc, vec![])(&args.args); + } + None => Some(0), // default precision + Some(cv) => { + return exec_err!( + "trunc function requires precision to be Int64, got {:?}", + cv.data_type() + ); + } + }; + + // Scalar fast path using tuple matching for (value, precision) + match (&args.args[0], precision) { + // Null cases + (ColumnarValue::Scalar(sv), _) if sv.is_null() => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + (_, None) => { + ColumnarValue::Scalar(ScalarValue::Null).cast_to(args.return_type(), None) + } + // Scalar cases + (ColumnarValue::Scalar(ScalarValue::Float64(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float64(Some(if p == 0 { + v.trunc() + } else { + compute_truncate64(*v, p) + }))), + ), + (ColumnarValue::Scalar(ScalarValue::Float32(Some(v))), Some(p)) => Ok( + ColumnarValue::Scalar(ScalarValue::Float32(Some(if p == 0 { + v.trunc() + } else { + compute_truncate32(*v, p) + }))), + ), + // Array path for everything else + _ => make_scalar_function(trunc, vec![])(&args.args), + } } fn output_ordering(&self, input: &[ExprProperties]) -> Result { @@ -158,11 +196,7 @@ fn trunc(args: &[ArrayRef]) -> Result { args[0] .as_primitive::() .unary::<_, Float64Type>(|x: f64| { - if x == 0_f64 { - 0_f64 - } else { - x.trunc() - } + if x == 0_f64 { 0_f64 } else { x.trunc() } }), ) as ArrayRef) } @@ -184,11 +218,7 @@ fn trunc(args: &[ArrayRef]) -> Result { args[0] .as_primitive::() .unary::<_, Float32Type>(|x: f32| { - if x == 0_f32 { - 0_f32 - } else { - x.trunc() - } + if x == 0_f32 { 0_f32 } else { x.trunc() } }), ) as ArrayRef) } @@ -210,12 +240,12 @@ fn trunc(args: &[ArrayRef]) -> Result { fn compute_truncate32(x: f32, y: i64) -> f32 { let factor = 10.0_f32.powi(y as i32); - (x * factor).round() / factor + (x * factor).trunc() / factor } fn compute_truncate64(x: f64, y: i64) -> f64 { let factor = 10.0_f64.powi(y as i32); - (x * factor).round() / factor + (x * factor).trunc() / factor } #[cfg(test)] @@ -246,9 +276,9 @@ mod test { assert_eq!(floats.len(), 5); assert_eq!(floats.value(0), 15.0); - assert_eq!(floats.value(1), 1_234.268); + assert_eq!(floats.value(1), 1_234.267); assert_eq!(floats.value(2), 1_233.12); - assert_eq!(floats.value(3), 3.312_98); + assert_eq!(floats.value(3), 3.312_97); assert_eq!(floats.value(4), -21.123_4); } @@ -271,9 +301,9 @@ mod test { assert_eq!(floats.len(), 5); assert_eq!(floats.value(0), 5.0); - assert_eq!(floats.value(1), 234.268); + assert_eq!(floats.value(1), 234.267); assert_eq!(floats.value(2), 123.12); - assert_eq!(floats.value(3), 123.312_98); + assert_eq!(floats.value(3), 123.312_97); assert_eq!(floats.value(4), -321.123_1); } diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs index ccd167997003e..9854326945e95 100644 --- a/datafusion/functions/src/planner.rs +++ b/datafusion/functions/src/planner.rs @@ -19,9 +19,9 @@ use datafusion_common::Result; use datafusion_expr::{ + Expr, expr::ScalarFunction, planner::{ExprPlanner, PlannerResult}, - Expr, }; #[deprecated( diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index da4e23f91de7d..75cc5d9514cbd 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -19,8 +19,8 @@ use arrow::error::ArrowError; use regex::Regex; -use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::collections::hash_map::Entry; use std::sync::Arc; pub mod regexpcount; pub mod regexpinstr; diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index ae08ca3e920cf..d970eccc43a54 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -22,10 +22,10 @@ use arrow::datatypes::{ DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, }; use arrow::error::ArrowError; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, - TypeSignature::Uniform, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature::Exact, TypeSignature::Uniform, Volatility, }; use datafusion_macros::user_doc; use itertools::izip; @@ -92,10 +92,6 @@ impl RegexpCountFunc { } impl ScalarUDFImpl for RegexpCountFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "regexp_count" } @@ -108,10 +104,7 @@ impl ScalarUDFImpl for RegexpCountFunc { Ok(Int64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; let len = args @@ -146,7 +139,9 @@ impl ScalarUDFImpl for RegexpCountFunc { pub fn regexp_count_func(args: &[ArrayRef]) -> Result { let args_len = args.len(); if !(2..=4).contains(&args_len) { - return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + return exec_err!( + "regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4." + ); } let values = &args[0]; @@ -273,7 +268,10 @@ where S: StringArrayType<'a>, { let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { - (Some(regex_array.value(0)), true) + ( + (!regex_array.is_null(0)).then(|| regex_array.value(0)), + true, + ) } else { (None, false) }; @@ -305,8 +303,8 @@ where match (is_regex_scalar, is_start_scalar, is_flags_scalar) { (true, true, true) => { let regex = match regex_scalar { - None | Some("") => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + None => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); } Some(regex) => regex, }; @@ -322,8 +320,8 @@ where } (true, true, false) => { let regex = match regex_scalar { - None | Some("") => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + None => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); } Some(regex) => regex, }; @@ -351,8 +349,8 @@ where } (true, false, true) => { let regex = match regex_scalar { - None | Some("") => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + None => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); } Some(regex) => regex, }; @@ -371,8 +369,8 @@ where } (true, false, false) => { let regex = match regex_scalar { - None | Some("") => { - return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + None => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))); } Some(regex) => regex, }; @@ -416,7 +414,7 @@ where .zip(regex_array.iter()) .map(|(value, regex)| { let regex = match regex { - None | Some("") => return Ok(0), + None => return Ok(0), Some(regex) => regex, }; @@ -452,7 +450,7 @@ where izip!(values.iter(), regex_array.iter(), flags_array.iter()) .map(|(value, regex, flags)| { let regex = match regex { - None | Some("") => return Ok(0), + None => return Ok(0), Some(regex) => regex, }; @@ -486,7 +484,7 @@ where izip!(values.iter(), regex_array.iter(), start_array.iter()) .map(|(value, regex, start)| { let regex = match regex { - None | Some("") => return Ok(0), + None => return Ok(0), Some(regex) => regex, }; @@ -536,7 +534,7 @@ where ) .map(|(value, regex, start, flags)| { let regex = match regex { - None | Some("") => return Ok(0), + None => return Ok(0), Some(regex) => regex, }; @@ -556,7 +554,7 @@ fn count_matches( start: Option, ) -> Result { let value = match value { - None | Some("") => return Ok(0), + None => return Ok(0), Some(value) => value, }; @@ -567,8 +565,27 @@ fn count_matches( )); } - let find_slice = value.chars().skip(start as usize - 1).collect::(); - let count = pattern.find_iter(find_slice.as_str()).count(); + let char_len = value.chars().count(); + let start_index = (start as usize).saturating_sub(1); + + if start_index > char_len { + return Ok(0); + } + + // Find the byte offset for the start position (1-based character index) + let byte_offset = if start_index == char_len { + value.len() + } else { + value + .char_indices() + .nth(start_index) + .map(|(idx, _)| idx) + .unwrap_or(value.len()) + }; + + // Use string slicing instead of collecting chars into a new String + let find_slice = &value[byte_offset..]; + let count = pattern.find_iter(find_slice).count(); Ok(count as i64) } else { let count = pattern.find_iter(value).count(); @@ -582,11 +599,11 @@ mod tests { use arrow::array::{GenericStringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; - use datafusion_expr::ScalarFunctionArgs; #[test] fn test_regexp_count() { test_case_sensitive_regexp_count_scalar(); + test_case_sensitive_regexp_count_empty_pattern_scalar(); test_case_sensitive_regexp_count_scalar_start(); test_case_insensitive_regexp_count_scalar_flags(); test_case_sensitive_regexp_count_start_scalar_complex(); @@ -673,6 +690,57 @@ mod tests { }); } + fn test_case_sensitive_regexp_count_empty_pattern_scalar() { + let values = ["", "abc", "abc"]; + let start_positions = [1, 1, 2]; + let expected: Vec = vec![1, 4, 3]; + + values + .iter() + .zip(start_positions.iter()) + .enumerate() + .for_each(|(pos, (&value, &start))| { + let expected = expected.get(pos).cloned(); + let start_sv = ScalarValue::Int64(Some(start)); + + let re = regexp_count_with_scalar_values(&[ + ScalarValue::Utf8(Some(value.to_string())), + ScalarValue::Utf8(Some("".to_string())), + start_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + let re = regexp_count_with_scalar_values(&[ + ScalarValue::LargeUtf8(Some(value.to_string())), + ScalarValue::LargeUtf8(Some("".to_string())), + start_sv.clone(), + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + let re = regexp_count_with_scalar_values(&[ + ScalarValue::Utf8View(Some(value.to_string())), + ScalarValue::Utf8View(Some("".to_string())), + start_sv, + ]); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + fn test_case_sensitive_regexp_count_scalar_start() { let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; let regex = "abc"; @@ -790,7 +858,7 @@ mod tests { let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); - let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + let expected = Int64Array::from(vec![1, 1, 2, 2, 2]); let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); assert_eq!(re.as_ref(), &expected); @@ -804,7 +872,7 @@ mod tests { let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); let start = Int64Array::from(vec![1, 2, 3, 4, 5]); - let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + let expected = Int64Array::from(vec![1, 0, 1, 1, 0]); let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) .unwrap(); @@ -820,7 +888,7 @@ mod tests { let start = Int64Array::from(vec![1]); let flags = A::from(vec!["", "i", "", "", "i"]); - let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + let expected = Int64Array::from(vec![1, 1, 2, 2, 3]); let re = regexp_count_func(&[ Arc::new(values), @@ -908,7 +976,7 @@ mod tests { let start = Int64Array::from(vec![1, 2, 3, 4, 5]); let flags = A::from(vec!["", "i", "", "", "i"]); - let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + let expected = Int64Array::from(vec![1, 1, 1, 1, 1]); let re = regexp_count_func(&[ Arc::new(values), diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs index 011564866584e..d46e4452dbab1 100644 --- a/datafusion/functions/src/regex/regexpinstr.rs +++ b/datafusion/functions/src/regex/regexpinstr.rs @@ -23,10 +23,10 @@ use arrow::datatypes::{ DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, }; use arrow::error::ArrowError; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, - TypeSignature::Uniform, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature::Exact, TypeSignature::Uniform, Volatility, }; use datafusion_macros::user_doc; use itertools::izip; @@ -109,10 +109,6 @@ impl RegexpInstrFunc { } impl ScalarUDFImpl for RegexpInstrFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "regexp_instr" } @@ -125,10 +121,7 @@ impl ScalarUDFImpl for RegexpInstrFunc { Ok(Int64) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; let len = args @@ -163,7 +156,9 @@ impl ScalarUDFImpl for RegexpInstrFunc { pub fn regexp_instr_func(args: &[ArrayRef]) -> Result { let args_len = args.len(); if !(2..=6).contains(&args_len) { - return exec_err!("regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6."); + return exec_err!( + "regexp_instr was called with {args_len} arguments. It requires at least 2 and at most 6." + ); } let values = &args[0]; @@ -286,7 +281,6 @@ fn regexp_instr( } } -#[allow(clippy::too_many_arguments)] fn regexp_instr_inner<'a, S>( values: &S, regex_array: &S, @@ -357,14 +351,14 @@ fn handle_subexp( value: &str, byte_start_offset: usize, ) -> Result, ArrowError> { - if let Some(captures) = pattern.captures(search_slice) { - if let Some(matched) = captures.get(subexpr as usize) { - // Convert byte offset relative to search_slice back to 1-based character offset - // relative to the original `value` string. - let start_char_offset = - value[..byte_start_offset + matched.start()].chars().count() as i64 + 1; - return Ok(Some(start_char_offset)); - } + if let Some(captures) = pattern.captures(search_slice) + && let Some(matched) = captures.get(subexpr as usize) + { + // Convert byte offset relative to search_slice back to 1-based character offset + // relative to the original `value` string. + let start_char_offset = + value[..byte_start_offset + matched.start()].chars().count() as i64 + 1; + return Ok(Some(start_char_offset)); } Ok(Some(0)) // Return 0 if the subexpression was not found } @@ -448,11 +442,9 @@ where #[cfg(test)] mod tests { use super::*; - use arrow::array::Int64Array; use arrow::array::{GenericStringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; - use datafusion_expr::ScalarFunctionArgs; #[test] fn test_regexp_instr() { test_case_sensitive_regexp_instr_nulls(); diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index d75eb9141c056..56754b13db227 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -17,25 +17,24 @@ //! Regex expressions -use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray}; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, GenericStringArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::types::logical_string; use datafusion_common::{ - arrow_datafusion_err, exec_err, internal_err, plan_err, DataFusionError, Result, - ScalarValue, + Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err, plan_err, }; use datafusion_expr::{ - binary_expr, cast, Coercion, ColumnarValue, Documentation, Expr, ScalarUDFImpl, - Signature, TypeSignature, TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, binary_expr, cast, }; use datafusion_macros::user_doc; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr_common::operator::Operator; use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer; -use std::any::Any; +use regex::Regex; use std::sync::Arc; #[user_doc( @@ -56,7 +55,7 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); | true | +--------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs) "#, standard_argument(name = "str", prefix = "String"), standard_argument(name = "regexp", prefix = "Regular"), @@ -103,10 +102,6 @@ impl RegexpLikeFunc { } impl ScalarUDFImpl for RegexpLikeFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "regexp_like" } @@ -126,40 +121,54 @@ impl ScalarUDFImpl for RegexpLikeFunc { }) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; - - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .map(|arg| arg.to_array(inferred_length)) - .collect::>>()?; - - let result = regexp_like(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) + match args.as_slice() { + [ColumnarValue::Scalar(value), ColumnarValue::Scalar(pattern)] => { + let value = scalar_string(value)?; + let pattern = scalar_string(pattern)?; + regexp_like_scalar(value, pattern, None) + } + [ + ColumnarValue::Scalar(value), + ColumnarValue::Scalar(pattern), + ColumnarValue::Scalar(flags), + ] => { + let value = scalar_string(value)?; + let pattern = scalar_string(pattern)?; + let flags = scalar_string(flags)?; + regexp_like_scalar(value, pattern, flags) + } + [ColumnarValue::Array(values), ColumnarValue::Scalar(pattern)] => { + let pattern = scalar_string(pattern)?; + let array = regexp_like_array_scalar(values, pattern, None)?; + Ok(ColumnarValue::Array(array)) + } + [ + ColumnarValue::Array(values), + ColumnarValue::Scalar(pattern), + ColumnarValue::Scalar(flags), + ] => { + let flags = scalar_string(flags)?; + if flags.is_some_and(|flagz| flagz.contains('g')) { + plan_err!("regexp_like() does not support the \"global\" option") + } else { + let pattern = scalar_string(pattern)?; + let array = regexp_like_array_scalar(values, pattern, flags)?; + Ok(ColumnarValue::Array(array)) + } + } + _ => { + let args = ColumnarValue::values_to_arrays(args)?; + regexp_like(&args).map(ColumnarValue::Array) + } } } fn simplify( &self, mut args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { // Try to simplify regexp_like usage to one of the builtin operators since those have // optimized code paths for the case where the regular expression pattern is a scalar. @@ -276,43 +285,125 @@ pub fn regexp_like(args: &[ArrayRef]) -> Result { Utf8 => args[2].as_string::(), LargeUtf8 => { let large_string_array = args[2].as_string::(); - let string_vec: Vec> = (0..large_string_array.len()).map(|i| { - if large_string_array.is_null(i) { - None - } else { - Some(large_string_array.value(i)) - } - }) - .collect(); + let string_vec: Vec> = (0..large_string_array.len()) + .map(|i| { + if large_string_array.is_null(i) { + None + } else { + Some(large_string_array.value(i)) + } + }) + .collect(); &GenericStringArray::::from(string_vec) - }, + } _ => { let string_view_array = args[2].as_string_view(); - let string_vec: Vec> = (0..string_view_array.len()).map(|i| { - if string_view_array.is_null(i) { - None - } else { - Some(string_view_array.value(i).to_string()) - } - }) - .collect(); + let string_vec: Vec> = (0..string_view_array.len()) + .map(|i| { + if string_view_array.is_null(i) { + None + } else { + Some(string_view_array.value(i).to_string()) + } + }) + .collect(); &GenericStringArray::::from(string_vec) - }, + } }; - if flags.iter().any(|s| s == Some("g")) { + if flags + .iter() + .any(|s| s.is_some_and(|flagz| flagz.contains('g'))) + { return plan_err!("regexp_like() does not support the \"global\" option"); } handle_regexp_like(&args[0], &args[1], Some(flags)) - }, + } other => exec_err!( "`regexp_like` was called with {other} arguments. It requires at least 2 and at most 3." ), } } +fn scalar_string(value: &ScalarValue) -> Result> { + match value.try_as_str() { + Some(v) => Ok(v), + None => internal_err!( + "Unsupported data type {:?} for function `regexp_like`", + value.data_type() + ), + } +} + +fn regexp_like_array_scalar( + values: &ArrayRef, + pattern: Option<&str>, + flags: Option<&str>, +) -> Result { + use DataType::*; + + let Some(pattern) = pattern else { + return Ok(Arc::new(BooleanArray::new_null(values.len()))); + }; + let array = match values.data_type() { + Utf8 => { + let array = values.as_string::(); + regexp::regexp_is_match_scalar(array, pattern, flags)? + } + Utf8View => { + let array = values.as_string_view(); + regexp::regexp_is_match_scalar(array, pattern, flags)? + } + LargeUtf8 => { + let array = values.as_string::(); + regexp::regexp_is_match_scalar(array, pattern, flags)? + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function `regexp_like`" + ); + } + }; + + Ok(Arc::new(array)) +} + +fn regexp_like_scalar( + value: Option<&str>, + pattern: Option<&str>, + flags: Option<&str>, +) -> Result { + if flags.is_some_and(|flagz| flagz.contains('g')) { + return plan_err!("regexp_like() does not support the \"global\" option"); + } + + if value.is_none() || pattern.is_none() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + let value = value.unwrap(); + let pattern = pattern.unwrap(); + let pattern = match flags { + Some(flagz) => format!("(?{flagz}){pattern}"), + None => pattern.to_string(), + }; + + let result = if pattern.is_empty() { + true + } else { + let re = Regex::new(pattern.as_str()).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + re.is_match(value) + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(result)))) +} + fn handle_regexp_like( values: &ArrayRef, patterns: &ArrayRef, @@ -355,7 +446,7 @@ fn handle_regexp_like( .map_err(|e| arrow_datafusion_err!(e))? } (Utf8, LargeUtf8) => { - let value = values.as_string_view(); + let value = values.as_string::(); let pattern = patterns.as_string::(); regexp::regexp_is_match(value, pattern, flags) @@ -385,7 +476,7 @@ fn handle_regexp_like( other => { return internal_err!( "Unsupported data type {other:?} for function `regexp_like`" - ) + ); } }; @@ -398,8 +489,37 @@ mod tests { use arrow::array::StringArray; use arrow::array::{BooleanBuilder, StringViewArray}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; - use crate::regex::regexplike::regexp_like; + use crate::regex::regexplike::{RegexpLikeFunc, regexp_like}; + + fn invoke_regexp_like(args: Vec) -> Result { + let number_rows = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(1); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Arc::new(Field::new(format!("arg_{idx}"), arg.data_type(), true)) + }) + .collect::>(); + + RegexpLikeFunc::new().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows, + return_field: Arc::new(Field::new("f", DataType::Boolean, true)), + config_options: Arc::new(ConfigOptions::default()), + }) + } #[test] fn test_case_sensitive_regexp_like_utf8() { @@ -498,4 +618,66 @@ mod tests { "Error during planning: regexp_like() does not support the \"global\" option" ); } + + #[test] + fn test_regexp_like_scalar_invoke() { + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("foobarbequebaz".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("(bar)(beque)".to_string()))), + ]; + let result = invoke_regexp_like(args).unwrap(); + match result { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {} + other => panic!("Unexpected result {other:?}"), + } + } + + #[test] + fn test_regexp_like_array_scalar_invoke() { + let values = Arc::new(StringArray::from(vec!["abc", "xyz"])); + let args = vec![ + ColumnarValue::Array(values), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("^(a)".to_string()))), + ]; + let result = invoke_regexp_like(args).unwrap(); + let mut expected_builder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + match result { + ColumnarValue::Array(array) => { + assert_eq!(array.as_ref(), &expected); + } + other => panic!("Unexpected result {other:?}"), + } + } + + #[test] + fn test_regexp_like_scalar_flags_with_global() { + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abc".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("^(a)".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("ig".to_string()))), + ]; + let err = invoke_regexp_like(args).expect_err("global flag should be rejected"); + assert_eq!( + err.strip_backtrace(), + "Error during planning: regexp_like() does not support the \"global\" option" + ); + } + + #[test] + fn test_regexp_like_array_scalar_flags_with_global() { + let values = Arc::new(StringArray::from(vec!["abc", "xyz"])); + let args = vec![ + ColumnarValue::Array(values), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("^(a)".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("ig".to_string()))), + ]; + let err = invoke_regexp_like(args).expect_err("global flag should be rejected"); + assert_eq!( + err.strip_backtrace(), + "Error during planning: regexp_like() does not support the \"global\" option" + ); + } } diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index ba52822a02f8c..34153d9c8ab96 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -20,14 +20,13 @@ use arrow::array::{Array, ArrayRef, AsArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use arrow::datatypes::Field; -use datafusion_common::exec_err; +use datafusion_common::Result; use datafusion_common::ScalarValue; +use datafusion_common::exec_err; use datafusion_common::{arrow_datafusion_err, plan_err}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; +use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( @@ -48,7 +47,7 @@ use std::sync::Arc; | [B] | +---------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs) "#, standard_argument(name = "str", prefix = "String"), argument( @@ -100,10 +99,6 @@ impl RegexpMatchFunc { } impl ScalarUDFImpl for RegexpMatchFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "regexp_match" } @@ -119,10 +114,7 @@ impl ScalarUDFImpl for RegexpMatchFunc { }) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; let len = args .iter() @@ -155,29 +147,35 @@ impl ScalarUDFImpl for RegexpMatchFunc { pub fn regexp_match(args: &[ArrayRef]) -> Result { match args.len() { - 2 => { - regexp::regexp_match(&args[0], &args[1], None) - .map_err(|e| arrow_datafusion_err!(e)) - } + 2 => regexp::regexp_match(&args[0], &args[1], None) + .map_err(|e| arrow_datafusion_err!(e)), 3 => { match args[2].data_type() { DataType::Utf8View => { if args[2].as_string_view().iter().any(|s| s == Some("g")) { - return plan_err!("regexp_match() does not support the \"global\" option"); + return plan_err!( + "regexp_match() does not support the \"global\" option" + ); } } DataType::Utf8 => { if args[2].as_string::().iter().any(|s| s == Some("g")) { - return plan_err!("regexp_match() does not support the \"global\" option"); + return plan_err!( + "regexp_match() does not support the \"global\" option" + ); } } DataType::LargeUtf8 => { if args[2].as_string::().iter().any(|s| s == Some("g")) { - return plan_err!("regexp_match() does not support the \"global\" option"); + return plan_err!( + "regexp_match() does not support the \"global\" option" + ); } } e => { - return plan_err!("regexp_match was called with unexpected data type {e:?}"); + return plan_err!( + "regexp_match was called with unexpected data type {e:?}" + ); } } @@ -254,6 +252,9 @@ mod tests { regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) .expect_err("unsupported flag should have failed"); - assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option"); + assert_eq!( + re_err.strip_backtrace(), + "Error during planning: regexp_match() does not support the \"global\" option" + ); } } diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 29da195c7a928..215dd33324375 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -16,30 +16,34 @@ // under the License. //! Regex expressions +use memchr::memchr; + use arrow::array::ArrayDataBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericStringArray; use arrow::array::StringViewBuilder; -use arrow::array::{new_null_array, ArrayIter, AsArray}; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::array::{ArrayAccessor, StringViewArray}; +use arrow::array::{ArrayIter, AsArray, new_null_array}; use arrow::datatypes::DataType; +use datafusion_common::ScalarValue; use datafusion_common::cast::{ as_large_string_array, as_string_array, as_string_view_array, }; use datafusion_common::exec_err; use datafusion_common::plan_err; -use datafusion_common::ScalarValue; use datafusion_common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result, + DataFusionError, Result, cast::as_generic_string_array, internal_err, }; -use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature; -use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::function::Hint; +use datafusion_expr::{ + Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_macros::user_doc; -use regex::Regex; -use std::any::Any; +use regex::{CaptureLocations, Regex}; +use std::borrow::Cow; use std::collections::HashMap; use std::sync::{Arc, LazyLock}; @@ -61,7 +65,7 @@ SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); | aAbBac | +-------------------------------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/builtin_functions/regexp.rs) "#, standard_argument(name = "str", prefix = "String"), argument( @@ -111,10 +115,6 @@ impl RegexpReplaceFunc { } impl ScalarUDFImpl for RegexpReplaceFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "regexp_replace" } @@ -149,10 +149,7 @@ impl ScalarUDFImpl for RegexpReplaceFunc { }) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; let len = args @@ -189,16 +186,114 @@ fn regexp_replace_func(args: &[ColumnarValue]) -> Result { } } -/// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) +/// replace POSIX capture groups (like \1 or \\1) with Rust Regex group (like ${1}) /// used by regexp_replace +/// Handles both single backslash (\1) and double backslash (\\1) which can occur +/// when SQL strings with escaped backslashes are passed through +/// +/// Note: \0 is converted to ${0}, which in Rust's regex replacement syntax +/// substitutes the entire match. This is consistent with POSIX behavior where +/// \0 (or &) refers to the entire matched string. fn regex_replace_posix_groups(replacement: &str) -> String { static CAPTURE_GROUPS_RE_LOCK: LazyLock = - LazyLock::new(|| Regex::new(r"(\\)(\d*)").unwrap()); + LazyLock::new(|| Regex::new(r"\\{1,2}(\d+)").unwrap()); CAPTURE_GROUPS_RE_LOCK - .replace_all(replacement, "$${$2}") + .replace_all(replacement, "$${$1}") .into_owned() } +struct ShortRegex { + /// Shortened anchored regex used to extract capture group 1 directly. + /// See [`try_build_short_extract_regex`] for details. + short_re: Regex, + /// Reusable capture locations for `short_re` to avoid per-row allocation. + locs: CaptureLocations, +} + +/// Holds the normal compiled regex together with the optional fast path used +/// for `regexp_replace(str, '^...(capture)...*$', '\1')`. +struct OptimizedRegex { + /// Full regex used for the normal replacement path and as a correctness fallback. + re: Regex, + /// Precomputed state for the direct-extraction fast path, when applicable. + short_re: Option, +} + +impl OptimizedRegex { + /// Builds any reusable state needed by the extraction fast path. + /// + /// The fast path is only enabled for single replacements where the pattern + /// and replacement satisfy [`try_build_short_extract_regex`]. + fn new(re: Regex, limit: usize, pattern: &str, replacement: &str) -> Self { + let short_re = if limit == 1 { + try_build_short_extract_regex(pattern, replacement) + } else { + None + }; + + let short_re = short_re.map(|short_re| { + let locs = short_re.capture_locations(); + ShortRegex { short_re, locs } + }); + + Self { re, short_re } + } + + /// Applies the direct-extraction fast path when it preserves the result of + /// `Regex::replacen`; otherwise falls back to the full regex replacement. + fn replacen<'a>( + &mut self, + val: &'a str, + limit: usize, + replacement: &str, + ) -> Cow<'a, str> { + // If this pattern is not eligible for direct extraction, use the full regex. + let Some(ShortRegex { short_re, locs }) = self.short_re.as_mut() else { + return self.re.replacen(val, limit, replacement); + }; + + // If the shortened regex does not match, the original anchored regex would + // also leave the input unchanged. + if short_re.captures_read(locs, val).is_none() { + return Cow::Borrowed(val); + }; + + // `captures_read` succeeded, so the overall shortened match is present. + let match_end = locs.get(0).unwrap().1; + if memchr(b'\n', &val.as_bytes()[match_end..]).is_some() { + // If there is a newline after the match, we can't use the short + // regex since it won't match across lines. Fall back to the full + // regex replacement. + return self.re.replacen(val, limit, replacement); + }; + // The fast path only applies to `${1}` replacements, so the result is + // either capture group 1 or the empty string if that group did not match. + if let Some((start, end)) = locs.get(1) { + Cow::Borrowed(&val[start..end]) + } else { + Cow::Borrowed("") + } + } +} + +/// For anchored patterns like `^...(capture)....*$` where the replacement +/// is `\1`, build a shorter regex (stripping trailing `.*$`) and use +/// `captures_read` with `CaptureLocations` for direct extraction — no +/// `expand()`, no `String` allocation. +/// This pattern appears in ClickBench Q28: which uses a regexp like +/// `^https?://(?:www\.)?([^/]+)/.*$` +fn try_build_short_extract_regex(pattern: &str, replacement: &str) -> Option { + if replacement != "${1}" || !pattern.starts_with('^') || !pattern.ends_with(".*$") { + return None; + } + let short = &pattern[..pattern.len() - 3]; + let re = Regex::new(short).ok()?; + if re.captures_len() != 2 { + return None; + } + Some(re) +} + /// Replaces substring(s) matching a PCRE-like regular expression. /// /// The full list of supported features and syntax can be found at @@ -422,7 +517,7 @@ macro_rules! fetch_string_arg { /// hold a single Regex object for the replace operation. This also speeds /// up the pre-processing time of the replacement string, since it only /// needs to processed once. -fn _regexp_replace_static_pattern_replace( +fn regexp_replace_static_pattern_replace( args: &[ArrayRef], ) -> Result { let array_size = args[0].len(); @@ -434,7 +529,7 @@ fn _regexp_replace_static_pattern_replace( other => { return exec_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." - ) + ); } }; @@ -457,6 +552,8 @@ fn _regexp_replace_static_pattern_replace( // with rust ones. let replacement = regex_replace_posix_groups(replacement); + let mut opt_re = OptimizedRegex::new(re, limit, &pattern, &replacement); + let string_array_type = args[0].data_type(); match string_array_type { DataType::Utf8 | DataType::LargeUtf8 => { @@ -475,7 +572,7 @@ fn _regexp_replace_static_pattern_replace( string_array.iter().for_each(|val| { if let Some(val) = val { - let result = re.replacen(val, limit, replacement.as_str()); + let result = opt_re.replacen(val, limit, replacement.as_str()); vals.append_slice(result.as_bytes()); } new_offsets.append(T::from_usize(vals.len()).unwrap()); @@ -496,8 +593,8 @@ fn _regexp_replace_static_pattern_replace( for val in string_view_array.iter() { if let Some(val) = val { - let result = re.replacen(val, limit, replacement.as_str()); - builder.append_value(result); + let result = opt_re.replacen(val, limit, replacement.as_str()); + builder.append_value(result.as_ref()); } else { builder.append_null(); } @@ -576,7 +673,7 @@ fn specialize_regexp_replace( arg.to_array(expansion_len) }) .collect::>>()?; - _regexp_replace_static_pattern_replace::(&args) + regexp_replace_static_pattern_replace::(&args) } // If there are no specialized implementations, we'll fall back to the @@ -659,6 +756,42 @@ mod tests { use super::*; + #[test] + fn test_regex_replace_posix_groups() { + // Test that \1, \2, etc. are replaced with ${1}, ${2}, etc. + assert_eq!(regex_replace_posix_groups(r"\1"), "${1}"); + assert_eq!(regex_replace_posix_groups(r"\12"), "${12}"); + assert_eq!(regex_replace_posix_groups(r"X\1Y"), "X${1}Y"); + assert_eq!(regex_replace_posix_groups(r"\1\2"), "${1}${2}"); + + // Test double backslash (from SQL escaped strings like '\\1') + assert_eq!(regex_replace_posix_groups(r"\\1"), "${1}"); + assert_eq!(regex_replace_posix_groups(r"X\\1Y"), "X${1}Y"); + assert_eq!(regex_replace_posix_groups(r"\\1\\2"), "${1}${2}"); + + // Test 3 or 4 backslashes before digits to document expected behavior + assert_eq!(regex_replace_posix_groups(r"\\\1"), r"\${1}"); + assert_eq!(regex_replace_posix_groups(r"\\\\1"), r"\\${1}"); + assert_eq!(regex_replace_posix_groups(r"\\\1\\\\2"), r"\${1}\\${2}"); + + // Test that a lone backslash is NOT replaced (requires at least one digit) + assert_eq!(regex_replace_posix_groups(r"\"), r"\"); + assert_eq!(regex_replace_posix_groups(r"foo\bar"), r"foo\bar"); + + // Test that backslash followed by non-digit is preserved + assert_eq!(regex_replace_posix_groups(r"\n"), r"\n"); + assert_eq!(regex_replace_posix_groups(r"\t"), r"\t"); + + // Test \0 behavior: \0 is converted to ${0}, which in Rust's regex + // replacement syntax substitutes the entire match. This is consistent + // with POSIX behavior where \0 (or &) refers to the entire matched string. + assert_eq!(regex_replace_posix_groups(r"\0"), "${0}"); + assert_eq!( + regex_replace_posix_groups(r"prefix\0suffix"), + "prefix${0}suffix" + ); + } + macro_rules! static_pattern_regexp_replace { ($name:ident, $T:ty, $O:ty) => { #[test] @@ -674,7 +807,7 @@ mod tests { let replacements = <$T>::from(replacement); let expected = <$T>::from(expected); - let re = _regexp_replace_static_pattern_replace::<$O>(&[ + let re = regexp_replace_static_pattern_replace::<$O>(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -719,7 +852,7 @@ mod tests { let flags = StringArray::from(vec!["i"; 5]); let expected = <$T>::from(expected); - let re = _regexp_replace_static_pattern_replace::<$O>(&[ + let re = regexp_replace_static_pattern_replace::<$O>(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -751,7 +884,7 @@ mod tests { let replacements = StringArray::from(vec!["foo"; 5]); let expected = StringArray::from(vec![None::<&str>; 5]); - let re = _regexp_replace_static_pattern_replace::(&[ + let re = regexp_replace_static_pattern_replace::(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -768,7 +901,7 @@ mod tests { let replacements = StringArray::from(Vec::>::new()); let expected = StringArray::from(Vec::>::new()); - let re = _regexp_replace_static_pattern_replace::(&[ + let re = regexp_replace_static_pattern_replace::(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -786,7 +919,7 @@ mod tests { let flags = StringArray::from(vec![None::<&str>; 5]); let expected = StringArray::from(vec![None::<&str>; 5]); - let re = _regexp_replace_static_pattern_replace::(&[ + let re = regexp_replace_static_pattern_replace::(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -805,7 +938,7 @@ mod tests { let patterns = StringArray::from(vec!["["; 5]); let replacements = StringArray::from(vec!["foo"; 5]); - let re = _regexp_replace_static_pattern_replace::(&[ + let re = regexp_replace_static_pattern_replace::(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -842,7 +975,7 @@ mod tests { Some("c"), ]); - let re = _regexp_replace_static_pattern_replace::(&[ + let re = regexp_replace_static_pattern_replace::(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), @@ -870,7 +1003,7 @@ mod tests { let replacements = StringArray::from(vec!["foo"; 1]); let expected = StringArray::from(vec![Some("b"), None, Some("foo"), None, None]); - let re = _regexp_replace_static_pattern_replace::(&[ + let re = regexp_replace_static_pattern_replace::(&[ Arc::new(values), Arc::new(patterns), Arc::new(replacements), diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 8b55f8fb8f7bc..bb5a8d0125a70 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -15,17 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::make_scalar_function; use arrow::array::{ArrayRef, AsArray, Int32Array, StringArrayType}; use arrow::datatypes::DataType; use arrow::error::ArrowError; use datafusion_common::types::logical_string; -use datafusion_common::{internal_err, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_expr_common::signature::Coercion; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( @@ -74,10 +73,6 @@ impl AsciiFunc { } impl ScalarUDFImpl for AsciiFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "ascii" } @@ -91,7 +86,31 @@ impl ScalarUDFImpl for AsciiFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(ascii, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))); + } + + match scalar { + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => { + let result = s.chars().next().map_or(0, |c| c as i32); + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } + _ => { + internal_err!( + "Unexpected data type {:?} for function ascii", + scalar.data_type() + ) + } + } + } + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(ascii(&[array])?)), + } } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 1578331e57f89..76d8bb73bba87 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -17,7 +17,6 @@ use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; -use std::any::Any; use crate::utils::utf8_to_int_type; use datafusion_common::types::logical_string; @@ -70,10 +69,6 @@ impl BitLengthFunc { } impl ScalarUDFImpl for BitLengthFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bit_length" } diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index a7fbdb3c69213..279f444d9ffe7 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -20,17 +20,16 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, spaces are removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' fn btrim(args: &[ArrayRef]) -> Result { let use_string_view = args[0].data_type() == &DataType::Utf8View; @@ -40,12 +39,12 @@ fn btrim(args: &[ArrayRef]) -> Result { } else { args.to_owned() }; - general_trim::(&args, TrimType::Both, use_string_view) + general_trim::(&args, use_string_view) } #[user_doc( doc_section(label = "String Functions"), - description = "Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.", + description = "Trims the specified trim string from the start and end of a string. If no trim string is provided, all spaces are removed from the start and end of the input string.", syntax_example = "btrim(str[, trim_str])", sql_example = r#"```sql > select btrim('__datafusion____', '_'); @@ -58,7 +57,7 @@ fn btrim(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument( name = "trim_str", - description = r"String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._" + description = r"String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is a space._" ), alternative_syntax = "trim(BOTH trim_str FROM str)", alternative_syntax = "trim(trim_str FROM str)", @@ -98,10 +97,6 @@ impl BTrimFunc { } impl ScalarUDFImpl for BTrimFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "btrim" } diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index 8706c43214ea5..60df0c47cfa14 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -15,56 +15,64 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::ArrayRef; -use arrow::array::GenericStringBuilder; +use arrow::array::{Array, ArrayRef, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; -use crate::utils::make_scalar_function; +use crate::strings::GenericStringArrayBuilder; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the character with the given code. /// chr(65) = 'A' -fn chr(args: &[ArrayRef]) -> Result { - let integer_array = as_int64_array(&args[0])?; - - let mut builder = GenericStringBuilder::::with_capacity( - integer_array.len(), - // 1 byte per character, assuming that is the common case - integer_array.len(), +fn chr_array(integer_array: &Int64Array) -> Result { + let len = integer_array.len(); + let mut builder = GenericStringArrayBuilder::::with_capacity( + len, // 1 byte per character, assuming that is the common case + len, ); let mut buf = [0u8; 4]; + let nulls = integer_array.nulls(); - for integer in integer_array { - match integer { - Some(integer) => { - if let Ok(u) = u32::try_from(integer) { - if let Some(c) = core::char::from_u32(u) { - builder.append_value(c.encode_utf8(&mut buf)); - continue; - } - } - - return exec_err!("invalid Unicode scalar value: {integer}"); + if let Some(n) = nulls { + for i in 0..len { + if n.is_null(i) { + builder.append_placeholder(); + continue; + } + // SAFETY: bounds + null check above. + let integer = unsafe { integer_array.value_unchecked(i) }; + if let Ok(u) = u32::try_from(integer) + && let Some(c) = core::char::from_u32(u) + { + builder.append_value(c.encode_utf8(&mut buf)); + continue; } - None => { - builder.append_null(); + return exec_err!("invalid Unicode scalar value: {integer}"); + } + } else { + for i in 0..len { + // SAFETY: no null buffer means every index is valid. + let integer = unsafe { integer_array.value_unchecked(i) }; + if let Ok(u) = u32::try_from(integer) + && let Some(c) = core::char::from_u32(u) + { + builder.append_value(c.encode_utf8(&mut buf)); + continue; } + return exec_err!("invalid Unicode scalar value: {integer}"); } } - let result = builder.finish(); - - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish(nulls.cloned())?) as ArrayRef) } #[user_doc( @@ -102,10 +110,6 @@ impl ChrFunc { } impl ScalarUDFImpl for ChrFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "chr" } @@ -119,7 +123,32 @@ impl ScalarUDFImpl for ChrFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(chr, vec![])(&args.args) + let [arg] = take_function_args(self.name(), args.args)?; + + match arg { + ColumnarValue::Scalar(ScalarValue::Int64(Some(code_point))) => { + if let Ok(u) = u32::try_from(code_point) + && let Some(c) = core::char::from_u32(u) + { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + c.to_string(), + )))) + } else { + exec_err!("invalid Unicode scalar value: {code_point}") + } + } + ColumnarValue::Scalar(ScalarValue::Int64(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + ColumnarValue::Array(array) => { + let integer_array = as_int64_array(&array)?; + Ok(ColumnarValue::Array(chr_array(integer_array)?)) + } + other => internal_err!( + "Unexpected data type {:?} for function chr", + other.data_type() + ), + } } fn documentation(&self) -> Option<&Documentation> { @@ -130,13 +159,26 @@ impl ScalarUDFImpl for ChrFunc { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, Int64Array, StringArray}; + + use arrow::array::{Array, StringArray}; + use arrow::datatypes::Field; use datafusion_common::assert_contains; + use datafusion_common::config::ConfigOptions; + + fn invoke_chr(arg: ColumnarValue, number_rows: usize) -> Result { + ChrFunc::new().invoke_with_args(ScalarFunctionArgs { + args: vec![arg], + arg_fields: vec![Field::new("a", Int64, true).into()], + number_rows, + return_field: Field::new("f", Utf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + } #[test] fn test_chr_normal() { let input = Arc::new(Int64Array::from(vec![ - Some(0), // null + Some(0), // \u{0000} Some(65), // A Some(66), // B Some(67), // C @@ -149,8 +191,13 @@ mod tests { Some(9), // tab Some(0x10FFFF), // 0x10FFFF, the largest Unicode code point ])); - let result = chr(&[input]).unwrap(); - let string_array = result.as_any().downcast_ref::().unwrap(); + + let result = invoke_chr(ColumnarValue::Array(input), 12).unwrap(); + let ColumnarValue::Array(arr) = result else { + panic!("Expected array"); + }; + let string_array = arr.as_any().downcast_ref::().unwrap(); + let expected = [ "\u{0000}", "A", @@ -167,62 +214,61 @@ mod tests { ]; assert_eq!(string_array.len(), expected.len()); + assert_eq!(string_array.null_count(), 1); + assert!(string_array.is_null(7)); for (i, e) in expected.iter().enumerate() { + if i == 7 { + continue; + } + assert!(!string_array.is_null(i)); assert_eq!(string_array.value(i), *e); } } #[test] fn test_chr_error() { - // invalid Unicode code points (too large) let input = Arc::new(Int64Array::from(vec![i64::MAX])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 9223372036854775807" ); - // invalid Unicode code points (too large) case 2 let input = Arc::new(Int64Array::from(vec![0x10FFFF + 1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 1114112" ); - // invalid Unicode code points (surrogate code point) - // link: let input = Arc::new(Int64Array::from(vec![0xD800 + 1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: 55297" ); - // negative input - let input = Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); // will be 2 if cast to u32 - let result = chr(&[input]); + let input = Arc::new(Int64Array::from(vec![i64::MIN + 2i64])); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: -9223372036854775806" ); - // negative input case 2 let input = Arc::new(Int64Array::from(vec![-1])); - let result = chr(&[input]); + let result = invoke_chr(ColumnarValue::Array(input), 1); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), "invalid Unicode scalar value: -1" ); - // one error with valid values after - let input = Arc::new(Int64Array::from(vec![65, -1, 66])); // A, -1, B - let result = chr(&[input]); + let input = Arc::new(Int64Array::from(vec![65, -1, 66])); + let result = invoke_chr(ColumnarValue::Array(input), 3); assert!(result.is_err()); assert_contains!( result.err().unwrap().to_string(), @@ -232,10 +278,36 @@ mod tests { #[test] fn test_chr_empty() { - // empty input array let input = Arc::new(Int64Array::from(Vec::::new())); - let result = chr(&[input]).unwrap(); - let string_array = result.as_any().downcast_ref::().unwrap(); + let result = invoke_chr(ColumnarValue::Array(input), 0).unwrap(); + let ColumnarValue::Array(arr) = result else { + panic!("Expected array"); + }; + let string_array = arr.as_any().downcast_ref::().unwrap(); assert_eq!(string_array.len(), 0); } + + #[test] + fn test_chr_scalar() { + let result = + invoke_chr(ColumnarValue::Scalar(ScalarValue::Int64(Some(65))), 1).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + assert_eq!(s, "A"); + } + other => panic!("Unexpected result: {other:?}"), + } + } + + #[test] + fn test_chr_scalar_null() { + let result = + invoke_chr(ColumnarValue::Scalar(ScalarValue::Int64(None)), 1).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + other => panic!("Unexpected result: {other:?}"), + } + } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 34f1b6232d412..6ecd41b0b9a5c 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -17,176 +17,203 @@ //! Common utilities for implementing string functions -use std::fmt::{Display, Formatter}; use std::sync::Arc; -use crate::strings::make_and_append_view; +use crate::strings::{ + GenericStringArrayBuilder, STRING_VIEW_INIT_BLOCK_SIZE, STRING_VIEW_MAX_BLOCK_SIZE, + StringViewArrayBuilder, append_view, +}; use arrow::array::{ - new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder, - NullBufferBuilder, OffsetSizeTrait, StringBuilder, StringViewArray, + Array, ArrayRef, GenericStringArray, NullBufferBuilder, OffsetSizeTrait, + StringViewArray, new_null_array, }; -use arrow::buffer::{Buffer, ScalarBuffer}; +use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::Result; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::{ScalarValue, exec_err}; use datafusion_expr::ColumnarValue; -#[derive(Copy, Clone)] -pub(crate) enum TrimType { - Left, - Right, - Both, +/// Trait for trim operations, allowing compile-time dispatch instead of runtime matching. +/// +/// Each implementation performs its specific trim operation and returns +/// (trimmed_str, start_offset) where start_offset is the byte offset +/// from the beginning of the input string where the trimmed result starts. +pub(crate) trait Trimmer { + fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32); + + /// Optimized trim for a single ASCII byte. + /// Uses byte-level scanning instead of char-level iteration. + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32); } -impl Display for TrimType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - TrimType::Left => write!(f, "ltrim"), - TrimType::Right => write!(f, "rtrim"), - TrimType::Both => write!(f, "btrim"), +/// Returns the number of leading bytes matching `byte` +#[inline] +fn leading_bytes(bytes: &[u8], byte: u8) -> usize { + bytes.iter().take_while(|&&b| b == byte).count() +} + +/// Returns the number of trailing bytes matching `byte` +#[inline] +fn trailing_bytes(bytes: &[u8], byte: u8) -> usize { + bytes.iter().rev().take_while(|&&b| b == byte).count() +} + +/// Left trim - removes leading characters +pub(crate) struct TrimLeft; + +impl Trimmer for TrimLeft { + #[inline] + fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) { + if pattern.len() == 1 && pattern[0].is_ascii() { + return Self::trim_ascii_char(input, pattern[0] as u8); } + let trimmed = input.trim_start_matches(pattern); + let offset = (input.len() - trimmed.len()) as u32; + (trimmed, offset) + } + + #[inline] + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) { + let start = leading_bytes(input.as_bytes(), byte); + (&input[start..], start as u32) + } +} + +/// Right trim - removes trailing characters +pub(crate) struct TrimRight; + +impl Trimmer for TrimRight { + #[inline] + fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) { + if pattern.len() == 1 && pattern[0].is_ascii() { + return Self::trim_ascii_char(input, pattern[0] as u8); + } + let trimmed = input.trim_end_matches(pattern); + (trimmed, 0) + } + + #[inline] + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) { + let bytes = input.as_bytes(); + let end = bytes.len() - trailing_bytes(bytes, byte); + (&input[..end], 0) + } +} + +/// Both trim - removes both leading and trailing characters +pub(crate) struct TrimBoth; + +impl Trimmer for TrimBoth { + #[inline] + fn trim<'a>(input: &'a str, pattern: &[char]) -> (&'a str, u32) { + if pattern.len() == 1 && pattern[0].is_ascii() { + return Self::trim_ascii_char(input, pattern[0] as u8); + } + let left_trimmed = input.trim_start_matches(pattern); + let offset = (input.len() - left_trimmed.len()) as u32; + let trimmed = left_trimmed.trim_end_matches(pattern); + (trimmed, offset) + } + + #[inline] + fn trim_ascii_char(input: &str, byte: u8) -> (&str, u32) { + let bytes = input.as_bytes(); + let start = leading_bytes(bytes, byte); + let end = bytes.len() - trailing_bytes(&bytes[start..], byte); + (&input[start..end], start as u32) } } -pub(crate) fn general_trim( +pub(crate) fn general_trim( args: &[ArrayRef], - trim_type: TrimType, use_string_view: bool, ) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - let ltrimmed_str = - str::trim_start_matches::<&[char]>(input, pattern.as_ref()); - // `ltrimmed_str` is actually `input`[start_offset..], - // so `start_offset` = len(`input`) - len(`ltrimmed_str`) - let start_offset = input.len() - ltrimmed_str.len(); - - (ltrimmed_str, start_offset as u32) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - let rtrimmed_str = str::trim_end_matches::<&[char]>(input, pattern.as_ref()); - - // `ltrimmed_str` is actually `input`[0..new_len], so `start_offset` is 0 - (rtrimmed_str, 0) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - let ltrimmed_str = - str::trim_start_matches::<&[char]>(input, pattern.as_ref()); - // `btrimmed_str` can be got by rtrim(ltrim(`input`)), - // so its `start_offset` should be same as ltrim situation above - let start_offset = input.len() - ltrimmed_str.len(); - let btrimmed_str = - str::trim_end_matches::<&[char]>(ltrimmed_str, pattern.as_ref()); - - (btrimmed_str, start_offset as u32) - }, - }; - if use_string_view { - string_view_trim(func, args) + string_view_trim::(args) } else { - string_trim::(func, args) + string_trim::(args) } } /// Applies the trim function to the given string view array(s) /// and returns a new string view array with the trimmed values. /// -/// # `trim_func`: The function to apply to each string view. -/// -/// ## Arguments -/// - The original string -/// - the pattern to trim -/// -/// ## Returns -/// - trimmed str (must be a substring of the first argument) -/// - start offset, needed in `string_view_trim` -/// -/// ## Examples -/// -/// For `ltrim`: -/// - `fn(" abc", " ") -> ("abc", 2)` -/// - `fn("abd", " ") -> ("abd", 0)` -/// -/// For `btrim`: -/// - `fn(" abc ", " ") -> ("abc", 2)` -/// - `fn("abd", " ") -> ("abd", 0)` -// removing 'a will cause compiler complaining lifetime of `func` -fn string_view_trim<'a>( - trim_func: fn(&'a str, &'a str) -> (&'a str, u32), - args: &'a [ArrayRef], -) -> Result { +/// Pre-computes the pattern characters once for scalar patterns to avoid +/// repeated allocations per row. +fn string_view_trim(args: &[ArrayRef]) -> Result { let string_view_array = as_string_view_array(&args[0])?; let mut views_buf = Vec::with_capacity(string_view_array.len()); let mut null_builder = NullBufferBuilder::new(string_view_array.len()); match args.len() { 1 => { - let array_iter = string_view_array.iter(); - let views_iter = string_view_array.views().iter(); - for (src_str_opt, raw_view) in array_iter.zip(views_iter) { - trim_and_append_str( - src_str_opt, - Some(" "), - trim_func, - &mut views_buf, - &mut null_builder, - raw_view, - ); + // Trim spaces by default + for (src_str_opt, raw_view) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + { + if let Some(src_str) = src_str_opt { + let (trimmed, offset) = Tr::trim_ascii_char(src_str, b' '); + append_view(&mut views_buf, raw_view, trimmed, offset); + null_builder.append_non_null(); + } else { + null_builder.append_null(); + views_buf.push(0); + } } } 2 => { let characters_array = as_string_view_array(&args[1])?; if characters_array.len() == 1 { - // Only one `trim characters` exist + // Scalar pattern - pre-compute pattern chars once if characters_array.is_null(0) { return Ok(new_null_array( - // The schema is expecting utf8 as null &DataType::Utf8View, string_view_array.len(), )); } - let characters = characters_array.value(0); - let array_iter = string_view_array.iter(); - let views_iter = string_view_array.views().iter(); - for (src_str_opt, raw_view) in array_iter.zip(views_iter) { - trim_and_append_str( + let pattern: Vec = characters_array.value(0).chars().collect(); + for (src_str_opt, raw_view) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + { + trim_and_append_view::( src_str_opt, - Some(characters), - trim_func, + &pattern, &mut views_buf, &mut null_builder, raw_view, ); } } else { - // A specific `trim characters` for a row in the string view array - let characters_iter = characters_array.iter(); - let array_iter = string_view_array.iter(); - let views_iter = string_view_array.views().iter(); - for ((src_str_opt, raw_view), characters_opt) in - array_iter.zip(views_iter).zip(characters_iter) + // Per-row pattern - must compute pattern chars for each row + let mut pattern: Vec = Vec::new(); + for ((src_str_opt, raw_view), characters_opt) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + .zip(characters_array.iter()) { - trim_and_append_str( - src_str_opt, - characters_opt, - trim_func, - &mut views_buf, - &mut null_builder, - raw_view, - ); + if let (Some(src_str), Some(characters)) = + (src_str_opt, characters_opt) + { + pattern.clear(); + pattern.extend(characters.chars()); + let (trimmed, offset) = Tr::trim(src_str, &pattern); + append_view(&mut views_buf, raw_view, trimmed, offset); + null_builder.append_non_null(); + } else { + null_builder.append_null(); + views_buf.push(0); + } } } } other => { return exec_err!( - "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." + "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." ); } } @@ -211,33 +238,24 @@ fn string_view_trim<'a>( /// Trims the given string and appends the trimmed string to the views buffer /// and the null buffer. /// -/// Calls `trim_func` on the string value in `original_view`, for non_null -/// values and appends the updated view to the views buffer / null_builder. -/// /// Arguments /// - `src_str_opt`: The original string value (represented by the view) -/// - `trim_characters_opt`: The characters to trim from the string -/// - `trim_func`: The function to apply to the string (see [`string_view_trim`] for details) +/// - `pattern`: Pre-computed character pattern to trim /// - `views_buf`: The buffer to append the updated views to /// - `null_builder`: The buffer to append the null values to /// - `original_view`: The original view value (that contains src_str_opt) -fn trim_and_append_str<'a>( - src_str_opt: Option<&'a str>, - trim_characters_opt: Option<&'a str>, - trim_func: fn(&'a str, &'a str) -> (&'a str, u32), +#[inline] +fn trim_and_append_view( + src_str_opt: Option<&str>, + pattern: &[char], views_buf: &mut Vec, null_builder: &mut NullBufferBuilder, original_view: &u128, ) { - if let (Some(src_str), Some(characters)) = (src_str_opt, trim_characters_opt) { - let (trim_str, start_offset) = trim_func(src_str, characters); - make_and_append_view( - views_buf, - null_builder, - original_view, - trim_str, - start_offset, - ); + if let Some(src_str) = src_str_opt { + let (trimmed, offset) = Tr::trim(src_str, pattern); + append_view(views_buf, original_view, trimmed, offset); + null_builder.append_non_null(); } else { null_builder.append_null(); views_buf.push(0); @@ -247,18 +265,17 @@ fn trim_and_append_str<'a>( /// Applies the trim function to the given string array(s) /// and returns a new string array with the trimmed values. /// -/// See [`string_view_trim`] for details on `func` -fn string_trim<'a, T: OffsetSizeTrait>( - func: fn(&'a str, &'a str) -> (&'a str, u32), - args: &'a [ArrayRef], -) -> Result { +/// Pre-computes the pattern characters once for scalar patterns to avoid +/// repeated allocations per row. +fn string_trim(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; match args.len() { 1 => { + // Trim spaces by default let result = string_array .iter() - .map(|string| string.map(|string: &str| func(string, " ").0)) + .map(|string| string.map(|s| Tr::trim_ascii_char(s, b' ').0)) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -267,6 +284,7 @@ fn string_trim<'a, T: OffsetSizeTrait>( let characters_array = as_generic_string_array::(&args[1])?; if characters_array.len() == 1 { + // Scalar pattern - pre-compute pattern chars once if characters_array.is_null(0) { return Ok(new_null_array( string_array.data_type(), @@ -274,19 +292,25 @@ fn string_trim<'a, T: OffsetSizeTrait>( )); } - let characters = characters_array.value(0); + let pattern: Vec = characters_array.value(0).chars().collect(); let result = string_array .iter() - .map(|item| item.map(|string| func(string, characters).0)) + .map(|item| item.map(|s| Tr::trim(s, &pattern).0)) .collect::>(); return Ok(Arc::new(result) as ArrayRef); } + // Per-row pattern - must compute pattern chars for each row + let mut pattern: Vec = Vec::new(); let result = string_array .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters).0), + (Some(s), Some(c)) => { + pattern.clear(); + pattern.extend(c.chars()); + Some(Tr::trim(s, &pattern).0) + } _ => None, }) .collect::>(); @@ -295,131 +319,299 @@ fn string_trim<'a, T: OffsetSizeTrait>( } other => { exec_err!( - "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." + "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." ) } } } pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result { - case_conversion(args, |string| string.to_lowercase(), name) + case_conversion(args, true, name) } pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result { - case_conversion(args, |string| string.to_uppercase(), name) + case_conversion(args, false, name) +} + +#[inline] +fn unicode_case(s: &str, lower: bool) -> String { + if lower { + s.to_lowercase() + } else { + s.to_uppercase() + } } -fn case_conversion<'a, F>( - args: &'a [ColumnarValue], - op: F, +fn case_conversion( + args: &[ColumnarValue], + lower: bool, name: &str, -) -> Result -where - F: Fn(&'a str) -> String, -{ +) -> Result { match &args[0] { ColumnarValue::Array(array) => match array.data_type() { - DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::( - array, op, + DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::( + array, lower, )?)), - DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::< - i64, - _, - >(array, op)?)), + DataType::LargeUtf8 => Ok(ColumnarValue::Array( + case_conversion_array::(array, lower)?, + )), DataType::Utf8View => { let string_array = as_string_view_array(array)?; - let mut string_builder = StringBuilder::with_capacity( - string_array.len(), - string_array.get_array_memory_size(), - ); - - for str in string_array.iter() { - if let Some(str) = str { - string_builder.append_value(op(str)); - } else { - string_builder.append_null(); + if string_array.is_ascii() { + return Ok(ColumnarValue::Array(Arc::new( + case_conversion_utf8view_ascii(string_array, lower), + ))); + } + let item_len = string_array.len(); + // Null-preserving: reuse the input null buffer as the output null buffer. + let nulls = string_array.nulls().cloned(); + let mut builder = StringViewArrayBuilder::with_capacity(item_len); + + if let Some(ref n) = nulls { + for i in 0..item_len { + if n.is_null(i) { + builder.append_placeholder(); + } else { + // SAFETY: `n.is_null(i)` was false in the branch above. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(&unicode_case(s, lower)); + } + } + } else { + for i in 0..item_len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(&unicode_case(s, lower)); } } - Ok(ColumnarValue::Array(Arc::new(string_builder.finish()))) + Ok(ColumnarValue::Array(Arc::new(builder.finish(nulls)?))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| op(x)); + let result = a.as_ref().map(|x| unicode_case(x, lower)); Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) } ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| op(x)); + let result = a.as_ref().map(|x| unicode_case(x, lower)); Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) } ScalarValue::Utf8View(a) => { - let result = a.as_ref().map(|x| op(x)); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + let result = a.as_ref().map(|x| unicode_case(x, lower)); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, } } -fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result -where - O: OffsetSizeTrait, - F: Fn(&'a str) -> String, -{ +fn case_conversion_array( + array: &ArrayRef, + lower: bool, +) -> Result { const PRE_ALLOC_BYTES: usize = 8; let string_array = as_generic_string_array::(array)?; - let value_data = string_array.value_data(); - - // All values are ASCII. - if value_data.is_ascii() { - return case_conversion_ascii_array::(string_array, op); + if string_array.is_ascii() { + return case_conversion_ascii_array::(string_array, lower); } // Values contain non-ASCII. let item_len = string_array.len(); - let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES; - let mut builder = GenericStringBuilder::::with_capacity(item_len, capacity); + let offsets = string_array.value_offsets(); + let start = offsets.first().unwrap().as_usize(); + let end = offsets.last().unwrap().as_usize(); + let capacity = (end - start) + PRE_ALLOC_BYTES; + // Null-preserving: reuse the input null buffer as the output null buffer. + let nulls = string_array.nulls().cloned(); + let mut builder = GenericStringArrayBuilder::::with_capacity(item_len, capacity); + + if let Some(ref n) = nulls { + for i in 0..item_len { + if n.is_null(i) { + builder.append_placeholder(); + } else { + // SAFETY: `n.is_null(i)` was false in the branch above. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(&unicode_case(s, lower)); + } + } + } else { + for i in 0..item_len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(&unicode_case(s, lower)); + } + } + Ok(Arc::new(builder.finish(nulls)?)) +} - if string_array.null_count() == 0 { - let iter = - (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) }))); - builder.extend(iter); +/// Fast path for case conversion on an all-ASCII `StringViewArray`. +fn case_conversion_utf8view_ascii( + array: &StringViewArray, + lower: bool, +) -> StringViewArray { + // Specialize per conversion so the byte call inlines in the hot loops below. + if lower { + case_conversion_utf8view_ascii_inner(array, u8::to_ascii_lowercase) } else { - let iter = string_array.iter().map(|string| string.map(&op)); - builder.extend(iter); + case_conversion_utf8view_ascii_inner(array, u8::to_ascii_uppercase) } - Ok(Arc::new(builder.finish())) } -/// All values of string_array are ASCII, and when converting case, there is no changes in the byte -/// array length. Therefore, the StringArray can be treated as a complete ASCII string for -/// case conversion, and we can reuse the offsets buffer and the nulls buffer. -fn case_conversion_ascii_array<'a, O, F>( - string_array: &'a GenericStringArray, - op: F, -) -> Result -where - O: OffsetSizeTrait, - F: Fn(&'a str) -> String, -{ - let value_data = string_array.value_data(); - // SAFETY: all items stored in value_data satisfy UTF8. - // ref: impl ByteArrayNativeType for str {...} - let str_values = unsafe { std::str::from_utf8_unchecked(value_data) }; - - // conversion - let converted_values = op(str_values); - assert_eq!(converted_values.len(), str_values.len()); - let bytes = converted_values.into_bytes(); - - // build result - let values = Buffer::from_vec(bytes); - let offsets = string_array.offsets().clone(); +/// Walks the views once and produces a new `StringViewArray` with +/// case-converted bytes. Inline strings (<= 12 bytes) are converted in-place; +/// long strings copy-and-convert into output buffers and have their view fields +/// rewritten to address the new bytes. ASCII case conversion preserves is byte +/// length, so no row migrates between the inline and long layouts. +fn case_conversion_utf8view_ascii_inner u8>( + array: &StringViewArray, + convert: F, +) -> StringViewArray { + let item_len = array.len(); + let views = array.views(); + let data_buffers = array.data_buffers(); + let nulls = array.nulls(); + + let mut new_views: Vec = Vec::with_capacity(item_len); + // Long values are packed into `in_progress`; when full it is sealed into + // `completed` and a new, larger block is started — same block-doubling + // scheme as Arrow's `GenericByteViewBuilder`. + let mut in_progress: Vec = Vec::new(); + let mut completed: Vec = Vec::new(); + let mut block_size: u32 = STRING_VIEW_INIT_BLOCK_SIZE; + + for i in 0..item_len { + if nulls.is_some_and(|n| n.is_null(i)) { + // Zero view = empty, no buffer reference; the null buffer is what + // marks the row null, so the view's value is irrelevant. + new_views.push(0); + continue; + } + let view = views[i]; + // Length is the low 32 bits; `as u32` discards the rest of the view. + let len = view as u32 as usize; + if len == 0 { + new_views.push(0); + continue; + } + let mut bytes = view.to_le_bytes(); + if len <= 12 { + // Inline: value is in bytes[4..4+len], no buffer reference. Convert + // in place; nothing else in the view needs to change. + for b in &mut bytes[4..4 + len] { + *b = convert(b); + } + new_views.push(u128::from_le_bytes(bytes)); + } else { + // Long: input view points into shared `data_buffers` we can't + // mutate, so copy-convert into our own buffer and rewrite the + // view's prefix/buffer_index/offset (length is preserved). + + // Ensure the current block has room; otherwise flush and grow. + let required_cap = in_progress.len() + len; + if in_progress.capacity() < required_cap { + if !in_progress.is_empty() { + completed.push(Buffer::from_vec(std::mem::take(&mut in_progress))); + } + if block_size < STRING_VIEW_MAX_BLOCK_SIZE { + block_size = block_size.saturating_mul(2); + } + let to_reserve = len.max(block_size as usize); + #[expect( + clippy::disallowed_methods, + reason = "StringView's block size bounds growth, so reserve cannot overflow capacity arithmetically. This hot loop intentionally avoids the extra `try_reserve` checks. It remains subject to allocator failure/OOM, which must be managed externally." + )] + in_progress.reserve(to_reserve); + } + + // The in-progress block will be sealed at index `completed.len()`, + // and our value starts at the current write position within it. + let buffer_index: u32 = i32::try_from(completed.len()) + .expect("buffer count exceeds i32::MAX") + as u32; + let new_offset: u32 = + i32::try_from(in_progress.len()).expect("offset exceeds i32::MAX") as u32; + + // Source location from the input view: bytes 8..12 are buffer + // index, bytes 12..16 are the offset within it. + let src_buffer_index = + u32::from_le_bytes(bytes[8..12].try_into().unwrap()) as usize; + let src_offset = + u32::from_le_bytes(bytes[12..16].try_into().unwrap()) as usize; + let src = + &data_buffers[src_buffer_index].as_slice()[src_offset..src_offset + len]; + + let prefix_start = in_progress.len(); + in_progress.extend(src.iter().map(&convert)); + + // Rewrite the three long-view fields; bytes[0..4] (length) is + // left untouched. The prefix is read back from the bytes we just + // wrote so the converted value has a single source of truth. + let prefix: [u8; 4] = in_progress[prefix_start..prefix_start + 4] + .try_into() + .unwrap(); + bytes[4..8].copy_from_slice(&prefix); + bytes[8..12].copy_from_slice(&buffer_index.to_le_bytes()); + bytes[12..16].copy_from_slice(&new_offset.to_le_bytes()); + new_views.push(u128::from_le_bytes(bytes)); + } + } + + if !in_progress.is_empty() { + completed.push(Buffer::from_vec(in_progress)); + } + + // SAFETY: each long view's buffer_index addresses a buffer we wrote, and + // its offset addresses bytes within that buffer; prefixes were copied from + // those same bytes; inline views were rewritten from valid inline bytes; + // null/empty rows are zero views with no buffer reference; row count is + // unchanged. + unsafe { + StringViewArray::new_unchecked( + ScalarBuffer::from(new_views), + completed, + array.nulls().cloned(), + ) + } +} + +/// Fast path for case conversion on an all-ASCII string array. ASCII case +/// conversion is byte-length-preserving, so we can convert the entire addressed +/// byte range in one pass over the value buffer and reuse the offsets and nulls +/// buffers — rebasing the offsets when the input is a sliced array. +fn case_conversion_ascii_array( + string_array: &GenericStringArray, + lower: bool, +) -> Result { + let value_offsets = string_array.value_offsets(); + let start = value_offsets.first().unwrap().as_usize(); + let end = value_offsets.last().unwrap().as_usize(); + let relevant = &string_array.value_data()[start..end]; + + let converted: Vec = if lower { + relevant.iter().map(u8::to_ascii_lowercase).collect() + } else { + relevant.iter().map(u8::to_ascii_uppercase).collect() + }; + let values = Buffer::from_vec(converted); + + // Shift offsets from `start`-based to 0-based so they index into `values`. + let offsets = if start == 0 { + string_array.offsets().clone() + } else { + let s = O::usize_as(start); + let rebased: Vec = value_offsets.iter().map(|&o| o - s).collect(); + // SAFETY: subtracting a constant from monotonic offsets preserves + // monotonicity, and `start` is the minimum offset, so no underflow. + unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(rebased)) } + }; + let nulls = string_array.nulls().cloned(); - // SAFETY: offsets and nulls are consistent with the input array. + // SAFETY: offsets are monotonic and in-bounds for `values`; nulls + // (if any) match the slice length. Ok(Arc::new(unsafe { GenericStringArray::::new_unchecked(offsets, values, nulls) })) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 3b53660463d44..b10db23472c99 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{as_largestring_array, Array}; +use arrow::array::{Array, as_largestring_array}; use arrow::datatypes::DataType; use datafusion_expr::sort_properties::ExprProperties; -use std::any::Any; use std::sync::Arc; use crate::string::concat; use crate::strings::{ - ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder, + ColumnarValueRef, ConcatLargeStringBuilder, ConcatStringBuilder, + ConcatStringViewBuilder, +}; +use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array}; +use datafusion_common::{ + Result, ScalarValue, exec_datafusion_err, internal_err, plan_err, }; -use datafusion_common::cast::{as_string_array, as_string_view_array}; -use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -68,18 +70,25 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8View, Utf8, LargeUtf8], + vec![Utf8View, Utf8, LargeUtf8, Binary], Volatility::Immutable, ), } } } -impl ScalarUDFImpl for ConcatFunc { - fn as_any(&self) -> &dyn Any { - self +fn deduce_return_type(arg_types: &[DataType]) -> DataType { + use DataType::*; + if arg_types.contains(&Utf8View) { + Utf8View + } else if arg_types.contains(&LargeUtf8) { + LargeUtf8 + } else { + Utf8 } +} +impl ScalarUDFImpl for ConcatFunc { fn name(&self) -> &str { "concat" } @@ -88,19 +97,11 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } + /// Match the return type to the input types to avoid unnecessary casts. On + /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid + /// potential overflow on LargeUtf8 input. fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - let mut dt = &Utf8; - arg_types.iter().for_each(|data_type| { - if data_type == &Utf8View { - dt = data_type; - } - if data_type == &LargeUtf8 && dt != &Utf8View { - dt = data_type; - } - }); - - Ok(dt.to_owned()) + Ok(deduce_return_type(arg_types)) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. @@ -108,43 +109,38 @@ impl ScalarUDFImpl for ConcatFunc { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; - let mut return_datatype = DataType::Utf8; - args.iter().for_each(|col| { - if col.data_type() == DataType::Utf8View { - return_datatype = col.data_type(); - } - if col.data_type() == DataType::LargeUtf8 - && return_datatype != DataType::Utf8View - { - return_datatype = col.data_type(); - } - }); + let arg_types: Vec = args.iter().map(|c| c.data_type()).collect(); + let return_datatype = deduce_return_type(&arg_types); - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { - let mut result = String::new(); - for arg in args { + let mut values: Vec<&[u8]> = Vec::with_capacity(args.len()); + for arg in &args { let ColumnarValue::Scalar(scalar) = arg else { return internal_err!("concat expected scalar value, got {arg:?}"); }; - - match scalar.try_as_str() { - Some(Some(v)) => result.push_str(v), - Some(None) => {} // null literal - None => plan_err!( - "Concat function does not support scalar type {}", - scalar - )?, + if let ScalarValue::Binary(Some(value)) = scalar { + values.push(value); + } else { + match scalar.try_as_str() { + Some(Some(v)) => values.push(v.as_bytes()), + Some(None) => {} // null literal + None => plan_err!( + "Concat function does not support scalar type {}", + scalar + )?, + } } } + let concat_bytes = values.concat(); + let result = std::str::from_utf8(&concat_bytes) + .map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))? + .to_string(); return match return_datatype { DataType::Utf8View => { @@ -177,6 +173,13 @@ impl ScalarUDFImpl for ConcatFunc { columns.push(ColumnarValueRef::Scalar(s.as_bytes())); } } + ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => { + if let Some(b) = maybe_value { + // data_size is a capacity hint, so doesn't matter if it is chars or bytes + data_size += b.len() * len; + columns.push(ColumnarValueRef::Scalar(b.as_slice())); + } + } ColumnarValue::Array(array) => { match array.data_type() { DataType::Utf8 => { @@ -189,7 +192,7 @@ impl ScalarUDFImpl for ConcatFunc { ColumnarValueRef::NonNullableArray(string_array) }; columns.push(column); - }, + } DataType::LargeUtf8 => { let string_array = as_largestring_array(array); @@ -197,23 +200,40 @@ impl ScalarUDFImpl for ConcatFunc { let column = if array.is_nullable() { ColumnarValueRef::NullableLargeStringArray(string_array) } else { - ColumnarValueRef::NonNullableLargeStringArray(string_array) + ColumnarValueRef::NonNullableLargeStringArray( + string_array, + ) }; columns.push(column); - }, + } DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array.len(); + // This is an estimate; in particular, it will + // undercount arrays of short strings (<= 12 bytes). + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { ColumnarValueRef::NonNullableStringViewArray(string_array) }; columns.push(column); - }, + } + DataType::Binary => { + let string_array = as_binary_array(array)?; + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableBinaryArray(string_array) + } else { + ColumnarValueRef::NonNullableBinaryArray(string_array) + }; + columns.push(column); + } other => { - return plan_err!("Input was {other} which is not a supported datatype for concat function") + return plan_err!( + "Input was {other} which is not a supported datatype for concat function" + ); } }; } @@ -223,39 +243,39 @@ impl ScalarUDFImpl for ConcatFunc { match return_datatype { DataType::Utf8 => { - let mut builder = StringArrayBuilder::with_capacity(len, data_size); + let mut builder = ConcatStringBuilder::with_capacity(len, data_size); for i in 0..len { columns .iter() .for_each(|column| builder.write::(column, i)); - builder.append_offset(); + builder.append_offset()?; } - let string_array = builder.finish(None); + let string_array = builder.finish(None)?; Ok(ColumnarValue::Array(Arc::new(string_array))) } DataType::Utf8View => { - let mut builder = StringViewArrayBuilder::with_capacity(len, data_size); + let mut builder = ConcatStringViewBuilder::with_capacity(len, data_size); for i in 0..len { columns .iter() .for_each(|column| builder.write::(column, i)); - builder.append_offset(); + builder.append_offset()?; } - let string_array = builder.finish(); + let string_array = builder.finish(None)?; Ok(ColumnarValue::Array(Arc::new(string_array))) } DataType::LargeUtf8 => { - let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size); + let mut builder = ConcatLargeStringBuilder::with_capacity(len, data_size); for i in 0..len { columns .iter() .for_each(|column| builder.write::(column, i)); - builder.append_offset(); + builder.append_offset()?; } - let string_array = builder.finish(None); + let string_array = builder.finish(None)?; Ok(ColumnarValue::Array(Arc::new(string_array))) } _ => unreachable!(), @@ -273,7 +293,7 @@ impl ScalarUDFImpl for ConcatFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { simplify_concat(args) } @@ -305,9 +325,8 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { Expr::Literal(ScalarValue::Utf8(None), _) => {} - Expr::Literal(ScalarValue::LargeUtf8(None), _) => { - } - Expr::Literal(ScalarValue::Utf8View(None), _) => { } + Expr::Literal(ScalarValue::LargeUtf8(None), _) => {} + Expr::Literal(ScalarValue::Utf8View(None), _) => {} // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. @@ -325,7 +344,7 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { Expr::Literal(x, _) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." - ) + ); } // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` (if it is not empty) and reset it to empty string. @@ -334,8 +353,10 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { if !contiguous_scalar.is_empty() { match return_type { DataType::Utf8 => new_args.push(lit(contiguous_scalar)), - DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), - DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + DataType::LargeUtf8 => new_args + .push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args + .push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), _ => unreachable!(), } contiguous_scalar = "".to_string(); @@ -374,11 +395,11 @@ pub(crate) fn simplify_concat(args: Vec) -> Result { mod tests { use super::*; use crate::utils::test::test_function; - use arrow::array::{Array, LargeStringArray, StringViewArray}; + use DataType::*; use arrow::array::{ArrayRef, StringArray}; + use arrow::array::{LargeStringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; - use DataType::*; #[test] fn test_functions() -> Result<()> { @@ -450,7 +471,33 @@ mod tests { Utf8View, StringViewArray ); - + test_function!( + ConcatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some( + "Café".as_bytes().into() + ))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("Cafécc")), + &str, + Utf8, + StringArray + ); + test_function!( + ConcatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from( + "Café".as_bytes() + )))), + ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))), + ], + Ok(Some("Cafécc")), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index cdd30ac8755ab..2c2d4bd42165b 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{as_largestring_array, Array, StringArray}; -use std::any::Any; +use arrow::array::Array; use std::sync::Arc; use arrow::datatypes::DataType; @@ -24,12 +23,17 @@ use arrow::datatypes::DataType; use crate::string::concat; use crate::string::concat::simplify_concat; use crate::string::concat_ws; -use crate::strings::{ColumnarValueRef, StringArrayBuilder}; -use datafusion_common::cast::{as_string_array, as_string_view_array}; -use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; +use crate::strings::{ + ColumnarValueRef, ConcatLargeStringBuilder, ConcatStringBuilder, + ConcatStringViewBuilder, +}; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -83,10 +87,6 @@ impl ConcatWsFunc { } impl ScalarUDFImpl for ConcatWsFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "concat_ws" } @@ -95,17 +95,27 @@ impl ScalarUDFImpl for ConcatWsFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { + /// Match the return type to the input types to avoid unnecessary casts. On + /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid + /// potential overflow on LargeUtf8 input. + fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; - Ok(Utf8) + if arg_types.contains(&Utf8View) { + Ok(Utf8View) + } else if arg_types.contains(&LargeUtf8) { + Ok(LargeUtf8) + } else { + Ok(Utf8) + } } - /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. + /// Concatenates all but the first argument, with separators. The first + /// argument is used as the separator string, and should not be NULL. Other + /// NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; - // do not accept 0 arguments. if args.len() < 2 { return exec_err!( "concat_ws was called with {} arguments. It requires at least 2.", @@ -113,68 +123,67 @@ impl ScalarUDFImpl for ConcatWsFunc { ); } - let array_len = args - .iter() - .filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .next(); + let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View) + { + DataType::Utf8View + } else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + + let array_len = args.iter().find_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }); // Scalar if array_len.is_none() { let ColumnarValue::Scalar(scalar) = &args[0] else { - // loop above checks for all args being scalar unreachable!() }; let sep = match scalar.try_as_str() { Some(Some(s)) => s, Some(None) => { // null literal string - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } None => return internal_err!("Expected string literal, got {scalar:?}"), }; - let mut result = String::new(); - // iterator over Option - let iter = &mut args[1..].iter().map(|arg| { + let mut values = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { let ColumnarValue::Scalar(scalar) = arg else { - // loop above checks for all args being scalar unreachable!() }; - scalar.try_as_str() - }); - - // append first non null arg - for scalar in iter.by_ref() { - match scalar { - Some(Some(s)) => { - result.push_str(s); - break; - } - Some(None) => {} // null literal string - None => { - return internal_err!("Expected string literal, got {scalar:?}") - } - } - } - // handle subsequent non null args - for scalar in iter.by_ref() { - match scalar { - Some(Some(s)) => { - result.push_str(sep); - result.push_str(s); - } + match scalar.try_as_str() { + Some(Some(v)) => values.push(v), Some(None) => {} // null literal string None => { - return internal_err!("Expected string literal, got {scalar:?}") + return internal_err!("Expected string literal, got {scalar:?}"); } } } + let result = values.join(sep); - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))), + }; } // Array @@ -183,23 +192,61 @@ impl ScalarUDFImpl for ConcatWsFunc { // parse sep let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } - ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); // estimate - if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(s)) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) } - } - _ => unreachable!("concat ws"), + Some(None) => { + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; + } + None => { + return internal_err!("Expected string separator, got {scalar:?}"); + } + }, + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + } + } + DataType::LargeUtf8 => { + let string_array = as_large_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + } + } + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + data_size += + string_array.total_buffer_bytes_used() * (args.len() - 2); + if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + } + } + other => { + return plan_err!( + "Input was {other} which is not a supported datatype for concat_ws separator" + ); + } + }, }; let mut columns = Vec::with_capacity(args.len() - 1); @@ -225,31 +272,37 @@ impl ScalarUDFImpl for ConcatWsFunc { ColumnarValueRef::NonNullableArray(string_array) }; columns.push(column); - }, + } DataType::LargeUtf8 => { - let string_array = as_largestring_array(array); + let string_array = as_large_string_array(array)?; data_size += string_array.values().len(); let column = if array.is_nullable() { ColumnarValueRef::NullableLargeStringArray(string_array) } else { - ColumnarValueRef::NonNullableLargeStringArray(string_array) + ColumnarValueRef::NonNullableLargeStringArray( + string_array, + ) }; columns.push(column); - }, + } DataType::Utf8View => { let string_array = as_string_view_array(array)?; - data_size += string_array.data_buffers().iter().map(|buf| buf.len()).sum::(); + // This is an estimate; in particular, it will + // undercount arrays of short strings (<= 12 bytes). + data_size += string_array.total_buffer_bytes_used(); let column = if array.is_nullable() { ColumnarValueRef::NullableStringViewArray(string_array) } else { ColumnarValueRef::NonNullableStringViewArray(string_array) }; columns.push(column); - }, + } other => { - return plan_err!("Input was {other} which is not a supported datatype for concat_ws function.") + return plan_err!( + "Input was {other} which is not a supported datatype for concat_ws function." + ); } }; } @@ -257,32 +310,71 @@ impl ScalarUDFImpl for ConcatWsFunc { } } - let mut builder = StringArrayBuilder::with_capacity(len, data_size); - for i in 0..len { - if !sep.is_valid(i) { - builder.append_offset(); - continue; + match return_datatype { + DataType::Utf8View => { + let mut builder = ConcatStringViewBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset()?; + continue; + } + let mut first = true; + for column in &columns { + if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } + builder.write::(column, i); + first = false; + } + } + builder.append_offset()?; + } + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?))) } - - let mut iter = columns.iter(); - for column in iter.by_ref() { - if column.is_valid(i) { - builder.write::(column, i); - break; + DataType::LargeUtf8 => { + let mut builder = ConcatLargeStringBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset()?; + continue; + } + let mut first = true; + for column in &columns { + if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } + builder.write::(column, i); + first = false; + } + } + builder.append_offset()?; } + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?))) } - - for column in iter { - if column.is_valid(i) { - builder.write::(&sep, i); - builder.write::(column, i); + _ => { + let mut builder = ConcatStringBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset()?; + continue; + } + let mut first = true; + for column in &columns { + if column.is_valid(i) { + if !first { + builder.write::(&sep, i); + } + builder.write::(column, i); + first = false; + } + } + builder.append_offset()?; } + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?))) } - - builder.append_offset(); } - - Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) } /// Simply the `concat_ws` function by @@ -293,7 +385,7 @@ impl ScalarUDFImpl for ConcatWsFunc { fn simplify( &self, args: Vec, - _info: &dyn SimplifyInfo, + _info: &SimplifyContext, ) -> Result { match &args[..] { [delimiter, vals @ ..] => simplify_concat_ws(delimiter, vals), @@ -307,6 +399,21 @@ impl ScalarUDFImpl for ConcatWsFunc { } fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { + // Preserve the delimiter's string type for any new literals produced + // during simplification. + let delimiter_type = match delimiter { + Expr::Literal(v, _) => v.data_type(), + _ => DataType::Utf8, + }; + + let typed_lit = |s: String| -> Expr { + match delimiter_type { + DataType::LargeUtf8 => lit(ScalarValue::LargeUtf8(Some(s))), + DataType::Utf8View => lit(ScalarValue::Utf8View(Some(s))), + _ => lit(s), + } + }; + match delimiter { Expr::Literal( ScalarValue::Utf8(delimiter) @@ -315,8 +422,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { - // when the delimiter is an empty string, - // we can use `concat` to replace `concat_ws` + // When the delimiter is the empty string, replace `concat_ws` + // with `concat` Some(delimiter) if delimiter.is_empty() => { match simplify_concat(args.to_vec())? { ExprSimplifyResult::Original(_) => { @@ -332,29 +439,41 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { let mut new_args = Vec::with_capacity(args.len()); - new_args.push(lit(delimiter)); + new_args.push(typed_lit(delimiter.to_string())); let mut contiguous_scalar = None; for arg in args { match arg { // filter out null args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None), _) => {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), _) => { - match contiguous_scalar { - None => contiguous_scalar = Some(v.to_string()), - Some(mut pre) => { - pre += delimiter; - pre += v; - contiguous_scalar = Some(pre) - } + Expr::Literal( + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Utf8View(None), + _, + ) => {} + Expr::Literal( + ScalarValue::Utf8(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) + | ScalarValue::Utf8View(Some(v)), + _, + ) => match contiguous_scalar { + None => contiguous_scalar = Some(v.to_string()), + Some(mut pre) => { + pre += delimiter; + pre += v; + contiguous_scalar = Some(pre) } + }, + Expr::Literal(s, _) => { + return internal_err!( + "The scalar {s} should be casted to string type during the type coercion." + ); } - Expr::Literal(s, _) => return internal_err!("The scalar {s} should be casted to string type during the type coercion."), // If the arg is not a literal, we should first push the current `contiguous_scalar` // to the `new_args` and reset it to None. // Then pushing this arg to the `new_args`. arg => { if let Some(val) = contiguous_scalar { - new_args.push(lit(val)); + new_args.push(typed_lit(val)); } new_args.push(arg.clone()); contiguous_scalar = None; @@ -362,7 +481,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Result Ok(ExprSimplifyResult::Simplified(Expr::Literal( - ScalarValue::Utf8(None), - None, - ))), + // If the delimiter is null, then the value of the whole expression is null. + None => { + let null_scalar = match delimiter_type { + DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Utf8View => ScalarValue::Utf8View(None), + _ => ScalarValue::Utf8(None), + }; + Ok(ExprSimplifyResult::Simplified(Expr::Literal( + null_scalar, + None, + ))) + } } } Expr::Literal(d, _) => internal_err!( @@ -406,12 +532,12 @@ mod tests { use std::sync::Arc; use crate::string::concat_ws::ConcatWsFunc; - use arrow::array::{Array, ArrayRef, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use arrow::datatypes::Field; - use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_common::ScalarValue; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::utils::test::test_function; @@ -547,4 +673,265 @@ mod tests { Ok(()) } + + #[test] + fn concat_ws_utf8view_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(StringViewArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_scalar_separator() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", Utf8, true).into(), + Field::new("a", Utf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(LargeStringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_utf8view_nullable_separator() -> Result<()> { + let c0 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some(","), + None, + Some("+"), + ]))); + let c1 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + "foo", "bar", "baz", + ]))); + let c2 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("x"), + Some("y"), + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = Arc::new(StringViewArray::from(vec![ + Some("foo,x"), + None, + Some("baz+z"), + ])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_arrays() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(",".to_string()))); + let c1 = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![ + "foo", "bar", "baz", + ]))); + let c2 = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = + Arc::new(LargeStringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!("Expected array result"), + } + + Ok(()) + } + + #[test] + fn concat_ws_utf8view_null_separator() -> Result<()> { + // All-scalar path: null Utf8View separator should return Utf8View(None) + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(None)); + let c1 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))); + let c2 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bb".to_string()))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 1, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {} + other => panic!("Expected Utf8View(None), got {other:?}"), + } + + // Array path: null Utf8View scalar separator with array args + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8View(None)); + let c1 = + ColumnarValue::Array(Arc::new(StringViewArray::from(vec!["foo", "bar"]))); + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("a", Utf8View, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1], + arg_fields, + number_rows: 2, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::Utf8View(None)) => {} + other => panic!("Expected Utf8View(None), got {other:?}"), + } + + Ok(()) + } + + #[test] + fn concat_ws_largeutf8_null_separator() -> Result<()> { + // All-scalar path: null LargeUtf8 separator should return LargeUtf8(None) + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)); + let c1 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("aa".to_string()))); + let c2 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bb".to_string()))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 1, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => panic!("Expected LargeUtf8(None), got {other:?}"), + } + + // Array path: null LargeUtf8 scalar separator with array args + let c0 = ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)); + let c1 = + ColumnarValue::Array(Arc::new(LargeStringArray::from(vec!["foo", "bar"]))); + + let arg_fields = vec![ + Field::new("a", LargeUtf8, true).into(), + Field::new("a", LargeUtf8, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1], + arg_fields, + number_rows: 2, + return_field: Field::new("f", LargeUtf8, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => panic!("Expected LargeUtf8(None), got {other:?}"), + } + + Ok(()) + } } diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 7e50676933c8d..8a75e3ac703ee 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::array::{Array, ArrayRef, Scalar}; use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::binary::{binary_to_string_coercion, string_coercion}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( @@ -72,10 +70,6 @@ impl ContainsFunc { } impl ScalarUDFImpl for ContainsFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "contains" } @@ -89,7 +83,7 @@ impl ScalarUDFImpl for ContainsFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(contains, vec![])(&args.args) + contains(args.args.as_slice()) } fn documentation(&self) -> Option<&Documentation> { @@ -97,43 +91,71 @@ impl ScalarUDFImpl for ContainsFunc { } } +fn to_array(value: &ColumnarValue) -> Result<(ArrayRef, bool)> { + match value { + ColumnarValue::Array(array) => Ok((Arc::clone(array), false)), + ColumnarValue::Scalar(scalar) => Ok((scalar.to_array()?, true)), + } +} + +/// Helper to call arrow_contains with proper Datum handling. +/// When an argument is marked as scalar, we wrap it in `Scalar` to tell arrow's +/// kernel to use the optimized single-value code path instead of iterating. +fn call_arrow_contains( + haystack: &ArrayRef, + haystack_is_scalar: bool, + needle: &ArrayRef, + needle_is_scalar: bool, +) -> Result { + // Arrow's Datum trait is implemented for ArrayRef, Arc, and Scalar + // We pass ArrayRef directly when not scalar, or wrap in Scalar when it is + let result = match (haystack_is_scalar, needle_is_scalar) { + (false, false) => arrow_contains(haystack, needle)?, + (false, true) => arrow_contains(haystack, &Scalar::new(Arc::clone(needle)))?, + (true, false) => arrow_contains(&Scalar::new(Arc::clone(haystack)), needle)?, + (true, true) => arrow_contains( + &Scalar::new(Arc::clone(haystack)), + &Scalar::new(Arc::clone(needle)), + )?, + }; + + // If both inputs were scalar, return a scalar result + if haystack_is_scalar && needle_is_scalar { + let scalar = datafusion_common::ScalarValue::try_from_array(&result, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } else { + Ok(ColumnarValue::Array(Arc::new(result))) + } +} + /// use `arrow::compute::contains` to do the calculation for contains -fn contains(args: &[ArrayRef]) -> Result { +fn contains(args: &[ColumnarValue]) -> Result { + let (haystack, haystack_is_scalar) = to_array(&args[0])?; + let (needle, needle_is_scalar) = to_array(&args[1])?; + if let Some(coercion_data_type) = - string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| { - binary_to_string_coercion(args[0].data_type(), args[1].data_type()) + string_coercion(haystack.data_type(), needle.data_type()).or_else(|| { + binary_to_string_coercion(haystack.data_type(), needle.data_type()) }) { - let arg0 = if args[0].data_type() == &coercion_data_type { - Arc::clone(&args[0]) + let haystack = if haystack.data_type() == &coercion_data_type { + haystack } else { - arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)? + arrow::compute::kernels::cast::cast(&haystack, &coercion_data_type)? }; - let arg1 = if args[1].data_type() == &coercion_data_type { - Arc::clone(&args[1]) + let needle = if needle.data_type() == &coercion_data_type { + needle } else { - arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)? + arrow::compute::kernels::cast::cast(&needle, &coercion_data_type)? }; match coercion_data_type { - Utf8View => { - let mod_str = arg0.as_string_view(); - let match_str = arg1.as_string_view(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) - } - Utf8 => { - let mod_str = arg0.as_string::(); - let match_str = arg1.as_string::(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) - } - LargeUtf8 => { - let mod_str = arg0.as_string::(); - let match_str = arg1.as_string::(); - let res = arrow_contains(mod_str, match_str)?; - Ok(Arc::new(res) as ArrayRef) - } + Utf8View | Utf8 | LargeUtf8 => call_arrow_contains( + &haystack, + haystack_is_scalar, + &needle, + needle_is_scalar, + ), other => { exec_err!("Unsupported data type {other:?} for function `contains`.") } @@ -153,8 +175,8 @@ mod test { use crate::expr_fn::contains; use arrow::array::{BooleanArray, StringArray}; use arrow::datatypes::{DataType, Field}; - use datafusion_common::config::ConfigOptions; use datafusion_common::ScalarValue; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 6090d9c84d4cd..6b84e260a2d11 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, Scalar}; +use arrow::compute::kernels::comparison::ends_with as arrow_ends_with; use arrow::datatypes::DataType; -use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; -use datafusion_common::{internal_err, Result}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::binary::{binary_to_string_coercion, string_coercion}; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -78,10 +78,6 @@ impl EndsWithFunc { } impl ScalarUDFImpl for EndsWithFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "ends_with" } @@ -95,12 +91,70 @@ impl ScalarUDFImpl for EndsWithFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(ends_with, vec![])(&args.args) + let [str_arg, suffix_arg] = take_function_args(self.name(), &args.args)?; + + // Determine the common type for coercion + let coercion_type = string_coercion( + &str_arg.data_type(), + &suffix_arg.data_type(), + ) + .or_else(|| { + binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type()) + }); + + let Some(coercion_type) = coercion_type else { + return exec_err!( + "Unsupported data types {:?}, {:?} for function `ends_with`.", + str_arg.data_type(), + suffix_arg.data_type() + ); + }; + + // Helper to cast an array if needed + let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result { + if arr.data_type() == target { + Ok(Arc::clone(arr)) + } else { + Ok(arrow::compute::kernels::cast::cast(arr, target)?) + } + }; + + match (str_arg, suffix_arg) { + // Both scalars - just compute directly + (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => { + let str_arr = str_scalar.to_array_of_size(1)?; + let suffix_arr = suffix_scalar.to_array_of_size(1)?; + let str_arr = maybe_cast(&str_arr, &coercion_type)?; + let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?; + let result = arrow_ends_with(&str_arr, &suffix_arr)?; + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } + // String is array, suffix is scalar - use Scalar wrapper for optimization + (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => { + let str_arr = maybe_cast(str_arr, &coercion_type)?; + let suffix_arr = suffix_scalar.to_array_of_size(1)?; + let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?; + let suffix_scalar = Scalar::new(suffix_arr); + let result = arrow_ends_with(&str_arr, &suffix_scalar)?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + // String is scalar, suffix is array - use Scalar wrapper for string + (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => { + let str_arr = str_scalar.to_array_of_size(1)?; + let str_arr = maybe_cast(&str_arr, &coercion_type)?; + let str_scalar = Scalar::new(str_arr); + let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?; + let result = arrow_ends_with(&str_scalar, &suffix_arr)?; + Ok(ColumnarValue::Array(Arc::new(result))) } - other => { - internal_err!("Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View")? + // Both arrays - pass directly + (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => { + let str_arr = maybe_cast(str_arr, &coercion_type)?; + let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?; + let result = arrow_ends_with(&str_arr, &suffix_arr)?; + Ok(ColumnarValue::Array(Arc::new(result))) } } } @@ -110,47 +164,24 @@ impl ScalarUDFImpl for EndsWithFunc { } } -/// Returns true if string ends with suffix. -/// ends_with('alphabet', 'abet') = 't' -fn ends_with(args: &[ArrayRef]) -> Result { - if let Some(coercion_data_type) = - string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| { - binary_to_string_coercion(args[0].data_type(), args[1].data_type()) - }) - { - let arg0 = if args[0].data_type() == &coercion_data_type { - Arc::clone(&args[0]) - } else { - arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)? - }; - let arg1 = if args[1].data_type() == &coercion_data_type { - Arc::clone(&args[1]) - } else { - arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)? - }; - let result = arrow::compute::kernels::comparison::ends_with(&arg0, &arg1)?; - Ok(Arc::new(result) as ArrayRef) - } else { - internal_err!( - "Unsupported data types for ends_with. Expected Utf8, LargeUtf8 or Utf8View" - ) - } -} - #[cfg(test)] mod tests { - use arrow::array::{Array, BooleanArray}; + use arrow::array::{Array, BooleanArray, StringArray}; use arrow::datatypes::DataType::Boolean; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; use datafusion_common::Result; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_common::config::ConfigOptions; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::string::ends_with::EndsWithFunc; use crate::utils::test::test_function; #[test] - fn test_functions() -> Result<()> { + fn test_scalar_scalar() -> Result<()> { + // Test Scalar + Scalar combinations test_function!( EndsWithFunc::new(), vec![ @@ -196,6 +227,186 @@ mod tests { BooleanArray ); + // Test with LargeUtf8 + test_function!( + EndsWithFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some( + "alphabet".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + + // Test with Utf8View + test_function!( + EndsWithFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "alphabet".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bet".to_string()))), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + + Ok(()) + } + + #[test] + fn test_array_scalar() -> Result<()> { + // Test Array + Scalar (the optimized path) + let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("alphabet"), + Some("alphabet"), + Some("beta"), + None, + ]))); + let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string()))); + + let args = vec![array, scalar]; + test_function!( + EndsWithFunc::new(), + args, + Ok(Some(true)), // First element result: "alphabet" ends with "bet" + bool, + Boolean, + BooleanArray + ); + Ok(()) } + + #[test] + fn test_array_scalar_full_result() { + // Test Array + Scalar and verify all results + let func = EndsWithFunc::new(); + let array = Arc::new(StringArray::from(vec![ + Some("alphabet"), + Some("alphabet"), + Some("beta"), + None, + ])); + let args = vec![ + ColumnarValue::Array(array), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string()))), + ]; + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ], + number_rows: 4, + return_field: Field::new("f", Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + let result_array = result.into_array(4).unwrap(); + let bool_array = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(bool_array.value(0)); // "alphabet" ends with "bet" + assert!(bool_array.value(1)); // "alphabet" ends with "bet" + assert!(!bool_array.value(2)); // "beta" does not end with "bet" + assert!(bool_array.is_null(3)); // null input -> null output + } + + #[test] + fn test_scalar_array() { + // Test Scalar + Array + let func = EndsWithFunc::new(); + let suffixes = Arc::new(StringArray::from(vec![ + Some("bet"), + Some("alph"), + Some("phabet"), + None, + ])); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))), + ColumnarValue::Array(suffixes), + ]; + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ], + number_rows: 4, + return_field: Field::new("f", Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + let result_array = result.into_array(4).unwrap(); + let bool_array = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(bool_array.value(0)); // "alphabet" ends with "bet" + assert!(!bool_array.value(1)); // "alphabet" does not end with "alph" + assert!(bool_array.value(2)); // "alphabet" ends with "phabet" + assert!(bool_array.is_null(3)); // null suffix -> null output + } + + #[test] + fn test_array_array() { + // Test Array + Array + let func = EndsWithFunc::new(); + let strings = Arc::new(StringArray::from(vec![ + Some("alphabet"), + Some("rust"), + Some("datafusion"), + None, + ])); + let suffixes = Arc::new(StringArray::from(vec![ + Some("bet"), + Some("st"), + Some("hello"), + Some("test"), + ])); + let args = vec![ + ColumnarValue::Array(strings), + ColumnarValue::Array(suffixes), + ]; + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ], + number_rows: 4, + return_field: Field::new("f", Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + let result_array = result.into_array(4).unwrap(); + let bool_array = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(bool_array.value(0)); // "alphabet" ends with "bet" + assert!(bool_array.value(1)); // "rust" ends with "st" + assert!(!bool_array.value(2)); // "datafusion" does not end with "hello" + assert!(bool_array.is_null(3)); // null string -> null output + } } diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index 2f7894df903d6..38fa8fa878de9 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; @@ -26,7 +25,7 @@ use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::types::logical_string; use datafusion_common::utils::datafusion_strsim; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::type_coercion::binary::{ binary_to_string_coercion, string_coercion, }; @@ -83,10 +82,6 @@ impl LevenshteinFunc { } impl ScalarUDFImpl for LevenshteinFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "levenshtein" } @@ -101,7 +96,9 @@ impl ScalarUDFImpl for LevenshteinFunc { { utf8_to_int_type(&coercion_data_type, "levenshtein") } else { - exec_err!("Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View") + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) } } @@ -149,12 +146,18 @@ fn levenshtein(args: &[ArrayRef]) -> Result { DataType::Utf8View => { let str1_array = as_string_view_array(&str1)?; let str2_array = as_string_view_array(&str2)?; + + // Reusable buffer to avoid allocating for each row + let mut cache = Vec::new(); + let result = str1_array .iter() .zip(str2_array.iter()) .map(|(string1, string2)| match (string1, string2) { (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i32) + Some(datafusion_strsim::levenshtein_with_buffer( + string1, string2, &mut cache, + ) as i32) } _ => None, }) @@ -164,12 +167,18 @@ fn levenshtein(args: &[ArrayRef]) -> Result { DataType::Utf8 => { let str1_array = as_generic_string_array::(&str1)?; let str2_array = as_generic_string_array::(&str2)?; + + // Reusable buffer to avoid allocating for each row + let mut cache = Vec::new(); + let result = str1_array .iter() .zip(str2_array.iter()) .map(|(string1, string2)| match (string1, string2) { (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i32) + Some(datafusion_strsim::levenshtein_with_buffer( + string1, string2, &mut cache, + ) as i32) } _ => None, }) @@ -179,12 +188,18 @@ fn levenshtein(args: &[ArrayRef]) -> Result { DataType::LargeUtf8 => { let str1_array = as_generic_string_array::(&str1)?; let str2_array = as_generic_string_array::(&str2)?; + + // Reusable buffer to avoid allocating for each row + let mut cache = Vec::new(); + let result = str1_array .iter() .zip(str2_array.iter()) .map(|(string1, string2)| match (string1, string2) { (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i64) + Some(datafusion_strsim::levenshtein_with_buffer( + string1, string2, &mut cache, + ) as i64) } _ => None, }) @@ -198,7 +213,9 @@ fn levenshtein(args: &[ArrayRef]) -> Result { } } } else { - exec_err!("Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View") + exec_err!( + "Unsupported data types for levenshtein. Expected Utf8, LargeUtf8 or Utf8View" + ) } } diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index ee56a6a549857..57cbe1d8779f0 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -16,12 +16,10 @@ // under the License. use arrow::datatypes::DataType; -use std::any::Any; use crate::string::common::to_lower; -use crate::utils::utf8_to_str_type; -use datafusion_common::types::logical_string; use datafusion_common::Result; +use datafusion_common::types::logical_string; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -69,10 +67,6 @@ impl LowerFunc { } impl ScalarUDFImpl for LowerFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "lower" } @@ -82,7 +76,7 @@ impl ScalarUDFImpl for LowerFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "lower") + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -97,28 +91,29 @@ impl ScalarUDFImpl for LowerFunc { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, ArrayRef, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, ArrayRef, StringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; use std::sync::Arc; - fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { + fn invoke_lower(input: ArrayRef) -> Result { let func = LowerFunc::new(); - let arg_fields = vec![Field::new("a", input.data_type().clone(), true).into()]; - + let data_type = input.data_type().clone(); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_fields, - return_field: Field::new("f", Utf8, true).into(), + arg_fields: vec![Field::new("a", data_type.clone(), true).into()], + return_field: Field::new("f", data_type, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - - let result = match func.invoke_with_args(args)? { - ColumnarValue::Array(result) => result, + match func.invoke_with_args(args)? { + ColumnarValue::Array(r) => Ok(r), _ => unreachable!("lower"), - }; + } + } + + fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { + let result = invoke_lower(input)?; assert_eq!(&expected, &result); Ok(()) } @@ -197,4 +192,182 @@ mod tests { to_lower(input, expected) } + + #[test] + fn lower_utf8view() -> Result<()> { + let input = Arc::new(StringViewArray::from(vec![ + Some("ARROW"), + None, + Some("TSCHÜSS"), + ])) as ArrayRef; + + let expected = Arc::new(StringViewArray::from(vec![ + Some("arrow"), + None, + Some("tschüss"), + ])) as ArrayRef; + + to_lower(input, expected) + } + + #[test] + fn lower_ascii_utf8view() -> Result<()> { + // Mix of inlined (≤12 bytes) and referenced (>12 bytes) strings, plus + // a null and an empty, to exercise the all-ASCII Utf8View fast path. + let input = Arc::new(StringViewArray::from(vec![ + Some("ARROW"), // inlined short + None, + Some("HELLO WORLD 123"), // referenced (15 bytes) + Some(""), + Some("0123456789"), // inlined, no case change + Some("DATAFUSION IS COOL"), // referenced + ])) as ArrayRef; + + let expected = Arc::new(StringViewArray::from(vec![ + Some("arrow"), + None, + Some("hello world 123"), + Some(""), + Some("0123456789"), + Some("datafusion is cool"), + ])) as ArrayRef; + + to_lower(input, expected) + } + + #[test] + fn lower_sliced_ascii_utf8view() -> Result<()> { + // Slice of a parent that contains a non-ASCII string outside the + // slice. The slice is all-ASCII, so the fast path must run and produce + // correct output while the parent's unaddressed non-ASCII bytes are + // irrelevant to the result. + let parent = Arc::new(StringViewArray::from(vec![ + Some("农历新年LONG ENOUGH FOR BUFFER"), + Some("HELLO WORLD 123"), + Some("DATAFUSION ROCKS!"), + Some("ZZZZZZZZZZZZZZZZ"), + ])) as ArrayRef; + let sliced = parent.slice(1, 2); + let result = invoke_lower(sliced)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected = StringViewArray::from(vec![ + Some("hello world 123"), + Some("datafusion rocks!"), + ]); + assert_eq!(result_sv, &expected); + // The slice's two long views address 15 + 17 = 32 bytes; the ASCII + // fast path must produce a single packed buffer of exactly that + // size, not one scaled to the parent's data buffer. + assert_eq!(result_sv.data_buffers().len(), 1); + assert_eq!(result_sv.data_buffers()[0].len(), 32); + Ok(()) + } + + #[test] + fn lower_utf8view_inline_only_no_buffers() -> Result<()> { + // An array whose values are all ≤ 12 bytes is fully inline; the ASCII + // fast path should produce no data buffers at all. + let input = Arc::new(StringViewArray::from(vec![ + Some("HELLO"), + None, + Some(""), + Some("0123456789ab"), // 12 bytes — inline boundary + ])) as ArrayRef; + let result = invoke_lower(input)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected = StringViewArray::from(vec![ + Some("hello"), + None, + Some(""), + Some("0123456789ab"), + ]); + assert_eq!(result_sv, &expected); + assert_eq!( + result_sv.data_buffers().len(), + 0, + "inline-only Utf8View should produce no data buffers" + ); + Ok(()) + } + + #[test] + fn lower_utf8view_long_packs_tight() -> Result<()> { + // Mix of long and inline values; the long values should be packed into + // a single tight output buffer whose size is exactly the sum of their + // lengths (inline values do not contribute). + let input = Arc::new(StringViewArray::from(vec![ + Some("HELLO WORLD 123"), // 15 bytes (long) + Some("ABC"), // inline + None, + Some("DATAFUSION ROCKS!"), // 17 bytes (long) + Some("ANOTHER LONG STRING"), // 19 bytes (long) + ])) as ArrayRef; + let result = invoke_lower(input)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected = StringViewArray::from(vec![ + Some("hello world 123"), + Some("abc"), + None, + Some("datafusion rocks!"), + Some("another long string"), + ]); + assert_eq!(result_sv, &expected); + assert_eq!(result_sv.data_buffers().len(), 1); + assert_eq!(result_sv.data_buffers()[0].len(), 15 + 17 + 19); + Ok(()) + } + + #[test] + fn lower_utf8view_splits_into_multiple_buffers() -> Result<()> { + // Produce enough long-string output to overflow the first data block + // (≈16 KiB after the initial doubling) and confirm the fast path + // splits across buffers rather than packing everything into one and + // risking the i32::MAX offset limit. + const STR_LEN: usize = 500; + const N: usize = 40; // 40 × 500 B = 20 KiB total — crosses the first block. + let value = "X".repeat(STR_LEN); + let inputs: Vec> = (0..N).map(|_| Some(value.clone())).collect(); + let input = Arc::new(StringViewArray::from(inputs.clone())) as ArrayRef; + let result = invoke_lower(input)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected_value = "x".repeat(STR_LEN); + let expected: Vec> = + (0..N).map(|_| Some(expected_value.as_str())).collect(); + assert_eq!(result_sv, &StringViewArray::from(expected)); + assert!( + result_sv.data_buffers().len() >= 2, + "expected the output to span more than one data buffer, got {}", + result_sv.data_buffers().len() + ); + // Total bytes across buffers must equal total long-value bytes + // (no row was inlined since each value is > 12 bytes). + let total: usize = result_sv.data_buffers().iter().map(|b| b.len()).sum(); + assert_eq!(total, N * STR_LEN); + Ok(()) + } + + #[test] + fn lower_sliced_utf8() -> Result<()> { + let parent = Arc::new(StringArray::from(vec![ + Some("AAAAAAAA"), + Some("HELLO"), + Some("WORLD"), + Some(""), + Some("ZZZZZZZZ"), + ])) as ArrayRef; + let sliced = parent.slice(1, 3); + let result = invoke_lower(sliced)?; + let result_sa = result.as_any().downcast_ref::().unwrap(); + + let expected = StringArray::from(vec![Some("hello"), Some("world"), Some("")]); + assert_eq!(result_sa, &expected); + // The slice's addressed bytes are "HELLO" + "WORLD" = 10; the ASCII + // fast path must produce a tight output buffer (not the parent's). + assert_eq!(result_sa.value_data().len(), 10); + Ok(()) + } } diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index dc6d30d38188c..e49ffeb0541ff 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -17,13 +17,12 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; -use std::any::Any; use std::sync::Arc; use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -31,7 +30,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. +/// Returns the longest string with leading characters removed. If the characters are not specified, spaces are removed. /// ltrim('zzzytest', 'xyz') = 'test' fn ltrim(args: &[ArrayRef]) -> Result { let use_string_view = args[0].data_type() == &DataType::Utf8View; @@ -41,12 +40,12 @@ fn ltrim(args: &[ArrayRef]) -> Result { } else { args.to_owned() }; - general_trim::(&args, TrimType::Left, use_string_view) + general_trim::(&args, use_string_view) } #[user_doc( doc_section(label = "String Functions"), - description = "Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.", + description = "Trims the specified trim string from the beginning of a string. If no trim string is provided, spaces are removed from the start of the input string.", syntax_example = "ltrim(str[, trim_str])", sql_example = r#"```sql > select ltrim(' datafusion '); @@ -65,7 +64,7 @@ fn ltrim(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument( name = "trim_str", - description = r"String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._" + description = r"String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is a space._" ), alternative_syntax = "trim(LEADING trim_str FROM str)", related_udf(name = "btrim"), @@ -102,10 +101,6 @@ impl LtrimFunc { } impl ScalarUDFImpl for LtrimFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "ltrim" } @@ -115,11 +110,7 @@ impl ScalarUDFImpl for LtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0] == DataType::Utf8View { - Ok(DataType::Utf8View) - } else { - utf8_to_str_type(&arg_types[0], "ltrim") - } + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index aa8257ef8fc53..ecffb2a6de7af 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -17,7 +17,6 @@ use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; -use std::any::Any; use crate::utils::utf8_to_int_type; use datafusion_common::types::logical_string; @@ -70,10 +69,6 @@ impl OctetLengthFunc { } impl ScalarUDFImpl for OctetLengthFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "octet_length" } @@ -119,7 +114,7 @@ mod tests { use arrow::datatypes::DataType::Int32; use datafusion_common::ScalarValue; - use datafusion_common::{exec_err, Result}; + use datafusion_common::{Result, exec_err}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use crate::string::octet_length::OctetLengthFunc; diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 1fc62d747de52..a53f1e2e4fc42 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -15,19 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use arrow::array::{ - ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringArrayType, StringViewArray, +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder, }; +use crate::utils::utf8_to_str_type; +use arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringArrayType}; +use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; -use datafusion_common::types::{logical_int64, logical_string, NativeType}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::types::{NativeType, logical_int64, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, +}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -82,10 +83,6 @@ impl RepeatFunc { } impl ScalarUDFImpl for RepeatFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "repeat" } @@ -95,11 +92,69 @@ impl ScalarUDFImpl for RepeatFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types[0] == Utf8View { + return Ok(Utf8View); + } utf8_to_str_type(&arg_types[0], "repeat") } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - make_scalar_function(repeat, vec![])(&args.args) + let return_type = args.return_field.data_type().clone(); + let [string_arg, count_arg] = take_function_args(self.name(), args.args)?; + + // Early return if either argument is a scalar null + if let ColumnarValue::Scalar(s) = &string_arg + && s.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + if let ColumnarValue::Scalar(c) = &count_arg + && c.is_null() + { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(&return_type)?)); + } + + match (&string_arg, &count_arg) { + ( + ColumnarValue::Scalar(string_scalar), + ColumnarValue::Scalar(count_scalar), + ) => { + let count = match count_scalar { + ScalarValue::Int64(Some(n)) => *n, + _ => { + return internal_err!( + "Unexpected data type {:?} for repeat count", + count_scalar.data_type() + ); + } + }; + + let result = match string_scalar { + ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some( + compute_repeat(s, count, i32::MAX as usize)?, + )), + ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some( + compute_repeat(s, count, i32::MAX as usize)?, + )), + ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some( + compute_repeat(s, count, i64::MAX as usize)?, + )), + _ => { + return internal_err!( + "Unexpected data type {:?} for function repeat", + string_scalar.data_type() + ); + } + }; + + Ok(ColumnarValue::Scalar(result)) + } + _ => { + let string_array = string_arg.to_array(args.number_rows)?; + let count_array = count_arg.to_array(args.number_rows)?; + Ok(ColumnarValue::Array(repeat(&string_array, &count_array)?)) + } + } } fn documentation(&self) -> Option<&Documentation> { @@ -107,34 +162,82 @@ impl ScalarUDFImpl for RepeatFunc { } } +/// Computes repeat for a single string value with max size check +#[inline] +fn compute_repeat(s: &str, count: i64, max_size: usize) -> Result { + if count <= 0 { + return Ok(String::new()); + } + let result_len = repeat_len(s.len(), count, max_size)?; + debug_assert!(result_len <= max_size); + let count = repeat_count(count, max_size)?; + Ok(s.repeat(count)) +} + +fn repeat_len(string_len: usize, count: i64, max_size: usize) -> Result { + let count = repeat_count(count, max_size)?; + let result_len = string_len.checked_mul(count).ok_or_else(|| { + exec_datafusion_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + usize::MAX + ) + })?; + if result_len > max_size { + return exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + result_len + ); + } + Ok(result_len) +} + +fn repeat_count(count: i64, max_size: usize) -> Result { + match usize::try_from(count) { + Ok(count) => Ok(count), + Err(_) => exec_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_size, + usize::MAX + ), + } +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' -fn repeat(args: &[ArrayRef]) -> Result { - let number_array = as_int64_array(&args[1])?; - match args[0].data_type() { +fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result { + let number_array = as_int64_array(count_array)?; + match string_array.data_type() { Utf8View => { - let string_view_array = args[0].as_string_view(); - repeat_impl::( + let string_view_array = string_array.as_string_view(); + let (_, max_item_capacity) = calculate_capacities( &string_view_array, number_array, i32::MAX as usize, - ) + )?; + let builder = StringViewArrayBuilder::with_capacity(string_array.len()); + repeat_impl(&string_view_array, number_array, max_item_capacity, builder) } Utf8 => { - let string_array = args[0].as_string::(); - repeat_impl::>( - &string_array, - number_array, - i32::MAX as usize, - ) + let string_arr = string_array.as_string::(); + let (total_capacity, max_item_capacity) = + calculate_capacities(&string_arr, number_array, i32::MAX as usize)?; + let builder = GenericStringArrayBuilder::::with_capacity( + string_array.len(), + total_capacity, + ); + repeat_impl(&string_arr, number_array, max_item_capacity, builder) } LargeUtf8 => { - let string_array = args[0].as_string::(); - repeat_impl::>( - &string_array, - number_array, - i64::MAX as usize, - ) + let string_arr = string_array.as_string::(); + let (total_capacity, max_item_capacity) = + calculate_capacities(&string_arr, number_array, i64::MAX as usize)?; + let builder = GenericStringArrayBuilder::::with_capacity( + string_array.len(), + total_capacity, + ); + repeat_impl(&string_arr, number_array, max_item_capacity, builder) } other => exec_err!( "Unsupported data type {other:?} for function repeat. \ @@ -143,29 +246,31 @@ fn repeat(args: &[ArrayRef]) -> Result { } } -fn repeat_impl<'a, T, S>( +fn calculate_capacities<'a, S>( string_array: &S, number_array: &Int64Array, max_str_len: usize, -) -> Result +) -> Result<(usize, usize)> where - T: OffsetSizeTrait, S: StringArrayType<'a>, { - let mut total_capacity = 0; + let mut total_capacity = 0usize; + let mut max_item_capacity = 0usize; + string_array.iter().zip(number_array.iter()).try_for_each( |(string, number)| -> Result<(), DataFusionError> { match (string, number) { (Some(string), Some(number)) if number >= 0 => { - let item_capacity = string.len() * number as usize; - if item_capacity > max_str_len { - return exec_err!( - "string size overflow on repeat, max size is {}, but got {}", - max_str_len, - number as usize * string.len() - ); - } - total_capacity += item_capacity; + let item_capacity = repeat_len(string.len(), number, max_str_len)?; + total_capacity = + total_capacity.checked_add(item_capacity).ok_or_else(|| { + exec_datafusion_err!( + "string size overflow on repeat, max size is {}, but got {}", + max_str_len, + usize::MAX + ) + })?; + max_item_capacity = max_item_capacity.max(item_capacity); } _ => (), } @@ -173,33 +278,89 @@ where }, )?; - let mut builder = - GenericStringBuilder::::with_capacity(string_array.len(), total_capacity); + Ok((total_capacity, max_item_capacity)) +} - string_array.iter().zip(number_array.iter()).try_for_each( - |(string, number)| -> Result<(), DataFusionError> { - match (string, number) { - (Some(string), Some(number)) if number >= 0 => { - builder.append_value(string.repeat(number as usize)); - } - (Some(_), Some(_)) => builder.append_value(""), - _ => builder.append_null(), +fn repeat_impl<'a, S, B>( + string_array: &S, + number_array: &Int64Array, + max_item_capacity: usize, + mut builder: B, +) -> Result +where + S: StringArrayType<'a> + 'a, + B: BulkNullStringArrayBuilder, +{ + // Reusable buffer to avoid allocations in string.repeat() + let mut buffer = Vec::::with_capacity(max_item_capacity); + + // Helper function to repeat a string into a buffer using doubling strategy + // count must be > 0 + #[inline] + fn repeat_to_buffer(buffer: &mut Vec, string: &str, count: usize) { + buffer.clear(); + if !string.is_empty() { + let src = string.as_bytes(); + // Initial copy + buffer.extend_from_slice(src); + // Doubling strategy: copy what we have so far until we reach the target + while buffer.len() < src.len() * count { + let copy_len = buffer.len().min(src.len() * count - buffer.len()); + // SAFETY: we're copying valid UTF-8 bytes that we already verified + buffer.extend_from_within(..copy_len); } - Ok(()) - }, - )?; - let array = builder.finish(); + } + } + + // Output is null IFF either input is null + let nulls = NullBuffer::union(string_array.nulls(), number_array.nulls()); + + if let Some(ref n) = nulls { + for i in 0..string_array.len() { + if n.is_null(i) { + builder.append_placeholder(); + continue; + } + // SAFETY: index `i` in both arrays is valid + let string = unsafe { string_array.value_unchecked(i) }; + let count = unsafe { number_array.value_unchecked(i) }; + if count > 0 { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str + builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); + } else { + builder.append_value(""); + } + } + } else { + for i in 0..string_array.len() { + // SAFETY: no nulls, so every index in both arrays is valid + let string = unsafe { string_array.value_unchecked(i) }; + let count = unsafe { number_array.value_unchecked(i) }; + if count > 0 { + repeat_to_buffer(&mut buffer, string, count as usize); + // SAFETY: buffer contains valid UTF-8 since we only copy from a valid &str + builder.append_value(unsafe { std::str::from_utf8_unchecked(&buffer) }); + } else { + builder.append_value(""); + } + } + } - Ok(Arc::new(array) as ArrayRef) + builder.finish(nulls) } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use std::sync::Arc; + + use arrow::array::{ + Array, ArrayRef, Int64Array, LargeStringArray, StringArray, StringViewArray, + }; + use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::ScalarValue; - use datafusion_common::{exec_err, Result}; + use datafusion_common::{Result, exec_err}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use crate::string::repeat::RepeatFunc; @@ -249,8 +410,8 @@ mod tests { ], Ok(Some("PgPgPgPg")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RepeatFunc::new(), @@ -260,8 +421,19 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + test_function!( + RepeatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + LargeUtf8, + LargeStringArray ); test_function!( RepeatFunc::new(), @@ -271,8 +443,8 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RepeatFunc::new(), @@ -292,4 +464,81 @@ mod tests { Ok(()) } + + // Slicing the input arrays produces a NullBuffer with a non-zero offset. + // The tests below use 6-row inputs sliced to (1, 4) so that: + // slot 0 (orig 1): "a" × 3 → "aaa" + // slot 1 (orig 2): "bb" × 2 → "bbbb" + // slot 2 (orig 3): "c" × NULL → NULL (count-side null) + // slot 3 (orig 4): NULL × 1 → NULL (string-side null) + fn sliced_offset_inputs(make_strings: F) -> (ArrayRef, ArrayRef) + where + F: FnOnce(Vec>) -> ArrayRef, + { + let strings = make_strings(vec![ + None, + Some("a"), + Some("bb"), + Some("c"), + None, + Some("d"), + ]); + let counts: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(2), + Some(3), + Some(2), + None, + Some(1), + Some(2), + ])); + (strings.slice(1, 4), counts.slice(1, 4)) + } + + fn assert_sliced_offset_output(result: ArrayRef) + where + for<'a> &'a A: arrow::array::ArrayAccessor, + { + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.len(), 4); + assert_eq!(arrow::array::ArrayAccessor::value(&result, 0), "aaa"); + assert_eq!(arrow::array::ArrayAccessor::value(&result, 1), "bbbb"); + assert!(result.is_null(2)); + assert!(result.is_null(3)); + assert_eq!(result.null_count(), 2); + } + + #[test] + fn test_repeat_sliced_string_with_null_offset() { + let (strings, counts) = sliced_offset_inputs(|v| Arc::new(StringArray::from(v))); + let result = super::repeat(&strings, &counts).unwrap(); + assert_sliced_offset_output::(result); + } + + #[test] + fn test_repeat_string_array_overflow() { + let strings: ArrayRef = Arc::new(StringArray::from(vec![Some("abc")])); + let counts: ArrayRef = Arc::new(Int64Array::from(vec![Some(i64::MAX)])); + + let err = super::repeat(&strings, &counts).unwrap_err().to_string(); + assert!( + err.contains("string size overflow on repeat"), + "unexpected error: {err}" + ); + } + + #[test] + fn test_repeat_sliced_large_string_with_null_offset() { + let (strings, counts) = + sliced_offset_inputs(|v| Arc::new(LargeStringArray::from(v))); + let result = super::repeat(&strings, &counts).unwrap(); + assert_sliced_offset_output::(result); + } + + #[test] + fn test_repeat_sliced_string_view_with_null_offset() { + let (strings, counts) = + sliced_offset_inputs(|v| Arc::new(StringViewArray::from(v))); + let result = super::repeat(&strings, &counts).unwrap(); + assert_sliced_offset_output::(result); + } } diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index f127b452b2d34..28f81769f56db 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -15,16 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; +use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringWriter, +}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::type_coercion::binary::{ binary_to_string_coercion, string_coercion, }; @@ -79,10 +82,6 @@ impl ReplaceFunc { } impl ScalarUDFImpl for ReplaceFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "replace" } @@ -101,7 +100,9 @@ impl ScalarUDFImpl for ReplaceFunc { { utf8_to_str_type(&coercion_data_type, "replace") } else { - exec_err!("Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View") + exec_err!( + "Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View" + ) } } @@ -163,17 +164,40 @@ fn replace_view(args: &[ArrayRef]) -> Result { let from_array = as_string_view_array(&args[1])?; let to_array = as_string_view_array(&args[2])?; - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), - _ => None, - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) + let len = string_array.len(); + let mut builder = GenericStringArrayBuilder::::with_capacity(len, 0); + let nulls = NullBuffer::union_many([ + string_array.nulls(), + from_array.nulls(), + to_array.nulls(), + ]); + + // Hoist the nulls.is_some() check out of the loop. LLVM does not always + // unswitch this loop on its own (the Utf8View body is large enough to + // exceed its cost-benefit threshold). + if let Some(nulls_ref) = nulls.as_ref() { + for i in 0..len { + if nulls_ref.is_null(i) { + builder.append_placeholder(); + continue; + } + // SAFETY: union of input nulls is non-null at i, so each input is too. + let string = unsafe { string_array.value_unchecked(i) }; + let from = unsafe { from_array.value_unchecked(i) }; + let to = unsafe { to_array.value_unchecked(i) }; + apply_replace(&mut builder, string, from, to); + } + } else { + for i in 0..len { + // SAFETY: i < len, and no input has a null buffer. + let string = unsafe { string_array.value_unchecked(i) }; + let from = unsafe { from_array.value_unchecked(i) }; + let to = unsafe { to_array.value_unchecked(i) }; + apply_replace(&mut builder, string, from, to); + } + } + + Ok(Arc::new(builder.finish(nulls)?) as ArrayRef) } /// Replaces all occurrences in string of substring from with substring to. @@ -183,24 +207,90 @@ fn replace(args: &[ArrayRef]) -> Result { let from_array = as_generic_string_array::(&args[1])?; let to_array = as_generic_string_array::(&args[2])?; - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + let len = string_array.len(); + let mut builder = GenericStringArrayBuilder::::with_capacity(len, 0); + let nulls = NullBuffer::union_many([ + string_array.nulls(), + from_array.nulls(), + to_array.nulls(), + ]); + + // Hoist the nulls.is_some() check out of the loop. LLVM unswitches this + // automatically today, but kept explicit so the no-nulls fast path is not + // contingent on the optimizer's cost heuristic. + if let Some(nulls_ref) = nulls.as_ref() { + for i in 0..len { + if nulls_ref.is_null(i) { + builder.append_placeholder(); + continue; + } + // SAFETY: union of input nulls is non-null at i, so each input is too. + let string = unsafe { string_array.value_unchecked(i) }; + let from = unsafe { from_array.value_unchecked(i) }; + let to = unsafe { to_array.value_unchecked(i) }; + apply_replace(&mut builder, string, from, to); + } + } else { + for i in 0..len { + // SAFETY: i < len, and no input has a null buffer. + let string = unsafe { string_array.value_unchecked(i) }; + let from = unsafe { from_array.value_unchecked(i) }; + let to = unsafe { to_array.value_unchecked(i) }; + apply_replace(&mut builder, string, from, to); + } + } + + Ok(Arc::new(builder.finish(nulls)?) as ArrayRef) +} + +#[inline] +fn apply_replace( + builder: &mut B, + string: &str, + from: &str, + to: &str, +) { + // Hot path: single ASCII byte → single ASCII byte. An ASCII byte (< 0x80) + // cannot appear inside a multi-byte UTF-8 sequence, so any multi-byte + // sequences in `string` pass through unchanged and output stays valid + // UTF-8. + if let (&[from_byte], &[to_byte]) = (from.as_bytes(), to.as_bytes()) + && from_byte.is_ascii() + && to_byte.is_ascii() + { + // SAFETY: see the contract above. + unsafe { + builder.append_byte_map(string.as_bytes(), |b| { + if b == from_byte { to_byte } else { b } + }); + } + return; + } + + if from.is_empty() { + // PostgreSQL returns the input unchanged when `from` is empty (#22253). + builder.append_value(string); + return; + } + + builder.append_with(|w| replace_into_writer(w, string, from, to)); +} + +#[inline] +fn replace_into_writer(w: &mut W, string: &str, from: &str, to: &str) { + let mut last_end = 0; + for (start, _part) in string.match_indices(from) { + w.write_str(&string[last_end..start]); + w.write_str(to); + last_end = start + from.len(); + } + w.write_str(&string[last_end..]); } #[cfg(test)] mod tests { use super::*; use crate::utils::test::test_function; - use arrow::array::Array; use arrow::array::LargeStringArray; use arrow::array::StringArray; use arrow::datatypes::DataType::{LargeUtf8, Utf8}; @@ -250,6 +340,19 @@ mod tests { StringArray ); + test_function!( + ReplaceFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("abc")))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("x")))), + ], + Ok(Some("abc")), + &str, + LargeUtf8, + LargeStringArray + ); + Ok(()) } } diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index be0595f65542a..05ad9e855976d 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -17,13 +17,12 @@ use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; -use std::any::Any; use std::sync::Arc; use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::make_scalar_function; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::function::Hint; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, @@ -31,7 +30,7 @@ use datafusion_expr::{ }; use datafusion_macros::user_doc; -/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. +/// Returns the longest string with trailing characters removed. If the characters are not specified, spaces are removed. /// rtrim('testxxzx', 'xyz') = 'test' fn rtrim(args: &[ArrayRef]) -> Result { let use_string_view = args[0].data_type() == &DataType::Utf8View; @@ -41,12 +40,12 @@ fn rtrim(args: &[ArrayRef]) -> Result { } else { args.to_owned() }; - general_trim::(&args, TrimType::Right, use_string_view) + general_trim::(&args, use_string_view) } #[user_doc( doc_section(label = "String Functions"), - description = "Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.", + description = "Trims the specified trim string from the end of a string. If no trim string is provided, all spaces are removed from the end of the input string.", syntax_example = "rtrim(str[, trim_str])", alternative_syntax = "trim(TRAILING trim_str FROM str)", sql_example = r#"```sql @@ -66,7 +65,7 @@ fn rtrim(args: &[ArrayRef]) -> Result { standard_argument(name = "str", prefix = "String"), argument( name = "trim_str", - description = "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._" + description = "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is a space._" ), related_udf(name = "btrim"), related_udf(name = "ltrim") @@ -102,10 +101,6 @@ impl RtrimFunc { } impl ScalarUDFImpl for RtrimFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "rtrim" } @@ -115,11 +110,7 @@ impl ScalarUDFImpl for RtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[0] == DataType::Utf8View { - Ok(DataType::Utf8View) - } else { - utf8_to_str_type(&arg_types[0], "rtrim") - } + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index c8b293f29811e..7e382868c4f23 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -15,20 +15,26 @@ // specific language governing permissions and limitations // under the License. +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder, +}; use crate::utils::utf8_to_str_type; use arrow::array::{ - ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType, - StringViewArray, + Array, ArrayRef, AsArray, ByteView, Int64Array, StringArrayType, StringViewArray, + make_view, new_null_array, }; -use arrow::array::{AsArray, GenericStringBuilder}; +use arrow::buffer::{NullBuffer, ScalarBuffer}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_int64_array; use datafusion_common::ScalarValue; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{NativeType, logical_int64, logical_string}; +use datafusion_common::{Result, exec_datafusion_err, exec_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, TypeSignatureClass, Volatility, +}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; -use std::any::Any; +use memchr::memmem; use std::sync::Arc; #[user_doc( @@ -45,7 +51,10 @@ use std::sync::Arc; ```"#, standard_argument(name = "str", prefix = "String"), argument(name = "delimiter", description = "String or character to split on."), - argument(name = "pos", description = "Position of the part to return.") + argument( + name = "pos", + description = "Position of the part to return (counting from 1). Negative values count backward from the end of the string." + ) )] #[derive(Debug, PartialEq, Eq, Hash)] pub struct SplitPartFunc { @@ -60,19 +69,16 @@ impl Default for SplitPartFunc { impl SplitPartFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( + signature: Signature::coercible( vec![ - TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8View, Utf8, Int64]), - TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]), - TypeSignature::Exact(vec![Utf8, Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8, Utf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]), - TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ), ], Volatility::Immutable, ), @@ -81,10 +87,6 @@ impl SplitPartFunc { } impl ScalarUDFImpl for SplitPartFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "split_part" } @@ -94,12 +96,26 @@ impl ScalarUDFImpl for SplitPartFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "split_part") + if arg_types[0] == DataType::Utf8View { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "split_part") + } } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, .. } = args; + // Fast path: array string, scalar delimiter and position. + if let ( + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(delim_scalar), + ColumnarValue::Scalar(pos_scalar), + ) = (&args[0], &args[1], &args[2]) + { + return split_part_scalar(string_array, delim_scalar, pos_scalar); + } + // First, determine if any of the arguments is an Array let len = args.iter().find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), @@ -120,71 +136,66 @@ impl ScalarUDFImpl for SplitPartFunc { // Unpack the ArrayRefs from the arguments let n_array = as_int64_array(&args[2])?; - let result = match (args[0].data_type(), args[1].data_type()) { - (DataType::Utf8View, DataType::Utf8View) => { - split_part_impl::<&StringViewArray, &StringViewArray, i32>( - &args[0].as_string_view(), - &args[1].as_string_view(), - n_array, - ) - } - (DataType::Utf8View, DataType::Utf8) => { - split_part_impl::<&StringViewArray, &GenericStringArray, i32>( - &args[0].as_string_view(), - &args[1].as_string::(), - n_array, - ) - } - (DataType::Utf8View, DataType::LargeUtf8) => { - split_part_impl::<&StringViewArray, &GenericStringArray, i32>( - &args[0].as_string_view(), - &args[1].as_string::(), - n_array, - ) - } - (DataType::Utf8, DataType::Utf8View) => { - split_part_impl::<&GenericStringArray, &StringViewArray, i32>( - &args[0].as_string::(), - &args[1].as_string_view(), - n_array, - ) - } - (DataType::LargeUtf8, DataType::Utf8View) => { - split_part_impl::<&GenericStringArray, &StringViewArray, i64>( - &args[0].as_string::(), - &args[1].as_string_view(), - n_array, - ) - } - (DataType::Utf8, DataType::Utf8) => { - split_part_impl::<&GenericStringArray, &GenericStringArray, i32>( - &args[0].as_string::(), - &args[1].as_string::(), - n_array, - ) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - split_part_impl::<&GenericStringArray, &GenericStringArray, i64>( - &args[0].as_string::(), - &args[1].as_string::(), - n_array, - ) - } - (DataType::Utf8, DataType::LargeUtf8) => { - split_part_impl::<&GenericStringArray, &GenericStringArray, i32>( - &args[0].as_string::(), - &args[1].as_string::(), - n_array, + + // Dispatch on delimiter type for a given string array and builder. + macro_rules! split_part_for_delimiter_type { + ($str_arr:expr, $builder:expr) => { + match args[1].data_type() { + DataType::Utf8View => split_part_impl( + $str_arr, + &args[1].as_string_view(), + n_array, + $builder, + ), + DataType::Utf8 => split_part_impl( + $str_arr, + &args[1].as_string::(), + n_array, + $builder, + ), + DataType::LargeUtf8 => split_part_impl( + $str_arr, + &args[1].as_string::(), + n_array, + $builder, + ), + other => { + exec_err!("Unsupported delimiter type {other:?} for split_part") + } + } + }; + } + + let result = match args[0].data_type() { + DataType::Utf8View => split_part_for_delimiter_type!( + &args[0].as_string_view(), + StringViewArrayBuilder::with_capacity(inferred_length) + ), + DataType::Utf8 => { + let str_arr = &args[0].as_string::(); + // Conservative under-estimate for data capacity: split_part + // output is typically much smaller than the input, so avoid + // pre-allocating the full input data size. + split_part_for_delimiter_type!( + str_arr, + GenericStringArrayBuilder::::with_capacity( + inferred_length, + inferred_length, + ) ) } - (DataType::LargeUtf8, DataType::Utf8) => { - split_part_impl::<&GenericStringArray, &GenericStringArray, i64>( - &args[0].as_string::(), - &args[1].as_string::(), - n_array, + DataType::LargeUtf8 => { + let str_arr = &args[0].as_string::(); + // Conservative under-estimate; see Utf8 comment above. + split_part_for_delimiter_type!( + str_arr, + GenericStringArrayBuilder::::with_capacity( + inferred_length, + inferred_length, + ) ) } - _ => exec_err!("Unsupported combination of argument types for split_part"), + other => exec_err!("Unsupported string type {other:?} for split_part"), }; if is_scalar { // If all inputs are scalar, keep the output as scalar @@ -200,57 +211,440 @@ impl ScalarUDFImpl for SplitPartFunc { } } -fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>( +/// Finds the `n`th (0-based) split part of `string` by `delimiter`. +#[inline] +fn split_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> { + if delimiter.len() == 1 { + // A single-byte UTF-8 string is always ASCII, so we can safely cast + // just the first byte to a character. `str::split(char)` internally + // uses memchr::memchr and is notably faster than `str::split(&str)`, + // even for a single character string. + string.split(delimiter.as_bytes()[0] as char).nth(n) + } else { + string.split(delimiter).nth(n) + } +} + +/// Like `split_nth` but splits from the right (`n` is 0-based from the end). +#[inline] +fn rsplit_nth<'a>(string: &'a str, delimiter: &str, n: usize) -> Option<&'a str> { + if delimiter.len() == 1 { + // A single-byte UTF-8 string is always ASCII, so we can safely cast + // just the first byte to a character. `str::rsplit(char)` internally + // uses memchr::memrchr and is notably faster than `str::rsplit(&str)`, + // even for a single character string. + string.rsplit(delimiter.as_bytes()[0] as char).nth(n) + } else { + string.rsplit(delimiter).nth(n) + } +} + +/// Fast path for `split_part(array, scalar_delimiter, scalar_position)`. +fn split_part_scalar( + string_array: &ArrayRef, + delim_scalar: &ScalarValue, + pos_scalar: &ScalarValue, +) -> Result { + // Empty input array → empty result. + if string_array.is_empty() { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + 0, + ))); + } + + let delimiter = delim_scalar.try_as_str().ok_or_else(|| { + exec_datafusion_err!( + "Unsupported delimiter type {:?} for split_part", + delim_scalar.data_type() + ) + })?; + + let position = match pos_scalar { + ScalarValue::Int64(v) => *v, + other => { + return exec_err!( + "Unsupported position type {:?} for split_part", + other.data_type() + ); + } + }; + + // Null delimiter or position → every row is null. + let (Some(delimiter), Some(position)) = (delimiter, position) else { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + string_array.len(), + ))); + }; + + if position == 0 { + return exec_err!("field position must not be zero"); + } + + let result = match string_array.data_type() { + DataType::Utf8View => { + split_part_scalar_view(string_array.as_string_view(), delimiter, position) + } + DataType::Utf8 => { + let arr = string_array.as_string::(); + // Conservative under-estimate for data capacity: split_part output + // is typically much smaller than the input, so avoid pre-allocating + // the full input data size. + split_part_scalar_impl( + arr, + delimiter, + position, + GenericStringArrayBuilder::::with_capacity(arr.len(), arr.len()), + ) + } + DataType::LargeUtf8 => { + let arr = string_array.as_string::(); + // Conservative under-estimate; see Utf8 comment above. + split_part_scalar_impl( + arr, + delimiter, + position, + GenericStringArrayBuilder::::with_capacity(arr.len(), arr.len()), + ) + } + other => exec_err!("Unsupported string type {other:?} for split_part"), + }?; + + Ok(ColumnarValue::Array(result)) +} + +/// Inner implementation for the scalar-delimiter, scalar-position fast path. +/// Constructing a `memmem::Finder` is somewhat expensive but it's a win when +/// done once and amortized over the entire batch. +fn split_part_scalar_impl<'a, S, B>( + string_array: S, + delimiter: &str, + position: i64, + builder: B, +) -> Result +where + S: StringArrayType<'a> + Copy, + B: BulkNullStringArrayBuilder, +{ + if delimiter.is_empty() { + // PostgreSQL: empty delimiter treats input as a single field, + // so only position 1 or -1 returns the input string. + return if position == 1 || position == -1 { + map_strings(string_array, builder, Some) + } else { + map_strings(string_array, builder, |_| None) + }; + } + + let delim_bytes = delimiter.as_bytes(); + let delim_len = delimiter.len(); + + if position > 0 { + let idx: usize = (position - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {position} exceeds maximum supported value" + ) + })?; + let finder = memmem::Finder::new(delim_bytes); + map_strings(string_array, builder, |s| { + split_nth_finder(s, &finder, delim_len, idx) + }) + } else { + let idx: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {position} exceeds minimum supported value" + ) + })?; + let finder_rev = memmem::FinderRev::new(delim_bytes); + map_strings(string_array, builder, |s| { + rsplit_nth_finder(s, &finder_rev, delim_len, idx) + }) + } +} + +/// Applies `f` to each non-null string in `string_array`, appending the +/// result (or `""` when `f` returns `None`) to `builder`. +#[inline] +fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result +where + S: StringArrayType<'a> + Copy, + B: BulkNullStringArrayBuilder, + F: Fn(&'a str) -> Option<&'a str>, +{ + let item_len = string_array.len(); + let nulls = string_array.nulls().cloned(); + + if let Some(ref n) = nulls { + for i in 0..item_len { + if n.is_null(i) { + builder.append_placeholder(); + } else { + // SAFETY: `n.is_null(i)` was false in the branch above. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(f(s).unwrap_or("")); + } + } + } else { + for i in 0..item_len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(f(s).unwrap_or("")); + } + } + + builder.finish(nulls) +} + +/// Finds the `n`th (0-based) split part using a pre-built `memmem::Finder`. +#[inline] +fn split_nth_finder<'a>( + string: &'a str, + finder: &memmem::Finder, + delim_len: usize, + n: usize, +) -> Option<&'a str> { + let bytes = string.as_bytes(); + let mut start = 0; + for _ in 0..n { + match finder.find(&bytes[start..]) { + Some(pos) => start += pos + delim_len, + None => return None, + } + } + match finder.find(&bytes[start..]) { + Some(pos) => Some(&string[start..start + pos]), + None => Some(&string[start..]), + } +} + +/// Like `split_nth_finder` but splits from the right (`n` is 0-based from +/// the end). +#[inline] +fn rsplit_nth_finder<'a>( + string: &'a str, + finder: &memmem::FinderRev, + delim_len: usize, + n: usize, +) -> Option<&'a str> { + let bytes = string.as_bytes(); + let mut end = bytes.len(); + for _ in 0..n { + match finder.rfind(&bytes[..end]) { + Some(pos) => end = pos, + None => return None, + } + } + match finder.rfind(&bytes[..end]) { + Some(pos) => Some(&string[pos + delim_len..end]), + None => Some(&string[..end]), + } +} + +/// Zero-copy scalar fast path for `StringViewArray` inputs. +/// +/// Instead of copying substring bytes into a new buffer, constructs +/// `StringView` entries that point back into the original array's data +/// buffers. +fn split_part_scalar_view( + string_view_array: &StringViewArray, + delimiter: &str, + position: i64, +) -> Result { + let len = string_view_array.len(); + let mut views_buf = Vec::with_capacity(len); + let views = string_view_array.views(); + + if delimiter.is_empty() { + // PostgreSQL: empty delimiter treats input as a single field. + let empty_view = make_view(b"", 0, 0); + let return_input = position == 1 || position == -1; + for i in 0..len { + if string_view_array.is_null(i) { + views_buf.push(0); + } else if return_input { + views_buf.push(views[i]); + } else { + views_buf.push(empty_view); + } + } + } else if position > 0 { + let idx: usize = (position - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {position} exceeds maximum supported value" + ) + })?; + let finder = memmem::Finder::new(delimiter.as_bytes()); + split_view_loop(string_view_array, views, &mut views_buf, |s| { + split_nth_finder(s, &finder, delimiter.len(), idx) + }); + } else { + let idx: usize = (position.unsigned_abs() - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {position} exceeds minimum supported value" + ) + })?; + let finder_rev = memmem::FinderRev::new(delimiter.as_bytes()); + split_view_loop(string_view_array, views, &mut views_buf, |s| { + rsplit_nth_finder(s, &finder_rev, delimiter.len(), idx) + }); + } + + let views_buf = ScalarBuffer::from(views_buf); + + // Nulls pass through unchanged, so we can use the input's null array. + let nulls = string_view_array.nulls().cloned(); + + // Safety: each view is either copied unchanged from the input, or built + // by `substr_view` from a substring that is a contiguous sub-range of the + // original string value stored in the input's data buffers. + unsafe { + Ok(Arc::new(StringViewArray::new_unchecked( + views_buf, + string_view_array.data_buffers().to_vec(), + nulls, + )) as ArrayRef) + } +} + +/// Creates a `StringView` referencing a substring of an existing view's buffer. +/// For substrings ≤ 12 bytes, creates an inline view instead. +#[inline] +fn substr_view(original_view: &u128, substr: &str, start_offset: u32) -> u128 { + if substr.len() > 12 { + let view = ByteView::from(*original_view); + make_view( + substr.as_bytes(), + view.buffer_index, + view.offset + start_offset, + ) + } else { + make_view(substr.as_bytes(), 0, 0) + } +} + +/// Applies `split_fn` to each non-null string and appends the resulting view to +/// `views_buf`. +#[inline(always)] +fn split_view_loop( + string_view_array: &StringViewArray, + views: &[u128], + views_buf: &mut Vec, + split_fn: F, +) where + F: Fn(&str) -> Option<&str>, +{ + let empty_view = make_view(b"", 0, 0); + for (i, raw_view) in views.iter().enumerate() { + if string_view_array.is_null(i) { + views_buf.push(0); + continue; + } + let string = string_view_array.value(i); + match split_fn(string) { + Some(substr) => { + let start_offset = substr.as_ptr() as usize - string.as_ptr() as usize; + views_buf.push(substr_view(raw_view, substr, start_offset as u32)); + } + None => views_buf.push(empty_view), + } + } +} + +fn split_part_impl<'a, StringArrType, DelimiterArrType, B>( string_array: &StringArrType, delimiter_array: &DelimiterArrType, n_array: &Int64Array, + mut builder: B, ) -> Result where StringArrType: StringArrayType<'a>, DelimiterArrType: StringArrayType<'a>, - StringArrayLen: OffsetSizeTrait, + B: BulkNullStringArrayBuilder, { - let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(n_array.iter()) - .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> { - match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - let split_string: Vec<&str> = string.split(delimiter).collect(); - let len = split_string.len(); - - let index = match n.cmp(&0) { - std::cmp::Ordering::Less => len as i64 + n, - std::cmp::Ordering::Equal => { - return exec_err!("field position must not be zero"); - } - std::cmp::Ordering::Greater => n - 1, - } as usize; - - if index < len { - builder.append_value(split_string[index]); - } else { - builder.append_value(""); - } - } - _ => builder.append_null(), + let nulls = NullBuffer::union_many([ + string_array.nulls(), + delimiter_array.nulls(), + n_array.nulls(), + ]); + + if let Some(ref n) = nulls { + for i in 0..string_array.len() { + if n.is_null(i) { + builder.append_placeholder(); + continue; } - Ok(()) - })?; - Ok(Arc::new(builder.finish()) as ArrayRef) + // SAFETY: the union null buffer is valid at `i`, so each input is valid. + let string = unsafe { string_array.value_unchecked(i) }; + let delimiter = unsafe { delimiter_array.value_unchecked(i) }; + let position = unsafe { n_array.value_unchecked(i) }; + append_split_part(string, delimiter, position, &mut builder)?; + } + } else { + for i in 0..string_array.len() { + // SAFETY: no input has a null buffer, so every index is valid. + let string = unsafe { string_array.value_unchecked(i) }; + let delimiter = unsafe { delimiter_array.value_unchecked(i) }; + let position = unsafe { n_array.value_unchecked(i) }; + append_split_part(string, delimiter, position, &mut builder)?; + } + } + + builder.finish(nulls) +} + +#[inline] +fn append_split_part( + string: &str, + delimiter: &str, + n: i64, + builder: &mut B, +) -> Result<()> { + let result = match n.cmp(&0) { + std::cmp::Ordering::Greater => { + let idx: usize = (n - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {n} exceeds maximum supported value" + ) + })?; + if delimiter.is_empty() { + // Match PostgreSQL's behavior: empty delimiter treats input + // as a single field, so only position 1 returns data. + (n == 1).then_some(string) + } else { + split_nth(string, delimiter, idx) + } + } + std::cmp::Ordering::Less => { + let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {n} exceeds minimum supported value" + ) + })?; + if delimiter.is_empty() { + // Match PostgreSQL's behavior: empty delimiter treats input + // as a single field, so only position -1 returns data. + (n == -1).then_some(string) + } else { + rsplit_nth(string, delimiter, idx) + } + } + std::cmp::Ordering::Equal => { + return exec_err!("field position must not be zero"); + } + }; + builder.append_value(result.unwrap_or("")); + Ok(()) } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; + use arrow::array::{Array, AsArray, StringArray, StringViewArray}; use arrow::datatypes::DataType::Utf8; use datafusion_common::ScalarValue; - use datafusion_common::{exec_err, Result}; + use datafusion_common::{Result, exec_err}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use crate::string::split_part::SplitPartFunc; @@ -314,6 +708,158 @@ mod tests { Utf8, StringArray ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MIN))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // Edge cases with delimiters + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("a")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(",")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + // Edge cases with delimiters with negative n + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + Ok(Some("a,b")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("a,b")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } + + #[test] + fn test_split_part_stringview_sliced() -> Result<()> { + use super::split_part_scalar_view; + + let strings: StringViewArray = vec![ + Some("skip_this.value"), + Some("this_is_a_long_prefix.suffix"), + Some("short.val"), + Some("another_long_result.rest"), + None, + ] + .into_iter() + .collect(); + + // Slice off the first element to get a non-zero offset array. + let sliced = strings.slice(1, 4); + let result = split_part_scalar_view(&sliced, ".", 1)?; + let result = result.as_string_view(); + assert_eq!(result.len(), 4); + assert_eq!(result.value(0), "this_is_a_long_prefix"); + assert_eq!(result.value(1), "short"); + assert_eq!(result.value(2), "another_long_result"); + assert!(result.is_null(3)); Ok(()) } diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index c4159cba86f34..c89bd66d72cdf 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -15,50 +15,24 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, Scalar}; +use arrow::compute::kernels::comparison::starts_with as arrow_starts_with; use arrow::datatypes::DataType; -use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_common::types::logical_string; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; use datafusion_expr::type_coercion::binary::{ binary_to_string_coercion, string_coercion, }; - -use crate::utils::make_scalar_function; -use datafusion_common::types::logical_string; -use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ - cast, Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs, - ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast, }; use datafusion_macros::user_doc; -/// Returns true if string starts with prefix. -/// starts_with('alphabet', 'alph') = 't' -fn starts_with(args: &[ArrayRef]) -> Result { - if let Some(coercion_data_type) = - string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| { - binary_to_string_coercion(args[0].data_type(), args[1].data_type()) - }) - { - let arg0 = if args[0].data_type() == &coercion_data_type { - Arc::clone(&args[0]) - } else { - arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)? - }; - let arg1 = if args[1].data_type() == &coercion_data_type { - Arc::clone(&args[1]) - } else { - arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)? - }; - let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?; - Ok(Arc::new(result) as ArrayRef) - } else { - internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View") - } -} - #[user_doc( doc_section(label = "String Functions"), description = "Tests if a string starts with a substring.", @@ -100,10 +74,6 @@ impl StartsWithFunc { } impl ScalarUDFImpl for StartsWithFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "starts_with" } @@ -117,30 +87,93 @@ impl ScalarUDFImpl for StartsWithFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(starts_with, vec![])(&args.args) + let [str_arg, prefix_arg] = take_function_args(self.name(), &args.args)?; + + // Determine the common type for coercion + let coercion_type = string_coercion( + &str_arg.data_type(), + &prefix_arg.data_type(), + ) + .or_else(|| { + binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type()) + }); + + let Some(coercion_type) = coercion_type else { + return exec_err!( + "Unsupported data types {:?}, {:?} for function `starts_with`.", + str_arg.data_type(), + prefix_arg.data_type() + ); + }; + + // Helper to cast an array if needed + let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result { + if arr.data_type() == target { + Ok(Arc::clone(arr)) + } else { + Ok(arrow::compute::kernels::cast::cast(arr, target)?) + } + }; + + match (str_arg, prefix_arg) { + // Both scalars - just compute directly + (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => { + let str_arr = str_scalar.to_array_of_size(1)?; + let prefix_arr = prefix_scalar.to_array_of_size(1)?; + let str_arr = maybe_cast(&str_arr, &coercion_type)?; + let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?; + let result = arrow_starts_with(&str_arr, &prefix_arr)?; + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } + // String is array, prefix is scalar - use Scalar wrapper for optimization + (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => { + let str_arr = maybe_cast(str_arr, &coercion_type)?; + let prefix_arr = prefix_scalar.to_array_of_size(1)?; + let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?; + let prefix_scalar = Scalar::new(prefix_arr); + let result = arrow_starts_with(&str_arr, &prefix_scalar)?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + // String is scalar, prefix is array - use Scalar wrapper for string + (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => { + let str_arr = str_scalar.to_array_of_size(1)?; + let str_arr = maybe_cast(&str_arr, &coercion_type)?; + let str_scalar = Scalar::new(str_arr); + let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?; + let result = arrow_starts_with(&str_scalar, &prefix_arr)?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + // Both arrays - pass directly + (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => { + let str_arr = maybe_cast(str_arr, &coercion_type)?; + let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?; + let result = arrow_starts_with(&str_arr, &prefix_arr)?; + Ok(ColumnarValue::Array(Arc::new(result))) } - _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?, } } fn simplify( &self, args: Vec, - info: &dyn SimplifyInfo, + info: &SimplifyContext, ) -> Result { if let Expr::Literal(scalar_value, _) = &args[1] { // Convert starts_with(col, 'prefix') to col LIKE 'prefix%' with proper escaping - // Example: starts_with(col, 'ja%') -> col LIKE 'ja\%%' - // 1. 'ja%' (input pattern) - // 2. 'ja\%' (escape special char '%') - // 3. 'ja\%%' (add suffix for starts_with) + // Escapes pattern characters: starts_with(col, 'j\_a%') -> col LIKE 'j\\\_a\%%' + // 1. 'j\_a%' (input pattern) + // 2. 'j\\\_a\%' (escape special chars '%', '_' and '\') + // 3. 'j\\\_a\%%' (add unescaped % suffix for starts_with) let like_expr = match scalar_value { ScalarValue::Utf8(Some(pattern)) | ScalarValue::LargeUtf8(Some(pattern)) | ScalarValue::Utf8View(Some(pattern)) => { - let escaped_pattern = pattern.replace("%", "\\%"); + let escaped_pattern = pattern + .replace("\\", "\\\\") + .replace("%", "\\%") + .replace("_", "\\_"); let like_pattern = format!("{escaped_pattern}%"); Expr::Literal(ScalarValue::Utf8(Some(like_pattern)), None) } @@ -188,16 +221,16 @@ impl ScalarUDFImpl for StartsWithFunc { #[cfg(test)] mod tests { use crate::utils::test::test_function; - use arrow::array::{Array, BooleanArray}; + use arrow::array::{Array, BooleanArray, StringArray}; use arrow::datatypes::DataType::Boolean; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; use super::*; #[test] - fn test_functions() -> Result<()> { - // Generate test cases for starts_with + fn test_scalar_scalar() -> Result<()> { + // Test Scalar + Scalar combinations let test_cases = vec![ (Some("alphabet"), Some("alph"), Some(true)), (Some("alphabet"), Some("bet"), Some(false)), @@ -241,4 +274,154 @@ mod tests { Ok(()) } + + #[test] + fn test_array_scalar() -> Result<()> { + // Test Array + Scalar (the optimized path) + let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("alphabet"), + Some("alphabet"), + Some("beta"), + None, + ]))); + let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))); + + let args = vec![array, scalar]; + test_function!( + StartsWithFunc::new(), + args, + Ok(Some(true)), // First element result + bool, + Boolean, + BooleanArray + ); + + Ok(()) + } + + #[test] + fn test_array_scalar_full_result() { + // Test Array + Scalar and verify all results + let func = StartsWithFunc::new(); + let array = Arc::new(StringArray::from(vec![ + Some("alphabet"), + Some("alphabet"), + Some("beta"), + None, + ])); + let args = vec![ + ColumnarValue::Array(array), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))), + ]; + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ], + number_rows: 4, + return_field: Field::new("f", Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + let result_array = result.into_array(4).unwrap(); + let bool_array = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(bool_array.value(0)); // "alphabet" starts with "alph" + assert!(bool_array.value(1)); // "alphabet" starts with "alph" + assert!(!bool_array.value(2)); // "beta" does not start with "alph" + assert!(bool_array.is_null(3)); // null input -> null output + } + + #[test] + fn test_scalar_array() { + // Test Scalar + Array + let func = StartsWithFunc::new(); + let prefixes = Arc::new(StringArray::from(vec![ + Some("alph"), + Some("bet"), + Some("alpha"), + None, + ])); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))), + ColumnarValue::Array(prefixes), + ]; + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ], + number_rows: 4, + return_field: Field::new("f", Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + let result_array = result.into_array(4).unwrap(); + let bool_array = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(bool_array.value(0)); // "alphabet" starts with "alph" + assert!(!bool_array.value(1)); // "alphabet" does not start with "bet" + assert!(bool_array.value(2)); // "alphabet" starts with "alpha" + assert!(bool_array.is_null(3)); // null prefix -> null output + } + + #[test] + fn test_array_array() { + // Test Array + Array + let func = StartsWithFunc::new(); + let strings = Arc::new(StringArray::from(vec![ + Some("alphabet"), + Some("rust"), + Some("datafusion"), + None, + ])); + let prefixes = Arc::new(StringArray::from(vec![ + Some("alph"), + Some("ru"), + Some("hello"), + Some("test"), + ])); + let args = vec![ + ColumnarValue::Array(strings), + ColumnarValue::Array(prefixes), + ]; + + let result = func + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: vec![ + Field::new("a", DataType::Utf8, true).into(), + Field::new("b", DataType::Utf8, true).into(), + ], + number_rows: 4, + return_field: Field::new("f", Boolean, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(); + + let result_array = result.into_array(4).unwrap(); + let bool_array = result_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert!(bool_array.value(0)); // "alphabet" starts with "alph" + assert!(bool_array.value(1)); // "rust" starts with "ru" + assert!(!bool_array.value(2)); // "datafusion" does not start with "hello" + assert!(bool_array.is_null(3)); // null string -> null output + } } diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 4000f3bb3be2a..497a0a1206922 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -15,64 +15,166 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::fmt::Write; use std::sync::Arc; -use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, GenericStringBuilder}; -use arrow::datatypes::DataType::{ - Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, Utf8, -}; +use arrow::array::{Array, ArrayRef, StringArray}; +use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, - Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowNativeType, ArrowPrimitiveType, DataType, Int8Type, Int16Type, Int32Type, + Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use datafusion_common::cast::as_primitive_array; -use datafusion_common::Result; -use datafusion_common::{exec_err, plan_err}; - -use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; -use datafusion_expr_common::signature::TypeSignature::Exact; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; use datafusion_macros::user_doc; +/// Hex lookup table for fast conversion +const HEX_CHARS: &[u8; 16] = b"0123456789abcdef"; + /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' -fn to_hex(args: &[ArrayRef]) -> Result +fn to_hex_array(array: &ArrayRef) -> Result where - T::Native: std::fmt::LowerHex, + T::Native: ToHex, { - let integer_array = as_primitive_array::(&args[0])?; + let integer_array = as_primitive_array::(array)?; + let len = integer_array.len(); - let mut result = GenericStringBuilder::::with_capacity( - integer_array.len(), - // * 8 to convert to bits, / 4 bits per hex char - integer_array.len() * (T::Native::get_byte_width() * 8 / 4), - ); + // Max hex string length: 16 chars for u64/i64 + let max_hex_len = T::Native::get_byte_width() * 2; - for integer in integer_array { - if let Some(value) = integer { - if let Some(value_usize) = value.to_usize() { - write!(result, "{value_usize:x}")?; - } else if let Some(value_isize) = value.to_isize() { - write!(result, "{value_isize:x}")?; - } else { - return exec_err!( - "Unsupported data type {integer:?} for function to_hex" - ); - } - result.append_value(""); - } else { - result.append_null(); - } + // Pre-allocate buffers - avoid the builder API overhead + let mut offsets: Vec = Vec::with_capacity(len + 1); + let mut values: Vec = Vec::with_capacity(len * max_hex_len); + + // Reusable buffer for hex conversion + let mut hex_buffer = [0u8; 16]; + + // Start with offset 0 + offsets.push(0); + + // Process all values directly (including null slots - we write empty strings for nulls) + // The null bitmap will mark which entries are actually null + for value in integer_array.values() { + let hex_len = value.write_hex_to_buffer(&mut hex_buffer); + values.extend_from_slice(&hex_buffer[16 - hex_len..]); + offsets.push(values.len() as i32); } - let result = result.finish(); + // Copy null bitmap from input (nulls pass through unchanged) + let nulls = integer_array.nulls().cloned(); + + // SAFETY: offsets are valid (monotonically increasing, last value equals values.len()) + // and values contains valid UTF-8 (only ASCII hex digits) + let offsets = + unsafe { OffsetBuffer::new_unchecked(Buffer::from_vec(offsets).into()) }; + let result = StringArray::new(offsets, Buffer::from_vec(values), nulls); Ok(Arc::new(result) as ArrayRef) } +#[inline] +fn to_hex_scalar(value: T) -> String { + let mut hex_buffer = [0u8; 16]; + let hex_len = value.write_hex_to_buffer(&mut hex_buffer); + // SAFETY: hex_buffer is ASCII hex digits + unsafe { std::str::from_utf8_unchecked(&hex_buffer[16 - hex_len..]).to_string() } +} + +/// Trait for converting integer types to hexadecimal in a buffer +trait ToHex: ArrowNativeType { + /// Write hex representation to buffer and return the number of hex digits written. + /// The hex digits are written right-aligned in the buffer (starting from position 16 - len). + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize; +} + +/// Write unsigned value to hex buffer and return the number of digits written. +/// Digits are written right-aligned in the buffer. +#[inline] +fn write_unsigned_hex_to_buffer(value: u64, buffer: &mut [u8; 16]) -> usize { + if value == 0 { + buffer[15] = b'0'; + return 1; + } + + // Write hex digits from right to left + let mut pos = 16; + let mut v = value; + while v > 0 { + pos -= 1; + buffer[pos] = HEX_CHARS[(v & 0xf) as usize]; + v >>= 4; + } + + 16 - pos +} + +/// Write signed value to hex buffer (two's complement for negative) and return digit count +#[inline] +fn write_signed_hex_to_buffer(value: i64, buffer: &mut [u8; 16]) -> usize { + // For negative values, use two's complement representation (same as casting to u64) + write_unsigned_hex_to_buffer(value as u64, buffer) +} + +impl ToHex for i8 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_signed_hex_to_buffer(self as i64, buffer) + } +} + +impl ToHex for i16 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_signed_hex_to_buffer(self as i64, buffer) + } +} + +impl ToHex for i32 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_signed_hex_to_buffer(self as i64, buffer) + } +} + +impl ToHex for i64 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_signed_hex_to_buffer(self, buffer) + } +} + +impl ToHex for u8 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_unsigned_hex_to_buffer(self as u64, buffer) + } +} + +impl ToHex for u16 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_unsigned_hex_to_buffer(self as u64, buffer) + } +} + +impl ToHex for u32 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_unsigned_hex_to_buffer(self as u64, buffer) + } +} + +impl ToHex for u64 { + #[inline] + fn write_hex_to_buffer(self, buffer: &mut [u8; 16]) -> usize { + write_unsigned_hex_to_buffer(self, buffer) + } +} + #[user_doc( doc_section(label = "String Functions"), description = "Converts an integer to a hexadecimal string.", @@ -101,17 +203,8 @@ impl Default for ToHexFunc { impl ToHexFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - Exact(vec![Int8]), - Exact(vec![Int16]), - Exact(vec![Int32]), - Exact(vec![Int64]), - Exact(vec![UInt8]), - Exact(vec![UInt16]), - Exact(vec![UInt32]), - Exact(vec![UInt64]), - ], + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Integer)], Volatility::Immutable, ), } @@ -119,10 +212,6 @@ impl ToHexFunc { } impl ScalarUDFImpl for ToHexFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "to_hex" } @@ -131,26 +220,76 @@ impl ScalarUDFImpl for ToHexFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(match arg_types[0] { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Utf8, - _ => { - return plan_err!("The to_hex function can only accept integers."); - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - match args.args[0].data_type() { - Int64 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt64 => make_scalar_function(to_hex::, vec![])(&args.args), - Int32 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt32 => make_scalar_function(to_hex::, vec![])(&args.args), - Int16 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt16 => make_scalar_function(to_hex::, vec![])(&args.args), - Int8 => make_scalar_function(to_hex::, vec![])(&args.args), - UInt8 => make_scalar_function(to_hex::, vec![])(&args.args), - other => exec_err!("Unsupported data type {other:?} for function to_hex"), + let arg = &args.args[0]; + + match arg { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok( + ColumnarValue::Scalar(ScalarValue::Utf8(Some(to_hex_scalar(*v)))), + ), + + // NULL scalars + ColumnarValue::Scalar(s) if s.is_null() => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) + } + + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt64 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int32 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt32 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int16 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt16 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::Int8 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + DataType::UInt8 => { + Ok(ColumnarValue::Array(to_hex_array::(array)?)) + } + other => exec_err!("Unsupported data type {other:?} for function to_hex"), + }, + + other => internal_err!( + "Unexpected argument type {:?} for function to_hex", + other.data_type() + ), } } @@ -162,8 +301,8 @@ impl ScalarUDFImpl for ToHexFunc { #[cfg(test)] mod tests { use arrow::array::{ - Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + Int8Array, Int16Array, Int32Array, Int64Array, StringArray, UInt8Array, + UInt16Array, UInt32Array, UInt64Array, }; use datafusion_common::cast::as_string_array; @@ -189,8 +328,8 @@ mod tests { let expected = $expected; let array = <$array_type>::from(input); - let array_ref = Arc::new(array); - let hex_result = to_hex::<$arrow_type>(&[array_ref])?; + let array_ref: ArrayRef = Arc::new(array); + let hex_result = to_hex_array::<$arrow_type>(&array_ref)?; let hex_array = as_string_array(&hex_result)?; let expected_array = StringArray::from(expected); diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 8bb2ec1d511cd..c0ac90b1bc598 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -16,16 +16,14 @@ // under the License. use crate::string::common::to_upper; -use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; -use datafusion_common::types::logical_string; use datafusion_common::Result; +use datafusion_common::types::logical_string; use datafusion_expr::{ Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; #[user_doc( doc_section(label = "String Functions"), @@ -68,10 +66,6 @@ impl UpperFunc { } impl ScalarUDFImpl for UpperFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "upper" } @@ -81,7 +75,7 @@ impl ScalarUDFImpl for UpperFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "upper") + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -96,28 +90,29 @@ impl ScalarUDFImpl for UpperFunc { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, ArrayRef, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, ArrayRef, StringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; use std::sync::Arc; - fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { + fn invoke_upper(input: ArrayRef) -> Result { let func = UpperFunc::new(); - - let arg_field = Field::new("a", input.data_type().clone(), true).into(); + let data_type = input.data_type().clone(); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_fields: vec![arg_field], - return_field: Field::new("f", Utf8, true).into(), + arg_fields: vec![Field::new("a", data_type.clone(), true).into()], + return_field: Field::new("f", data_type, true).into(), config_options: Arc::new(ConfigOptions::default()), }; - - let result = match func.invoke_with_args(args)? { - ColumnarValue::Array(result) => result, + match func.invoke_with_args(args)? { + ColumnarValue::Array(r) => Ok(r), _ => unreachable!("upper"), - }; + } + } + + fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { + let result = invoke_upper(input)?; assert_eq!(&expected, &result); Ok(()) } @@ -196,4 +191,182 @@ mod tests { to_upper(input, expected) } + + #[test] + fn upper_utf8view() -> Result<()> { + let input = Arc::new(StringViewArray::from(vec![ + Some("arrow"), + None, + Some("tschüß"), + ])) as ArrayRef; + + let expected = Arc::new(StringViewArray::from(vec![ + Some("ARROW"), + None, + Some("TSCHÜSS"), + ])) as ArrayRef; + + to_upper(input, expected) + } + + #[test] + fn upper_ascii_utf8view() -> Result<()> { + // Mix of inlined (≤12 bytes) and referenced (>12 bytes) strings, plus + // a null and an empty, to exercise the all-ASCII Utf8View fast path. + let input = Arc::new(StringViewArray::from(vec![ + Some("arrow"), // inlined short + None, + Some("hello world 123"), // referenced (15 bytes) + Some(""), + Some("0123456789"), // inlined, no case change + Some("datafusion is cool"), // referenced + ])) as ArrayRef; + + let expected = Arc::new(StringViewArray::from(vec![ + Some("ARROW"), + None, + Some("HELLO WORLD 123"), + Some(""), + Some("0123456789"), + Some("DATAFUSION IS COOL"), + ])) as ArrayRef; + + to_upper(input, expected) + } + + #[test] + fn upper_sliced_ascii_utf8view() -> Result<()> { + // Slice of a parent that contains a non-ASCII string outside the + // slice. The slice is all-ASCII, so the fast path must run and produce + // correct output while the parent's unaddressed non-ASCII bytes are + // irrelevant to the result. + let parent = Arc::new(StringViewArray::from(vec![ + Some("农历新年long enough for buffer"), + Some("hello world 123"), + Some("datafusion rocks!"), + Some("zzzzzzzzzzzzzzzz"), + ])) as ArrayRef; + let sliced = parent.slice(1, 2); + let result = invoke_upper(sliced)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected = StringViewArray::from(vec![ + Some("HELLO WORLD 123"), + Some("DATAFUSION ROCKS!"), + ]); + assert_eq!(result_sv, &expected); + // The slice's two long views address 15 + 17 = 32 bytes; the ASCII + // fast path must produce a single packed buffer of exactly that + // size, not one scaled to the parent's data buffer. + assert_eq!(result_sv.data_buffers().len(), 1); + assert_eq!(result_sv.data_buffers()[0].len(), 32); + Ok(()) + } + + #[test] + fn upper_utf8view_inline_only_no_buffers() -> Result<()> { + // An array whose values are all ≤ 12 bytes is fully inline; the ASCII + // fast path should produce no data buffers at all. + let input = Arc::new(StringViewArray::from(vec![ + Some("hello"), + None, + Some(""), + Some("0123456789AB"), // 12 bytes — inline boundary + ])) as ArrayRef; + let result = invoke_upper(input)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected = StringViewArray::from(vec![ + Some("HELLO"), + None, + Some(""), + Some("0123456789AB"), + ]); + assert_eq!(result_sv, &expected); + assert_eq!( + result_sv.data_buffers().len(), + 0, + "inline-only Utf8View should produce no data buffers" + ); + Ok(()) + } + + #[test] + fn upper_utf8view_long_packs_tight() -> Result<()> { + // Mix of long and inline values; the long values should be packed into + // a single tight output buffer whose size is exactly the sum of their + // lengths (inline values do not contribute). + let input = Arc::new(StringViewArray::from(vec![ + Some("hello world 123"), // 15 bytes (long) + Some("abc"), // inline + None, + Some("datafusion rocks!"), // 17 bytes (long) + Some("another long string"), // 19 bytes (long) + ])) as ArrayRef; + let result = invoke_upper(input)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected = StringViewArray::from(vec![ + Some("HELLO WORLD 123"), + Some("ABC"), + None, + Some("DATAFUSION ROCKS!"), + Some("ANOTHER LONG STRING"), + ]); + assert_eq!(result_sv, &expected); + assert_eq!(result_sv.data_buffers().len(), 1); + assert_eq!(result_sv.data_buffers()[0].len(), 15 + 17 + 19); + Ok(()) + } + + #[test] + fn upper_utf8view_splits_into_multiple_buffers() -> Result<()> { + // Produce enough long-string output to overflow the first data block + // (≈16 KiB after the initial doubling) and confirm the fast path + // splits across buffers rather than packing everything into one and + // risking the i32::MAX offset limit. + const STR_LEN: usize = 500; + const N: usize = 40; // 40 × 500 B = 20 KiB total — crosses the first block. + let value = "x".repeat(STR_LEN); + let inputs: Vec> = (0..N).map(|_| Some(value.clone())).collect(); + let input = Arc::new(StringViewArray::from(inputs.clone())) as ArrayRef; + let result = invoke_upper(input)?; + let result_sv = result.as_any().downcast_ref::().unwrap(); + + let expected_value = "X".repeat(STR_LEN); + let expected: Vec> = + (0..N).map(|_| Some(expected_value.as_str())).collect(); + assert_eq!(result_sv, &StringViewArray::from(expected)); + assert!( + result_sv.data_buffers().len() >= 2, + "expected the output to span more than one data buffer, got {}", + result_sv.data_buffers().len() + ); + // Total bytes across buffers must equal total long-value bytes + // (no row was inlined since each value is > 12 bytes). + let total: usize = result_sv.data_buffers().iter().map(|b| b.len()).sum(); + assert_eq!(total, N * STR_LEN); + Ok(()) + } + + #[test] + fn upper_sliced_utf8() -> Result<()> { + let parent = Arc::new(StringArray::from(vec![ + Some("aaaaaaaa"), + Some("hello"), + Some("world"), + Some(""), + Some("zzzzzzzz"), + ])) as ArrayRef; + let sliced = parent.slice(1, 3); + let result = invoke_upper(sliced)?; + let result_sa = result.as_any().downcast_ref::().unwrap(); + + let expected = StringArray::from(vec![Some("HELLO"), Some("WORLD"), Some("")]); + assert_eq!(result_sa, &expected); + // The slice's addressed bytes are "hello" + "world" = 10; the ASCII + // fast path must produce a tight output buffer (not the parent's). + assert_eq!(result_sa.value_data().len(), 10); + Ok(()) + } } diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 96ce9439028c6..0bf680b44a183 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -15,23 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::GenericStringBuilder; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Utf8; use rand::Rng; use uuid::Uuid; -use datafusion_common::{assert_or_internal_err, Result}; +use crate::strings::GenericStringArrayBuilder; +use datafusion_common::{Result, assert_or_internal_err}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), - description = "Returns [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_(random)) string value which is unique per row.", + description = "Returns [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_%28random%29) string value which is unique per row.", syntax_example = "uuid()", sql_example = r#"```sql > select uuid(); @@ -56,16 +55,12 @@ impl Default for UuidFunc { impl UuidFunc { pub fn new() -> Self { Self { - signature: Signature::exact(vec![], Volatility::Volatile), + signature: Signature::nullary(Volatility::Volatile), } } } impl ScalarUDFImpl for UuidFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "uuid" } @@ -92,7 +87,7 @@ impl ScalarUDFImpl for UuidFunc { let mut randoms = vec![0u128; args.number_rows]; rng.fill(&mut randoms[..]); - let mut builder = GenericStringBuilder::::with_capacity( + let mut builder = GenericStringArrayBuilder::::with_capacity( args.number_rows, args.number_rows * 36, ); @@ -106,7 +101,7 @@ impl ScalarUDFImpl for UuidFunc { builder.append_value(fmt.encode_lower(&mut buffer)); } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + Ok(ColumnarValue::Array(Arc::new(builder.finish(None)?))) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index 108c20e136670..144d567f5be0a 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -15,24 +15,38 @@ // specific language governing permissions and limitations // under the License. +use std::marker::PhantomData; use std::mem::size_of; +use std::sync::Arc; + +use datafusion_common::{Result, exec_datafusion_err, internal_err}; use arrow::array::{ - make_view, Array, ArrayAccessor, ArrayDataBuilder, ByteView, LargeStringArray, - NullBufferBuilder, StringArray, StringViewArray, StringViewBuilder, + Array, ArrayAccessor, ArrayDataBuilder, ArrayRef, BinaryArray, ByteView, + GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray, + make_view, }; -use arrow::buffer::{MutableBuffer, NullBuffer}; +use arrow::buffer::{Buffer, MutableBuffer, NullBuffer, ScalarBuffer}; use arrow::datatypes::DataType; -/// Optimized version of the StringBuilder in Arrow that: -/// 1. Precalculating the expected length of the result, avoiding reallocations. -/// 2. Avoids creating / incrementally creating a `NullBufferBuilder` -pub struct StringArrayBuilder { +/// Builder used by `concat`/`concat_ws` to assemble a [`StringArray`] one row +/// at a time from multiple input columns. +/// +/// Each row is written via repeated `write` calls (one per input fragment) +/// followed by a single `append_offset` to commit the row. The output null +/// buffer is computed in bulk by the caller and supplied to `finish`, avoiding +/// per-row NULL handling work. +/// +/// For the common "produce one `&str` per row" pattern, prefer +/// `GenericStringArrayBuilder` instead. +pub(crate) struct ConcatStringBuilder { offsets_buffer: MutableBuffer, value_buffer: MutableBuffer, + /// If true, a safety check is required during the `finish` call + tainted: bool, } -impl StringArrayBuilder { +impl ConcatStringBuilder { pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { let capacity = item_capacity .checked_add(1) @@ -45,6 +59,7 @@ impl StringArrayBuilder { Self { offsets_buffer, value_buffer: MutableBuffer::with_capacity(data_capacity), + tainted: false, } } @@ -56,6 +71,7 @@ impl StringArrayBuilder { match column { ColumnarValueRef::Scalar(s) => { self.value_buffer.extend_from_slice(s); + self.tainted = true; } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { @@ -75,6 +91,12 @@ impl StringArrayBuilder { .extend_from_slice(array.value(i).as_bytes()); } } + ColumnarValueRef::NullableBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer.extend_from_slice(array.value(i)); + } + self.tainted = true; + } ColumnarValueRef::NonNullableArray(array) => { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); @@ -87,31 +109,36 @@ impl StringArrayBuilder { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); } + ColumnarValueRef::NonNullableBinaryArray(array) => { + self.value_buffer.extend_from_slice(array.value(i)); + self.tainted = true; + } } } - pub fn append_offset(&mut self) { + pub fn append_offset(&mut self) -> Result<()> { let next_offset: i32 = self .value_buffer .len() .try_into() - .expect("byte array offset overflow"); + .map_err(|_| exec_datafusion_err!("byte array offset overflow"))?; self.offsets_buffer.push(next_offset); + Ok(()) } /// Finalize the builder into a concrete [`StringArray`]. /// - /// # Panics + /// # Errors /// - /// This method can panic when: + /// Returns an error when: /// /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. - pub fn finish(self, null_buffer: Option) -> StringArray { + pub fn finish(self, null_buffer: Option) -> Result { let row_count = self.offsets_buffer.len() / size_of::() - 1; - if let Some(ref null_buffer) = null_buffer { - assert_eq!( - null_buffer.len(), - row_count, + if let Some(ref null_buffer) = null_buffer + && null_buffer.len() != row_count + { + return internal_err!( "Null buffer and offsets buffer must be the same length" ); } @@ -120,24 +147,45 @@ impl StringArrayBuilder { .add_buffer(self.offsets_buffer.into()) .add_buffer(self.value_buffer.into()) .nulls(null_buffer); - // SAFETY: all data that was appended was valid UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - StringArray::from(array_data) + if self.tainted { + // Raw binary arrays with possible invalid utf-8 were used, + // so let ArrayDataBuilder perform validation + let array_data = array_builder.build()?; + Ok(StringArray::from(array_data)) + } else { + // SAFETY: all data that was appended was valid UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + Ok(StringArray::from(array_data)) + } } } -pub struct StringViewArrayBuilder { - builder: StringViewBuilder, - block: String, +/// Builder used by `concat`/`concat_ws` to assemble a [`StringViewArray`] one +/// row at a time from multiple input columns. +/// +/// Each row is written via repeated `write` calls (one per input +/// fragment) followed by a single `append_offset` to commit the row +/// as a single string view. The output null buffer is supplied by the caller +/// at `finish` time, avoiding per-row NULL handling work. +/// +/// For the common "produce one `&str` per row" pattern, prefer +/// [`StringViewArrayBuilder`] instead. +pub(crate) struct ConcatStringViewBuilder { + views: Vec, + data: Vec, + block: Vec, + /// If true, a safety check is required during the `append_offset` call + tainted: bool, } -impl StringViewArrayBuilder { - pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self { - let builder = StringViewBuilder::with_capacity(data_capacity); +impl ConcatStringViewBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { Self { - builder, - block: String::new(), + views: Vec::with_capacity(item_capacity), + data: Vec::with_capacity(data_capacity), + block: vec![], + tainted: false, } } @@ -148,60 +196,124 @@ impl StringViewArrayBuilder { ) { match column { ColumnarValueRef::Scalar(s) => { - self.block.push_str(std::str::from_utf8(s).unwrap()); + self.block.extend_from_slice(s); + self.tainted = true; } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.extend_from_slice(array.value(i).as_bytes()); } } ColumnarValueRef::NullableLargeStringArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.extend_from_slice(array.value(i).as_bytes()); } } ColumnarValueRef::NullableStringViewArray(array) => { if !CHECK_VALID || array.is_valid(i) { - self.block.push_str( - std::str::from_utf8(array.value(i).as_bytes()).unwrap(), - ); + self.block.extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.extend_from_slice(array.value(i)); } + self.tainted = true; } ColumnarValueRef::NonNullableArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.extend_from_slice(array.value(i).as_bytes()); } ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.extend_from_slice(array.value(i).as_bytes()); } ColumnarValueRef::NonNullableStringViewArray(array) => { - self.block - .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + self.block.extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableBinaryArray(array) => { + self.block.extend_from_slice(array.value(i)); + self.tainted = true; } } } - pub fn append_offset(&mut self) { - self.builder.append_value(&self.block); - self.block = String::new(); + /// Finalizes the current row by converting the accumulated data into a + /// StringView and appending it to the views buffer. + pub fn append_offset(&mut self) -> Result<()> { + if self.tainted { + std::str::from_utf8(&self.block) + .map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))?; + } + + let v = &self.block; + if v.len() > 12 { + let offset: u32 = self + .data + .len() + .try_into() + .map_err(|_| exec_datafusion_err!("byte array offset overflow"))?; + self.data.extend_from_slice(v); + self.views.push(make_view(v, 0, offset)); + } else { + self.views.push(make_view(v, 0, 0)); + } + + self.block.clear(); + self.tainted = false; + Ok(()) } - pub fn finish(mut self) -> StringViewArray { - self.builder.finish() + /// Finalize the builder into a concrete [`StringViewArray`]. + /// + /// # Errors + /// + /// Returns an error when: + /// + /// - the provided `null_buffer` length does not match the row count. + pub fn finish(self, null_buffer: Option) -> Result { + if let Some(ref nulls) = null_buffer + && nulls.len() != self.views.len() + { + return internal_err!( + "Null buffer length ({}) must match row count ({})", + nulls.len(), + self.views.len() + ); + } + + let buffers: Vec = if self.data.is_empty() { + vec![] + } else { + vec![Buffer::from(self.data)] + }; + + // SAFETY: views were constructed with correct lengths, offsets, and + // prefixes. UTF-8 validity was checked in append_offset() for any row + // where tainted data (e.g., binary literals) was appended. + let array = unsafe { + StringViewArray::new_unchecked( + ScalarBuffer::from(self.views), + buffers, + null_buffer, + ) + }; + Ok(array) } } -pub struct LargeStringArrayBuilder { +/// Builder used by `concat`/`concat_ws` to assemble a [`LargeStringArray`] one +/// row at a time from multiple input columns. See [`ConcatStringBuilder`] for +/// details on the row-composition contract. +/// +/// For the common "produce one `&str` per row" pattern, prefer +/// `GenericStringArrayBuilder` instead. +pub(crate) struct ConcatLargeStringBuilder { offsets_buffer: MutableBuffer, value_buffer: MutableBuffer, + /// If true, a safety check is required during the `finish` call + tainted: bool, } -impl LargeStringArrayBuilder { +impl ConcatLargeStringBuilder { pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { let capacity = item_capacity .checked_add(1) @@ -214,6 +326,7 @@ impl LargeStringArrayBuilder { Self { offsets_buffer, value_buffer: MutableBuffer::with_capacity(data_capacity), + tainted: false, } } @@ -225,6 +338,7 @@ impl LargeStringArrayBuilder { match column { ColumnarValueRef::Scalar(s) => { self.value_buffer.extend_from_slice(s); + self.tainted = true; } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { @@ -244,6 +358,12 @@ impl LargeStringArrayBuilder { .extend_from_slice(array.value(i).as_bytes()); } } + ColumnarValueRef::NullableBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer.extend_from_slice(array.value(i)); + } + self.tainted = true; + } ColumnarValueRef::NonNullableArray(array) => { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); @@ -256,31 +376,36 @@ impl LargeStringArrayBuilder { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); } + ColumnarValueRef::NonNullableBinaryArray(array) => { + self.value_buffer.extend_from_slice(array.value(i)); + self.tainted = true; + } } } - pub fn append_offset(&mut self) { + pub fn append_offset(&mut self) -> Result<()> { let next_offset: i64 = self .value_buffer .len() .try_into() - .expect("byte array offset overflow"); + .map_err(|_| exec_datafusion_err!("byte array offset overflow"))?; self.offsets_buffer.push(next_offset); + Ok(()) } /// Finalize the builder into a concrete [`LargeStringArray`]. /// - /// # Panics + /// # Errors /// - /// This method can panic when: + /// Returns an error when: /// /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. - pub fn finish(self, null_buffer: Option) -> LargeStringArray { + pub fn finish(self, null_buffer: Option) -> Result { let row_count = self.offsets_buffer.len() / size_of::() - 1; - if let Some(ref null_buffer) = null_buffer { - assert_eq!( - null_buffer.len(), - row_count, + if let Some(ref null_buffer) = null_buffer + && null_buffer.len() != row_count + { + return internal_err!( "Null buffer and offsets buffer must be the same length" ); } @@ -289,14 +414,711 @@ impl LargeStringArrayBuilder { .add_buffer(self.offsets_buffer.into()) .add_buffer(self.value_buffer.into()) .nulls(null_buffer); - // SAFETY: all data that was appended was valid Large UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - LargeStringArray::from(array_data) + if self.tainted { + // Raw binary arrays with possible invalid utf-8 were used, + // so let ArrayDataBuilder perform validation + let array_data = array_builder.build()?; + Ok(LargeStringArray::from(array_data)) + } else { + // SAFETY: all data that was appended was valid Large UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + Ok(LargeStringArray::from(array_data)) + } + } +} + +// ---------------------------------------------------------------------------- +// Bulk-nulls builders +// +// These builders are similar to Arrow's `GenericStringBuilder` and +// `StringViewBuilder` but tuned for string UDFs along two axes: +// +// * Bulk-NULL handling. The NULL bitmap is passed to `finish()` rather than +// maintained per-row. Many string UDFs can compute the bitmap in bulk, +// where this is significantly more efficient. +// * Closure-based row emission. Beyond `append_value(&str)`, the builders +// expose `append_with` (fragments written into the builder via a +// `StringWriter`) and `append_byte_map` (byte-to-byte mapping of an input +// slice), letting UDFs emit a row without first assembling it in a scratch +// `String`. +// ---------------------------------------------------------------------------- + +/// Builder for a [`GenericStringArray`]. Instantiate with `O = i32` for +/// [`StringArray`] (Utf8) or `O = i64` for [`LargeStringArray`] (LargeUtf8). +pub(crate) struct GenericStringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, + placeholder_count: usize, + _phantom: PhantomData, +} + +impl GenericStringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let capacity = item_capacity + .checked_add(1) + .map(|i| i.saturating_mul(size_of::())) + .expect("capacity integer overflow"); + + let mut offsets_buffer = MutableBuffer::with_capacity(capacity); + offsets_buffer.push(O::usize_as(0)); + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + placeholder_count: 0, + _phantom: PhantomData, + } + } + + /// See [`BulkNullStringArrayBuilder::append_value`]. + /// + /// # Panics + /// + /// Panics if the cumulative byte length exceeds `O::MAX`. + #[inline] + pub fn append_value(&mut self, value: &str) { + self.value_buffer.extend_from_slice(value.as_bytes()); + let next_offset = + O::from_usize(self.value_buffer.len()).expect("byte array offset overflow"); + self.offsets_buffer.push(next_offset); + } + + /// See [`BulkNullStringArrayBuilder::append_placeholder`]. + #[inline] + pub fn append_placeholder(&mut self) { + let next_offset = + O::from_usize(self.value_buffer.len()).expect("byte array offset overflow"); + self.offsets_buffer.push(next_offset); + self.placeholder_count += 1; + } + + /// See [`BulkNullStringArrayBuilder::append_byte_map`]. + /// + /// # Safety + /// + /// The bytes produced by applying `map` to each byte of `src`, in order, + /// must form valid UTF-8. + /// + /// # Panics + /// + /// Panics if the cumulative byte length exceeds `O::MAX`. + #[inline] + pub unsafe fn append_byte_map u8>(&mut self, src: &[u8], mut map: F) { + self.value_buffer.extend(src.iter().map(|&b| map(b))); + let next_offset = + O::from_usize(self.value_buffer.len()).expect("byte array offset overflow"); + self.offsets_buffer.push(next_offset); + } + + /// See [`BulkNullStringArrayBuilder::append_with`]. + /// + /// # Panics + /// + /// Panics if the cumulative byte length exceeds `O::MAX`. + #[inline] + pub fn append_with(&mut self, f: F) + where + F: FnOnce(&mut GenericStringWriter<'_>), + { + let mut writer = GenericStringWriter { + value_buffer: &mut self.value_buffer, + }; + f(&mut writer); + let next_offset = + O::from_usize(self.value_buffer.len()).expect("byte array offset overflow"); + self.offsets_buffer.push(next_offset); + } + + /// Finalize into a [`GenericStringArray`] using the caller-supplied + /// null buffer. + /// + /// # Errors + /// + /// Returns an error when `null_buffer.len()` does not match the number of + /// appended rows. + pub fn finish( + self, + null_buffer: Option, + ) -> Result> { + let row_count = self.offsets_buffer.len() / size_of::() - 1; + if let Some(ref n) = null_buffer + && n.len() != row_count + { + return internal_err!( + "Null buffer length ({}) must match row count ({row_count})", + n.len() + ); + } + let null_count = null_buffer.as_ref().map_or(0, |n| n.null_count()); + debug_assert!( + null_count >= self.placeholder_count, + "{} placeholder rows but null buffer has {null_count} nulls", + self.placeholder_count, + ); + let array_data = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) + .len(row_count) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: every appended value came from a `&str`, so the value + // buffer is valid UTF-8 and offsets are monotonically non-decreasing. + let array_data = unsafe { array_data.build_unchecked() }; + Ok(GenericStringArray::::from(array_data)) + } +} + +/// Starting size for the long-string data block used by `StringView`-style +/// arrays; matches Arrow's `GenericByteViewBuilder` default. +pub(crate) const STRING_VIEW_INIT_BLOCK_SIZE: u32 = 8 * 1024; +/// Maximum size each long-string data block in a `StringView`-style array +/// grows to; matches Arrow's `GenericByteViewBuilder` default. +pub(crate) const STRING_VIEW_MAX_BLOCK_SIZE: u32 = 2 * 1024 * 1024; + +/// Append-only writer handed to closures passed to `append_with`. +pub(crate) trait StringWriter { + fn write_str(&mut self, s: &str); + fn write_char(&mut self, c: char); +} + +/// [`StringWriter`] for [`GenericStringArrayBuilder`]. Writes go straight to +/// the value buffer. +pub(crate) struct GenericStringWriter<'a> { + value_buffer: &'a mut MutableBuffer, +} + +impl StringWriter for GenericStringWriter<'_> { + #[inline(always)] + fn write_str(&mut self, s: &str) { + push_bytes_to_mutable_buffer(self.value_buffer, s.as_bytes()); + } + + #[inline(always)] + fn write_char(&mut self, c: char) { + push_char_to_mutable_buffer(self.value_buffer, c); + } +} + +/// Write `bytes` into `value_buffer`. For repeated small writes, +/// MutableBuffer::extend_from_slice can be slow (memcpy per call), so we extend +/// the buffer here directly and force inlining. +#[inline(always)] +fn push_bytes_to_mutable_buffer(value_buffer: &mut MutableBuffer, bytes: &[u8]) { + let n = bytes.len(); + let old_len = value_buffer.len(); + value_buffer.reserve(n); + + // SAFETY: we reserved `n` bytes; the source and destination do not alias + // because `bytes` was passed in by the caller and `value_buffer` is owned. + unsafe { + let dst = value_buffer.as_mut_ptr().add(old_len); + let src = bytes.as_ptr(); + match n { + 0 => {} + 1 => std::ptr::copy_nonoverlapping(src, dst, 1), + 2 => std::ptr::copy_nonoverlapping(src, dst, 2), + 3 => std::ptr::copy_nonoverlapping(src, dst, 3), + 4 => std::ptr::copy_nonoverlapping(src, dst, 4), + 5 => std::ptr::copy_nonoverlapping(src, dst, 5), + 6 => std::ptr::copy_nonoverlapping(src, dst, 6), + 7 => std::ptr::copy_nonoverlapping(src, dst, 7), + 8 => std::ptr::copy_nonoverlapping(src, dst, 8), + _ => std::ptr::copy_nonoverlapping(src, dst, n), + } + value_buffer.set_len(old_len + n); } } -/// Append a new view to the views buffer with the given substr +#[inline(always)] +fn push_char_to_mutable_buffer(value_buffer: &mut MutableBuffer, c: char) { + let len = c.len_utf8(); + let old_len = value_buffer.len(); + value_buffer.reserve(len); + + // SAFETY: we reserved `len` bytes above, write valid UTF-8 into those + // bytes, then update the initialized length to include them. + unsafe { + let dst = value_buffer.as_mut_ptr().add(old_len); + if len == 1 { + *dst = c as u8; + } else { + c.encode_utf8(std::slice::from_raw_parts_mut(dst, len)); + } + value_buffer.set_len(old_len + len); + } +} + +/// Builder for a [`StringViewArray`]. +/// +/// Short strings (≤ 12 bytes) are inlined into the view itself; long strings +/// are appended into an in-progress data block. When the in-progress block +/// fills up it is flushed into `completed` and a new block — double the size +/// of the last, capped at [`STRING_VIEW_MAX_BLOCK_SIZE`] — is started. +pub(crate) struct StringViewArrayBuilder { + views: Vec, + in_progress: Vec, + completed: Vec, + block_size: u32, + placeholder_count: usize, +} + +impl StringViewArrayBuilder { + pub fn with_capacity(item_capacity: usize) -> Self { + Self { + views: Vec::with_capacity(item_capacity), + in_progress: Vec::new(), + completed: Vec::new(), + block_size: STRING_VIEW_INIT_BLOCK_SIZE, + placeholder_count: 0, + } + } + + /// Doubles the block-size target and returns the new size. + fn next_block_size(&mut self) -> u32 { + if self.block_size < STRING_VIEW_MAX_BLOCK_SIZE { + self.block_size = self.block_size.saturating_mul(2); + } + self.block_size + } + + /// See [`BulkNullStringArrayBuilder::append_value`]. + /// + /// # Panics + /// + /// Panics if the value length, the in-progress buffer offset, or the + /// number of completed buffers exceeds `i32::MAX`. The ByteView spec + /// uses signed 32-bit integers for these fields; exceeding `i32::MAX` + /// would produce an array that does not round-trip through Arrow IPC + /// (see ). + #[inline] + pub fn append_value(&mut self, value: &str) { + let v = value.as_bytes(); + let length: u32 = + i32::try_from(v.len()).expect("value length exceeds i32::MAX") as u32; + if length <= 12 { + self.views.push(make_view(v, 0, 0)); + return; + } + + let required_cap = self.in_progress.len() + length as usize; + if self.in_progress.capacity() < required_cap { + self.flush_in_progress(); + let to_reserve = (length as usize).max(self.next_block_size() as usize); + #[expect( + clippy::disallowed_methods, + reason = "StringView's block size bounds growth, so reserve cannot overflow capacity arithmetically. This hot loop intentionally avoids the extra `try_reserve` checks. It remains subject to allocator failure/OOM, which must be managed externally." + )] + self.in_progress.reserve(to_reserve); + } + + let offset: u32 = i32::try_from(self.in_progress.len()) + .expect("offset exceeds i32::MAX") as u32; + self.in_progress.extend_from_slice(v); + self.views.push(self.make_long_view(length, offset, v)); + } + + /// See [`BulkNullStringArrayBuilder::append_placeholder`]. + #[inline] + pub fn append_placeholder(&mut self) { + // Zero-length inline view — `length` field is 0, no buffer ref. + self.views.push(0); + self.placeholder_count += 1; + } + + /// Ensure the in-progress block has room for `length` more bytes, + /// flushing the current block and starting a new (doubled) one if not. + /// Caller must invoke this only when no bytes of the current row are + /// yet in `in_progress` — flushing mid-row would orphan partial data. + #[inline] + fn ensure_long_capacity(&mut self, length: u32) { + let required_cap = self.in_progress.len() + length as usize; + if self.in_progress.capacity() < required_cap { + self.flush_in_progress(); + let to_reserve = (length as usize).max(self.next_block_size() as usize); + #[expect( + clippy::disallowed_methods, + reason = "StringView's block size bounds growth, so reserve cannot overflow capacity arithmetically. This hot loop intentionally avoids the extra `try_reserve` checks. It remains subject to allocator failure/OOM, which must be managed externally." + )] + self.in_progress.reserve(to_reserve); + } + } + + /// Encode a long-form view referencing `length` bytes already written + /// into the in-progress block at `offset`. `prefix_bytes` is the row's + /// data slice (or any slice starting with the row's first 4 bytes). + /// + /// Built inline rather than going through Arrow's `make_view`: that + /// function is `[inline(never)]` and has to handle short strings, so + /// building the view here ourselves is faster. + #[inline] + fn make_long_view(&self, length: u32, offset: u32, prefix_bytes: &[u8]) -> u128 { + let buffer_index: u32 = i32::try_from(self.completed.len()) + .expect("buffer count exceeds i32::MAX") + as u32; + ByteView { + length, + // length > 12, so prefix_bytes has at least 4 bytes. + prefix: u32::from_le_bytes(prefix_bytes[..4].try_into().unwrap()), + buffer_index, + offset, + } + .into() + } + + /// See [`BulkNullStringArrayBuilder::append_byte_map`]. + /// + /// # Safety + /// + /// The bytes produced by applying `map` to each byte of `src`, in order, + /// must form valid UTF-8. + /// + /// # Panics + /// + /// Panics under the same conditions as [`Self::append_value`]: if + /// `src.len()`, the in-progress buffer offset, or the number of completed + /// buffers exceeds `i32::MAX`. + #[inline] + pub unsafe fn append_byte_map u8>(&mut self, src: &[u8], mut map: F) { + let length: u32 = + i32::try_from(src.len()).expect("value length exceeds i32::MAX") as u32; + if length <= 12 { + let mut bytes = [0u8; 12]; + for (d, &b) in bytes[..src.len()].iter_mut().zip(src) { + *d = map(b); + } + self.views.push(make_view(&bytes[..src.len()], 0, 0)); + return; + } + + self.ensure_long_capacity(length); + + let cursor = self.in_progress.len(); + let offset: u32 = i32::try_from(cursor).expect("offset exceeds i32::MAX") as u32; + self.in_progress.extend(src.iter().map(|&b| map(b))); + self.views + .push(self.make_long_view(length, offset, &self.in_progress[cursor..])); + } + + /// See [`BulkNullStringArrayBuilder::append_with`]. + /// + /// # Panics + /// + /// Panics under the same conditions as [`Self::append_value`]: if the + /// row's byte length, the in-progress buffer offset, or the number of + /// completed buffers exceeds `i32::MAX`. + #[inline] + pub fn append_with(&mut self, f: F) + where + F: FnOnce(&mut StringViewWriter<'_>), + { + let mut writer = StringViewWriter { + inline_buf: [0u8; 12], + inline_len: 0, + spill_cursor: None, + builder: self, + }; + f(&mut writer); + // Destructure to release the borrow on `self` and pull out the + // inline-buffer state by-value. Copy types only; the &mut self is + // dropped here, ending the borrow. + let StringViewWriter { + inline_buf, + inline_len, + spill_cursor, + .. + } = writer; + + match spill_cursor { + None => { + self.views + .push(make_view(&inline_buf[..inline_len as usize], 0, 0)); + } + Some(start) => { + let end = self.in_progress.len(); + let length: u32 = i32::try_from(end - start) + .expect("value length exceeds i32::MAX") + as u32; + let offset: u32 = + i32::try_from(start).expect("offset exceeds i32::MAX") as u32; + self.views.push(self.make_long_view( + length, + offset, + &self.in_progress[start..], + )); + } + } + } + + fn flush_in_progress(&mut self) { + if !self.in_progress.is_empty() { + let block = std::mem::take(&mut self.in_progress); + self.completed.push(Buffer::from_vec(block)); + } + } + + /// Finalize into a [`StringViewArray`] using the caller-supplied null + /// buffer. + /// + /// # Errors + /// + /// Returns an error when `null_buffer.len()` does not match the number of + /// appended rows. + pub fn finish(mut self, null_buffer: Option) -> Result { + if let Some(ref n) = null_buffer + && n.len() != self.views.len() + { + return internal_err!( + "Null buffer length ({}) must match row count ({})", + n.len(), + self.views.len() + ); + } + let null_count = null_buffer.as_ref().map_or(0, |n| n.null_count()); + debug_assert!( + null_count >= self.placeholder_count, + "{} placeholder rows but null buffer has {null_count} nulls", + self.placeholder_count, + ); + self.flush_in_progress(); + // SAFETY: every long-string view references bytes we wrote ourselves + // into `self.completed`, with prefixes derived from those same bytes. + // Inline views were built from valid `&str`. Placeholder views are + // zero-length with no buffer reference. + let array = unsafe { + StringViewArray::new_unchecked( + ScalarBuffer::from(self.views), + self.completed, + null_buffer, + ) + }; + Ok(array) + } +} + +/// [`StringWriter`] for [`StringViewArrayBuilder`]. +/// +/// The writer accumulates the first up-to-12 bytes of a row in a stack +/// buffer; if the row stays inline-sized, it never touches the data block. +/// On the first write that would exceed 12 bytes, the stack buffer is +/// spilled into the builder's in-progress block and subsequent writes go +/// directly there. +pub(crate) struct StringViewWriter<'a> { + inline_buf: [u8; 12], + inline_len: u8, + /// `None` while the row fits inline; becomes `Some(start)` (offset of + /// the row's first byte in `in_progress`) at first spill. + spill_cursor: Option, + builder: &'a mut StringViewArrayBuilder, +} + +impl StringWriter for StringViewWriter<'_> { + #[inline] + fn write_str(&mut self, s: &str) { + let bytes = s.as_bytes(); + if self.spill_cursor.is_some() { + self.builder.in_progress.extend_from_slice(bytes); + return; + } + + let inline_len = self.inline_len as usize; + let new_len = inline_len + bytes.len(); + if new_len <= 12 { + self.inline_buf[inline_len..new_len].copy_from_slice(bytes); + self.inline_len = new_len as u8; + return; + } + + // First spill of this row: `ensure_long_capacity` may flush the + // current block, which is safe because no row-data for this row + // is in it yet — the inline prefix is still in `inline_buf`. + self.builder.ensure_long_capacity(new_len as u32); + let cursor = self.builder.in_progress.len(); + self.builder + .in_progress + .extend_from_slice(&self.inline_buf[..inline_len]); + self.builder.in_progress.extend_from_slice(bytes); + self.spill_cursor = Some(cursor); + } + + #[inline] + fn write_char(&mut self, c: char) { + let len = c.len_utf8(); + if self.spill_cursor.is_some() { + push_char_to_vec(&mut self.builder.in_progress, c); + return; + } + + let inline_len = self.inline_len as usize; + let new_len = inline_len + len; + if new_len <= 12 { + c.encode_utf8(&mut self.inline_buf[inline_len..new_len]); + self.inline_len = new_len as u8; + return; + } + + self.builder.ensure_long_capacity(new_len as u32); + let cursor = self.builder.in_progress.len(); + self.builder + .in_progress + .extend_from_slice(&self.inline_buf[..inline_len]); + push_char_to_vec(&mut self.builder.in_progress, c); + self.spill_cursor = Some(cursor); + } +} + +#[inline] +fn push_char_to_vec(v: &mut Vec, c: char) { + let mut buf = [0u8; 4]; + v.extend_from_slice(c.encode_utf8(&mut buf).as_bytes()); +} + +/// Trait abstracting over the bulk-NULL string array builders. +/// +/// Similar to Arrow's `StringLikeArrayBuilder`, this allows generic dispatch +/// over the three string array types (Utf8, LargeUtf8, Utf8View) when the +/// function body is uniform across them. +/// +/// Three methods append a non-null row; which method to pick depends on how the +/// row is produced: +/// +/// - [`append_value`](Self::append_value) pushes an already-finished `&str`. +/// Use it when the row is forwarded from an existing slice (e.g. an input +/// column) — there is nothing to elide. +/// - [`append_byte_map`](Self::append_byte_map) emits a row whose bytes are a +/// byte-to-byte mapping of an input slice. Output length is known up front +/// and the inner loop is straight-line, so this is the fastest path when the +/// shape fits. +/// - [`append_with`](Self::append_with) emits a row by feeding fragments to a +/// [`StringWriter`]. Use it when the row is computed from multiple sources or +/// when the output length is not known up front. Bytes are written directly +/// into the builder, so it is typically faster than assembling a `String` and +/// calling `append_value(&scratch)`. +/// +/// For a NULL row, call [`append_placeholder`](Self::append_placeholder) to +/// advance the row count without writing into the value buffer; the caller MUST +/// clear the corresponding bit in the null buffer passed to +/// [`finish`](Self::finish). +pub(crate) trait BulkNullStringArrayBuilder { + /// Per-builder concrete writer type, exposed as a GAT so generic callers + /// can use the inherent (non-`dyn`) writer methods without vtable + /// dispatch. + type Writer<'a>: StringWriter + where + Self: 'a; + + /// Append `value` as the next row. + /// + /// # Panics + /// + /// Panics if the resulting array would exceed the per-implementation + /// size limit. See the inherent method on each builder for specifics. + fn append_value(&mut self, value: &str); + + /// Append an empty placeholder row. The corresponding slot MUST be masked + /// as null by the null buffer passed to [`finish`](Self::finish). + fn append_placeholder(&mut self); + + /// Append a row whose bytes are produced by `f` calling write methods on + /// the supplied [`StringWriter`]. + /// + /// The closure can call `write_str` or `write_char` on the supplied + /// `StringWriter` zero or more times. Zero calls produces a row containing + /// the empty string. + /// + /// # Panics + /// + /// See [`append_value`](Self::append_value). + fn append_with(&mut self, f: F) + where + F: for<'a> FnOnce(&mut Self::Writer<'a>); + + /// Append a row whose bytes are produced by mapping each byte of `src` + /// through `map`, in order. Output length equals `src.len()`. + /// + /// Because the output length is known up front and the inner loop is + /// straight-line, this is more efficient than + /// [`append_with`](Self::append_with) for byte-to-byte mappings and + /// autovectorizes well. + /// + /// # Safety + /// + /// The bytes produced by applying `map` to each byte of `src`, in order, + /// must form valid UTF-8. + /// + /// # Panics + /// + /// See [`append_value`](Self::append_value). + unsafe fn append_byte_map u8>(&mut self, src: &[u8], map: F); + + /// Finalize into a concrete array using the caller-supplied null buffer. + /// + /// # Errors + /// + /// Returns an error when `null_buffer.len()` does not match the number + /// of appended rows. + fn finish(self, nulls: Option) -> Result; +} + +impl BulkNullStringArrayBuilder for GenericStringArrayBuilder { + type Writer<'a> = GenericStringWriter<'a>; + + #[inline] + fn append_value(&mut self, value: &str) { + GenericStringArrayBuilder::::append_value(self, value) + } + #[inline] + fn append_placeholder(&mut self) { + GenericStringArrayBuilder::::append_placeholder(self) + } + #[inline] + fn append_with(&mut self, f: F) + where + F: for<'a> FnOnce(&mut Self::Writer<'a>), + { + GenericStringArrayBuilder::::append_with(self, f) + } + #[inline] + unsafe fn append_byte_map u8>(&mut self, src: &[u8], map: F) { + // SAFETY: contract forwarded. + unsafe { GenericStringArrayBuilder::::append_byte_map(self, src, map) } + } + fn finish(self, nulls: Option) -> Result { + Ok(Arc::new(GenericStringArrayBuilder::::finish( + self, nulls, + )?)) + } +} + +impl BulkNullStringArrayBuilder for StringViewArrayBuilder { + type Writer<'a> = StringViewWriter<'a>; + + #[inline] + fn append_value(&mut self, value: &str) { + StringViewArrayBuilder::append_value(self, value) + } + #[inline] + fn append_placeholder(&mut self) { + StringViewArrayBuilder::append_placeholder(self) + } + #[inline] + fn append_with(&mut self, f: F) + where + F: for<'a> FnOnce(&mut Self::Writer<'a>), + { + StringViewArrayBuilder::append_with(self, f) + } + #[inline] + unsafe fn append_byte_map u8>(&mut self, src: &[u8], map: F) { + // SAFETY: contract forwarded. + unsafe { StringViewArrayBuilder::append_byte_map(self, src, map) } + } + fn finish(self, nulls: Option) -> Result { + Ok(Arc::new(StringViewArrayBuilder::finish(self, nulls)?)) + } +} + +/// Append a new view to the views buffer with the given substr. +/// +/// Callers are responsible for their own null tracking. /// /// # Safety /// @@ -305,13 +1127,15 @@ impl LargeStringArrayBuilder { /// /// # Arguments /// - views_buffer: The buffer to append the new view to -/// - null_builder: The buffer to append the null value to /// - original_view: The original view value /// - substr: The substring to append. Must be a valid substring of the original view /// - start_offset: The start offset of the substring in the view -pub fn make_and_append_view( +/// +/// LLVM is apparently overly eager to inline this function into some hot loops, +/// which bloats them and regresses performance, so we disable inlining for now. +#[inline(never)] +pub(crate) fn append_view( views_buffer: &mut Vec, - null_builder: &mut NullBufferBuilder, original_view: &u128, substr: &str, start_offset: u32, @@ -325,15 +1149,13 @@ pub fn make_and_append_view( view.offset + start_offset, ) } else { - // inline value does not need block id or offset make_view(substr.as_bytes(), 0, 0) }; views_buffer.push(sub_view); - null_builder.append_non_null(); } #[derive(Debug)] -pub enum ColumnarValueRef<'a> { +pub(crate) enum ColumnarValueRef<'a> { Scalar(&'a [u8]), NullableArray(&'a StringArray), NonNullableArray(&'a StringArray), @@ -341,6 +1163,8 @@ pub enum ColumnarValueRef<'a> { NonNullableLargeStringArray(&'a LargeStringArray), NullableStringViewArray(&'a StringViewArray), NonNullableStringViewArray(&'a StringViewArray), + NullableBinaryArray(&'a BinaryArray), + NonNullableBinaryArray(&'a BinaryArray), } impl ColumnarValueRef<'_> { @@ -350,10 +1174,12 @@ impl ColumnarValueRef<'_> { Self::Scalar(_) | Self::NonNullableArray(_) | Self::NonNullableLargeStringArray(_) - | Self::NonNullableStringViewArray(_) => true, + | Self::NonNullableStringViewArray(_) + | Self::NonNullableBinaryArray(_) => true, Self::NullableArray(array) => array.is_valid(i), Self::NullableStringViewArray(array) => array.is_valid(i), Self::NullableLargeStringArray(array) => array.is_valid(i), + Self::NullableBinaryArray(array) => array.is_valid(i), } } @@ -363,10 +1189,12 @@ impl ColumnarValueRef<'_> { Self::Scalar(_) | Self::NonNullableArray(_) | Self::NonNullableStringViewArray(_) - | Self::NonNullableLargeStringArray(_) => None, + | Self::NonNullableLargeStringArray(_) + | Self::NonNullableBinaryArray(_) => None, Self::NullableArray(array) => array.nulls().cloned(), Self::NullableStringViewArray(array) => array.nulls().cloned(), Self::NullableLargeStringArray(array) => array.nulls().cloned(), + Self::NullableBinaryArray(array) => array.nulls().cloned(), } } } @@ -375,15 +1203,393 @@ impl ColumnarValueRef<'_> { mod tests { use super::*; + /// Run `scenario` against `builder`, finish with a null buffer derived + /// from `expected` (a bit is set wherever `expected[i].is_some()`), and + /// assert the resulting array equals the corresponding + /// `*Array::from(expected)`. + /// + /// The caller is responsible for driving NULLs in `scenario` — usually + /// by calling `append_placeholder` at each index where `expected[i]` is + /// `None`. + fn run_scenario(mut builder: B, expected: &[Option<&str>], scenario: F) + where + B: BulkNullStringArrayBuilder, + F: FnOnce(&mut B), + { + scenario(&mut builder); + let bits: Vec = expected.iter().map(|x| x.is_some()).collect(); + let nulls = if bits.iter().any(|v| !v) { + Some(NullBuffer::from(bits)) + } else { + None + }; + let array = builder.finish(nulls).unwrap(); + let owned: Vec> = expected.to_vec(); + if let Some(a) = array.as_any().downcast_ref::() { + assert_eq!(a, &StringArray::from(owned)); + } else if let Some(a) = array.as_any().downcast_ref::() { + assert_eq!(a, &LargeStringArray::from(owned)); + } else if let Some(a) = array.as_any().downcast_ref::() { + assert_eq!(a, &StringViewArray::from(owned)); + } else { + panic!("unexpected array type"); + } + } + + /// Run `$scenario` against all three bulk-null builders, asserting each + /// produces an array equivalent to `$expected`. `$scenario` is a closure + /// `|builder| { ... }`; it is duplicated syntactically at each call site + /// so the `BulkNullStringArrayBuilder::Writer` GAT can specialize per + /// builder. + macro_rules! check_on_all_builders { + ($expected:expr, $scenario:expr $(,)?) => {{ + let expected = $expected; + run_scenario( + GenericStringArrayBuilder::::with_capacity(0, 0), + expected, + $scenario, + ); + run_scenario( + GenericStringArrayBuilder::::with_capacity(0, 0), + expected, + $scenario, + ); + run_scenario( + StringViewArrayBuilder::with_capacity(0), + expected, + $scenario, + ); + }}; + } + + fn assert_finish_errs_on_length_mismatch(mut builder: B) + where + B: BulkNullStringArrayBuilder, + { + builder.append_value("a"); + builder.append_value("b"); + let nulls = NullBuffer::from(vec![true, false, true]); + assert!(builder.finish(Some(nulls)).is_err()); + } + #[test] #[should_panic(expected = "capacity integer overflow")] - fn test_overflow_string_array_builder() { - let _builder = StringArrayBuilder::with_capacity(usize::MAX, usize::MAX); + fn test_overflow_concat_string_builder() { + let _builder = ConcatStringBuilder::with_capacity(usize::MAX, usize::MAX); } #[test] #[should_panic(expected = "capacity integer overflow")] - fn test_overflow_large_string_array_builder() { - let _builder = LargeStringArrayBuilder::with_capacity(usize::MAX, usize::MAX); + fn test_overflow_concat_large_string_builder() { + let _builder = ConcatLargeStringBuilder::with_capacity(usize::MAX, usize::MAX); + } + + #[test] + fn bulk_append_value_with_nulls() { + check_on_all_builders!( + &[ + Some("a string longer than twelve bytes"), + None, + Some("short"), + None, + ], + |b| { + b.append_value("a string longer than twelve bytes"); + b.append_placeholder(); + b.append_value("short"); + b.append_placeholder(); + }, + ); + } + + #[test] + fn bulk_empty_builder() { + check_on_all_builders!(&[], |_b| {}); + } + + #[test] + fn bulk_all_placeholders() { + check_on_all_builders!(&[None, None, None], |b| { + b.append_placeholder(); + b.append_placeholder(); + b.append_placeholder(); + }); + } + + #[test] + fn bulk_append_value_no_nulls() { + check_on_all_builders!( + &[ + Some("foo"), + Some(""), + Some("a string longer than twelve bytes") + ], + |b| { + b.append_value("foo"); + b.append_value(""); + b.append_value("a string longer than twelve bytes"); + }, + ); + } + + #[test] + fn bulk_append_with() { + check_on_all_builders!( + &[ + Some("hello"), + None, + Some("hello world"), + Some("a long string of 25 bytes"), + Some(""), + ], + |b| { + b.append_with(|w| w.write_str("hello")); + b.append_placeholder(); + b.append_with(|w| { + w.write_str("hello "); + w.write_str("world"); + }); + b.append_with(|w| w.write_str("a long string of 25 bytes")); + b.append_with(|_w| {}); + }, + ); + } + + #[test] + fn bulk_append_with_chars() { + check_on_all_builders!(&[Some("hé!"), Some("x")], |b| { + b.append_with(|w| { + w.write_char('h'); + w.write_char('é'); + w.write_char('!'); + }); + b.append_with(|w| w.write_char('x')); + }); + } + + #[test] + fn bulk_append_byte_map() { + // SAFETY: ASCII inputs and ASCII outputs in every call. + check_on_all_builders!(&[Some("HELLO"), Some("aXcaX"), Some("")], |b| unsafe { + b.append_byte_map(b"hello", |x| x.to_ascii_uppercase()); + b.append_byte_map(b"abcab", |x| if x == b'b' { b'X' } else { x }); + b.append_byte_map(b"", |x| x); + },); + } + + #[test] + fn bulk_finish_errors_on_null_buffer_length_mismatch() { + assert_finish_errs_on_length_mismatch( + GenericStringArrayBuilder::::with_capacity(2, 4), + ); + assert_finish_errs_on_length_mismatch( + GenericStringArrayBuilder::::with_capacity(2, 4), + ); + assert_finish_errs_on_length_mismatch(StringViewArrayBuilder::with_capacity(2)); + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "placeholder rows")] + fn string_array_builder_placeholder_without_null_mask() { + let mut builder = GenericStringArrayBuilder::::with_capacity(2, 4); + builder.append_value("a"); + builder.append_placeholder(); + // Slot 1 is a placeholder but the null buffer doesn't mark it null. + let nulls = NullBuffer::from(vec![true, true]); + let _ = builder.finish(Some(nulls)); + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "placeholder rows")] + fn string_array_builder_placeholder_with_none_null_buffer() { + let mut builder = GenericStringArrayBuilder::::with_capacity(1, 4); + builder.append_placeholder(); + let _ = builder.finish(None); + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "placeholder rows")] + fn string_view_array_builder_placeholder_without_null_mask() { + let mut builder = StringViewArrayBuilder::with_capacity(2); + builder.append_value("a"); + builder.append_placeholder(); + let nulls = NullBuffer::from(vec![true, true]); + let _ = builder.finish(Some(nulls)); + } + + #[test] + #[cfg(debug_assertions)] + #[should_panic(expected = "placeholder rows")] + fn string_view_array_builder_placeholder_with_none_null_buffer() { + let mut builder = StringViewArrayBuilder::with_capacity(1); + builder.append_placeholder(); + let _ = builder.finish(None); + } + + #[test] + fn string_view_array_builder_append_with_inline() { + // Rows that stay ≤ 12 bytes never touch the data block. + let mut builder = StringViewArrayBuilder::with_capacity(4); + let inputs = ["hello", "world!", "", "0123456789ab"]; + for s in &inputs { + builder.append_with(|w| w.write_str(s)); + } + let array = builder.finish(None).unwrap(); + assert_eq!(array.len(), inputs.len()); + for (i, s) in inputs.iter().enumerate() { + assert_eq!(array.value(i), *s); + } + assert_eq!(array.data_buffers().len(), 0); + } + + #[test] + fn string_view_array_builder_append_byte_map() { + let mut builder = StringViewArrayBuilder::with_capacity(4); + // SAFETY: ASCII inputs and ASCII outputs in every call. + unsafe { + builder.append_byte_map(b"hello", |b| b.to_ascii_uppercase()); + builder.append_byte_map(b"a long string of 25 bytes", |b| { + if b == b' ' { b'_' } else { b } + }); + // 12 bytes — exactly at the inline boundary. + builder.append_byte_map(b"abcdefghijkl", |b| b); + builder.append_byte_map(b"", |b| b); + } + let array = builder.finish(None).unwrap(); + assert_eq!(array.value(0), "HELLO"); + assert_eq!(array.value(1), "a_long_string_of_25_bytes"); + assert_eq!(array.value(2), "abcdefghijkl"); + assert_eq!(array.value(3), ""); + assert_eq!(array.data_buffers().len(), 1); + assert_eq!(array.data_buffers()[0].len(), 25); + } + + #[test] + fn string_view_array_builder_append_with_at_inline_boundary() { + // Building exactly 12 bytes via several writes should still go inline. + let mut builder = StringViewArrayBuilder::with_capacity(2); + builder.append_with(|w| { + w.write_str("hello"); + w.write_str(" world!"); + }); + builder.append_with(|w| { + for _ in 0..6 { + w.write_str("ab"); + } + }); + let array = builder.finish(None).unwrap(); + assert_eq!(array.value(0), "hello world!"); + assert_eq!(array.value(1), "abababababab"); + assert_eq!(array.data_buffers().len(), 0); + } + + #[test] + fn string_view_array_builder_append_with_spill_on_overflow() { + // 12 bytes from one write, +1 byte from another → spill at boundary. + let mut builder = StringViewArrayBuilder::with_capacity(1); + builder.append_with(|w| { + w.write_str("hello world!"); + w.write_str("X"); + }); + let array = builder.finish(None).unwrap(); + assert_eq!(array.value(0), "hello world!X"); + assert_eq!(array.data_buffers().len(), 1); + assert_eq!(array.data_buffers()[0].len(), 13); + } + + #[test] + fn string_view_array_builder_append_with_long_single_write() { + // A single write larger than 12 bytes spills immediately with an + // empty inline_buf prefix. + let mut builder = StringViewArrayBuilder::with_capacity(1); + builder.append_with(|w| w.write_str("a long string of 25 bytes")); + let array = builder.finish(None).unwrap(); + assert_eq!(array.value(0), "a long string of 25 bytes"); + assert_eq!(array.data_buffers().len(), 1); + assert_eq!(array.data_buffers()[0].len(), 25); + } + + #[test] + fn string_view_array_builder_append_with_many_small_writes_spilling() { + // 30 × "ab" (60 bytes total): first 6 fit inline, remainder spills. + let mut builder = StringViewArrayBuilder::with_capacity(1); + builder.append_with(|w| { + for _ in 0..30 { + w.write_str("ab"); + } + }); + let array = builder.finish(None).unwrap(); + assert_eq!(array.value(0), "ab".repeat(30)); + assert_eq!(array.data_buffers().len(), 1); + assert_eq!(array.data_buffers()[0].len(), 60); + } + + #[test] + fn string_view_array_builder_append_with_chars() { + // write_char with multi-byte UTF-8: row 0 stays inline (3 bytes), + // row 1 spills (40 bytes). + let mut builder = StringViewArrayBuilder::with_capacity(2); + builder.append_with(|w| { + w.write_char('é'); + w.write_char('!'); + }); + builder.append_with(|w| { + for _ in 0..10 { + w.write_char('🦀'); + } + }); + let array = builder.finish(None).unwrap(); + assert_eq!(array.value(0), "é!"); + assert_eq!(array.value(1), "🦀".repeat(10)); + } + + #[test] + fn string_view_array_builder_append_with_block_rotation() { + // 40 long rows, 500 bytes each, exceeds the first doubled block + // (~16 KiB). Forces the builder to rotate blocks between rows. + const STR_LEN: usize = 500; + const N: usize = 40; + let s = "x".repeat(STR_LEN); + let mut builder = StringViewArrayBuilder::with_capacity(N); + for _ in 0..N { + builder.append_with(|w| w.write_str(&s)); + } + let array = builder.finish(None).unwrap(); + assert_eq!(array.len(), N); + assert!( + array.data_buffers().len() >= 2, + "expected multiple data buffers, got {}", + array.data_buffers().len() + ); + let total: usize = array.data_buffers().iter().map(|b| b.len()).sum(); + assert_eq!(total, N * STR_LEN); + for i in 0..N { + assert_eq!(array.value(i), s); + } + } + + #[test] + fn string_view_array_builder_flushes_full_blocks() { + // Each value is 300 bytes. The first data block is 2 × STRING_VIEW_INIT_BLOCK_SIZE + // = 16 KiB, so ~50 values saturate it and the rest spill into additional + // blocks. + let value = "x".repeat(300); + let mut builder = StringViewArrayBuilder::with_capacity(100); + for _ in 0..100 { + builder.append_value(&value); + } + let array = builder.finish(None).unwrap(); + assert_eq!(array.len(), 100); + assert!( + array.data_buffers().len() > 1, + "expected multiple data buffers, got {}", + array.data_buffers().len() + ); + for i in 0..100 { + assert_eq!(array.value(i), value); + } } } diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index a97bf9710e4a9..465b15ace1d10 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -23,10 +23,10 @@ use arrow::array::{ use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; use std::sync::Arc; #[user_doc( @@ -72,10 +72,6 @@ impl CharacterLengthFunc { } impl ScalarUDFImpl for CharacterLengthFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "character_length" } @@ -88,10 +84,7 @@ impl ScalarUDFImpl for CharacterLengthFunc { utf8_to_int_type(&arg_types[0], "character_length") } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(character_length, vec![])(&args.args) } @@ -227,7 +220,9 @@ mod tests { #[cfg(not(feature = "unicode_expressions"))] test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé"))))], + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], internal_err!( "function character_length requires compilation with feature flag: unicode_expressions." ), diff --git a/datafusion/functions/src/unicode/common.rs b/datafusion/functions/src/unicode/common.rs new file mode 100644 index 0000000000000..092f2b8003b1b --- /dev/null +++ b/datafusion/functions/src/unicode/common.rs @@ -0,0 +1,275 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common utilities for implementing unicode functions + +use arrow::array::{ + Array, ArrayRef, ByteView, GenericStringArray, Int64Array, OffsetSizeTrait, + StringViewArray, make_view, +}; +use arrow::datatypes::DataType; +use arrow_buffer::{NullBuffer, ScalarBuffer}; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; +use datafusion_common::exec_err; +use datafusion_expr::ColumnarValue; +use std::cmp::Ordering; +use std::ops::Range; +use std::sync::Arc; + +/// If `cv` is a non-null scalar string, return its value. +pub(crate) fn try_as_scalar_str(cv: &ColumnarValue) -> Option<&str> { + match cv { + ColumnarValue::Scalar(s) => s.try_as_str().flatten(), + _ => None, + } +} + +/// If `cv` is a non-null scalar Int64, return its value. +pub(crate) fn try_as_scalar_i64(cv: &ColumnarValue) -> Option { + match cv { + ColumnarValue::Scalar(ScalarValue::Int64(v)) => *v, + _ => None, + } +} + +/// A trait for `left` and `right` byte slicing operations +pub(crate) trait LeftRightSlicer { + fn slice(string: &str, n: i64) -> Range; +} + +pub(crate) struct LeftSlicer {} + +impl LeftRightSlicer for LeftSlicer { + fn slice(string: &str, n: i64) -> Range { + 0..left_right_byte_length(string, n) + } +} + +pub(crate) struct RightSlicer {} + +impl LeftRightSlicer for RightSlicer { + fn slice(string: &str, n: i64) -> Range { + if n == 0 { + // Return nothing for `n=0` + 0..0 + } else if n == i64::MIN { + // Special case for i64::MIN overflow + 0..0 + } else { + left_right_byte_length(string, -n)..string.len() + } + } +} + +/// Returns the byte offset of the `n`th codepoint in `string`, +/// or `string.len()` if the string has fewer than `n` codepoints. +#[inline] +pub(crate) fn byte_offset_of_char(string: &str, n: usize) -> usize { + string + .char_indices() + .nth(n) + .map_or(string.len(), |(i, _)| i) +} + +/// If `string` has more than `n` codepoints, returns the byte offset of +/// the `n`-th codepoint boundary. Otherwise returns the total codepoint count. +#[inline] +pub(crate) fn char_count_or_boundary(string: &str, n: usize) -> StringCharLen { + let mut count = 0; + for (byte_idx, _) in string.char_indices() { + if count == n { + return StringCharLen::ByteOffset(byte_idx); + } + count += 1; + } + StringCharLen::CharCount(count) +} + +/// Result of [`char_count_or_boundary`]. +pub(crate) enum StringCharLen { + /// The string has more than `n` codepoints; contains the byte offset + /// at the `n`-th codepoint boundary. + ByteOffset(usize), + /// The string has `n` or fewer codepoints; contains the exact count. + CharCount(usize), +} + +/// Calculate the byte length of the substring of `n` chars from string `string` +#[inline] +fn left_right_byte_length(string: &str, n: i64) -> usize { + match n.cmp(&0) { + Ordering::Less => string + .char_indices() + .nth_back((n.unsigned_abs().min(usize::MAX as u64) - 1) as usize) + .map(|(index, _)| index) + .unwrap_or(0), + Ordering::Equal => 0, + Ordering::Greater => { + byte_offset_of_char(string, n.unsigned_abs().min(usize::MAX as u64) as usize) + } + } +} + +/// General implementation for `left` and `right` functions +pub(crate) fn general_left_right( + args: &[ArrayRef], +) -> Result { + let n_array = as_int64_array(&args[1])?; + + match args[0].data_type() { + DataType::Utf8 => { + let string_array = as_generic_string_array::(&args[0])?; + general_left_right_array::(string_array, n_array) + } + DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&args[0])?; + general_left_right_array::(string_array, n_array) + } + DataType::Utf8View => { + let string_view_array = as_string_view_array(&args[0])?; + general_left_right_view::(string_view_array, n_array) + } + _ => exec_err!("Not supported"), + } +} + +/// Returns true if all offsets in the array fit in i32, meaning the values +/// buffer can be referenced by StringView's offset field. +fn values_fit_in_i32(string_array: &GenericStringArray) -> bool { + string_array + .offsets() + .last() + .map(|offset| offset.as_usize() <= i32::MAX as usize) + .unwrap_or(true) +} + +/// `left`/`right` for Utf8/LargeUtf8 input. +/// +/// When offsets fit in i32, produces a zero-copy `StringViewArray` with views +/// pointing into the input values buffer. Otherwise falls back to building a +/// `StringViewArray` by copying. +fn general_left_right_array( + string_array: &GenericStringArray, + n_array: &Int64Array, +) -> Result { + if !values_fit_in_i32(string_array) { + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => Some(&string[F::slice(string, n)]), + _ => None, + }) + .collect::(); + return Ok(Arc::new(result) as ArrayRef); + } + + let len = string_array.len(); + let offsets = string_array.value_offsets(); + let nulls = NullBuffer::union(string_array.nulls(), n_array.nulls()); + + let mut views_buf = Vec::with_capacity(len); + let mut has_out_of_line = false; + + for (i, offset) in offsets.iter().enumerate().take(len) { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + views_buf.push(0); + continue; + } + + // SAFETY: we just checked validity above + let string = unsafe { string_array.value_unchecked(i) }; + let n = n_array.value(i); + let range = F::slice(string, n); + let result_bytes = &string.as_bytes()[range.clone()]; + if result_bytes.len() > 12 { + has_out_of_line = true; + } + + let buf_offset = offset.as_usize() as u32 + range.start as u32; + views_buf.push(make_view(result_bytes, 0, buf_offset)); + } + + let views = ScalarBuffer::from(views_buf); + let data_buffers = if has_out_of_line { + vec![string_array.values().clone()] + } else { + vec![] + }; + + // SAFETY: + // - Each view is produced by `make_view` with correct bytes and offset + // - Out-of-line views reference buffer index 0, which is the original + // values buffer included in data_buffers when has_out_of_line is true + // - values_fit_in_i32 guarantees all offsets fit in i32 + unsafe { + let array = StringViewArray::new_unchecked(views, data_buffers, nulls); + Ok(Arc::new(array) as ArrayRef) + } +} + +/// `general_left_right` for StringViewArray input. +fn general_left_right_view( + string_view_array: &StringViewArray, + n_array: &Int64Array, +) -> Result { + let views = string_view_array.views(); + let new_nulls = NullBuffer::union(string_view_array.nulls(), n_array.nulls()); + let len = n_array.len(); + let mut has_out_of_line = false; + + let new_views = (0..len) + .map(|idx| { + if new_nulls.as_ref().is_some_and(|n| n.is_null(idx)) { + return 0; + } + + // SAFETY: we just checked validity above + let string: &str = unsafe { string_view_array.value_unchecked(idx) }; + let n = n_array.value(idx); + + let range = F::slice(string, n); + let result_bytes = &string.as_bytes()[range.clone()]; + if result_bytes.len() > 12 { + has_out_of_line = true; + } + + let byte_view = ByteView::from(views[idx]); + let new_offset = byte_view.offset + (range.start as u32); + make_view(result_bytes, byte_view.buffer_index, new_offset) + }) + .collect::>(); + + let views = ScalarBuffer::from(new_views); + let data_buffers = if has_out_of_line { + string_view_array.data_buffers().to_vec() + } else { + vec![] + }; + + // SAFETY: + // - Each view is produced by `make_view` with correct bytes and offset + // - Out-of-line views reuse the original buffer index and adjusted offset + unsafe { + let array = StringViewArray::new_unchecked(views, data_buffers, new_nulls); + Ok(Arc::new(array) as ArrayRef) + } +} diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index e83e3d99a329c..0a83eb3ed61ef 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ - new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, - OffsetSizeTrait, PrimitiveArray, + ArrayAccessor, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow_buffer::NullBuffer; use crate::utils::utf8_to_int_type; use datafusion_common::{ - exec_err, internal_err, utils::take_function_args, Result, ScalarValue, + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, }; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ @@ -81,10 +80,6 @@ impl FindInSetFunc { } impl ScalarUDFImpl for FindInSetFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "find_in_set" } @@ -98,9 +93,8 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let ScalarFunctionArgs { args, .. } = args; - - let [string, str_list] = take_function_args(self.name(), args)?; + let return_field = args.return_field; + let [string, str_list] = take_function_args(self.name(), args.args)?; match (string, str_list) { // both inputs are scalars @@ -139,9 +133,11 @@ impl ScalarUDFImpl for FindInSetFunc { | ScalarValue::LargeUtf8(str_list_literal), ), ) => { - let result_array = match str_list_literal { + match str_list_literal { // find_in_set(column_a, null) = null - None => new_null_array(str_array.data_type(), str_array.len()), + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + return_field.data_type(), + )?)), Some(str_list_literal) => { let str_list = str_list_literal.split(',').collect::>(); let result = match str_array.data_type() { @@ -167,13 +163,14 @@ impl ScalarUDFImpl for FindInSetFunc { ) } other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") + exec_err!( + "Unsupported data type {other:?} for function find_in_set" + ) } }; - Arc::new(result?) + Ok(ColumnarValue::Array(Arc::new(result?))) } - }; - Ok(ColumnarValue::Array(result_array)) + } } // `string` is scalar, `str_list` is an array @@ -185,11 +182,11 @@ impl ScalarUDFImpl for FindInSetFunc { ), ColumnarValue::Array(str_list_array), ) => { - let res = match string_literal { + match string_literal { // find_in_set(null, column_b) = null - None => { - new_null_array(str_list_array.data_type(), str_list_array.len()) - } + None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + return_field.data_type(), + )?)), Some(string) => { let result = match str_list_array.data_type() { DataType::Utf8 => { @@ -211,13 +208,14 @@ impl ScalarUDFImpl for FindInSetFunc { ) } other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") + exec_err!( + "Unsupported data type {other:?} for function find_in_set" + ) } }; - Arc::new(result?) + Ok(ColumnarValue::Array(Arc::new(result?))) } - }; - Ok(ColumnarValue::Array(res)) + } } // both inputs are arrays @@ -267,53 +265,55 @@ fn find_in_set_general<'a, T, V>(string_array: V, str_list_array: V) -> Result, + V: ArrayAccessor + Copy, { - let string_iter = ArrayIter::new(string_array); - let str_list_iter = ArrayIter::new(str_list_array); - - let mut builder = PrimitiveArray::::builder(string_iter.len()); - - string_iter - .zip(str_list_iter) - .for_each( - |(string_opt, str_list_opt)| match (string_opt, str_list_opt) { - (Some(string), Some(str_list)) => { - let position = str_list - .split(',') - .position(|s| s == string) - .map_or(0, |idx| idx + 1); - builder.append_value(T::Native::from_usize(position).unwrap()); - } - _ => builder.append_null(), - }, - ); + let len = string_array.len(); + let nulls = NullBuffer::union(string_array.nulls(), str_list_array.nulls()); + let zero = T::Native::from_usize(0).unwrap(); + + let values: Vec = (0..len) + .map(|i| { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + return zero; + } + let string = string_array.value(i); + let str_list = str_list_array.value(i); + let position = str_list + .split(',') + .position(|s| s == string) + .map_or(0, |idx| idx + 1); + T::Native::from_usize(position).unwrap() + }) + .collect(); - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(PrimitiveArray::::new(values.into(), nulls)) as ArrayRef) } fn find_in_set_left_literal<'a, T, V>(string: &str, str_list_array: V) -> Result where T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, - V: ArrayAccessor, + V: ArrayAccessor + Copy, { - let mut builder = PrimitiveArray::::builder(str_list_array.len()); - - let str_list_iter = ArrayIter::new(str_list_array); - - str_list_iter.for_each(|str_list_opt| match str_list_opt { - Some(str_list) => { + let len = str_list_array.len(); + let nulls = str_list_array.nulls().cloned(); + let zero = T::Native::from_usize(0).unwrap(); + + let values: Vec = (0..len) + .map(|i| { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + return zero; + } + let str_list = str_list_array.value(i); let position = str_list .split(',') .position(|s| s == string) .map_or(0, |idx| idx + 1); - builder.append_value(T::Native::from_usize(position).unwrap()); - } - None => builder.append_null(), - }); + T::Native::from_usize(position).unwrap() + }) + .collect(); - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(PrimitiveArray::::new(values.into(), nulls)) as ArrayRef) } fn find_in_set_right_literal<'a, T, V>( @@ -323,24 +323,27 @@ fn find_in_set_right_literal<'a, T, V>( where T: ArrowPrimitiveType, T::Native: OffsetSizeTrait, - V: ArrayAccessor, + V: ArrayAccessor + Copy, { - let mut builder = PrimitiveArray::::builder(string_array.len()); - - let string_iter = ArrayIter::new(string_array); - - string_iter.for_each(|string_opt| match string_opt { - Some(string) => { + let len = string_array.len(); + let nulls = string_array.nulls().cloned(); + let zero = T::Native::from_usize(0).unwrap(); + + let values: Vec = (0..len) + .map(|i| { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + return zero; + } + let string = string_array.value(i); let position = str_list .iter() .position(|s| *s == string) .map_or(0, |idx| idx + 1); - builder.append_value(T::Native::from_usize(position).unwrap()); - } - None => builder.append_null(), - }); + T::Native::from_usize(position).unwrap() + }) + .collect(); - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(PrimitiveArray::::new(values.into(), nulls)) as ArrayRef) } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/initcap.rs b/datafusion/functions/src/unicode/initcap.rs index 62862fbe78980..711b2c49b09f6 100644 --- a/datafusion/functions/src/unicode/initcap.rs +++ b/datafusion/functions/src/unicode/initcap.rs @@ -15,21 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{ - Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder, -}; +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::buffer::{Buffer, OffsetBuffer}; use arrow::datatypes::DataType; +use crate::strings::{GenericStringArrayBuilder, StringViewArrayBuilder}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, - Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -75,10 +74,6 @@ impl InitcapFunc { } impl ScalarUDFImpl for InitcapFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "initcap" } @@ -95,10 +90,40 @@ impl ScalarUDFImpl for InitcapFunc { } } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let arg = &args.args[0]; + + // Scalar fast path - handle directly without array conversion + if let ColumnarValue::Scalar(scalar) = arg { + return match scalar { + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Utf8View(None) => Ok(arg.clone()), + ScalarValue::Utf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + ScalarValue::LargeUtf8(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + ScalarValue::Utf8View(Some(s)) => { + let mut result = String::new(); + initcap_string(s, &mut result); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + other => { + exec_err!( + "Unsupported data type {:?} for function `initcap`", + other.data_type() + ) + } + }; + } + + // Array path let args = &args.args; match args[0].data_type() { DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), @@ -115,8 +140,8 @@ impl ScalarUDFImpl for InitcapFunc { } } -/// Converts the first letter of each word to upper case and the rest to lower -/// case. Words are sequences of alphanumeric characters separated by +/// Converts the first letter of each word to uppercase and the rest to +/// lowercase. Words are sequences of alphanumeric characters separated by /// non-alphanumeric characters. /// /// Example: @@ -126,38 +151,125 @@ impl ScalarUDFImpl for InitcapFunc { fn initcap(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; - let mut builder = GenericStringBuilder::::with_capacity( - string_array.len(), + if string_array.is_ascii() { + return Ok(initcap_ascii_array(string_array)); + } + + let len = string_array.len(); + let mut builder = GenericStringArrayBuilder::::with_capacity( + len, string_array.value_data().len(), ); let mut container = String::new(); - string_array.iter().for_each(|str| match str { - Some(s) => { + let nulls = string_array.nulls().cloned(); + if let Some(ref n) = nulls { + for i in 0..len { + if n.is_null(i) { + builder.append_placeholder(); + } else { + // SAFETY: not null per check above. + let s = unsafe { string_array.value_unchecked(i) }; + initcap_string(s, &mut container); + builder.append_value(&container); + } + } + } else { + for i in 0..len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; initcap_string(s, &mut container); builder.append_value(&container); } - None => builder.append_null(), - }); + } - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(builder.finish(nulls)?) as ArrayRef) } -fn initcap_utf8view(args: &[ArrayRef]) -> Result { - let string_view_array = as_string_view_array(&args[0])?; +/// Fast path for `Utf8` or `LargeUtf8` arrays that are ASCII-only. We can use a +/// single pass over the buffer and operate directly on bytes. +fn initcap_ascii_array( + string_array: &GenericStringArray, +) -> ArrayRef { + let offsets = string_array.offsets(); + let src = string_array.value_data(); + let first_offset = offsets.first().unwrap().as_usize(); + let last_offset = offsets.last().unwrap().as_usize(); - let mut builder = StringViewBuilder::with_capacity(string_view_array.len()); + // For sliced arrays, only convert the visible bytes, not the entire input + // buffer. + let mut out = Vec::with_capacity(last_offset - first_offset); + for window in offsets.windows(2) { + let start = window[0].as_usize(); + let end = window[1].as_usize(); + + let mut prev_is_alnum = false; + for &b in &src[start..end] { + let converted = if prev_is_alnum { + b.to_ascii_lowercase() + } else { + b.to_ascii_uppercase() + }; + out.push(converted); + prev_is_alnum = b.is_ascii_alphanumeric(); + } + } + + let values = Buffer::from_vec(out); + let out_offsets = if first_offset == 0 { + offsets.clone() + } else { + // For sliced arrays, we need to rebase the offsets to reflect that the + // output only contains the bytes in the visible slice. + let rebased_offsets = offsets + .iter() + .map(|offset| T::usize_as(offset.as_usize() - first_offset)) + .collect::>(); + OffsetBuffer::::new(rebased_offsets.into()) + }; + + // SAFETY: ASCII case conversion preserves byte length, so the original + // string boundaries are preserved. `out_offsets` is either identical to + // the input offsets or a rebased version relative to the compacted values + // buffer. + Arc::new(unsafe { + GenericStringArray::::new_unchecked( + out_offsets, + values, + string_array.nulls().cloned(), + ) + }) +} + +fn initcap_utf8view(args: &[ArrayRef]) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + let len = string_view_array.len(); + let mut builder = StringViewArrayBuilder::with_capacity(len); let mut container = String::new(); - string_view_array.iter().for_each(|str| match str { - Some(s) => { + + let nulls = string_view_array.nulls().cloned(); + if let Some(ref n) = nulls { + for i in 0..len { + if n.is_null(i) { + builder.append_placeholder(); + } else { + // SAFETY: not null per check above. + let s = unsafe { string_view_array.value_unchecked(i) }; + initcap_string(s, &mut container); + builder.append_value(&container); + } + } + } else { + for i in 0..len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_view_array.value_unchecked(i) }; initcap_string(s, &mut container); builder.append_value(&container); } - None => builder.append_null(), - }); + } - Ok(Arc::new(builder.finish()) as ArrayRef) + Ok(Arc::new(builder.finish(nulls)?) as ArrayRef) } fn initcap_string(input: &str, container: &mut String) { @@ -165,13 +277,16 @@ fn initcap_string(input: &str, container: &mut String) { let mut prev_is_alphanumeric = false; if input.is_ascii() { - for c in input.chars() { + container.reserve(input.len()); + // SAFETY: each byte is ASCII, so the result is valid UTF-8. + let out = unsafe { container.as_mut_vec() }; + for &b in input.as_bytes() { if prev_is_alphanumeric { - container.push(c.to_ascii_lowercase()); + out.push(b.to_ascii_lowercase()); } else { - container.push(c.to_ascii_uppercase()); - }; - prev_is_alphanumeric = c.is_ascii_alphanumeric(); + out.push(b.to_ascii_uppercase()); + } + prev_is_alphanumeric = b.is_ascii_alphanumeric(); } } else { for c in input.chars() { @@ -189,10 +304,11 @@ fn initcap_string(input: &str, container: &mut String) { mod tests { use crate::unicode::initcap::InitcapFunc; use crate::utils::test::test_function; - use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; #[test] fn test_functions() -> Result<()> { @@ -296,4 +412,114 @@ mod tests { Ok(()) } + + #[test] + fn test_initcap_ascii_array() -> Result<()> { + let array = StringArray::from(vec![ + Some("hello world"), + None, + Some("foo-bar_baz/baX"), + Some(""), + Some("123 abc 456DEF"), + Some("ALL CAPS"), + Some("already correct"), + ]); + let args: Vec = vec![Arc::new(array)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 7); + assert_eq!(result.value(0), "Hello World"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "Foo-Bar_Baz/Bax"); + assert_eq!(result.value(3), ""); + assert_eq!(result.value(4), "123 Abc 456def"); + assert_eq!(result.value(5), "All Caps"); + assert_eq!(result.value(6), "Already Correct"); + Ok(()) + } + + #[test] + fn test_initcap_ascii_large_array() -> Result<()> { + let array = LargeStringArray::from(vec![ + Some("hello world"), + None, + Some("foo-bar_baz/baX"), + Some(""), + Some("123 abc 456DEF"), + Some("ALL CAPS"), + Some("already correct"), + ]); + let args: Vec = vec![Arc::new(array)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 7); + assert_eq!(result.value(0), "Hello World"); + assert!(result.is_null(1)); + assert_eq!(result.value(2), "Foo-Bar_Baz/Bax"); + assert_eq!(result.value(3), ""); + assert_eq!(result.value(4), "123 Abc 456def"); + assert_eq!(result.value(5), "All Caps"); + assert_eq!(result.value(6), "Already Correct"); + Ok(()) + } + + /// Test that initcap works correctly on a sliced ASCII StringArray. + #[test] + fn test_initcap_sliced_ascii_array() -> Result<()> { + let array = StringArray::from(vec![ + Some("hello world"), + Some("foo bar"), + Some("baz qux"), + ]); + // Slice to get only the last two elements. The resulting array's + // offsets are [11, 18, 25] (non-zero start), but value_data still + // contains the full original buffer. + let sliced = array.slice(1, 2); + let args: Vec = vec![Arc::new(sliced)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "Foo Bar"); + assert_eq!(result.value(1), "Baz Qux"); + + // The output values buffer should be compact + assert_eq!(*result.offsets().first().unwrap(), 0); + assert_eq!( + result.value_data().len(), + *result.offsets().last().unwrap() as usize + ); + Ok(()) + } + + /// Test that initcap works correctly on a sliced ASCII LargeStringArray. + #[test] + fn test_initcap_sliced_ascii_large_array() -> Result<()> { + let array = LargeStringArray::from(vec![ + Some("hello world"), + Some("foo bar"), + Some("baz qux"), + ]); + // Slice to get only the last two elements. The resulting array's + // offsets are [11, 18, 25] (non-zero start), but value_data still + // contains the full original buffer. + let sliced = array.slice(1, 2); + let args: Vec = vec![Arc::new(sliced)]; + let result = super::initcap::(&args)?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.value(0), "Foo Bar"); + assert_eq!(result.value(1), "Baz Qux"); + + // The output values buffer should be compact + assert_eq!(*result.offsets().first().unwrap(), 0); + assert_eq!( + result.value_data().len(), + *result.offsets().last().unwrap() as usize + ); + Ok(()) + } } diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index ec7ec456ab8b5..423ab4d5dc54b 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -15,25 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::cmp::Ordering; -use std::sync::Arc; - -use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, - OffsetSizeTrait, -}; +use crate::unicode::common::{LeftSlicer, general_left_right}; +use crate::utils::make_scalar_function; use arrow::datatypes::DataType; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; -use datafusion_common::exec_err; use datafusion_common::Result; +use datafusion_common::exec_err; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -81,10 +71,6 @@ impl LeftFunc { } impl ScalarUDFImpl for LeftFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "left" } @@ -93,23 +79,24 @@ impl ScalarUDFImpl for LeftFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "left") + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. + /// left('abcde', 2) = 'ab' + /// left('abcde', -2) = 'abc' + /// The implementation uses UTF-8 code points as characters + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(left::, vec![])(args) + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + make_scalar_function(general_left_right::, vec![])(args) } - DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), other => exec_err!( - "Unsupported data type {other:?} for function left,\ - expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function {},\ + expected Utf8View, Utf8 or LargeUtf8.", + self.name() ), } } @@ -119,54 +106,10 @@ impl ScalarUDFImpl for LeftFunc { } } -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. -/// left('abcde', 2) = 'ab' -/// The implementation uses UTF-8 code points as characters -fn left(args: &[ArrayRef]) -> Result { - let n_array = as_int64_array(&args[1])?; - - if args[0].data_type() == &DataType::Utf8View { - let string_array = as_string_view_array(&args[0])?; - left_impl::(string_array, n_array) - } else { - let string_array = as_generic_string_array::(&args[0])?; - left_impl::(string_array, n_array) - } -} - -fn left_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( - string_array: V, - n_array: &Int64Array, -) -> Result { - let iter = ArrayIter::new(string_array); - let result = iter - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => { - let len = string.chars().count() as i64; - Some(if n.abs() < len { - string.chars().take((len + n) as usize).collect::() - } else { - "".to_string() - }) - } - Ordering::Equal => Some("".to_string()), - Ordering::Greater => { - Some(string.chars().take(n as usize).collect::()) - } - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringViewArray}; + use arrow::datatypes::DataType::Utf8View; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -184,8 +127,8 @@ mod tests { ], Ok(Some("ab")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -195,8 +138,8 @@ mod tests { ], Ok(Some("abcde")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -206,8 +149,19 @@ mod tests { ], Ok(Some("abc")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -217,8 +171,8 @@ mod tests { ], Ok(Some("")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -228,8 +182,8 @@ mod tests { ], Ok(Some("")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -239,8 +193,8 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -250,8 +204,8 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -261,8 +215,8 @@ mod tests { ], Ok(Some("joséé")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( LeftFunc::new(), @@ -272,8 +226,8 @@ mod tests { ], Ok(Some("joséé")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); #[cfg(not(feature = "unicode_expressions"))] test_function!( @@ -286,9 +240,77 @@ mod tests { "function left requires compilation with feature flag: unicode_expressions." ), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + + // StringView cases + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ab")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray ); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "joséésoj".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("joséé")), + &str, + Utf8View, + StringViewArray + ); + + // Unicode indexing case + let input = "joé楽s𐀀so↓j"; + for n in 1..=input.chars().count() { + let expected = input + .chars() + .take(input.chars().count() - n) + .collect::(); + test_function!( + LeftFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from(input)), + ColumnarValue::Scalar(ScalarValue::from(-(n as i64))), + ], + Ok(Some(expected.as_str())), + &str, + Utf8View, + StringViewArray + ); + } Ok(()) } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 6940459b177a9..d27bc8633e730 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -15,24 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::fmt::Write; use std::sync::Arc; +use DataType::{LargeUtf8, Utf8, Utf8View}; use arrow::array::{ Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; -use unicode_segmentation::UnicodeSegmentation; -use DataType::{LargeUtf8, Utf8, Utf8View}; use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -49,7 +48,10 @@ use datafusion_macros::user_doc; +---------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = "n", description = "String length to pad to."), + argument( + name = "n", + description = "String length to pad to. If the input string is longer than this length, it is truncated (on the right)." + ), argument( name = "padding_str", description = "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._" @@ -93,10 +95,6 @@ impl LPadFunc { } impl ScalarUDFImpl for LPadFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "lpad" } @@ -109,14 +107,67 @@ impl ScalarUDFImpl for LPadFunc { utf8_to_str_type(&arg_types[0], "lpad") } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = &args.args; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args, number_rows, .. + } = args; + + const MAX_SCALAR_TARGET_LEN: usize = 16384; + + // If target_len and fill (if specified) are constants, use the scalar + // fast path. + if let Some(target_len) = try_as_scalar_i64(&args[1]) { + let target_len: usize = match usize::try_from(target_len) { + Ok(n) if n <= i32::MAX as usize => n, + Ok(n) => { + return exec_err!( + "lpad requested length {n} too large, maximum allowed length is {}", + i32::MAX + ); + } + Err(_) => 0, // negative → 0 + }; + + let fill_str = if args.len() == 3 { + try_as_scalar_str(&args[2]) + } else { + Some(" ") + }; + + // Skip the fast path for very large `target_len` values to avoid + // consuming too much memory. Such large padding values are uncommon + // in practice. + if target_len <= MAX_SCALAR_TARGET_LEN + && let Some(fill) = fill_str + { + let string_array = args[0].to_array_of_size(number_rows)?; + let result = match string_array.data_type() { + Utf8View => lpad_scalar_args::<_, i32>( + string_array.as_string_view(), + target_len, + fill, + ), + Utf8 => lpad_scalar_args::<_, i32>( + string_array.as_string::(), + target_len, + fill, + ), + LargeUtf8 => lpad_scalar_args::<_, i64>( + string_array.as_string::(), + target_len, + fill, + ), + other => { + exec_err!("Unsupported data type {other:?} for function lpad") + } + }?; + return Ok(ColumnarValue::Array(result)); + } + } + match args[0].data_type() { - Utf8 | Utf8View => make_scalar_function(lpad::, vec![])(args), - LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + Utf8 | Utf8View => make_scalar_function(lpad::, vec![])(&args), + LargeUtf8 => make_scalar_function(lpad::, vec![])(&args), other => exec_err!("Unsupported data type {other:?} for function lpad"), } } @@ -126,8 +177,125 @@ impl ScalarUDFImpl for LPadFunc { } } -/// Extends the string to length 'length' by prepending the characters fill (a space by default). -/// If the string is already longer than length then it is truncated (on the right). +use super::common::{ + StringCharLen, char_count_or_boundary, try_as_scalar_i64, try_as_scalar_str, +}; + +/// Optimized lpad for constant target_len and fill arguments. +fn lpad_scalar_args<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( + string_array: V, + target_len: usize, + fill: &str, +) -> Result { + if string_array.is_ascii() && fill.is_ascii() { + lpad_scalar_ascii::(string_array, target_len, fill) + } else { + lpad_scalar_unicode::(string_array, target_len, fill) + } +} + +fn lpad_scalar_ascii<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( + string_array: V, + target_len: usize, + fill: &str, +) -> Result { + // With a scalar `target_len` and `fill`, we can precompute a padding + // buffer of `target_len` fill characters repeated cyclically. + let padding_buf = if !fill.is_empty() { + let mut buf = String::with_capacity(target_len); + while buf.len() < target_len { + let remaining = target_len - buf.len(); + if remaining >= fill.len() { + buf.push_str(fill); + } else { + buf.push_str(&fill[..remaining]); + } + } + buf + } else { + String::new() + }; + + // Each output row is exactly `target_len` ASCII bytes (padding + string). + let data_capacity = string_array.len().saturating_mul(target_len); + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); + + for maybe_string in string_array.iter() { + match maybe_string { + Some(string) => { + let str_len = string.len(); + if target_len <= str_len { + builder.append_value(&string[..target_len]); + } else if fill.is_empty() { + builder.append_value(string); + } else { + let pad_needed = target_len - str_len; + builder.write_str(&padding_buf[..pad_needed])?; + builder.append_value(string); + } + } + None => builder.append_null(), + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn lpad_scalar_unicode<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( + string_array: V, + target_len: usize, + fill: &str, +) -> Result { + let fill_chars: Vec = fill.chars().collect(); + + // With a scalar `target_len` and `fill`, we can precompute a padding buffer + // of `target_len` fill characters repeated cyclically. Because Unicode + // characters are variable-width, we build a byte-offset table to map from + // character count to the corresponding byte position in the padding buffer. + let (padding_buf, char_byte_offsets) = if !fill_chars.is_empty() { + let mut buf = String::new(); + let mut offsets = Vec::with_capacity(target_len + 1); + offsets.push(0usize); + for i in 0..target_len { + buf.push(fill_chars[i % fill_chars.len()]); + offsets.push(buf.len()); + } + (buf, offsets) + } else { + (String::new(), vec![0]) + }; + + // Each output row is `target_len` chars; multiply by 4 (max UTF-8 bytes + // per char) for an upper bound in bytes. + let data_capacity = string_array.len().saturating_mul(target_len * 4); + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); + + for maybe_string in string_array.iter() { + match maybe_string { + Some(string) => match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + if !fill_chars.is_empty() { + let pad_chars = target_len - char_count; + let pad_bytes = char_byte_offsets[pad_chars]; + builder.write_str(&padding_buf[..pad_bytes])?; + } + builder.append_value(string); + } + }, + None => builder.append_null(), + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +/// Left-pads `string` to `target_len` using the fill string (default: space). +/// Truncates from the right if `string` is already longer than `target_len`. /// lpad('hi', 5, 'xy') = 'xyxhi' fn lpad(args: &[ArrayRef]) -> Result { if args.len() <= 1 || args.len() > 3 { @@ -160,7 +328,7 @@ fn lpad(args: &[ArrayRef]) -> Result { length_array, &args[2], ), - (_, _) => unreachable!("lpad"), + (len, dt) => unreachable!("lpad: unexpected arg count ({len}) or type ({dt})"), } } @@ -206,37 +374,73 @@ where { let array = if let Some(fill_array) = fill_array { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + let mut fill_chars_buf = Vec::new(); - for ((string, length), fill) in string_array + for ((string, target_len), fill) in string_array .iter() .zip(length_array.iter()) .zip(fill_array.iter()) { - if let (Some(string), Some(length), Some(fill)) = (string, length, fill) { - if length > i32::MAX as i64 { - return exec_err!("lpad requested length {length} too large"); + if let (Some(string), Some(target_len), Some(fill)) = + (string, target_len, fill) + { + if target_len > i32::MAX as i64 { + return exec_err!( + "lpad requested length {target_len} too large, maximum allowed length is {}", + i32::MAX + ); } - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { + let target_len = if target_len < 0 { + 0 + } else { + target_len as usize + }; + if target_len == 0 { builder.append_value(""); continue; } - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - builder.append_value(graphemes[..length].concat()); - } else if fill_chars.is_empty() { - builder.append_value(string); + if string.is_ascii() && fill.is_ascii() { + // ASCII fast path: byte length == character length. + let str_len = string.len(); + if target_len < str_len { + builder.append_value(&string[..target_len]); + } else if fill.is_empty() { + builder.append_value(string); + } else { + let pad_len = target_len - str_len; + let fill_len = fill.len(); + let full_reps = pad_len / fill_len; + let remainder = pad_len % fill_len; + for _ in 0..full_reps { + builder.write_str(fill)?; + } + if remainder > 0 { + builder.write_str(&fill[..remainder])?; + } + builder.append_value(string); + } } else { - for l in 0..length - graphemes.len() { - let c = *fill_chars.get(l % fill_chars.len()).unwrap(); - builder.write_char(c)?; + fill_chars_buf.clear(); + fill_chars_buf.extend(fill.chars()); + + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + if !fill_chars_buf.is_empty() { + for l in 0..target_len - char_count { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; + } + } + builder.append_value(string); + } } - builder.write_str(string)?; - builder.append_value(""); } } else { builder.append_null(); @@ -247,25 +451,48 @@ where } else { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - for (string, length) in string_array.iter().zip(length_array.iter()) { - if let (Some(string), Some(length)) = (string, length) { - if length > i32::MAX as i64 { - return exec_err!("lpad requested length {length} too large"); + for (string, target_len) in string_array.iter().zip(length_array.iter()) { + if let (Some(string), Some(target_len)) = (string, target_len) { + if target_len > i32::MAX as i64 { + return exec_err!( + "lpad requested length {target_len} too large, maximum allowed length is {}", + i32::MAX + ); } - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { + let target_len = if target_len < 0 { + 0 + } else { + target_len as usize + }; + if target_len == 0 { builder.append_value(""); continue; } - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - builder.append_value(graphemes[..length].concat()); + if string.is_ascii() { + // ASCII fast path: byte length == character length + let str_len = string.len(); + if target_len < str_len { + builder.append_value(&string[..target_len]); + } else { + for _ in 0..(target_len - str_len) { + builder.write_str(" ")?; + } + builder.append_value(string); + } } else { - builder.write_str(" ".repeat(length - graphemes.len()).as_str())?; - builder.write_str(string)?; - builder.append_value(""); + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + for _ in 0..(target_len - char_count) { + builder.write_str(" ")?; + } + builder.append_value(string); + } + } } } else { builder.append_null(); @@ -512,6 +739,17 @@ mod tests { None, Ok(None) ); + test_lpad!( + Some("hello".into()), + ScalarValue::Int64(Some(2i64)), + Ok(Some("he")) + ); + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(6i64)), + Some("xy".into()), + Ok(Some("xyxyhi")) + ); test_lpad!( Some("josé".into()), ScalarValue::Int64(Some(10i64)), @@ -526,9 +764,13 @@ mod tests { ); #[cfg(not(feature = "unicode_expressions"))] - test_lpad!(Some("josé".into()), ScalarValue::Int64(Some(5i64)), internal_err!( + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(5i64)), + internal_err!( "function lpad requires compilation with feature flag: unicode_expressions." - )); + ) + ); Ok(()) } diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 4a0dd21d749af..7250b3915fb5c 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; pub mod character_length; +pub mod common; pub mod find_in_set; pub mod initcap; pub mod left; diff --git a/datafusion/functions/src/unicode/planner.rs b/datafusion/functions/src/unicode/planner.rs index e4f29be3d13dc..38c82486416a6 100644 --- a/datafusion/functions/src/unicode/planner.rs +++ b/datafusion/functions/src/unicode/planner.rs @@ -17,9 +17,9 @@ //! SQL planning extensions like [`UnicodeFunctionPlanner`] +use datafusion_expr::Expr; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::planner::{ExprPlanner, PlannerResult}; -use datafusion_expr::Expr; #[derive(Default, Debug)] pub struct UnicodeFunctionPlanner; diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 56f6048d6b6e9..813dcb5f504dd 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -15,20 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use arrow::array::{ - Array, ArrayRef, AsArray, GenericStringBuilder, OffsetSizeTrait, StringArrayType, +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder, }; +use crate::utils::make_scalar_function; +use DataType::{LargeUtf8, Utf8, Utf8View}; +use arrow::array::{Array, ArrayRef, AsArray, StringArrayType}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use DataType::{LargeUtf8, Utf8, Utf8View}; #[user_doc( doc_section(label = "String Functions"), @@ -69,10 +68,6 @@ impl ReverseFunc { } impl ScalarUDFImpl for ReverseFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "reverse" } @@ -82,17 +77,13 @@ impl ScalarUDFImpl for ReverseFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "reverse") + Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; match args[0].data_type() { - Utf8 | Utf8View => make_scalar_function(reverse::, vec![])(args), - LargeUtf8 => make_scalar_function(reverse::, vec![])(args), + Utf8 | Utf8View | LargeUtf8 => make_scalar_function(reverse, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function reverse") } @@ -106,48 +97,89 @@ impl ScalarUDFImpl for ReverseFunc { /// Reverses the order of the characters in the string `reverse('abcde') = 'edcba'`. /// The implementation uses UTF-8 code points as characters -fn reverse(args: &[ArrayRef]) -> Result { - if args[0].data_type() == &Utf8View { - reverse_impl::(&args[0].as_string_view()) - } else { - reverse_impl::(&args[0].as_string::()) +fn reverse(args: &[ArrayRef]) -> Result { + let len = args[0].len(); + + match args[0].data_type() { + LargeUtf8 => reverse_impl( + &args[0].as_string::(), + GenericStringArrayBuilder::::with_capacity(len, 1024), + ), + Utf8 => reverse_impl( + &args[0].as_string::(), + GenericStringArrayBuilder::::with_capacity(len, 1024), + ), + Utf8View => reverse_impl( + &args[0].as_string_view(), + StringViewArrayBuilder::with_capacity(len), + ), + _ => unreachable!( + "Reverse can only be applied to Utf8View, Utf8 and LargeUtf8 types" + ), } } -fn reverse_impl<'a, T: OffsetSizeTrait, V: StringArrayType<'a>>( - string_array: &V, -) -> Result { - let mut builder = GenericStringBuilder::::with_capacity(string_array.len(), 1024); - +fn reverse_impl<'a, StringArrType, B>( + string_array: &StringArrType, + mut array_builder: B, +) -> Result +where + StringArrType: StringArrayType<'a>, + B: BulkNullStringArrayBuilder, +{ + let item_len = string_array.len(); + // Null-preserving: reuse the input null buffer as the output null buffer. + let nulls = string_array.nulls().cloned(); let mut string_buf = String::new(); let mut byte_buf = Vec::::new(); - for string in string_array.iter() { - if let Some(s) = string { - if s.is_ascii() { - // reverse bytes directly since ASCII characters are single bytes - byte_buf.extend(s.as_bytes()); - byte_buf.reverse(); - // SAFETY: Since the original string was ASCII, reversing the bytes still results in valid UTF-8. - let reversed = unsafe { std::str::from_utf8_unchecked(&byte_buf) }; - builder.append_value(reversed); - byte_buf.clear(); + + if let Some(ref n) = nulls { + for i in 0..item_len { + if n.is_null(i) { + array_builder.append_placeholder(); } else { - string_buf.extend(s.chars().rev()); - builder.append_value(&string_buf); - string_buf.clear(); + // SAFETY: `n.is_null(i)` was false in the branch above. + let s = unsafe { string_array.value_unchecked(i) }; + append_reversed(s, &mut array_builder, &mut byte_buf, &mut string_buf); } - } else { - builder.append_null(); + } + } else { + for i in 0..item_len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; + append_reversed(s, &mut array_builder, &mut byte_buf, &mut string_buf); } } - Ok(Arc::new(builder.finish()) as ArrayRef) + array_builder.finish(nulls) +} + +#[inline] +fn append_reversed( + s: &str, + builder: &mut B, + byte_buf: &mut Vec, + string_buf: &mut String, +) { + if s.is_ascii() { + // reverse bytes directly since ASCII characters are single bytes + byte_buf.extend(s.as_bytes()); + byte_buf.reverse(); + // SAFETY: input was ASCII, so reversed bytes are still valid UTF-8. + let reversed = unsafe { std::str::from_utf8_unchecked(byte_buf) }; + builder.append_value(reversed); + byte_buf.clear(); + } else { + string_buf.extend(s.chars().rev()); + builder.append_value(string_buf); + string_buf.clear(); + } } #[cfg(test)] mod tests { - use arrow::array::{Array, LargeStringArray, StringArray}; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -180,8 +212,8 @@ mod tests { vec![ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); }; } @@ -202,4 +234,58 @@ mod tests { Ok(()) } + + #[test] + fn test_array_with_nulls() { + use crate::unicode::reverse::reverse; + use arrow::array::ArrayRef; + use std::sync::Arc; + + let input_values = vec![Some("abcd"), None, Some("XYZ"), Some("héllo"), None]; + let expected: Vec> = + vec![Some("dcba"), None, Some("ZYX"), Some("olléh"), None]; + + let cases: Vec<(&str, ArrayRef)> = vec![ + ( + "StringArray", + Arc::new(StringArray::from(input_values.clone())), + ), + ( + "LargeStringArray", + Arc::new(LargeStringArray::from(input_values.clone())), + ), + ( + "StringViewArray", + Arc::new(StringViewArray::from(input_values.clone())), + ), + ]; + + for (label, input) in cases { + let out = reverse(&[input]).unwrap(); + assert_eq!(out.len(), expected.len(), "{label}: length mismatch"); + + let actual: Vec> = match out.data_type() { + Utf8 => out + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(), + LargeUtf8 => out + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(), + Utf8View => out + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .collect(), + other => panic!("{label}: unexpected output type {other:?}"), + }; + assert_eq!(actual, expected, "{label}: value mismatch"); + } + } } diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 670586e11b4f7..0ed170fef72d7 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -15,25 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::cmp::{max, Ordering}; -use std::sync::Arc; - -use arrow::array::{ - Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, - OffsetSizeTrait, -}; +use crate::unicode::common::{RightSlicer, general_left_right}; +use crate::utils::make_scalar_function; use arrow::datatypes::DataType; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; -use datafusion_common::exec_err; use datafusion_common::Result; +use datafusion_common::exec_err; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -81,10 +71,6 @@ impl RightFunc { } impl ScalarUDFImpl for RightFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "right" } @@ -93,23 +79,24 @@ impl ScalarUDFImpl for RightFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "right") + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8View) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + /// Returns right n characters in the string, or when n is negative, returns all but first |n| characters. + /// right('abcde', 2) = 'de' + /// right('abcde', -2) = 'cde' + /// The implementation uses UTF-8 code points as characters + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { let args = &args.args; match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { - make_scalar_function(right::, vec![])(args) + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + make_scalar_function(general_left_right::, vec![])(args) } - DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), other => exec_err!( - "Unsupported data type {other:?} for function right,\ - expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function {},\ + expected Utf8View, Utf8 or LargeUtf8.", + self.name() ), } } @@ -119,58 +106,10 @@ impl ScalarUDFImpl for RightFunc { } } -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. -/// right('abcde', 2) = 'de' -/// The implementation uses UTF-8 code points as characters -fn right(args: &[ArrayRef]) -> Result { - let n_array = as_int64_array(&args[1])?; - if args[0].data_type() == &DataType::Utf8View { - // string_view_right(args) - let string_array = as_string_view_array(&args[0])?; - right_impl::(&mut string_array.iter(), n_array) - } else { - // string_right::(args) - let string_array = &as_generic_string_array::(&args[0])?; - right_impl::(&mut string_array.iter(), n_array) - } -} - -// Currently the return type can only be Utf8 or LargeUtf8, to reach fully support, we need -// to edit the `get_optimal_return_type` in utils.rs to make the udfs be able to return Utf8View -// See https://github.com/apache/datafusion/issues/11790#issuecomment-2283777166 -fn right_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( - string_array_iter: &mut ArrayIter, - n_array: &Int64Array, -) -> Result { - let result = string_array_iter - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => Some( - string - .chars() - .skip(n.unsigned_abs() as usize) - .collect::(), - ), - Ordering::Equal => Some("".to_string()), - Ordering::Greater => Some( - string - .chars() - .skip(max(string.chars().count() as i64 - n, 0) as usize) - .collect::(), - ), - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringViewArray}; + use arrow::datatypes::DataType::Utf8View; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -188,8 +127,8 @@ mod tests { ], Ok(Some("de")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), @@ -199,8 +138,8 @@ mod tests { ], Ok(Some("abcde")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), @@ -210,8 +149,19 @@ mod tests { ], Ok(Some("cde")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), @@ -221,8 +171,8 @@ mod tests { ], Ok(Some("")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), @@ -232,8 +182,8 @@ mod tests { ], Ok(Some("")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), @@ -243,8 +193,8 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), @@ -254,30 +204,30 @@ mod tests { ], Ok(None), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from("joséérend")), ColumnarValue::Scalar(ScalarValue::from(5i64)), ], - Ok(Some("éésoj")), + Ok(Some("érend")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); test_function!( RightFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from("joséérend")), ColumnarValue::Scalar(ScalarValue::from(-3i64)), ], - Ok(Some("éésoj")), + Ok(Some("éérend")), &str, - Utf8, - StringArray + Utf8View, + StringViewArray ); #[cfg(not(feature = "unicode_expressions"))] test_function!( @@ -290,9 +240,74 @@ mod tests { "function right requires compilation with feature flag: unicode_expressions." ), &str, - Utf8, - StringArray + Utf8View, + StringViewArray + ); + + // StringView cases + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("de")), + &str, + Utf8View, + StringViewArray ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("abcde".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "joséérend".to_string() + ))), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("éérend")), + &str, + Utf8View, + StringViewArray + ); + + // Unicode indexing case + let input = "joé楽s𐀀so↓j"; + for n in 1..=input.chars().count() { + let expected = input.chars().skip(n).collect::(); + test_function!( + RightFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from(input)), + ColumnarValue::Scalar(ScalarValue::from(-(n as i64))), + ], + Ok(Some(expected.as_str())), + &str, + Utf8View, + StringViewArray + ); + } Ok(()) } diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index a7e951051d7cd..b3e14f93526ab 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -15,25 +15,25 @@ // specific language governing permissions and limitations // under the License. -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use std::fmt::Write; +use std::sync::Arc; + +use DataType::{LargeUtf8, Utf8, Utf8View}; use arrow::array::{ ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, OffsetSizeTrait, StringArrayType, StringViewArray, }; use arrow::datatypes::DataType; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::as_int64_array; -use datafusion_common::DataFusionError; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; -use std::any::Any; -use std::fmt::Write; -use std::sync::Arc; -use unicode_segmentation::UnicodeSegmentation; -use DataType::{LargeUtf8, Utf8, Utf8View}; #[user_doc( doc_section(label = "String Functions"), @@ -48,7 +48,10 @@ use DataType::{LargeUtf8, Utf8, Utf8View}; +-----------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = "n", description = "String length to pad to."), + argument( + name = "n", + description = "String length to pad to. If the input string is longer than this length, it is truncated." + ), argument( name = "padding_str", description = "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._" @@ -92,10 +95,6 @@ impl RPadFunc { } impl ScalarUDFImpl for RPadFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "rpad" } @@ -108,36 +107,69 @@ impl ScalarUDFImpl for RPadFunc { utf8_to_str_type(&arg_types[0], "rpad") } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - let args = &args.args; - match ( - args.len(), - args[0].data_type(), - args.get(2).map(|arg| arg.data_type()), - ) { - (2, Utf8 | Utf8View, _) => { - make_scalar_function(rpad::, vec![])(args) - } - (2, LargeUtf8, _) => make_scalar_function(rpad::, vec![])(args), - (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => { - make_scalar_function(rpad::, vec![])(args) - } - (3, LargeUtf8, Some(LargeUtf8)) => { - make_scalar_function(rpad::, vec![])(args) - } - (3, Utf8 | Utf8View, Some(LargeUtf8)) => { - make_scalar_function(rpad::, vec![])(args) - } - (3, LargeUtf8, Some(Utf8 | Utf8View)) => { - make_scalar_function(rpad::, vec![])(args) - } - (_, _, _) => { - exec_err!("Unsupported combination of data types for function rpad") + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args, number_rows, .. + } = args; + + const MAX_SCALAR_TARGET_LEN: usize = 16384; + + // If target_len and fill (if specified) are constants, use the + // scalar fast path. + if let Some(target_len) = try_as_scalar_i64(&args[1]) { + let target_len: usize = match usize::try_from(target_len) { + Ok(n) if n <= i32::MAX as usize => n, + Ok(n) => { + return exec_err!( + "rpad requested length {n} too large, maximum allowed length is {}", + i32::MAX + ); + } + Err(_) => 0, // negative → 0 + }; + + let fill_str = if args.len() == 3 { + try_as_scalar_str(&args[2]) + } else { + Some(" ") + }; + + // Skip the fast path for very large `target_len` values to avoid + // consuming too much memory. Such large padding values are uncommon + // in practice. + if target_len <= MAX_SCALAR_TARGET_LEN + && let Some(fill) = fill_str + { + let string_array = args[0].to_array_of_size(number_rows)?; + let result = match string_array.data_type() { + Utf8View => rpad_scalar_args::<_, i32>( + string_array.as_string_view(), + target_len, + fill, + ), + Utf8 => rpad_scalar_args::<_, i32>( + string_array.as_string::(), + target_len, + fill, + ), + LargeUtf8 => rpad_scalar_args::<_, i64>( + string_array.as_string::(), + target_len, + fill, + ), + other => { + exec_err!("Unsupported data type {other:?} for function rpad") + } + }?; + return Ok(ColumnarValue::Array(result)); } } + + match args[0].data_type() { + Utf8 | Utf8View => make_scalar_function(rpad::, vec![])(&args), + LargeUtf8 => make_scalar_function(rpad::, vec![])(&args), + other => exec_err!("Unsupported data type {other:?} for function rpad"), + } } fn documentation(&self) -> Option<&Documentation> { @@ -145,154 +177,335 @@ impl ScalarUDFImpl for RPadFunc { } } -fn rpad( - args: &[ArrayRef], +use super::common::{ + StringCharLen, char_count_or_boundary, try_as_scalar_i64, try_as_scalar_str, +}; + +/// Optimized rpad for constant target_len and fill arguments. +fn rpad_scalar_args<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( + string_array: V, + target_len: usize, + fill: &str, ) -> Result { - if args.len() < 2 || args.len() > 3 { + if string_array.is_ascii() && fill.is_ascii() { + rpad_scalar_ascii::(string_array, target_len, fill) + } else { + rpad_scalar_unicode::(string_array, target_len, fill) + } +} + +fn rpad_scalar_ascii<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( + string_array: V, + target_len: usize, + fill: &str, +) -> Result { + // With a scalar `target_len` and `fill`, we can precompute a padding + // buffer of `target_len` fill characters repeated cyclically. + let padding_buf = if !fill.is_empty() { + let mut buf = String::with_capacity(target_len); + while buf.len() < target_len { + let remaining = target_len - buf.len(); + if remaining >= fill.len() { + buf.push_str(fill); + } else { + buf.push_str(&fill[..remaining]); + } + } + buf + } else { + String::new() + }; + + // Each output row is exactly `target_len` ASCII bytes (string + padding). + let data_capacity = string_array.len().saturating_mul(target_len); + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); + + for maybe_string in string_array.iter() { + match maybe_string { + Some(string) => { + let str_len = string.len(); + if target_len <= str_len { + builder.append_value(&string[..target_len]); + } else if fill.is_empty() { + builder.append_value(string); + } else { + let pad_needed = target_len - str_len; + builder.write_str(string)?; + builder.write_str(&padding_buf[..pad_needed])?; + builder.append_value(""); + } + } + None => builder.append_null(), + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn rpad_scalar_unicode<'a, V: StringArrayType<'a> + Copy, T: OffsetSizeTrait>( + string_array: V, + target_len: usize, + fill: &str, +) -> Result { + let fill_chars: Vec = fill.chars().collect(); + + // With a scalar `target_len` and `fill`, we can precompute a padding buffer + // of `target_len` fill characters repeated cyclically. Because Unicode + // characters are variable-width, we build a byte-offset table to map from + // character count to the corresponding byte position in the padding buffer. + let (padding_buf, char_byte_offsets) = if !fill_chars.is_empty() { + let mut buf = String::new(); + let mut offsets = Vec::with_capacity(target_len + 1); + offsets.push(0usize); + for i in 0..target_len { + buf.push(fill_chars[i % fill_chars.len()]); + offsets.push(buf.len()); + } + (buf, offsets) + } else { + (String::new(), vec![0]) + }; + + // Each output row is `target_len` chars; multiply by 4 (max UTF-8 bytes + // per char) for an upper bound in bytes. + let data_capacity = string_array.len().saturating_mul(target_len * 4); + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), data_capacity); + + for maybe_string in string_array.iter() { + match maybe_string { + Some(string) => match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + if !fill_chars.is_empty() { + let pad_chars = target_len - char_count; + let pad_bytes = char_byte_offsets[pad_chars]; + builder.write_str(&padding_buf[..pad_bytes])?; + } + builder.append_value(""); + } + }, + None => builder.append_null(), + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +fn rpad(args: &[ArrayRef]) -> Result { + if args.len() <= 1 || args.len() > 3 { return exec_err!( - "rpad was called with {} arguments. It requires 2 or 3 arguments.", + "rpad was called with {} arguments. It requires at least 2 and at most 3.", args.len() ); } let length_array = as_int64_array(&args[1])?; - match ( - args.len(), - args[0].data_type(), - args.get(2).map(|arg| arg.data_type()), - ) { - (2, Utf8View, _) => { - rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( - &args[0].as_string_view(), - length_array, - None, - ) - } - (3, Utf8View, Some(Utf8View)) => { - rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( - &args[0].as_string_view(), - length_array, - Some(args[2].as_string_view()), - ) - } - (3, Utf8View, Some(Utf8 | LargeUtf8)) => { - rpad_impl::<&StringViewArray, &GenericStringArray, StringArrayLen>( - &args[0].as_string_view(), - length_array, - Some(args[2].as_string::()), - ) - } - (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::< - &GenericStringArray, - &StringViewArray, - StringArrayLen, - >( - &args[0].as_string::(), + + match (args.len(), args[0].data_type()) { + (2, Utf8View) => rpad_impl::<&StringViewArray, &GenericStringArray, T>( + &args[0].as_string_view(), length_array, - Some(args[2].as_string_view()), + None, ), - (_, _, _) => rpad_impl::< - &GenericStringArray, - &GenericStringArray, - StringArrayLen, - >( - &args[0].as_string::(), + (2, Utf8 | LargeUtf8) => rpad_impl::< + &GenericStringArray, + &GenericStringArray, + T, + >(&args[0].as_string::(), length_array, None), + (3, Utf8View) => rpad_with_replace::<&StringViewArray, T>( + &args[0].as_string_view(), length_array, - args.get(2).map(|arg| arg.as_string::()), + &args[2], ), + (3, Utf8 | LargeUtf8) => rpad_with_replace::<&GenericStringArray, T>( + &args[0].as_string::(), + length_array, + &args[2], + ), + (len, dt) => unreachable!("rpad: unexpected arg count ({len}) or type ({dt})"), } } -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. -/// rpad('hi', 5, 'xy') = 'hixyx' -fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>( - string_array: &StringArrType, +fn rpad_with_replace<'a, V, T: OffsetSizeTrait>( + string_array: &V, length_array: &Int64Array, - fill_array: Option, + fill_array: &'a ArrayRef, ) -> Result where - StringArrType: StringArrayType<'a>, - FillArrType: StringArrayType<'a>, - StringArrayLen: OffsetSizeTrait, + V: StringArrayType<'a>, { - let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - - match fill_array { - None => { - string_array.iter().zip(length_array.iter()).try_for_each( - |(string, length)| -> Result<(), DataFusionError> { - match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {} too large", - length - ); - } - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - builder.append_value(""); - } else { - let graphemes = - string.graphemes(true).collect::>(); - if length < graphemes.len() { - builder.append_value(graphemes[..length].concat()); - } else { - builder.write_str(string)?; - builder.write_str( - &" ".repeat(length - graphemes.len()), - )?; - builder.append_value(""); + match fill_array.data_type() { + Utf8View => rpad_impl::( + string_array, + length_array, + Some(fill_array.as_string_view()), + ), + LargeUtf8 => rpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + Utf8 => rpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + other => { + exec_err!("Unsupported data type {other:?} for function rpad") + } + } +} + +fn rpad_impl<'a, V, V2, T>( + string_array: &V, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + V: StringArrayType<'a>, + V2: StringArrayType<'a>, + T: OffsetSizeTrait, +{ + let array = if let Some(fill_array) = fill_array { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + let mut fill_chars_buf = Vec::new(); + + for ((string, target_len), fill) in string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + { + if let (Some(string), Some(target_len), Some(fill)) = + (string, target_len, fill) + { + if target_len > i32::MAX as i64 { + return exec_err!( + "rpad requested length {target_len} too large, maximum allowed length is {}", + i32::MAX + ); + } + + let target_len = if target_len < 0 { + 0 + } else { + target_len as usize + }; + if target_len == 0 { + builder.append_value(""); + continue; + } + + if string.is_ascii() && fill.is_ascii() { + // ASCII fast path: byte length == character length. + let str_len = string.len(); + if target_len < str_len { + builder.append_value(&string[..target_len]); + } else if fill.is_empty() { + builder.append_value(string); + } else { + let pad_len = target_len - str_len; + let fill_len = fill.len(); + let full_reps = pad_len / fill_len; + let remainder = pad_len % fill_len; + builder.write_str(string)?; + for _ in 0..full_reps { + builder.write_str(fill)?; + } + if remainder > 0 { + builder.write_str(&fill[..remainder])?; + } + builder.append_value(""); + } + } else { + fill_chars_buf.clear(); + fill_chars_buf.extend(fill.chars()); + + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + if !fill_chars_buf.is_empty() { + for l in 0..target_len - char_count { + let c = *fill_chars_buf + .get(l % fill_chars_buf.len()) + .unwrap(); + builder.write_char(c)?; } } + builder.append_value(""); } - _ => builder.append_null(), } - Ok(()) - }, - )?; + } + } else { + builder.append_null(); + } } - Some(fill_array) => { - string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .try_for_each( - |((string, length), fill)| -> Result<(), DataFusionError> { - match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {} too large", - length - ); - } - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = - string.graphemes(true).collect::>(); - - if length < graphemes.len() { - builder.append_value(graphemes[..length].concat()); - } else if fill.is_empty() { - builder.append_value(string); - } else { - builder.write_str(string)?; - fill.chars() - .cycle() - .take(length - graphemes.len()) - .for_each(|ch| builder.write_char(ch).unwrap()); - builder.append_value(""); - } + + builder.finish() + } else { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + + for (string, target_len) in string_array.iter().zip(length_array.iter()) { + if let (Some(string), Some(target_len)) = (string, target_len) { + if target_len > i32::MAX as i64 { + return exec_err!( + "rpad requested length {target_len} too large, maximum allowed length is {}", + i32::MAX + ); + } + + let target_len = if target_len < 0 { + 0 + } else { + target_len as usize + }; + if target_len == 0 { + builder.append_value(""); + continue; + } + + if string.is_ascii() { + // ASCII fast path: byte length == character length + let str_len = string.len(); + if target_len < str_len { + builder.append_value(&string[..target_len]); + } else { + builder.write_str(string)?; + for _ in 0..(target_len - str_len) { + builder.write_str(" ")?; + } + builder.append_value(""); + } + } else { + match char_count_or_boundary(string, target_len) { + StringCharLen::ByteOffset(offset) => { + builder.append_value(&string[..offset]); + } + StringCharLen::CharCount(char_count) => { + builder.write_str(string)?; + for _ in 0..(target_len - char_count) { + builder.write_str(" ")?; } - _ => builder.append_null(), + builder.append_value(""); } - Ok(()) - }, - )?; + } + } + } else { + builder.append_null(); + } } - } - Ok(Arc::new(builder.finish()) as ArrayRef) + builder.finish() + }; + + Ok(Arc::new(array) as ArrayRef) } #[cfg(test)] @@ -447,6 +660,29 @@ mod tests { Utf8, StringArray ); + test_function!( + RPadFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("hello")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("he")), + &str, + Utf8, + StringArray + ); + test_function!( + RPadFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(6i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("hixyxy")), + &str, + Utf8, + StringArray + ); test_function!( RPadFunc::new(), vec![ diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 294f783ba693b..d361ecdbc1721 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use crate::utils::{make_scalar_function, utf8_to_int_type}; @@ -26,12 +25,13 @@ use arrow::datatypes::{ ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, }; use datafusion_common::types::logical_string; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass, - Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use memchr::{memchr, memmem}; #[user_doc( doc_section(label = "String Functions"), @@ -77,10 +77,6 @@ impl StrposFunc { } impl ScalarUDFImpl for StrposFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "strpos" } @@ -109,10 +105,15 @@ impl ScalarUDFImpl for StrposFunc { ) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // Fast path for array haystack and scalar needle + if let ( + ColumnarValue::Array(haystack_array), + ColumnarValue::Scalar(needle_scalar), + ) = (&args.args[0], &args.args[1]) + { + return strpos_scalar_needle(haystack_array, needle_scalar); + } make_scalar_function(strpos, vec![])(&args.args) } @@ -126,103 +127,121 @@ impl ScalarUDFImpl for StrposFunc { } fn strpos(args: &[ArrayRef]) -> Result { - match (args[0].data_type(), args[1].data_type()) { - (DataType::Utf8, DataType::Utf8) => { - let string_array = args[0].as_string::(); - let substring_array = args[1].as_string::(); - calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array) - } - (DataType::Utf8, DataType::Utf8View) => { - let string_array = args[0].as_string::(); - let substring_array = args[1].as_string_view(); - calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array) - } - (DataType::Utf8, DataType::LargeUtf8) => { - let string_array = args[0].as_string::(); - let substring_array = args[1].as_string::(); - calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array) - } - (DataType::LargeUtf8, DataType::Utf8) => { - let string_array = args[0].as_string::(); - let substring_array = args[1].as_string::(); - calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array) - } - (DataType::LargeUtf8, DataType::Utf8View) => { - let string_array = args[0].as_string::(); - let substring_array = args[1].as_string_view(); - calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - let string_array = args[0].as_string::(); - let substring_array = args[1].as_string::(); - calculate_strpos::<_, _, Int64Type>(&string_array, &substring_array) + /// Dispatches the needle array to the correct string type and calls + /// `strpos_general` with the given haystack and result type. + macro_rules! dispatch_needle { + ($haystack:expr, $result_type:ty, $args:expr) => { + match $args[1].data_type() { + DataType::Utf8 => strpos_general::<_, _, $result_type>( + $haystack, + $args[1].as_string::(), + ), + DataType::LargeUtf8 => strpos_general::<_, _, $result_type>( + $haystack, + $args[1].as_string::(), + ), + DataType::Utf8View => strpos_general::<_, _, $result_type>( + $haystack, + $args[1].as_string_view(), + ), + other => exec_err!("Unsupported data type {other:?} for strpos needle"), + } + }; + } + + match args[0].data_type() { + DataType::Utf8 => dispatch_needle!(args[0].as_string::(), Int32Type, args), + DataType::LargeUtf8 => { + dispatch_needle!(args[0].as_string::(), Int64Type, args) } - (DataType::Utf8View, DataType::Utf8View) => { - let string_array = args[0].as_string_view(); - let substring_array = args[1].as_string_view(); - calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array) + DataType::Utf8View => dispatch_needle!(args[0].as_string_view(), Int32Type, args), + other => { + exec_err!("Unsupported data type {other:?} for strpos haystack") } - (DataType::Utf8View, DataType::Utf8) => { - let string_array = args[0].as_string_view(); - let substring_array = args[1].as_string::(); - calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array) + } +} + +/// Find `needle` in `haystack` using `memchr` to quickly skip to positions +/// where the first byte matches, then verify the remaining bytes. Returns +/// the 0-based byte offset of the match, or `None` if not found. An empty +/// `needle` matches at offset 0. +fn find_substring_bytes(haystack: &[u8], needle: &[u8]) -> Option { + let needle_len = needle.len(); + let haystack_len = haystack.len(); + + if needle_len == 0 { + return Some(0); + } + if needle_len > haystack_len { + return None; + } + + let first_byte = needle[0]; + let mut offset = 0; + + while let Some(pos) = memchr(first_byte, &haystack[offset..]) { + let start = offset + pos; + if start + needle_len > haystack.len() { + return None; } - (DataType::Utf8View, DataType::LargeUtf8) => { - let string_array = args[0].as_string_view(); - let substring_array = args[1].as_string::(); - calculate_strpos::<_, _, Int32Type>(&string_array, &substring_array) + if haystack[start..start + needle_len] == *needle { + return Some(start); } + offset = start + 1; + } - other => { - exec_err!("Unsupported data type combination {other:?} for function strpos") - } + None +} + +/// Converts a byte offset within a haystack to a 1-based character position. +/// For ASCII data, byte offset == char offset so we just add 1. For non-ASCII, +/// we count UTF-8 characters in the prefix before the match. +#[inline] +fn byte_offset_to_char_pos( + haystack: &str, + byte_offset: usize, + ascii_only: bool, +) -> Option { + if ascii_only { + return T::Native::from_usize(byte_offset + 1); } + // SAFETY: byte_offset is at a UTF-8 char boundary because both haystack + // and needle are valid UTF-8, and UTF-8 is self-synchronizing: a valid + // needle byte sequence can only match starting at a char boundary in a + // valid haystack. + debug_assert!(haystack.is_char_boundary(byte_offset)); + let prefix = + unsafe { std::str::from_utf8_unchecked(&haystack.as_bytes()[..byte_offset]) }; + T::Native::from_usize(prefix.chars().count() + 1) } -/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) -/// strpos('high', 'ig') = 2 -/// The implementation uses UTF-8 code points as characters -fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>( - string_array: &V1, - substring_array: &V2, +/// Fallback strpos implementation for when both haystack and needle are arrays. +/// Building a new `memmem::Finder` for every row is too expensive; it is faster +/// to use `memchr::memchr`. +fn strpos_general<'a, V1, V2, T: ArrowPrimitiveType>( + haystack_array: V1, + needle_array: V2, ) -> Result where - V1: StringArrayType<'a, Item = &'a str>, - V2: StringArrayType<'a, Item = &'a str>, + V1: StringArrayType<'a, Item = &'a str> + Copy, + V2: StringArrayType<'a, Item = &'a str> + Copy, { - let ascii_only = substring_array.is_ascii() && string_array.is_ascii(); - let string_iter = string_array.iter(); - let substring_iter = substring_array.iter(); - - let result = string_iter - .zip(substring_iter) - .map(|(string, substring)| match (string, substring) { - (Some(string), Some(substring)) => { - // If only ASCII characters are present, we can use the slide window method to find - // the sub vector in the main vector. This is faster than string.find() method. - if ascii_only { - // If the substring is empty, the result is 1. - if substring.is_empty() { - T::Native::from_usize(1) - } else { - T::Native::from_usize( - string - .as_bytes() - .windows(substring.len()) - .position(|w| w == substring.as_bytes()) - .map(|x| x + 1) - .unwrap_or(0), - ) + let ascii_only = needle_array.is_ascii() && haystack_array.is_ascii(); + let haystack_iter = haystack_array.iter(); + let needle_iter = needle_array.iter(); + + let result = haystack_iter + .zip(needle_iter) + .map(|(haystack, needle)| match (haystack, needle) { + (Some(haystack), Some(needle)) => { + let haystack_bytes = haystack.as_bytes(); + let needle_bytes = needle.as_bytes(); + + match find_substring_bytes(haystack_bytes, needle_bytes) { + None => T::Native::from_usize(0), + Some(byte_offset) => { + byte_offset_to_char_pos::(haystack, byte_offset, ascii_only) } - } else { - // The `find` method returns the byte index of the substring. - // We count the number of chars up to that byte index. - T::Native::from_usize( - string - .find(substring) - .map(|x| string[..x].chars().count() + 1) - .unwrap_or(0), - ) } } _ => None, @@ -232,6 +251,85 @@ where Ok(Arc::new(result) as ArrayRef) } +/// Fast-path strpos implementation for when the haystack is an array and the +/// needle is a scalar. We can pre-build a `memmem::Finder` once and reuse it +/// for every haystack row. +fn strpos_scalar_needle( + haystack_array: &ArrayRef, + needle_scalar: &ScalarValue, +) -> Result { + let Some(needle_str) = needle_scalar.try_as_str() else { + return exec_err!( + "Unsupported data type {:?} for strpos needle", + needle_scalar.data_type() + ); + }; + + // Null needle => null result for every row + let Some(needle_str) = needle_str else { + return match haystack_array.data_type() { + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new( + PrimitiveArray::::new_null(haystack_array.len()), + ))) + } + DataType::Utf8 | DataType::Utf8View => Ok(ColumnarValue::Array(Arc::new( + PrimitiveArray::::new_null(haystack_array.len()), + ))), + other => exec_err!("Unsupported data type {other:?} for strpos haystack"), + }; + }; + + let result = match haystack_array.data_type() { + DataType::Utf8 => strpos_with_finder::<_, Int32Type>( + haystack_array.as_string::(), + needle_str, + ), + DataType::LargeUtf8 => strpos_with_finder::<_, Int64Type>( + haystack_array.as_string::(), + needle_str, + ), + DataType::Utf8View => strpos_with_finder::<_, Int32Type>( + haystack_array.as_string_view(), + needle_str, + ), + other => { + exec_err!("Unsupported data type {other:?} for strpos haystack") + } + }?; + Ok(ColumnarValue::Array(result)) +} + +fn strpos_with_finder<'a, V, T: ArrowPrimitiveType>( + haystack_array: V, + needle: &str, +) -> Result +where + V: StringArrayType<'a, Item = &'a str> + Copy, +{ + let needle_bytes = needle.as_bytes(); + let ascii_haystack = haystack_array.is_ascii(); + let finder = memmem::Finder::new(needle_bytes); + + let result = haystack_array + .iter() + .map(|string| match string { + Some(string) => { + let haystack_bytes = string.as_bytes(); + match finder.find(haystack_bytes) { + None => T::Native::from_usize(0), + Some(byte_offset) => { + byte_offset_to_char_pos::(string, byte_offset, ascii_haystack) + } + } + } + None => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + #[cfg(test)] mod tests { use arrow::array::{Array, Int32Array, Int64Array}; diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 27b194ca2b99d..903c03857e370 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -15,25 +15,24 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use crate::strings::make_and_append_view; +use crate::strings::{StringViewArrayBuilder, append_view}; use crate::utils::make_scalar_function; use arrow::array::{ - Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType, - StringViewArray, StringViewBuilder, + Array, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, + StringArrayType, StringViewArray, make_view, }; -use arrow::buffer::ScalarBuffer; +use arrow::buffer::{NullBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use datafusion_common::cast::as_int64_array; use datafusion_common::types::{ - logical_int32, logical_int64, logical_string, NativeType, + NativeType, logical_int32, logical_int64, logical_string, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ - Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -53,7 +52,7 @@ use datafusion_macros::user_doc; standard_argument(name = "str", prefix = "String"), argument( name = "start_pos", - description = "Character position to start the substring at. The first character in the string has a position of 1." + description = "Character position to start the substring at. The first character in the string has a position of 1. If the start position is less than 1, it is treated as if it is before the start of the string and the (absolute) number of characters before position 1 is subtracted from `length` (if given). For example, `substr('abc', -3, 6)` returns `'ab'`." ), argument( name = "length", @@ -104,10 +103,6 @@ impl SubstrFunc { } impl ScalarUDFImpl for SubstrFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "substr" } @@ -121,10 +116,7 @@ impl ScalarUDFImpl for SubstrFunc { Ok(DataType::Utf8View) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(substr, vec![])(&args.args) } @@ -137,19 +129,16 @@ impl ScalarUDFImpl for SubstrFunc { } } -/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) -/// substr('alphabet', 3) = 'phabet' -/// substr('alphabet', 3, 2) = 'ph' -/// The implementation uses UTF-8 code points as characters +/// Dispatches `substr` to the appropriate string array implementation. fn substr(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8 => { let string_array = args[0].as_string::(); - string_substr::<_>(string_array, &args[1..]) + generic_string_substr(string_array, &args[1..]) } DataType::LargeUtf8 => { let string_array = args[0].as_string::(); - string_substr::<_>(string_array, &args[1..]) + generic_string_substr(string_array, &args[1..]) } DataType::Utf8View => { let string_array = args[0].as_string_view(); @@ -162,67 +151,74 @@ fn substr(args: &[ArrayRef]) -> Result { } } -// Convert the given `start` and `count` to valid byte indices within `input` string -// -// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)` -// `start` is 1-based, if `count` is not provided count to the end of the string -// Input indices are character-based, and return values are byte indices -// The input bounds can be outside string bounds, this function will return -// the intersection between input bounds and valid string bounds -// `input_ascii_only` is used to optimize this function if `input` is ASCII-only -// -// * Example -// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] -// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` -// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` -// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` -fn get_true_start_end( +/// Convert the given `start` and `count` to valid byte indices within `input` string. +/// +/// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)`. +/// `start` is 1-based; if `count` is not provided, returns indices to the end of the string. +/// Input indices are character-based, and return values are byte indices. +/// The input bounds can be outside string bounds; this function will return +/// the intersection between input bounds and valid string bounds. +/// `is_input_ascii_only` is used to optimize this function if `input` is ASCII-only. +/// +/// # Example +/// ```text +/// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] +/// get_true_start_end('Hi🌏', 1, None) -> Ok((0, 6)) +/// get_true_start_end('Hi🌏', 1, Some(1)) -> Ok((0, 1)) +/// get_true_start_end('Hi🌏', -10, Some(2)) -> Ok((0, 0)) +/// ``` +pub fn get_true_start_end( input: &str, start: i64, - count: Option, + count: Option, is_input_ascii_only: bool, -) -> (usize, usize) { - let start = start.checked_sub(1).unwrap_or(start); +) -> Result<(usize, usize)> { + if let Some(count) = count + && count < 0 + { + return exec_err!("negative count not allowed: {count}"); + } + + // The caller-provided `start` is 1-indexed. + let Some(start) = start.checked_sub(1) else { + return exec_err!("start position overflow: {start}"); + }; let end = match count { - Some(count) => start + count as i64, + Some(count) => start.saturating_add(count), None => input.len() as i64, }; - let count_to_end = count.is_some(); let start = start.clamp(0, input.len() as i64) as usize; let end = end.clamp(0, input.len() as i64) as usize; - let count = end - start; - // If input is ASCII-only, byte-based indices equals to char-based indices + // If input is ASCII-only, byte-based indices equal char-based indices if is_input_ascii_only { - return (start, end); + return Ok((start, end)); } - // Otherwise, calculate byte indices from char indices - // Note this decoding is relatively expensive for this simple `substr` function,, - // so the implementation attempts to decode in one pass (and caused the complexity) - let (mut st, mut ed) = (input.len(), input.len()); - let mut start_counting = false; - let mut cnt = 0; - for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() { - if char_cnt == start { - st = byte_cnt; - if count_to_end { - start_counting = true; - } else { + // Otherwise, calculate byte indices from char indices. We initialize both + // `byte_start` and `byte_end` to the string length to handle cases where + // the requested 'start' or 'end' positions are at or beyond the end of the + // string (resulting in an empty substring). + let mut byte_start = input.len(); + let mut byte_end = input.len(); + + for (char_idx, (byte_idx, _)) in input.char_indices().enumerate() { + if char_idx == start { + byte_start = byte_idx; + // If no length is specified, we only need the start offset. + if count.is_none() { break; } } - if start_counting { - if cnt == count { - ed = byte_cnt; - break; - } - cnt += 1; + if char_idx == end { + byte_end = byte_idx; + break; } } - (st, ed) + + Ok((byte_start, byte_end)) } // String characters are variable length encoded in UTF-8, `substr()` function's @@ -235,7 +231,7 @@ fn get_true_start_end( // string, such as `substr(long_str_with_1k_chars, 1, 32)`. // In such case the overhead of ASCII-validation may not be worth it, so // skip the validation for short prefix for now. -fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( +pub fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( string_array: &V, start: &Int64Array, count: Option<&Int64Array>, @@ -247,7 +243,7 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( // HACK: can be simplified if function has specialized // implementation for `ScalarValue` (implement without `make_scalar_function()`) - let avg_prefix_len = start + let total_prefix_len = start .iter() .zip(count.iter()) .take(n_sample) @@ -255,11 +251,11 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( let start = start.unwrap_or(0); let count = count.unwrap_or(0); // To get substring, need to decode from 0 to start+count instead of start to start+count - start + count + start.saturating_add(count) }) - .sum::(); + .fold(0i64, |acc, val| acc.saturating_add(val)); - avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold + (total_prefix_len as f64 / n_sample as f64) <= short_prefix_threshold } None => false, }; @@ -272,104 +268,42 @@ fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( } } -// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 -// From for ByteView fn string_view_substr( string_view_array: &StringViewArray, args: &[ArrayRef], ) -> Result { - let mut views_buf = Vec::with_capacity(string_view_array.len()); - let mut null_builder = NullBufferBuilder::new(string_view_array.len()); - let start_array = as_int64_array(&args[0])?; - let count_array_opt = if args.len() == 2 { - Some(as_int64_array(&args[1])?) - } else { - None - }; + let count_array_opt = args.get(1).map(|a| as_int64_array(a)).transpose()?; - let enable_ascii_fast_path = + let is_ascii = enable_ascii_fast_path(&string_view_array, start_array, count_array_opt); - // In either case of `substr(s, i)` or `substr(s, i, cnt)` - // If any of input argument is `NULL`, the result is `NULL` - match args.len() { - 1 => { - for ((str_opt, raw_view), start_opt) in string_view_array - .iter() - .zip(string_view_array.views().iter()) - .zip(start_array.iter()) - { - if let (Some(str), Some(start)) = (str_opt, start_opt) { - let (start, end) = - get_true_start_end(str, start, None, enable_ascii_fast_path); - let substr = &str[start..end]; - - make_and_append_view( - &mut views_buf, - &mut null_builder, - raw_view, - substr, - start as u32, - ); - } else { - null_builder.append_null(); - views_buf.push(0); - } - } - } - 2 => { - let count_array = count_array_opt.unwrap(); - for (((str_opt, raw_view), start_opt), count_opt) in string_view_array - .iter() - .zip(string_view_array.views().iter()) - .zip(start_array.iter()) - .zip(count_array.iter()) - { - if let (Some(str), Some(start), Some(count)) = - (str_opt, start_opt, count_opt) - { - if count < 0 { - return exec_err!( - "negative substring length not allowed: substr(, {start}, {count})" - ); - } else { - if start == i64::MIN { - return exec_err!( - "negative overflow when calculating skip value" - ); - } - let (start, end) = get_true_start_end( - str, - start, - Some(count as u64), - enable_ascii_fast_path, - ); - let substr = &str[start..end]; - - make_and_append_view( - &mut views_buf, - &mut null_builder, - raw_view, - substr, - start as u32, - ); - } - } else { - null_builder.append_null(); - views_buf.push(0); - } - } - } - other => { - return exec_err!( - "substr was called with {other} arguments. It requires 2 or 3." - ) + // Combine null bitmaps from all inputs in bulk. + let nulls = NullBuffer::union_many([ + string_view_array.nulls(), + start_array.nulls(), + count_array_opt.and_then(|a| a.nulls()), + ]); + + let mut views_buf = Vec::with_capacity(string_view_array.len()); + + for (i, raw_view) in string_view_array.views().iter().enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + views_buf.push(0); + continue; } + + let string = string_view_array.value(i); + let start = start_array.value(i); + let count = count_array_opt.map(|a| a.value(i)); + + let (byte_start, byte_end) = get_true_start_end(string, start, count, is_ascii)?; + let substr = &string[byte_start..byte_end]; + + append_view(&mut views_buf, raw_view, substr, byte_start as u32); } let views_buf = ScalarBuffer::from(views_buf); - let nulls_buf = null_builder.finish(); // Safety: // (1) The blocks of the given views are all provided @@ -379,98 +313,149 @@ fn string_view_substr( let array = StringViewArray::new_unchecked( views_buf, string_view_array.data_buffers().to_vec(), - nulls_buf, + nulls, ); Ok(Arc::new(array) as ArrayRef) } } -fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result -where - V: StringArrayType<'a>, -{ +fn values_fit_in_i32(string_array: &GenericStringArray) -> bool { + // The Arrow spec defines StringView offset fields as signed 32-bit + // integers, so the maximum representable offset is i32::MAX. + string_array + .offsets() + .last() + .map(|offset| offset.as_usize() <= i32::MAX as usize) + .unwrap_or(true) +} + +#[inline] +fn append_view_from_buffer( + views_buf: &mut Vec, + substr: &str, + byte_offset: usize, +) -> bool { + let byte_offset = + u32::try_from(byte_offset).expect("validated string buffer offset fits in i32"); + let view = make_view(substr.as_bytes(), 0, byte_offset); + views_buf.push(view); + substr.len() > 12 +} + +#[expect(clippy::needless_range_loop)] +fn generic_string_substr( + string_array: &GenericStringArray, + args: &[ArrayRef], +) -> Result { + // We'd like to return a StringViewArray that points into the input string + // array's values buffer. Since the Arrow spec defines StringView offsets + // as i32, we can't use this approach when the values buffer is >2GB, so + // fallback to copying. + if !values_fit_in_i32(string_array) { + return generic_string_substr_copy(string_array, args); + } + let start_array = as_int64_array(&args[0])?; - let count_array_opt = if args.len() == 2 { - Some(as_int64_array(&args[1])?) + let count_array_opt = args.get(1).map(|a| as_int64_array(a)).transpose()?; + + let is_ascii = enable_ascii_fast_path(&string_array, start_array, count_array_opt); + let offsets = string_array.value_offsets(); + let mut views_buf = Vec::with_capacity(string_array.len()); + let mut has_out_of_line = false; + + // Combine null bitmaps from all inputs in bulk. + let nulls = NullBuffer::union_many([ + string_array.nulls(), + start_array.nulls(), + count_array_opt.and_then(|a| a.nulls()), + ]); + + for i in 0..string_array.len() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + views_buf.push(0); + continue; + } + + let string = string_array.value(i); + let source_offset = offsets[i].as_usize(); + let start = start_array.value(i); + let count = count_array_opt.map(|a| a.value(i)); + + let (byte_start, byte_end) = get_true_start_end(string, start, count, is_ascii)?; + has_out_of_line |= append_view_from_buffer( + &mut views_buf, + &string[byte_start..byte_end], + source_offset + byte_start, + ); + } + + let views_buf = ScalarBuffer::from(views_buf); + + // If all result strings are stored inline, we don't need to retain the + // input string array. + let data_buffers = if has_out_of_line { + vec![string_array.values().clone()] } else { - None + vec![] }; - let enable_ascii_fast_path = - enable_ascii_fast_path(&string_array, start_array, count_array_opt); - - match args.len() { - 1 => { - let iter = ArrayIter::new(string_array); - let mut result_builder = StringViewBuilder::new(); - for (string, start) in iter.zip(start_array.iter()) { - match (string, start) { - (Some(string), Some(start)) => { - let (start, end) = get_true_start_end( - string, - start, - None, - enable_ascii_fast_path, - ); // start, end is byte-based - let substr = &string[start..end]; - result_builder.append_value(substr); - } - _ => { - result_builder.append_null(); - } - } - } - Ok(Arc::new(result_builder.finish()) as ArrayRef) - } - 2 => { - let iter = ArrayIter::new(string_array); - let count_array = count_array_opt.unwrap(); - let mut result_builder = StringViewBuilder::new(); - - for ((string, start), count) in - iter.zip(start_array.iter()).zip(count_array.iter()) - { - match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - return exec_err!( - "negative substring length not allowed: substr(, {start}, {count})" - ); - } else { - if start == i64::MIN { - return exec_err!( - "negative overflow when calculating skip value" - ); - } - let (start, end) = get_true_start_end( - string, - start, - Some(count as u64), - enable_ascii_fast_path, - ); // start, end is byte-based - let substr = &string[start..end]; - result_builder.append_value(substr); - } - } - _ => { - result_builder.append_null(); - } - } - } - Ok(Arc::new(result_builder.finish()) as ArrayRef) - } - other => { - exec_err!("substr was called with {other} arguments. It requires 2 or 3.") + // Safety: + // (1) The blocks of the given views are all provided + // (2) Each referenced range in the source values buffer is within bounds + unsafe { + let array = StringViewArray::new_unchecked(views_buf, data_buffers, nulls); + Ok(Arc::new(array) as ArrayRef) + } +} + +// Fallback for `generic_string_substr` if we can't use zerocopy because the +// input string array is too large. +fn generic_string_substr_copy( + string_array: &GenericStringArray, + args: &[ArrayRef], +) -> Result { + let start_array = as_int64_array(&args[0])?; + let count_array_opt = args.get(1).map(|a| as_int64_array(a)).transpose()?; + + let is_ascii = enable_ascii_fast_path(&string_array, start_array, count_array_opt); + + // Combine null bitmaps from all inputs in bulk. + let nulls = NullBuffer::union_many([ + string_array.nulls(), + start_array.nulls(), + count_array_opt.and_then(|a| a.nulls()), + ]); + + let len = string_array.len(); + let mut result_builder = StringViewArrayBuilder::with_capacity(len); + + for i in 0..len { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + result_builder.append_placeholder(); + continue; } + + let string = string_array.value(i); + let start = start_array.value(i); + let count = count_array_opt.map(|a| a.value(i)); + + let (byte_start, byte_end) = get_true_start_end(string, start, count, is_ascii)?; + result_builder.append_value(&string[byte_start..byte_end]); } + + Ok(Arc::new(result_builder.finish(nulls)?) as ArrayRef) } #[cfg(test)] mod tests { - use arrow::array::{Array, StringViewArray}; + use std::sync::Arc; + + use arrow::array::{ + Array, ArrayRef, AsArray, Int64Array, StringArray, StringViewArray, + }; use arrow::datatypes::DataType::Utf8View; - use datafusion_common::{exec_err, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use crate::unicode::substr::SubstrFunc; @@ -775,7 +760,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::from(1i64)), ColumnarValue::Scalar(ScalarValue::from(-1i64)), ], - exec_err!("negative substring length not allowed: substr(, 1, -1)"), + exec_err!("negative count not allowed: -1"), &str, Utf8View, StringViewArray @@ -810,9 +795,9 @@ mod tests { SubstrFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::from("abc")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), ], - Ok(Some("abc")), + exec_err!("start position overflow: -9223372036854775808"), &str, Utf8View, StringViewArray @@ -821,10 +806,22 @@ mod tests { SubstrFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::from("overflow")), - ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), ColumnarValue::Scalar(ScalarValue::from(1i64)), ], - exec_err!("negative overflow when calculating skip value"), + exec_err!("start position overflow: -9223372036854775808"), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("large count")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ], + Ok(Some("arge count")), &str, Utf8View, StringViewArray @@ -832,4 +829,25 @@ mod tests { Ok(()) } + + #[test] + fn test_sliced_string_array_array_args() -> Result<()> { + // Use strings longer than 12 bytes so the result views are out-of-line. + let string_array = Arc::new(StringArray::from(vec![ + "skipped_prefix_value", + "alphabet_long_string", + "joséésojanother_long", + ])) as ArrayRef; + let string_array = string_array.slice(1, 2); + let start_array = Arc::new(Int64Array::from(vec![3, 5])) as ArrayRef; + let count_array = Arc::new(Int64Array::from(vec![15, 14])) as ArrayRef; + + let result = super::substr(&[string_array, start_array, count_array])?; + let result = result.as_string_view(); + + assert_eq!(result.value(0), "phabet_long_str"); + assert_eq!(result.value(1), "ésojanother_lo"); + + Ok(()) + } } diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index bf59787206927..d122a34a9fc38 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -15,22 +15,28 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, - PrimitiveArray, StringBuilder, + Array, ArrayRef, AsArray, ByteView, GenericStringArray, OffsetSizeTrait, + PrimitiveArray, StringArrayType, StringViewArray, make_view, new_null_array, }; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{DataType, Int64Type}; +use arrow_buffer::NullBuffer; -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::{exec_err, utils::take_function_args, Result}; +use crate::strings::GenericStringArrayBuilder; +use crate::utils::make_scalar_function; +use datafusion_common::{ + Result, ScalarValue, exec_datafusion_err, exec_err, utils::take_function_args, +}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; +use memchr::{memchr_iter, memmem, memrchr_iter}; #[user_doc( doc_section(label = "String Functions"), @@ -92,10 +98,6 @@ impl SubstrIndexFunc { } impl ScalarUDFImpl for SubstrIndexFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "substr_index" } @@ -105,14 +107,22 @@ impl ScalarUDFImpl for SubstrIndexFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "substr_index") + Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(substr_index, vec![])(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + + if let ( + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(delim_scalar), + ColumnarValue::Scalar(count_scalar), + ) = (&args[0], &args[1], &args[2]) + { + return substr_index_scalar(string_array, delim_scalar, count_scalar); + } + + make_scalar_function(substr_index, vec![])(&args) } fn aliases(&self) -> &[String] { @@ -137,31 +147,35 @@ fn substr_index(args: &[ArrayRef]) -> Result { let string_array = str.as_string::(); let delimiter_array = delim.as_string::(); let count_array: &PrimitiveArray = count.as_primitive(); - substr_index_general::( + substr_index_general( string_array, delimiter_array, count_array, + GenericStringArrayBuilder::::with_capacity( + string_array.len(), + visible_string_bytes(string_array), + ), ) } DataType::LargeUtf8 => { let string_array = str.as_string::(); let delimiter_array = delim.as_string::(); let count_array: &PrimitiveArray = count.as_primitive(); - substr_index_general::( + substr_index_general( string_array, delimiter_array, count_array, + GenericStringArrayBuilder::::with_capacity( + string_array.len(), + visible_string_bytes(string_array), + ), ) } DataType::Utf8View => { let string_array = str.as_string_view(); let delimiter_array = delim.as_string_view(); let count_array: &PrimitiveArray = count.as_primitive(); - substr_index_general::( - string_array, - delimiter_array, - count_array, - ) + substr_index_view(string_array, delimiter_array, count_array) } other => { exec_err!("Unsupported data type {other:?} for function substr_index") @@ -169,75 +183,459 @@ fn substr_index(args: &[ArrayRef]) -> Result { } } -fn substr_index_general< - 'a, - T: ArrowPrimitiveType, - V: ArrayAccessor, - P: ArrayAccessor, ->( - string_array: V, - delimiter_array: V, - count_array: P, +fn substr_index_scalar( + string_array: &ArrayRef, + delim_scalar: &ScalarValue, + count_scalar: &ScalarValue, +) -> Result { + if string_array.is_empty() { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + 0, + ))); + } + + let delimiter = delim_scalar.try_as_str().ok_or_else(|| { + exec_datafusion_err!( + "Unsupported delimiter type {:?} for substr_index", + delim_scalar.data_type() + ) + })?; + + let count = match count_scalar { + ScalarValue::Int64(v) => *v, + other => { + return exec_err!( + "Unsupported count type {:?} for substr_index", + other.data_type() + ); + } + }; + + let (Some(delimiter), Some(count)) = (delimiter, count) else { + return Ok(ColumnarValue::Array(new_null_array( + string_array.data_type(), + string_array.len(), + ))); + }; + + let result = match string_array.data_type() { + DataType::Utf8View => { + substr_index_scalar_view(string_array.as_string_view(), delimiter, count) + } + DataType::Utf8 => { + let arr = string_array.as_string::(); + substr_index_scalar_impl( + arr, + delimiter, + count, + GenericStringArrayBuilder::::with_capacity( + arr.len(), + visible_string_bytes(arr), + ), + ) + } + DataType::LargeUtf8 => { + let arr = string_array.as_string::(); + substr_index_scalar_impl( + arr, + delimiter, + count, + GenericStringArrayBuilder::::with_capacity( + arr.len(), + visible_string_bytes(arr), + ), + ) + } + other => exec_err!("Unsupported string type {other:?} for substr_index"), + }?; + + Ok(ColumnarValue::Array(result)) +} + +#[inline] +fn visible_string_bytes( + string_array: &GenericStringArray, +) -> usize { + let offsets = string_array.value_offsets(); + offsets[offsets.len() - 1].as_usize() - offsets[0].as_usize() +} + +fn substr_index_general<'a, S, O>( + string_array: S, + delimiter_array: S, + count_array: &PrimitiveArray, + mut builder: GenericStringArrayBuilder, ) -> Result where - T::Native: OffsetSizeTrait, + S: StringArrayType<'a> + Copy, + O: OffsetSizeTrait, { - let mut builder = StringBuilder::new(); - let string_iter = ArrayIter::new(string_array); - let delimiter_array_iter = ArrayIter::new(delimiter_array); - let count_array_iter = ArrayIter::new(count_array); - string_iter - .zip(delimiter_array_iter) - .zip(count_array_iter) - .for_each(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - // In MySQL, these cases will return an empty string. - if n == 0 || string.is_empty() || delimiter.is_empty() { - builder.append_value(""); - return; + let num_rows = string_array.len(); + // Output is null if and only if any input is null. + let nulls = NullBuffer::union_many([ + string_array.nulls(), + delimiter_array.nulls(), + count_array.nulls(), + ]); + + for i in 0..num_rows { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + builder.append_placeholder(); + continue; + } + // SAFETY: `i < num_rows` and the union of input nulls is valid at i, + // so each input is also valid at i. + let string = unsafe { string_array.value_unchecked(i) }; + let delimiter = unsafe { delimiter_array.value_unchecked(i) }; + let n = unsafe { count_array.value_unchecked(i) }; + builder.append_value(substr_index_slice(string, delimiter, n)); + } + + Ok(Arc::new(builder.finish(nulls)?) as ArrayRef) +} + +fn substr_index_view( + string_array: &StringViewArray, + delimiter_array: &StringViewArray, + count_array: &PrimitiveArray, +) -> Result { + let nulls = NullBuffer::union_many([ + string_array.nulls(), + delimiter_array.nulls(), + count_array.nulls(), + ]); + let views = string_array.views(); + let mut views_buf = Vec::with_capacity(string_array.len()); + let mut has_out_of_line = false; + + for (i, raw_view) in views.iter().enumerate() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + views_buf.push(0); + continue; + } + + let string = string_array.value(i); + let delimiter = delimiter_array.value(i); + let count = count_array.value(i); + let substr = substr_index_slice(string, delimiter, count); + has_out_of_line |= append_substr_view(&mut views_buf, raw_view, string, substr); + } + + let data_buffers = if has_out_of_line { + string_array.data_buffers().to_vec() + } else { + vec![] + }; + + // Safety: each appended view is either: + // (1) a copied null sentinel, + // (2) the original valid input view, or + // (3) built by `append_view` for a contiguous substring of the input row. + unsafe { + Ok(Arc::new(StringViewArray::new_unchecked( + ScalarBuffer::from(views_buf), + data_buffers, + nulls, + )) as ArrayRef) + } +} + +fn substr_index_scalar_impl<'a, S, O>( + string_array: S, + delimiter: &str, + count: i64, + builder: GenericStringArrayBuilder, +) -> Result +where + S: StringArrayType<'a> + Copy, + O: OffsetSizeTrait, +{ + if count == 0 || delimiter.is_empty() { + return map_strings(string_array, builder, |string| &string[..0]); + } + + if delimiter.len() == 1 { + let delimiter_byte = delimiter.as_bytes()[0]; + return map_strings(string_array, builder, |string| { + substr_index_single_byte(string, delimiter_byte, count) + }); + } + + let occurrence_idx = usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX) - 1; + if count > 0 { + let finder = memmem::Finder::new(delimiter.as_bytes()); + map_strings(string_array, builder, |string| { + substr_index_slice_finder(string, &finder, delimiter.len(), occurrence_idx) + }) + } else { + let finder_rev = memmem::FinderRev::new(delimiter.as_bytes()); + map_strings(string_array, builder, |string| { + substr_index_rslice_finder( + string, + &finder_rev, + delimiter.len(), + occurrence_idx, + ) + }) + } +} + +fn substr_index_scalar_view( + string_array: &StringViewArray, + delimiter: &str, + count: i64, +) -> Result { + let views = string_array.views(); + let mut views_buf = Vec::with_capacity(string_array.len()); + let mut has_out_of_line = false; + + if count == 0 || delimiter.is_empty() { + let empty_view = make_view(b"", 0, 0); + for i in 0..string_array.len() { + if string_array.is_null(i) { + views_buf.push(0); + } else { + views_buf.push(empty_view); + } + } + } else if delimiter.len() == 1 { + let delimiter_byte = delimiter.as_bytes()[0]; + for (i, raw_view) in views.iter().enumerate() { + if string_array.is_null(i) { + views_buf.push(0); + continue; + } + + let string = string_array.value(i); + let substr = substr_index_single_byte(string, delimiter_byte, count); + has_out_of_line |= + append_substr_view(&mut views_buf, raw_view, string, substr); + } + } else { + let occurrence_idx = + usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX) - 1; + if count > 0 { + let finder = memmem::Finder::new(delimiter.as_bytes()); + for (i, raw_view) in views.iter().enumerate() { + if string_array.is_null(i) { + views_buf.push(0); + continue; } - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - let length = if n > 0 { - let split = string.split(delimiter); - split - .take(occurrences) - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len() - } else { - let split = string.rsplit(delimiter); - split - .take(occurrences) - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len() - }; - if n > 0 { - match string.get(..length) { - Some(substring) => builder.append_value(substring), - None => builder.append_null(), - } - } else { - match string.get(string.len().saturating_sub(length)..) { - Some(substring) => builder.append_value(substring), - None => builder.append_null(), - } + let string = string_array.value(i); + let substr = substr_index_slice_finder( + string, + &finder, + delimiter.len(), + occurrence_idx, + ); + has_out_of_line |= + append_substr_view(&mut views_buf, raw_view, string, substr); + } + } else { + let finder_rev = memmem::FinderRev::new(delimiter.as_bytes()); + for (i, raw_view) in views.iter().enumerate() { + if string_array.is_null(i) { + views_buf.push(0); + continue; } + + let string = string_array.value(i); + let substr = substr_index_rslice_finder( + string, + &finder_rev, + delimiter.len(), + occurrence_idx, + ); + has_out_of_line |= + append_substr_view(&mut views_buf, raw_view, string, substr); } - _ => builder.append_null(), - }); + } + } + + let data_buffers = if has_out_of_line { + string_array.data_buffers().to_vec() + } else { + vec![] + }; + + // Safety: each appended view is either: + // (1) a copied null sentinel, + // (2) the original valid input view, + // (3) an inline empty string view, or + // (4) built by `append_view` for a contiguous substring of the input row. + unsafe { + Ok(Arc::new(StringViewArray::new_unchecked( + ScalarBuffer::from(views_buf), + data_buffers, + string_array.nulls().cloned(), + )) as ArrayRef) + } +} + +fn map_strings<'a, S, O, F>( + string_array: S, + mut builder: GenericStringArrayBuilder, + f: F, +) -> Result +where + S: StringArrayType<'a> + Copy, + O: OffsetSizeTrait, + F: Fn(&'a str) -> &'a str, +{ + let nulls = string_array.nulls().cloned(); + for i in 0..string_array.len() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + builder.append_placeholder(); + continue; + } + // SAFETY: `i < string_array.len()` and `nulls` is valid at i, so the + // input is also valid at i. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(f(s)); + } + Ok(Arc::new(builder.finish(nulls)?) as ArrayRef) +} + +#[inline] +fn substr_index_slice<'a>(string: &'a str, delimiter: &str, count: i64) -> &'a str { + if count == 0 || string.is_empty() || delimiter.is_empty() { + return &string[..0]; + } + + if delimiter.len() == 1 { + return substr_index_single_byte(string, delimiter.as_bytes()[0], count); + } + + let occurrences = usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX); + if count > 0 { + string + .match_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| &string[..idx]) + .unwrap_or(string) + } else { + string + .rmatch_indices(delimiter) + .nth(occurrences - 1) + .map(|(idx, _)| &string[idx + delimiter.len()..]) + .unwrap_or(string) + } +} + +#[inline] +fn substr_index_single_byte(string: &str, delimiter: u8, count: i64) -> &str { + let occurrences = usize::try_from(count.unsigned_abs()).unwrap_or(usize::MAX); + let idx = if count > 0 { + memchr_iter(delimiter, string.as_bytes()).nth(occurrences - 1) + } else { + memrchr_iter(delimiter, string.as_bytes()) + .nth(occurrences - 1) + .map(|idx| idx + 1) + }; + + match idx { + Some(idx) if count > 0 => &string[..idx], + Some(idx) => &string[idx..], + None => string, + } +} + +#[inline] +fn substr_index_slice_finder<'a>( + string: &'a str, + finder: &memmem::Finder, + delimiter_len: usize, + occurrence_idx: usize, +) -> &'a str { + let bytes = string.as_bytes(); + let mut start = 0; + for _ in 0..occurrence_idx { + match finder.find(&bytes[start..]) { + Some(pos) => start += pos + delimiter_len, + None => return string, + } + } + + match finder.find(&bytes[start..]) { + Some(pos) => &string[..start + pos], + None => string, + } +} + +#[inline] +fn substr_index_rslice_finder<'a>( + string: &'a str, + finder: &memmem::FinderRev, + delimiter_len: usize, + occurrence_idx: usize, +) -> &'a str { + let bytes = string.as_bytes(); + let mut end = bytes.len(); + for _ in 0..occurrence_idx { + match finder.rfind(&bytes[..end]) { + Some(pos) => end = pos, + None => return string, + } + } - Ok(Arc::new(builder.finish()) as ArrayRef) + match finder.rfind(&bytes[..end]) { + Some(pos) => &string[pos + delimiter_len..], + None => string, + } +} + +#[inline] +fn substr_view(original_view: &u128, substr: &str, start_offset: u32) -> u128 { + if substr.len() > 12 { + let view = ByteView::from(*original_view); + make_view( + substr.as_bytes(), + view.buffer_index, + view.offset + start_offset, + ) + } else { + make_view(substr.as_bytes(), 0, 0) + } +} + +#[inline] +fn append_substr_view( + views_buf: &mut Vec, + raw_view: &u128, + string: &str, + substr: &str, +) -> bool { + if substr.len() == string.len() { + views_buf.push(*raw_view); + return substr.len() > 12; + } + + if substr.is_empty() { + views_buf.push(make_view(b"", 0, 0)); + return false; + } + + let start_offset = substr.as_ptr() as usize - string.as_ptr() as usize; + let start_offset = + u32::try_from(start_offset).expect("string view offsets fit in u32"); + views_buf.push(substr_view(raw_view, substr, start_offset)); + substr.len() > 12 } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{ + Array, ArrayRef, AsArray, Int64Array, StringArray, StringViewArray, + }; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + use std::sync::Arc; use crate::unicode::substrindex::SubstrIndexFunc; use crate::utils::test::test_function; @@ -328,6 +726,135 @@ mod tests { Utf8, StringArray ); + test_function!( + SubstrIndexFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "verylongprefix.segment.tail".into(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("verylongprefix")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrIndexFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "www.apache.org".into(), + ))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))), + ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ], + Ok(Some("org")), + &str, + Utf8View, + StringViewArray + ); + Ok(()) + } + + #[test] + fn test_substr_index_utf8view_scalar_fast_path() -> Result<()> { + let input = Arc::new(StringViewArray::from(vec![ + Some("alpha.beta.gamma"), + Some("short.val"), + None, + ])) as ArrayRef; + + let arg_fields = vec![ + Field::new("a", Utf8View, true).into(), + Field::new("b", Utf8View, true).into(), + Field::new("c", DataType::Int64, true).into(), + ]; + + let args = ScalarFunctionArgs { + number_rows: input.len(), + args: vec![ + ColumnarValue::Array(Arc::clone(&input)), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(".".into()))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + arg_fields, + return_field: Field::new("f", Utf8View, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = match SubstrIndexFunc::new().invoke_with_args(args)? { + ColumnarValue::Array(result) => result, + other => panic!("expected array result, got {other:?}"), + }; + let result = result.as_string_view(); + + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), "alpha"); + assert_eq!(result.value(1), "short"); + assert!(result.is_null(2)); + + Ok(()) + } + + #[test] + fn test_substr_index_utf8view_array_sliced() -> Result<()> { + use super::substr_index_view; + + let strings: StringViewArray = vec![ + Some("skip_this.value"), + Some("this_is_a_long_prefix.suffix"), + Some("short.val"), + Some("another_long_result.rest"), + None, + ] + .into_iter() + .collect(); + let delimiters: StringViewArray = + vec![Some("."), Some("."), Some("."), Some("."), Some(".")] + .into_iter() + .collect(); + let counts = Int64Array::from(vec![1, 1, -1, 1, 1]); + + let sliced_strings = strings.slice(1, 4); + let sliced_delimiters = delimiters.slice(1, 4); + let sliced_counts = counts.slice(1, 4); + + let result = + substr_index_view(&sliced_strings, &sliced_delimiters, &sliced_counts)?; + let result = result.as_string_view(); + + assert_eq!(result.len(), 4); + assert_eq!(result.value(0), "this_is_a_long_prefix"); + assert_eq!(result.value(1), "val"); + assert_eq!(result.value(2), "another_long_result"); + assert!(result.is_null(3)); + + Ok(()) + } + + #[test] + fn test_substr_index_utf8view_scalar_reuses_original_view_when_unchanged() + -> Result<()> { + use super::substr_index_scalar_view; + + let strings: StringViewArray = vec![ + Some("very_long_value_without_separator"), + Some("short"), + None, + ] + .into_iter() + .collect(); + + let result = substr_index_scalar_view(&strings, ".", 1)?; + let result = result.as_string_view(); + + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), "very_long_value_without_separator"); + assert_eq!(result.value(1), "short"); + assert_eq!(result.views()[0], strings.views()[0]); + assert_eq!(result.views()[1], strings.views()[1]); + assert!(result.is_null(2)); Ok(()) } diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 911b8d311996e..85e83897f41da 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -15,28 +15,29 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ - ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, -}; +use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringArrayType}; +use arrow::buffer::NullBuffer; use arrow::datatypes::DataType; use datafusion_common::HashMap; -use unicode_segmentation::UnicodeSegmentation; -use crate::utils::{make_scalar_function, utf8_to_str_type}; -use datafusion_common::{exec_err, Result}; +use super::common::try_as_scalar_str; +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder, + StringWriter, +}; +use crate::utils::make_scalar_function; +use datafusion_common::{Result, exec_err}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), - description = "Translates characters in a string to specified translation characters.", - syntax_example = "translate(str, chars, translation)", + description = "Performs character-wise substitution based on a mapping.", + syntax_example = "translate(str, from, to)", sql_example = r#"```sql > select translate('twice', 'wic', 'her'); +--------------------------------------------------+ @@ -46,10 +47,10 @@ use datafusion_macros::user_doc; +--------------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), - argument(name = "chars", description = "Characters to translate."), + argument(name = "from", description = "The characters to be replaced."), argument( - name = "translation", - description = "Translation characters. Translation characters replace only characters at the same position in the **chars** string." + name = "to", + description = "The characters to replace them with. Each character in **from** that is found in **str** is replaced by the character at the same index in **to**. Any characters in **from** that don't have a corresponding character in **to** are removed. If a character appears more than once in **from**, the first occurrence determines the mapping." ) )] #[derive(Debug, PartialEq, Eq, Hash)] @@ -71,6 +72,7 @@ impl TranslateFunc { vec![ Exact(vec![Utf8View, Utf8, Utf8]), Exact(vec![Utf8, Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8, Utf8]), ], Volatility::Immutable, ), @@ -79,10 +81,6 @@ impl TranslateFunc { } impl ScalarUDFImpl for TranslateFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "translate" } @@ -92,13 +90,52 @@ impl ScalarUDFImpl for TranslateFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "translate") + Ok(arg_types[0].clone()) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // When from and to are scalars, pre-build the translation map once + if let (Some(from_str), Some(to_str)) = ( + try_as_scalar_str(&args.args[1]), + try_as_scalar_str(&args.args[2]), + ) { + let table = build_translate_table(from_str, to_str); + + let string_array = args.args[0].to_array_of_size(args.number_rows)?; + let len = string_array.len(); + + let result = match string_array.data_type() { + DataType::Utf8View => { + let arr = string_array.as_string_view(); + let builder = StringViewArrayBuilder::with_capacity(len); + translate_with_table(&arr, &table, builder) + } + DataType::Utf8 => { + let arr = string_array.as_string::(); + let builder = GenericStringArrayBuilder::::with_capacity( + len, + arr.value_data().len(), + ); + translate_with_table(&arr, &table, builder) + } + DataType::LargeUtf8 => { + let arr = string_array.as_string::(); + let builder = GenericStringArrayBuilder::::with_capacity( + len, + arr.value_data().len(), + ); + translate_with_table(&arr, &table, builder) + } + other => { + return exec_err!( + "Unsupported data type {other:?} for function translate" + ); + } + }?; + + return Ok(ColumnarValue::Array(result)); + } + make_scalar_function(invoke_translate, vec![])(&args.args) } @@ -108,24 +145,34 @@ impl ScalarUDFImpl for TranslateFunc { } fn invoke_translate(args: &[ArrayRef]) -> Result { + let len = args[0].len(); match args[0].data_type() { DataType::Utf8View => { let string_array = args[0].as_string_view(); let from_array = args[1].as_string::(); let to_array = args[2].as_string::(); - translate::(string_array, from_array, to_array) + let builder = StringViewArrayBuilder::with_capacity(len); + translate(&string_array, from_array, to_array, builder) } DataType::Utf8 => { let string_array = args[0].as_string::(); let from_array = args[1].as_string::(); let to_array = args[2].as_string::(); - translate::(string_array, from_array, to_array) + let builder = GenericStringArrayBuilder::::with_capacity( + len, + string_array.value_data().len(), + ); + translate(&string_array, from_array, to_array, builder) } DataType::LargeUtf8 => { let string_array = args[0].as_string::(); - let from_array = args[1].as_string::(); - let to_array = args[2].as_string::(); - translate::(string_array, from_array, to_array) + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + let builder = GenericStringArrayBuilder::::with_capacity( + len, + string_array.value_data().len(), + ); + translate(&string_array, from_array, to_array, builder) } other => { exec_err!("Unsupported data type {other:?} for function translate") @@ -133,61 +180,263 @@ fn invoke_translate(args: &[ArrayRef]) -> Result { } } -/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. +/// Replaces each character in string that matches a character in the from set +/// with the corresponding character in the to set. If from is longer than to, +/// occurrences of the extra characters in from are deleted. +/// /// translate('12345', '143', 'ax') = 'a2x5' -fn translate<'a, T: OffsetSizeTrait, V, B>( - string_array: V, - from_array: B, - to_array: B, +fn translate<'a, S, O>( + string_array: &S, + from_array: &GenericStringArray, + to_array: &GenericStringArray, + mut builder: O, ) -> Result where - V: ArrayAccessor, - B: ArrayAccessor, + S: StringArrayType<'a>, + O: BulkNullStringArrayBuilder, { - let string_array_iter = ArrayIter::new(string_array); - let from_array_iter = ArrayIter::new(from_array); - let to_array_iter = ArrayIter::new(to_array); - - let result = string_array_iter - .zip(from_array_iter) - .zip(to_array_iter) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => { - // create a hashmap of [char, index] to change from O(n) to O(1) for from list - let from_map: HashMap<&str, usize> = from - .graphemes(true) - .collect::>() - .iter() - .enumerate() - .map(|(index, c)| (c.to_owned(), index)) - .collect(); - - let to = to.graphemes(true).collect::>(); - - Some( - string - .graphemes(true) - .collect::>() - .iter() - .flat_map(|c| match from_map.get(*c) { - Some(n) => to.get(*n).copied(), - None => Some(*c), - }) - .collect::>() - .concat(), - ) + let mut from_map: HashMap> = HashMap::new(); + let len = string_array.len(); + let nulls = NullBuffer::union_many([ + string_array.nulls(), + from_array.nulls(), + to_array.nulls(), + ]); + + if let Some(nulls_ref) = nulls.as_ref() { + for i in 0..len { + if nulls_ref.is_null(i) { + builder.append_placeholder(); + continue; + } + + // SAFETY: union of input nulls is non-null at i, so each input is too. + let string = unsafe { string_array.value_unchecked(i) }; + let from = unsafe { from_array.value_unchecked(i) }; + let to = unsafe { to_array.value_unchecked(i) }; + append_translated_row(&mut builder, string, from, to, &mut from_map); + } + } else { + for i in 0..len { + // SAFETY: i < len, and no input has a null buffer. + let string = unsafe { string_array.value_unchecked(i) }; + let from = unsafe { from_array.value_unchecked(i) }; + let to = unsafe { to_array.value_unchecked(i) }; + append_translated_row(&mut builder, string, from, to, &mut from_map); + } + } + + builder.finish(nulls) +} + +#[inline] +fn append_translated_row( + builder: &mut B, + string: &str, + from: &str, + to: &str, + from_map: &mut HashMap>, +) { + if let Some(ascii_table) = build_ascii_translate_table(from, to) { + append_translated_ascii(builder, string, &ascii_table); + return; + } + + from_map.clear(); + let mut to_iter = to.chars(); + for c in from.chars() { + let replacement = to_iter.next(); + from_map.entry(c).or_insert(replacement); + } + + builder.append_with(|w| write_translated_chars(w, string, from_map)); +} + +#[inline] +fn write_translated_chars( + w: &mut W, + input: &str, + from_map: &HashMap>, +) { + for c in input.chars() { + match from_map.get(&c) { + Some(Some(r)) => w.write_char(*r), + Some(None) => {} // delete: `from` had no corresponding `to` char + None => w.write_char(c), + } + } +} + +/// Sentinel value in the ASCII translate table indicating the character should +/// be deleted (the `from` character has no corresponding `to` character). Any +/// value > 127 works since valid ASCII is 0–127. +const ASCII_DELETE: u8 = 0xFF; + +/// Lookup table for ASCII-only translation. Entries 0..128 map input bytes to +/// replacement bytes, or `ASCII_DELETE` if the character should be deleted. +/// Entries 128..256 map to themselves so non-ASCII bytes pass through +/// unchanged. +#[derive(Debug)] +struct AsciiTranslateTable { + map: [u8; 256], + has_delete: bool, +} + +/// We use a byte-indexed table when both `from` and `to` strings are ASCII, +/// otherwise a char-indexed map where `None` means delete. +#[expect( + clippy::large_enum_variant, + reason = "one instance per call, passed by reference" +)] +enum TranslateTable { + Byte(AsciiTranslateTable), + Char(HashMap>), +} + +#[inline] +fn build_translate_table(from: &str, to: &str) -> TranslateTable { + if let Some(ascii) = build_ascii_translate_table(from, to) { + return TranslateTable::Byte(ascii); + } + let mut from_map: HashMap> = HashMap::with_capacity(from.len()); + let mut to_iter = to.chars(); + for c in from.chars() { + let replacement = to_iter.next(); + from_map.entry(c).or_insert(replacement); + } + TranslateTable::Char(from_map) +} + +/// Returns `None` if either string contains non-ASCII characters. +fn build_ascii_translate_table(from: &str, to: &str) -> Option { + if !from.is_ascii() || !to.is_ascii() { + return None; + } + + let to_bytes = to.as_bytes(); + let mut map = std::array::from_fn::(|i| i as u8); + let mut seen = [false; 128]; + let mut has_delete = false; + + for (i, from_byte) in from.bytes().enumerate() { + let idx = from_byte as usize; + if !seen[idx] { + seen[idx] = true; + if i < to_bytes.len() { + map[idx] = to_bytes[i]; + } else { + map[idx] = ASCII_DELETE; + has_delete = true; } - _ => None, - }) - .collect::>(); + } + } + + Some(AsciiTranslateTable { map, has_delete }) +} - Ok(Arc::new(result) as ArrayRef) +#[inline] +fn append_translated_ascii( + builder: &mut B, + input: &str, + table: &AsciiTranslateTable, +) { + // Fast path: equal-length byte-to-byte map when no deletions. + if !table.has_delete { + // SAFETY: ASCII source bytes map to ASCII replacements; non-ASCII + // bytes 128..256 map to themselves, so multi-byte UTF-8 sequences + // pass through unchanged. Output length equals input length and + // remains valid UTF-8. + unsafe { + builder.append_byte_map(input.as_bytes(), |b| table.map[b as usize]); + } + } else { + builder.append_with(|w| write_translated_ascii(w, input, table)); + } +} + +#[inline] +fn write_translated_ascii( + w: &mut W, + input: &str, + table: &AsciiTranslateTable, +) { + let bytes = input.as_bytes(); + let mut copy_start = 0; + + for (i, &b) in bytes.iter().enumerate() { + let mapped = table.map[b as usize]; + if mapped == b { + continue; + } + + if copy_start < i { + w.write_str(&input[copy_start..i]); + } + if mapped != ASCII_DELETE { + w.write_char(mapped as char); + } + copy_start = i + 1; + } + + if copy_start < input.len() { + w.write_str(&input[copy_start..]); + } +} + +fn translate_with_table<'a, S, O>( + string_array: &S, + table: &TranslateTable, + mut builder: O, +) -> Result +where + S: StringArrayType<'a>, + O: BulkNullStringArrayBuilder, +{ + let len = string_array.len(); + let nulls = string_array.nulls().cloned(); + + if let Some(nulls_ref) = nulls.as_ref() { + for i in 0..len { + if nulls_ref.is_null(i) { + builder.append_placeholder(); + continue; + } + + // SAFETY: input null buffer is non-null at i. + let s = unsafe { string_array.value_unchecked(i) }; + apply_translate_table(&mut builder, s, table); + } + } else { + for i in 0..len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; + apply_translate_table(&mut builder, s, table); + } + } + + builder.finish(nulls) +} + +#[inline] +fn apply_translate_table( + builder: &mut B, + input: &str, + table: &TranslateTable, +) { + match table { + TranslateTable::Byte(t) => append_translated_ascii(builder, input, t), + TranslateTable::Char(m) => { + builder.append_with(|w| write_translated_chars(w, input, m)) + } + } } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use std::sync::Arc; + + use arrow::array::{Array, ArrayRef, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -245,6 +494,18 @@ mod tests { Utf8, StringArray ); + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("abcabc")), + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::from("de")) + ], + Ok(Some("dbcdbc")), + &str, + Utf8, + StringArray + ); test_function!( TranslateFunc::new(), vec![ @@ -257,6 +518,59 @@ mod tests { Utf8, StringArray ); + // Non-ASCII input with ASCII scalar from/to. + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::from("café")), + ColumnarValue::Scalar(ScalarValue::from("ae")), + ColumnarValue::Scalar(ScalarValue::from("AE")) + ], + Ok(Some("cAfé")), + &str, + Utf8, + StringArray + ); + // Utf8View input should produce Utf8View output + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("12345".into()))), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(Some("a2x5")), + &str, + Utf8View, + StringViewArray + ); + // Null Utf8View input + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + // Non-ASCII Utf8View input + test_function!( + TranslateFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("é2íñ5".into()))), + ColumnarValue::Scalar(ScalarValue::from("éñí")), + ColumnarValue::Scalar(ScalarValue::from("óü")) + ], + Ok(Some("ó2ü5")), + &str, + Utf8View, + StringViewArray + ); + #[cfg(not(feature = "unicode_expressions"))] test_function!( TranslateFunc::new(), @@ -275,4 +589,27 @@ mod tests { Ok(()) } + + #[test] + fn test_array_args_with_nulls() -> Result<()> { + let string_array = Arc::new(StringArray::from(vec![ + Some("café!"), + Some("abc"), + Some("abc"), + ])) as ArrayRef; + let from_array = + Arc::new(StringArray::from(vec![Some("!"), Some("a"), None])) as ArrayRef; + let to_array = + Arc::new(StringArray::from(vec![Some(""), Some("x"), Some("y")])) as ArrayRef; + + let result = super::invoke_translate(&[string_array, from_array, to_array])?; + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!(result.value(0), "café"); + assert_eq!(result.value(1), "xbc"); + assert!(result.is_null(2)); + + Ok(()) + } } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 5a2368a38ef9d..b9bde1454994c 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -19,9 +19,9 @@ use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray} use arrow::compute::try_binary; use arrow::datatypes::{DataType, DecimalType}; use arrow::error::ArrowError; -use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::function::Hint; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; +use datafusion_expr::function::Hint; use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given @@ -147,7 +147,7 @@ where if scalar.is_null() { // Null scalar is castable to any numeric, creating a non-null expression. // Provide null array explicitly to make result null - PrimitiveArray::::new_null(1) + PrimitiveArray::::new_null(left.len()) } else { let right = R::Native::try_from(scalar.clone()).map_err(|_| { DataFusionError::NotImplemented(format!( @@ -189,20 +189,50 @@ where R::Native: TryFrom, { let result_array = calculate_binary_math::(left, right, fun)?; + Ok(Arc::new( + result_array + .as_ref() + .clone() + .with_precision_and_scale(precision, scale)?, + )) +} + +/// Converts Decimal128 components (value and scale) to an unscaled i128 +pub fn decimal128_to_i128(value: i128, scale: i8) -> Result { if scale < 0 { - not_impl_err!("Negative scale is not supported for power for decimal types") + Err(ArrowError::ComputeError( + "Negative scale is not supported".into(), + )) + } else if scale == 0 { + Ok(value) } else { - Ok(Arc::new( - result_array - .as_ref() - .clone() - .with_precision_and_scale(precision, scale)?, + match i128::from(10).checked_pow(scale as u32) { + Some(divisor) => Ok(value / divisor), + None => Err(ArrowError::ComputeError(format!( + "Cannot get a power of {scale}" + ))), + } + } +} + +pub fn decimal32_to_i32(value: i32, scale: i8) -> Result { + if scale < 0 { + Err(ArrowError::ComputeError( + "Negative scale is not supported".into(), )) + } else if scale == 0 { + Ok(value) + } else { + match 10_i32.checked_pow(scale as u32) { + Some(divisor) => Ok(value / divisor), + None => Err(ArrowError::ComputeError(format!( + "Cannot get a power of {scale}" + ))), + } } } -/// Converts Decimal128 components (value and scale) to an unscaled i128 -pub fn decimal128_to_i128(value: i128, scale: i8) -> Result { +pub fn decimal64_to_i64(value: i64, scale: i8) -> Result { if scale < 0 { Err(ArrowError::ComputeError( "Negative scale is not supported".into(), @@ -210,7 +240,7 @@ pub fn decimal128_to_i128(value: i128, scale: i8) -> Result { } else if scale == 0 { Ok(value) } else { - match i128::from(10).checked_pow(scale as u32) { + match i64::from(10).checked_pow(scale as u32) { Some(divisor) => Ok(value / divisor), None => Err(ArrowError::ComputeError(format!( "Cannot get a power of {scale}" @@ -333,12 +363,30 @@ pub mod test { }; } - use arrow::datatypes::DataType; - #[allow(unused_imports)] + use arrow::{ + array::Int32Array, + datatypes::{DataType, Int32Type}, + }; + use itertools::Either; pub(crate) use test_function; use super::*; + #[test] + fn test_calculate_binary_math_scalar_null() { + let left = Int32Array::from(vec![1, 2]); + let right = ColumnarValue::Scalar(ScalarValue::Int32(None)); + let result = calculate_binary_math::( + &left, + &right, + |x, y| Ok(x + y), + ) + .unwrap(); + + assert_eq!(result.len(), 2); + assert_eq!(result.null_count(), 2); + } + #[test] fn string_to_int_type() { let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap(); @@ -377,4 +425,106 @@ pub mod test { } } } + + #[test] + fn test_decimal32_to_i32() { + let cases: [(i32, i8, Either); _] = [ + (123, 0, Either::Left(123)), + (1230, 1, Either::Left(123)), + (123000, 3, Either::Left(123)), + (1234567, 2, Either::Left(12345)), + (-1234567, 2, Either::Left(-12345)), + (1, 0, Either::Left(1)), + ( + 123, + -3, + Either::Right("Negative scale is not supported".into()), + ), + ( + 123, + i8::MAX, + Either::Right("Cannot get a power of 127".into()), + ), + (999999999, 0, Either::Left(999999999)), + (999999999, 3, Either::Left(999999)), + ]; + + for (value, scale, expected) in cases { + match decimal32_to_i32(value, scale) { + Ok(actual) => { + let expected_value = + expected.left().expect("Got value but expected none"); + assert_eq!( + actual, expected_value, + "{value} and {scale} vs {expected_value:?}" + ); + } + Err(ArrowError::ComputeError(msg)) => { + assert_eq!( + msg, + expected.right().expect("Got error but expected value") + ); + } + Err(_) => { + assert!(expected.is_right()) + } + } + } + } + + #[test] + fn test_decimal64_to_i64() { + let cases: [(i64, i8, Either); _] = [ + (123, 0, Either::Left(123)), + (1234567890, 2, Either::Left(12345678)), + (-1234567890, 2, Either::Left(-12345678)), + ( + 123, + -3, + Either::Right("Negative scale is not supported".into()), + ), + ( + 123, + i8::MAX, + Either::Right("Cannot get a power of 127".into()), + ), + ( + 999999999999999999i64, + 0, + Either::Left(999999999999999999i64), + ), + ( + 999999999999999999i64, + 3, + Either::Left(999999999999999999i64 / 1000), + ), + ( + -999999999999999999i64, + 3, + Either::Left(-999999999999999999i64 / 1000), + ), + ]; + + for (value, scale, expected) in cases { + match decimal64_to_i64(value, scale) { + Ok(actual) => { + let expected_value = + expected.left().expect("Got value but expected none"); + assert_eq!( + actual, expected_value, + "{value} and {scale} vs {expected_value:?}" + ); + } + Err(ArrowError::ComputeError(msg)) => { + assert_eq!( + msg, + expected.right().expect("Got error but expected value") + ); + } + Err(_) => { + assert!(expected.is_right()) + } + } + } + } } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 15d3261ca5132..0822a17f24f16 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -43,6 +43,14 @@ name = "datafusion_optimizer" [features] recursive_protection = ["dep:recursive"] +# Note -- please DO NOT add a dependency here to any of the datafusion-functions +# crates. While it is tempting to try and add an optimizer pass that uses +# datafusion-functions Doing so makes it harder for downstream crates to +# provide their own function library and smaller install footprint. +# +# If you want to add special handling for a specific function, use the methods +# on the ScalarUDFImpl or AggregateUDFImpl traits (or add a new method to those +# traits). [dependencies] arrow = { workspace = true } chrono = { workspace = true } @@ -55,7 +63,7 @@ itertools = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } regex = { workspace = true } -regex-syntax = "0.8.6" +regex-syntax = "0.8.9" [dev-dependencies] async-trait = { workspace = true } @@ -71,3 +79,11 @@ insta = { workspace = true } [[bench]] name = "projection_unnecessary" harness = false + +[[bench]] +name = "optimize_projections" +harness = false + +[[bench]] +name = "unions_to_filter" +harness = false diff --git a/datafusion/optimizer/benches/optimize_projections.rs b/datafusion/optimizer/benches/optimize_projections.rs new file mode 100644 index 0000000000000..d190c5ceabb2f --- /dev/null +++ b/datafusion/optimizer/benches/optimize_projections.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Micro-benchmarks for the `OptimizeProjections` logical optimizer rule. +//! +//! Each case models a plan shape typical of TPC-H, TPC-DS, or ClickBench. +//! Schemas use realistic widths and the rule operates on a fresh +//! `LogicalPlan` per iteration (construction is in the criterion setup +//! closure and excluded from measurement). + +use std::hint::black_box; + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; +use datafusion_expr::{ + JoinType, LogicalPlan, LogicalPlanBuilder, col, lit, logical_plan::table_scan, +}; +use datafusion_functions_aggregate::expr_fn::sum; +use datafusion_optimizer::optimize_projections::OptimizeProjections; +use datafusion_optimizer::{OptimizerContext, OptimizerRule}; + +fn table(name: &str, cols: usize) -> LogicalPlan { + let fields: Vec = (0..cols) + .map(|i| Field::new(format!("c{i}"), DataType::Int32, true)) + .collect(); + table_scan(Some(name), &Schema::new(fields), None) + .unwrap() + .build() + .unwrap() +} + +fn scan_with_filter(name: &str, cols: usize, filter_col: usize) -> LogicalPlan { + LogicalPlanBuilder::from(table(name, cols)) + .filter(col(format!("{name}.c{filter_col}")).gt(lit(0i32))) + .unwrap() + .build() + .unwrap() +} + +/// TPC-H Q3-like: customer ⨝ orders ⨝ lineitem with filters above each scan, +/// GROUP BY 3 keys, 1 SUM aggregate. Models the canonical filter→join→aggregate +/// analytical shape after PushDownFilter. +fn plan_tpch_q3() -> LogicalPlan { + let customer = scan_with_filter("customer", 8, 6); + let orders = scan_with_filter("orders", 9, 4); + let lineitem = scan_with_filter("lineitem", 16, 10); + + LogicalPlanBuilder::from(customer) + .join_on( + orders, + JoinType::Inner, + vec![col("customer.c0").eq(col("orders.c1"))], + ) + .unwrap() + .join_on( + lineitem, + JoinType::Inner, + vec![col("lineitem.c0").eq(col("orders.c0"))], + ) + .unwrap() + .aggregate( + vec![col("lineitem.c0"), col("orders.c4"), col("orders.c7")], + vec![sum(col("lineitem.c5") - col("lineitem.c6"))], + ) + .unwrap() + .build() + .unwrap() +} + +/// TPC-H Q5-like: 6-way join through region→nation→customer→orders→lineitem +/// →supplier, GROUP BY 1 key, 1 SUM. Exercises nested-join pruning depth. +fn plan_tpch_q5() -> LogicalPlan { + let region = scan_with_filter("region", 3, 1); + let nation = table("nation", 4); + let customer = table("customer", 8); + let orders = table("orders", 9); + let lineitem = table("lineitem", 16); + let supplier = table("supplier", 7); + + LogicalPlanBuilder::from(region) + .join_on( + nation, + JoinType::Inner, + vec![col("region.c0").eq(col("nation.c2"))], + ) + .unwrap() + .join_on( + customer, + JoinType::Inner, + vec![col("nation.c0").eq(col("customer.c3"))], + ) + .unwrap() + .join_on( + orders, + JoinType::Inner, + vec![col("customer.c0").eq(col("orders.c1"))], + ) + .unwrap() + .join_on( + lineitem, + JoinType::Inner, + vec![col("lineitem.c0").eq(col("orders.c0"))], + ) + .unwrap() + .join_on( + supplier, + JoinType::Inner, + vec![col("lineitem.c2").eq(col("supplier.c0"))], + ) + .unwrap() + .aggregate( + vec![col("nation.c1")], + vec![sum(col("lineitem.c5") - col("lineitem.c6"))], + ) + .unwrap() + .build() + .unwrap() +} + +/// ClickBench-style: single wide `hits` table (100 cols), conjunctive filter, +/// GROUP BY 2 keys, 2 SUM aggregates. Stresses wide-schema column lookup. +fn plan_clickbench_groupby() -> LogicalPlan { + let hits = table("hits", 100); + let predicate = col("hits.c5") + .gt(lit(100i32)) + .and(col("hits.c12").lt(lit(1000i32))); + LogicalPlanBuilder::from(hits) + .filter(predicate) + .unwrap() + .aggregate( + vec![col("hits.c3"), col("hits.c7")], + vec![sum(col("hits.c42")), sum(col("hits.c60"))], + ) + .unwrap() + .build() + .unwrap() +} + +/// TPC-DS-style CTE shape: a SubqueryAlias wrapping a filter+projection over +/// a wide fact table, joined back on two dimension tables and aggregated. +fn plan_tpcds_subquery() -> LogicalPlan { + let store_sales = table("store_sales", 23); + let customer = table("customer", 18); + let item = table("item", 22); + + let sub = LogicalPlanBuilder::from(store_sales) + .filter(col("store_sales.c5").gt(lit(0i32))) + .unwrap() + .project(vec![ + col("store_sales.c0"), + col("store_sales.c3"), + col("store_sales.c13"), + ]) + .unwrap() + .alias("sub") + .unwrap() + .build() + .unwrap(); + + LogicalPlanBuilder::from(customer) + .join_on( + sub, + JoinType::Inner, + vec![col("customer.c0").eq(col("sub.c3"))], + ) + .unwrap() + .join_on( + item, + JoinType::Inner, + vec![col("item.c0").eq(col("sub.c0"))], + ) + .unwrap() + .aggregate(vec![col("customer.c2")], vec![sum(col("sub.c13"))]) + .unwrap() + .build() + .unwrap() +} + +/// Narrow 10-column table, single filter, project 3 cols. Guards against +/// regressions on the common small-schema case where a lookup-map fix for +/// wide schemas might hurt by adding hashing overhead. +fn plan_small_schema() -> LogicalPlan { + LogicalPlanBuilder::from(table("t", 10)) + .filter(col("t.c3").gt(lit(0i32))) + .unwrap() + .project(vec![col("t.c0"), col("t.c1"), col("t.c5")]) + .unwrap() + .build() + .unwrap() +} + +type BenchCase = (&'static str, fn() -> LogicalPlan); + +fn bench_optimize_projections(c: &mut Criterion) { + let rule = OptimizeProjections::new(); + let config = OptimizerContext::new(); + let mut group = c.benchmark_group("optimize_projections"); + + let cases: &[BenchCase] = &[ + ("tpch_q3", plan_tpch_q3), + ("tpch_q5", plan_tpch_q5), + ("clickbench_groupby", plan_clickbench_groupby), + ("tpcds_subquery", plan_tpcds_subquery), + ("small_schema", plan_small_schema), + ]; + + for (name, build) in cases { + group.bench_function(*name, |b| { + b.iter_batched( + build, + |plan| black_box(rule.rewrite(plan, &config).unwrap()), + BatchSize::SmallInput, + ); + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_optimize_projections); +criterion_main!(benches); diff --git a/datafusion/optimizer/benches/projection_unnecessary.rs b/datafusion/optimizer/benches/projection_unnecessary.rs index bdc59de4820b7..2082ed6a37515 100644 --- a/datafusion/optimizer/benches/projection_unnecessary.rs +++ b/datafusion/optimizer/benches/projection_unnecessary.rs @@ -16,10 +16,10 @@ // under the License. use arrow::datatypes::{DataType, Field, Schema}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::ToDFSchema; use datafusion_common::{Column, TableReference}; -use datafusion_expr::{logical_plan::LogicalPlan, projection_schema, Expr}; +use datafusion_expr::{Expr, logical_plan::LogicalPlan, projection_schema}; use datafusion_optimizer::optimize_projections::is_projection_unnecessary; use std::hint::black_box; use std::sync::Arc; diff --git a/datafusion/optimizer/benches/unions_to_filter.rs b/datafusion/optimizer/benches/unions_to_filter.rs new file mode 100644 index 0000000000000..3f7ef1e582410 --- /dev/null +++ b/datafusion/optimizer/benches/unions_to_filter.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Microbenchmarks for the [`UnionsToFilter`] optimizer rule. +//! +//! Three scenarios are covered: +//! +//! 1. **merge** – N branches over the *same* table, each with a simple +//! equality filter. All branches should be merged into a single +//! `DISTINCT(Filter(OR …))` plan. +//! +//! 2. **no_merge** – N branches over *different* tables. The rule must +//! recognise that no merge is possible and leave the plan unchanged. +//! This exercises the "cold path" without any rewrite work. +//! +//! 3. **merge_with_projection** – N branches over the same table but each +//! branch wraps the filter in a `Projection`. This exercises the wrapper- +//! peeling and re-wrapping paths in addition to the core merge logic. +//! +//! To generate a flamegraph (requires `cargo-flamegraph`): +//! ```text +//! cargo flamegraph -p datafusion-optimizer --bench unions_to_filter \ +//! --flamechart --root --profile profiling --freq 1000 -- --bench +//! ``` + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, logical_plan::table_scan}; +use datafusion_expr::{col, lit}; +use datafusion_optimizer::OptimizerContext; +use datafusion_optimizer::unions_to_filter::UnionsToFilter; +use datafusion_optimizer::{Optimizer, OptimizerRule}; +use std::hint::black_box; +use std::sync::Arc; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Build a three-column table scan for `name`. +fn scan(name: &str) -> LogicalPlan { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + table_scan(Some(name), &schema, None) + .unwrap() + .build() + .unwrap() +} + +/// Build a `DISTINCT (UNION ALL …)` plan whose `n` branches all filter over +/// the *same* table (`t`), so the rule can merge them. +fn build_merge_plan(n: usize) -> LogicalPlan { + assert!(n >= 2); + let mut builder: Option = None; + for i in 0..n { + let branch = LogicalPlanBuilder::from(scan("t")) + .filter(col("a").eq(lit(i as i32))) + .unwrap() + .build() + .unwrap(); + builder = Some(match builder { + None => LogicalPlanBuilder::from(branch), + Some(b) => b.union(branch).unwrap(), + }); + } + builder.unwrap().distinct().unwrap().build().unwrap() +} + +/// Build a `DISTINCT (UNION ALL …)` plan whose `n` branches each filter over a +/// *different* table, so no merge is possible. +fn build_no_merge_plan(n: usize) -> LogicalPlan { + assert!(n >= 2); + let mut builder: Option = None; + for i in 0..n { + let branch = LogicalPlanBuilder::from(scan(&format!("t{i}"))) + .filter(col("a").eq(lit(i as i32))) + .unwrap() + .build() + .unwrap(); + builder = Some(match builder { + None => LogicalPlanBuilder::from(branch), + Some(b) => b.union(branch).unwrap(), + }); + } + builder.unwrap().distinct().unwrap().build().unwrap() +} + +/// Build a `DISTINCT (UNION ALL …)` plan whose `n` branches each wrap the +/// filter inside a `Projection` over the *same* table. +fn build_merge_with_projection_plan(n: usize) -> LogicalPlan { + assert!(n >= 2); + let mut builder: Option = None; + for i in 0..n { + let branch = LogicalPlanBuilder::from(scan("t")) + .filter(col("a").eq(lit(i as i32))) + .unwrap() + .project(vec![col("a"), col("b")]) + .unwrap() + .build() + .unwrap(); + builder = Some(match builder { + None => LogicalPlanBuilder::from(branch), + Some(b) => b.union(branch).unwrap(), + }); + } + builder.unwrap().distinct().unwrap().build().unwrap() +} + +/// Run the [`UnionsToFilter`] rule through the full [`Optimizer`] pipeline +/// (single pass, feature flag enabled). +fn run_optimizer(plan: &LogicalPlan, ctx: &OptimizerContext) -> LogicalPlan { + let rules: Vec> = + vec![Arc::new(UnionsToFilter::new())]; + Optimizer::with_rules(rules) + .optimize(plan.clone(), ctx, |_, _| {}) + .unwrap() +} + +// --------------------------------------------------------------------------- +// Benchmark functions +// --------------------------------------------------------------------------- + +fn bench_merge(c: &mut Criterion) { + let mut options = ConfigOptions::default(); + options.optimizer.enable_unions_to_filter = true; + let ctx = + OptimizerContext::new_with_config_options(Arc::new(options)).with_max_passes(1); + + let mut group = c.benchmark_group("unions_to_filter/merge"); + for n in [2, 8, 32, 128] { + let plan = build_merge_plan(n); + group.bench_with_input(BenchmarkId::from_parameter(n), &plan, |b, p| { + b.iter(|| black_box(run_optimizer(p, &ctx))); + }); + } + group.finish(); +} + +fn bench_no_merge(c: &mut Criterion) { + let mut options = ConfigOptions::default(); + options.optimizer.enable_unions_to_filter = true; + let ctx = + OptimizerContext::new_with_config_options(Arc::new(options)).with_max_passes(1); + + let mut group = c.benchmark_group("unions_to_filter/no_merge"); + for n in [2, 8, 32, 128] { + let plan = build_no_merge_plan(n); + group.bench_with_input(BenchmarkId::from_parameter(n), &plan, |b, p| { + b.iter(|| black_box(run_optimizer(p, &ctx))); + }); + } + group.finish(); +} + +fn bench_merge_with_projection(c: &mut Criterion) { + let mut options = ConfigOptions::default(); + options.optimizer.enable_unions_to_filter = true; + let ctx = + OptimizerContext::new_with_config_options(Arc::new(options)).with_max_passes(1); + + let mut group = c.benchmark_group("unions_to_filter/merge_with_projection"); + for n in [2, 8, 32, 128] { + let plan = build_merge_with_projection_plan(n); + group.bench_with_input(BenchmarkId::from_parameter(n), &plan, |b, p| { + b.iter(|| black_box(run_optimizer(p, &ctx))); + }); + } + group.finish(); +} + +criterion_group!( + benches, + bench_merge, + bench_no_merge, + bench_merge_with_projection +); +criterion_main!(benches); diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index c6bf14ebce2e3..9faa60d939fe3 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -23,9 +23,9 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{DFSchema, Result}; use crate::utils::NamePreserver; +use datafusion_expr::LogicalPlan; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::utils::merge_schema; -use datafusion_expr::LogicalPlan; use std::sync::Arc; /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 272692f983683..ddb3b828f01dd 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -22,9 +22,9 @@ use std::sync::Arc; use log::debug; +use datafusion_common::Result; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::Result; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{InvariantLevel, LogicalPlan}; diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 6381db63122dd..95649ab8286b7 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -28,14 +28,14 @@ use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - internal_datafusion_err, plan_err, Column, DFSchema, Result, ScalarValue, + Column, DFSchema, Result, ScalarValue, internal_datafusion_err, plan_err, }; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ - bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, - Expr, Projection, + Aggregate, Expr, Projection, bitwise_and, bitwise_or, bitwise_shift_left, + bitwise_shift_right, cast, }; use itertools::Itertools; @@ -72,6 +72,7 @@ fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result>()) } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn replace_grouping_exprs( input: Arc, schema: &DFSchema, @@ -96,12 +97,19 @@ fn replace_grouping_exprs( .into_iter() .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) { + let grouping_id_type = is_grouping_set + .then(|| { + schema + .field_with_name(None, Aggregate::INTERNAL_GROUPING_ID) + .map(|f| f.data_type().clone()) + }) + .transpose()?; match expr { Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { let grouping_expr = grouping_function_on_id( function, &group_expr_to_bitmap_index, - is_grouping_set, + grouping_id_type, )?; projection_exprs.push(Expr::Alias(Alias::new( grouping_expr, @@ -109,6 +117,24 @@ fn replace_grouping_exprs( column.name, ))); } + Expr::Alias(Alias { + ref relation, + ref name, + .. + }) if is_grouping_function(&expr) => { + let function = unwrap_alias_to_grouping_function(&expr)?; + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + grouping_id_type, + )?; + // Preserve the outermost user-provided alias + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + relation.clone(), + name.clone(), + ))); + } _ => { projection_exprs.push(Expr::Column(column)); new_agg_expr.push(expr); @@ -147,10 +173,27 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Ok(transformed_plan) } +/// Recursively unwrap `Expr::Alias` nodes to reach the inner `AggregateFunction`. +/// Returns an error if the innermost expression is not an `AggregateFunction`, +/// which should not happen if `is_grouping_function` returned true. +fn unwrap_alias_to_grouping_function(expr: &Expr) -> Result<&AggregateFunction> { + match expr { + Expr::AggregateFunction(function) => Ok(function), + Expr::Alias(Alias { expr, .. }) => unwrap_alias_to_grouping_function(expr), + _ => plan_err!("Expected grouping aggregate function inside alias, got {expr}"), + } +} + fn is_grouping_function(expr: &Expr) -> bool { // TODO: Do something better than name here should grouping be a built // in expression? - matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") + match expr { + Expr::AggregateFunction(AggregateFunction { func, .. }) => { + func.name() == "grouping" + } + Expr::Alias(Alias { expr, .. }) => is_grouping_function(expr), + _ => false, + } } fn contains_grouping_function(exprs: &[Expr]) -> bool { @@ -158,6 +201,7 @@ fn contains_grouping_function(exprs: &[Expr]) -> bool { } /// Validate that the arguments to the grouping function are in the group by clause. +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn validate_args( function: &AggregateFunction, group_by_expr: &HashMap<&Expr, usize>, @@ -178,43 +222,48 @@ fn validate_args( } } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn grouping_function_on_id( function: &AggregateFunction, group_by_expr: &HashMap<&Expr, usize>, - is_grouping_set: bool, + // None means not a grouping set (result is always 0). + grouping_id_type: Option, ) -> Result { validate_args(function, group_by_expr)?; let args = &function.params.args; // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 - if !is_grouping_set { + let Some(grouping_id_type) = grouping_id_type else { return Ok(Expr::Literal(ScalarValue::from(0i32), None)); - } - - let group_by_expr_count = group_by_expr.len(); - let literal = |value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8), None) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16), None) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32), None) - } else { - Expr::Literal(ScalarValue::from(value as u64), None) - } }; + // Use the actual __grouping_id column type to size literals correctly. This + // accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type` + // packs into the high bits of the column, which a simple count of grouping + // expressions would miss. + let literal = |value: usize| match &grouping_id_type { + DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None), + DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None), + DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None), + DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None), + other => panic!("unexpected __grouping_id type: {other}"), + }; let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); - // The grouping call is exactly our internal grouping id - if args.len() == group_by_expr_count + if args.len() == group_by_expr.len() && args .iter() .rev() .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::Int32)); + let n = group_by_expr.len(); + // Mask the ordinal bits above position `n` so only the semantic bitmask is visible. + // checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX. + let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1); + let masked_id = + bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize)); + return Ok(cast(masked_id, DataType::Int32)); } args.iter() diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4fb0f8553b4ba..032fe2524096e 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -17,42 +17,49 @@ //! Optimizer rule for type validation and coercion -use std::sync::Arc; - +use arrow::compute::can_cast_types; use datafusion_expr::binary::BinaryTypeCoercer; -use itertools::{izip, Itertools as _}; - -use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; +use itertools::{Itertools as _, izip}; +use std::sync::{Arc, LazyLock}; use crate::analyzer::AnalyzerRule; use crate::utils::NamePreserver; + +use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; +use arrow::temporal_conversions::SECONDS_IN_DAY; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ + Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, - plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, + plan_err, }; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, Sort, WindowFunction, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, + HigherOrderFunction, InList, InSubquery, Like, ScalarFunction, SetComparison, Sort, + WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; -use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; +use datafusion_expr::type_coercion::binary::{ + comparison_coercion, like_coercion, type_union_coercion, +}; use datafusion_expr::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, + UDFCoercionExt, fields_with_udf, value_fields_with_higher_order_udf_and_lambdas, }; use datafusion_expr::type_coercion::other::{ - get_coerce_type_for_case_expression, get_coerce_type_for_list, + get_coerce_type_for_case_expression, get_coerce_type_for_case_when, + get_coerce_type_for_list, +}; +use datafusion_expr::type_coercion::{ + is_datetime, is_interval, is_signed_numeric, is_timestamp, }; -use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_utf8view_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - AggregateUDF, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, - ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, + Cast, Expr, ExprSchemable, Join, Limit, LogicalPlan, Operator, Projection, Union, + ValueOrLambda, WindowFrame, WindowFrameBound, WindowFrameUnits, is_false, + is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, lit, not, }; /// Performs type coercion by determining the schema @@ -90,11 +97,11 @@ impl AnalyzerRule for TypeCoercion { } fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { - let empty_schema = DFSchema::empty(); + static EMPTY_SCHEMA: LazyLock = LazyLock::new(DFSchema::empty); // recurse let transformed_plan = plan - .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))? + .transform_up_with_subqueries(|plan| analyze_internal(&EMPTY_SCHEMA, plan))? .data; // finish @@ -290,17 +297,150 @@ impl<'a> TypeCoercionRewriter<'a> { right: Expr, right_schema: &DFSchema, ) -> Result<(Expr, Expr)> { - let (left_type, right_type) = BinaryTypeCoercer::new( - &left.get_type(left_schema)?, + let left_data_type = left.get_type(left_schema)?; + let right_data_type = right.get_type(right_schema)?; + let (left_type, right_type) = + BinaryTypeCoercer::new(&left_data_type, &op, &right_data_type) + .get_input_types()?; + let left_cast_ok = can_cast_types(&left_data_type, &left_type); + let right_cast_ok = can_cast_types(&right_data_type, &right_type); + + // handle special cases for + // * Date +/- int => Date + // * Date + time => Timestamp + let left_expr = if !left_cast_ok { + Self::coerce_date_time_math_op( + left, + &op, + &left_data_type, + &left_type, + &right_type, + )? + } else { + left.cast_to(&left_type, left_schema)? + }; + + let right_expr = if !right_cast_ok { + Self::coerce_date_time_math_op( + right, + &op, + &right_data_type, + &right_type, + &left_type, + )? + } else { + right.cast_to(&right_type, right_schema)? + }; + + Ok((left_expr, right_expr)) + } + + fn coerce_date_time_math_op( + expr: Expr, + op: &Operator, + left_current_type: &DataType, + left_target_type: &DataType, + right_target_type: &DataType, + ) -> Result { + use DataType::*; + + fn cast(expr: Expr, target_type: DataType) -> Expr { + Expr::Cast(Cast::new(Box::new(expr), target_type)) + } + + fn time_to_nanos( + expr: Expr, + expr_type: &DataType, + ) -> Result { + let expr = match expr_type { + Time32(TimeUnit::Second) => { + cast(cast(expr, Int32), Int64) + * lit(ScalarValue::Int64(Some(1_000_000_000))) + } + Time32(TimeUnit::Millisecond) => { + cast(cast(expr, Int32), Int64) + * lit(ScalarValue::Int64(Some(1_000_000))) + } + Time64(TimeUnit::Microsecond) => { + cast(expr, Int64) * lit(ScalarValue::Int64(Some(1_000))) + } + Time64(TimeUnit::Nanosecond) => cast(expr, Int64), + t => return internal_err!("Unexpected time data type {t}"), + }; + + Ok(expr) + } + + let e = match ( &op, - &right.get_type(right_schema)?, - ) - .get_input_types()?; + &left_current_type, + &left_target_type, + &right_target_type, + ) { + // int +/- date => date + ( + Operator::Plus | Operator::Minus, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64, + Interval(IntervalUnit::MonthDayNano), + Date32 | Date64, + ) => { + // cast to i64 first + let expr = match *left_current_type { + Int64 => expr, + _ => cast(expr, Int64), + }; + // next, multiply by 86400 to get seconds + let expr = expr * lit(ScalarValue::from(SECONDS_IN_DAY)); + // cast to duration + let expr = cast(expr, Duration(TimeUnit::Second)); + // finally cast to interval + cast(expr, Interval(IntervalUnit::MonthDayNano)) + } + // These might seem to be a bit convoluted, however for arrow to do date + time arithmetic + // date must be cast to Timestamp(Nanosecond) and time cast to Duration(Nanosecond) + // (they must be the same timeunit). + // + // For Time32/64 we first need to cast to an Int64, convert that to nanoseconds based + // on the time unit, then cast that to duration. + // + // Time + date -> timestamp or + ( + Operator::Plus | Operator::Minus, + Time32(_) | Time64(_), + Duration(TimeUnit::Nanosecond), + Timestamp(TimeUnit::Nanosecond, None), + ) => { + // cast to int64, convert to nanoseconds + let expr = time_to_nanos(expr, left_current_type)?; + // cast to duration + cast(expr, Duration(TimeUnit::Nanosecond)) + } + // Similar to above, for arrow to do time - time we need to convert to an interval. + // To do that we first need to cast to an Int64, convert that to nanoseconds based + // on the time unit, then cast that to duration, then finally cast to an interval. + // + // Time - time -> timestamp + ( + Operator::Plus | Operator::Minus, + Time32(_) | Time64(_), + Interval(IntervalUnit::MonthDayNano), + Interval(IntervalUnit::MonthDayNano), + ) => { + // cast to int64, convert to nanoseconds + let expr = time_to_nanos(expr, left_current_type)?; + // cast to duration + let expr = cast(expr, Duration(TimeUnit::Nanosecond)); + // finally cast to interval + cast(expr, Interval(IntervalUnit::MonthDayNano)) + } + _ => { + return plan_err!( + "Cannot automatically convert {left_current_type} to {left_target_type}" + ); + } + }; - Ok(( - left.cast_to(&left_type, left_schema)?, - right.cast_to(&right_type, right_schema)?, - )) + Ok(e) } } @@ -368,6 +508,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, )))) } + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => { + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? + .data; + let expr_type = expr.get_type(self.schema)?; + let subquery_type = new_plan.schema().field(0).data_type(); + if (expr_type.is_numeric() && subquery_type.is_string()) + || (subquery_type.is_numeric() && expr_type.is_string()) + { + return plan_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ); + } + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( + plan_datafusion_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ), + )?; + let new_subquery = Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }; + Ok(Transformed::yes(Expr::SetComparison(SetComparison::new( + Box::new(expr.cast_to(&common_type, self.schema)?), + cast_subquery(new_subquery, &common_type)?, + op, + quantifier, + )))) + } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, @@ -390,6 +567,20 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( get_casted_expr_for_bool_op(*expr, self.schema)?, ))), + Expr::Negative(expr) => { + let data_type = expr.get_type(self.schema)?; + if data_type.is_null() + || is_signed_numeric(&data_type) + || is_interval(&data_type) + || is_timestamp(&data_type) + { + Ok(Transformed::no(Expr::Negative(expr))) + } else { + plan_err!( + "Negation only supports numeric, interval and timestamp types" + ) + } + } Expr::Like(Like { negated, expr, @@ -480,7 +671,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { get_coerce_type_for_list(&expr_data_type, &list_data_types); match result_type { None => plan_err!( - "Can not find compatible types to compare {expr_data_type} with [{}]", list_data_types.iter().join(", ") + "Can not find compatible types to compare {expr_data_type} with [{}]", + list_data_types.iter().join(", ") ), Some(coerced_type) => { // find the coerced type @@ -491,9 +683,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; - Ok(Transformed::yes(Expr::InList(InList ::new( - Box::new(cast_expr), - cast_list_expr, + Ok(Transformed::yes(Expr::InList(InList::new( + Box::new(cast_expr), + cast_list_expr, negated, )))) } @@ -504,11 +696,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let new_expr = coerce_arguments_for_signature_with_scalar_udf( - args, - self.schema, - &func, - )?; + let new_expr = + coerce_arguments_for_signature(args, self.schema, func.as_ref())?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), ))) @@ -524,11 +713,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { null_treatment, }, }) => { - let new_expr = coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, - &func, - )?; + let new_expr = + coerce_arguments_for_signature(args, self.schema, func.as_ref())?; + + let filter = filter + .map(|filter| filter.cast_to(&DataType::Boolean, self.schema)) + .transpose()? + .map(Box::new); + Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf( func, @@ -559,15 +751,18 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { - coerce_arguments_for_signature_with_aggregate_udf( - args, - self.schema, - udf, - )? + coerce_arguments_for_signature(args, self.schema, udf.as_ref())? + } + expr::WindowFunctionDefinition::WindowUDF(udf) => { + coerce_arguments_for_signature(args, self.schema, udf.as_ref())? } - _ => args, }; + let filter = filter + .map(|filter| filter.cast_to(&DataType::Boolean, self.schema)) + .transpose()? + .map(Box::new); + let new_expr = Expr::from(WindowFunction { fun, params: expr::WindowFunctionParams { @@ -582,6 +777,35 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { }); Ok(Transformed::yes(new_expr)) } + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { + let current_fields = args + .iter() + .map(|arg| match arg { + Expr::Lambda(lambda) => Ok(ValueOrLambda::Lambda( + lambda.body.to_field(self.schema)?.1, + )), + _ => Ok(ValueOrLambda::Value(arg.to_field(self.schema)?.1)), + }) + .collect::>>()?; + + let new_fields = value_fields_with_higher_order_udf_and_lambdas( + ¤t_fields, + func.as_ref(), + )?; + + let new_args = std::iter::zip(args, new_fields) + .map(|(arg, new_field)| match (&arg, new_field) { + (Expr::Lambda(_lambda), ValueOrLambda::Lambda(_)) => Ok(arg), + (Expr::Lambda(_lambda), ValueOrLambda::Value(_)) => internal_err!("value_fields_with_higher_order_udf returned a value for a lambda argument"), + (_, ValueOrLambda::Value(new_field)) => arg.cast_to(new_field.data_type(), self.schema), + (_, ValueOrLambda::Lambda(_)) => internal_err!("value_fields_with_higher_order_udf returned a lambda for a value argument"), + }) + .collect::>()?; + + Ok(Transformed::yes(Expr::HigherOrderFunction( + HigherOrderFunction::new(func, new_args), + ))) + } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::Alias(_) @@ -591,13 +815,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::SimilarTo(_) | Expr::IsNotNull(_) | Expr::IsNull(_) - | Expr::Negative(_) | Expr::Cast(_) | Expr::TryCast(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::OuterReferenceColumn(_, _) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(Transformed::no(expr)), } } } @@ -687,7 +912,7 @@ fn coerce_scalar_range_aware( // If type coercion fails, check if the largest type in family works: if let Some(largest_type) = get_widest_type_in_family(target_type) { coerce_scalar(largest_type, value).map_or_else( - |_| exec_err!("Cannot cast {value:?} to {target_type}"), + |_| exec_err!("Cannot cast {value} to {target_type}"), |_| ScalarValue::try_from(target_type), ) } else { @@ -726,12 +951,15 @@ fn coerce_frame_bound( fn extract_window_frame_target_type(col_type: &DataType) -> Result { if col_type.is_numeric() - || is_utf8_or_utf8view_or_large_utf8(col_type) - || matches!(col_type, DataType::List(_)) - || matches!(col_type, DataType::LargeList(_)) - || matches!(col_type, DataType::FixedSizeList(_, _)) - || matches!(col_type, DataType::Null) - || matches!(col_type, DataType::Boolean) + || col_type.is_string() + || col_type.is_null() + || matches!( + col_type, + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Boolean + ) { Ok(col_type.clone()) } else if is_datetime(col_type) { @@ -784,48 +1012,17 @@ fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { /// `signature`, if possible. /// /// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature_with_scalar_udf( +fn coerce_arguments_for_signature( expressions: Vec, schema: &DFSchema, - func: &ScalarUDF, + func: &F, ) -> Result> { - if expressions.is_empty() { - return Ok(expressions); - } - - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_types = data_types_with_scalar_udf(¤t_types, func)?; - - expressions - .into_iter() - .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) - .collect() -} - -/// Returns `expressions` coerced to types compatible with -/// `signature`, if possible. -/// -/// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature_with_aggregate_udf( - expressions: Vec, - schema: &DFSchema, - func: &AggregateUDF, -) -> Result> { - if expressions.is_empty() { - return Ok(expressions); - } - let current_fields = expressions .iter() .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()?; - let new_types = fields_with_aggregate_udf(¤t_fields, func)? + let coerced_types = fields_with_udf(¤t_fields, func)? .into_iter() .map(|f| f.data_type().clone()) .collect::>(); @@ -833,7 +1030,7 @@ fn coerce_arguments_for_signature_with_aggregate_udf( expressions .into_iter() .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) .collect() } @@ -894,8 +1091,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { .iter() .map(|(when, _then)| when.get_type(schema)) .collect::>>()?; - let coerced_type = - get_coerce_type_for_case_expression(&when_types, Some(case_type)); + let coerced_type = get_coerce_type_for_case_when(&when_types, case_type); coerced_type.ok_or_else(|| { plan_datafusion_err!( "Failed to coerce case ({case_type}) and when ({}) \ @@ -973,7 +1169,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { /// **Field-level metadata merging**: Later fields take precedence for duplicate metadata keys. /// /// **Type coercion precedence**: The coerced type is determined by iteratively applying -/// `comparison_coercion()` between the accumulated type and each new input's type. The +/// `type_union_coercion()` between the accumulated type and each new input's type. The /// result depends on type coercion rules, not input order. /// /// **Nullability merging**: Nullability is accumulated using logical OR (`||`). @@ -996,7 +1192,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { /// ``` /// /// **Precedence Summary**: -/// - **Datatypes**: Determined by `comparison_coercion()` rules, not input order +/// - **Datatypes**: Determined by `type_union_coercion()` rules, not input order /// - **Nullability**: Later inputs can add nullability but cannot remove it (logical OR) /// - **Metadata**: Later inputs take precedence for same keys (HashMap::extend semantics) pub fn coerce_union_schema(inputs: &[Arc]) -> Result { @@ -1046,7 +1242,7 @@ fn coerce_union_schema_with_schema( plan_schema.fields().iter() ) { let coerced_type = - comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else( + type_union_coercion(union_datatype, plan_field.data_type()).ok_or_else( || { plan_datafusion_err!( "Incompatible inputs for Union: Previous inputs were \ @@ -1112,17 +1308,17 @@ fn project_with_column_index( #[cfg(test)] mod test { - use std::any::Any; + use std::sync::Arc; use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, TimeUnit}; use insta::assert_snapshot; + use crate::analyzer::Analyzer; use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + TypeCoercion, TypeCoercionRewriter, coerce_case_expression, }; - use crate::analyzer::Analyzer; use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; @@ -1131,10 +1327,10 @@ mod test { use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; use datafusion_expr::test::function_stub::avg_udaf; use datafusion_expr::{ - cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF, - BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, - Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - SimpleAggregateUDF, Subquery, Union, Volatility, + AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr, + ExprSchemable, Filter, LogicalPlan, Operator, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, Union, Volatility, cast, + col, create_udaf, is_true, lit, }; use datafusion_functions_aggregate::average::AvgAccumulator; use datafusion_sql::TableReference; @@ -1235,6 +1431,17 @@ mod test { ) } + #[test] + fn negative_expr_wrapped_by_is_null_errors() -> Result<()> { + let predicate = Expr::IsNull(Box::new(Expr::Negative(Box::new(lit("a"))))); + let plan = LogicalPlan::Filter(Filter::try_new(predicate, empty())?); + + assert_type_coercion_error( + plan, + "Negation only supports numeric, interval and timestamp types", + ) + } + #[test] fn test_coerce_union() -> Result<()> { let left_plan = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { @@ -1305,7 +1512,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) + Projection: CAST(a AS LargeUtf8) AS a EmptyRelation: rows=0 " )?; @@ -1341,7 +1548,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) + Projection: CAST(a AS LargeUtf8) AS a EmptyRelation: rows=0 " )?; @@ -1371,7 +1578,7 @@ mod test { true, sort_plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) + Projection: CAST(a AS LargeUtf8) AS a Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1400,7 +1607,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeUtf8) + Projection: CAST(a AS LargeUtf8) AS a Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1436,7 +1643,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeBinary) + Projection: CAST(a AS LargeBinary) AS a EmptyRelation: rows=0 " )?; @@ -1493,7 +1700,7 @@ mod test { true, sort_plan.clone(), @r" - Projection: CAST(a AS LargeBinary) + Projection: CAST(a AS LargeBinary) AS a Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1524,7 +1731,7 @@ mod test { true, plan.clone(), @r" - Projection: CAST(a AS LargeBinary) + Projection: CAST(a AS LargeBinary) AS a Sort: a ASC NULLS FIRST Projection: a EmptyRelation: rows=0 @@ -1559,10 +1766,6 @@ mod test { } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "TestScalarUDF" } @@ -1580,6 +1783,31 @@ mod test { } } + #[derive(Debug, Hash, PartialEq, Eq)] + struct TestArrayElementUDF; + + impl ScalarUDFImpl for TestArrayElementUDF { + fn name(&self) -> &str { + "TestArrayElementUDF" + } + + fn signature(&self) -> &Signature { + static SIGNATURE: std::sync::LazyLock = + std::sync::LazyLock::new(|| { + Signature::array_and_index(Volatility::Immutable) + }); + &SIGNATURE + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + } + } + #[test] fn scalar_udf() -> Result<()> { let empty = empty(); @@ -1754,7 +1982,10 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert!(err.starts_with("Error during planning: Failed to coerce arguments to satisfy a call to 'avg' function: coercion from Utf8 to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed")); + assert!( + err.contains("Function 'avg' failed to match any signature"), + "Err: {err:?}" + ); Ok(()) } @@ -1882,7 +2113,7 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); assert_type_coercion_error( plan, - "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean" + "Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean", )?; // is not true @@ -2028,7 +2259,7 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); assert_type_coercion_error( plan, - "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean" + "Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean", )?; // is not unknown @@ -2211,6 +2442,9 @@ mod test { let actual = coerce_case_expression(case, &schema)?; assert_eq!(expected, actual); + // CASE string WHEN float/integer/string: comparison coercion + // prefers numeric, so the common type for the CASE expr and + // WHEN values is Float32. let case = Case { expr: Some(Box::new(col("string"))), when_then_expr: vec![ @@ -2220,7 +2454,7 @@ mod test { ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = Utf8; + let case_when_common_type = DataType::Float32; let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), @@ -2465,7 +2699,34 @@ mod test { assert_analyzed_plan_eq!( plan, @r#" - Projection: a = CAST(CAST(a AS Map("key_value": Struct("key": Utf8, "value": nullable Float64), unsorted)) AS Map("entries": Struct("key": Utf8, "value": nullable Float64), unsorted)) + Projection: a = CAST(CAST(a AS Map("key_value": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) AS Map("entries": non-null Struct("key": non-null Utf8, "value": Float64), unsorted)) + EmptyRelation: rows=0 + "# + ) + } + + #[test] + fn array_element_preserves_parquet_list_field_name() -> Result<()> { + let list_type = DataType::List(Arc::new(Field::new( + "element", + DataType::Struct( + vec![ + Field::new("id", Utf8, true), + Field::new("prim", DataType::Boolean, true), + ] + .into(), + ), + true, + ))); + + let expr = ScalarUDF::from(TestArrayElementUDF).call(vec![col("a"), lit(1_i64)]); + let empty = empty_with_type(list_type); + let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); + + assert_analyzed_plan_eq!( + plan, + @r#" + Projection: TestArrayElementUDF(a, Int64(1)) EmptyRelation: rows=0 "# ) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2510068494591..2775d62144c56 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -27,14 +27,16 @@ use crate::optimizer::ApplyOrder; use crate::utils::NamePreserver; use datafusion_common::alias::AliasGenerator; -use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::cse::{CSE, CSEController, FoundCommonNodes}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_common::{Column, DFSchema, DFSchemaRef, Result, qualified_name}; +use datafusion_expr::expr::{Alias, HigherOrderFunction, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; -use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator, SortExpr}; +use datafusion_expr::{ + BinaryExpr, Case, Expr, ExpressionPlacement, Operator, SortExpr, col, +}; const CSE_PREFIX: &str = "__common_expr"; @@ -323,11 +325,7 @@ impl CommonSubexprEliminate { .map(|expr| Some(name_preserver.save(expr))) .collect::>() } else { - new_aggr_expr - .clone() - .into_iter() - .map(|_| None) - .collect::>() + (0..new_aggr_expr.len()).map(|_| None).collect() }; let mut agg_exprs = common_exprs @@ -588,8 +586,12 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like - // manner. - plan.map_children(|c| self.rewrite(c, config))? + // manner. Process uncorrelated subqueries in expressions + // (e.g., Expr::ScalarSubquery), then direct children. + plan.map_uncorrelated_subqueries(|c| self.rewrite(c, config))? + .transform_sibling(|plan| { + plan.map_children(|c| self.rewrite(c, config)) + })? } }; @@ -649,12 +651,15 @@ impl CSEController for ExprCSEController<'_> { fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { match node { - // In case of `ScalarFunction`s we don't know which children are surely + // In case of `ScalarFunction`s and `HigherOrderFunction`s we don't know which children are surely // executed so start visiting all children conditionally and stop the // recursion with `TreeNodeRecursion::Jump`. Expr::ScalarFunction(ScalarFunction { func, args }) => { func.conditional_arguments(args) } + Expr::HigherOrderFunction(HigherOrderFunction { func, args }) => { + func.conditional_arguments(args) + } // In case of `And` and `Or` the first child is surely executed, but we // account subexpressions as conditional in the second. @@ -695,18 +700,38 @@ impl CSEController for ExprCSEController<'_> { fn is_valid(node: &Expr) -> bool { !node.is_volatile_node() + && !matches!(node, Expr::Lambda(_) | Expr::LambdaVariable(_)) } fn is_ignored(&self, node: &Expr) -> bool { + // MoveTowardsLeafNodes expressions (e.g. get_field) are cheap struct + // field accesses that the ExtractLeafExpressions / PushDownLeafProjections + // rules deliberately duplicate when needed (one copy for a filter + // predicate, another for an output column). CSE deduplicating them + // creates intermediate projections that fight with those rules, + // causing optimizer instability — ExtractLeafExpressions will undo + // the dedup, creating an infinite loop that runs until the iteration + // limit is hit. Skip them. + if node.placement() == ExpressionPlacement::MoveTowardsLeafNodes { + return true; + } + // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] let is_normal_minus_aggregates = matches!( node, + // TODO: there's an argument for removing `Literal` from here, + // maybe using `Expr::placemement().should_push_to_leaves()` instead + // so that we extract common literals and don't broadcast them to num_batch_rows multiple times. + // However that currently breaks things like `percentile_cont()` which expect literal arguments + // (and would instead be getting `col(__common_expr_n)`). Expr::Literal(..) | Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Alias(..) | Expr::Wildcard { .. } + | Expr::Lambda(_) + | Expr::LambdaVariable(_) ); let is_aggr = matches!(node, Expr::AggregateFunction(..)); @@ -810,21 +835,22 @@ fn extract_expressions(expr: &Expr, result: &mut Vec) { #[cfg(test)] mod test { - use std::any::Any; + use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::logical_plan::{table_scan, JoinType}; + use datafusion_expr::logical_plan::{JoinType, table_scan}; use datafusion_expr::{ - grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF, - ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - SimpleAggregateUDF, Volatility, + AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarFunctionArgs, + ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, + grouping_set, is_null, not, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use super::*; use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::OptimizerContext; + use crate::test::udfs::leaf_udf_expr; use crate::test::*; use datafusion_expr::test::function_stub::{avg, sum}; @@ -1680,9 +1706,6 @@ mod test { } impl ScalarUDFImpl for TestUdf { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "my_udf" } @@ -1806,10 +1829,6 @@ mod test { } } impl ScalarUDFImpl for RandomStub { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "random" } @@ -1826,4 +1845,56 @@ mod test { panic!("dummy - not implemented") } } + + /// Identical MoveTowardsLeafNodes expressions should NOT be deduplicated + /// by CSE — they are cheap (e.g. struct field access) and the extraction + /// rules deliberately duplicate them. Deduplicating causes optimizer + /// instability where one optimizer rule will undo the work of another, + /// resulting in an infinite optimization loop until the + /// we hit the max iteration limit and then give up. + #[test] + fn test_leaf_expression_not_extracted() -> Result<()> { + let table_scan = test_table_scan()?; + + let leaf = leaf_udf_expr(col("a")); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf.clone().alias("c1"), leaf.alias("c2")])? + .build()?; + + // Plan should be unchanged — no __common_expr introduced + assert_optimized_plan_equal!( + plan, + @r" + Projection: leaf_udf(test.a) AS c1, leaf_udf(test.a) AS c2 + TableScan: test + " + ) + } + + /// When a MoveTowardsLeafNodes expression appears as a sub-expression of + /// a larger expression that IS duplicated, only the outer expression gets + /// deduplicated; the leaf sub-expression stays inline. + #[test] + fn test_leaf_subexpression_not_extracted() -> Result<()> { + let table_scan = test_table_scan()?; + + // leaf_udf(a) + b appears twice — the outer `+` is a common + // sub-expression, but leaf_udf(a) by itself is MoveTowardsLeafNodes + // and should not be extracted separately. + let common = leaf_udf_expr(col("a")) + col("b"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![common.clone().alias("c1"), common.alias("c2")])? + .build()?; + + // The whole `leaf_udf(a) + b` gets deduplicated as __common_expr_1, + // but leaf_udf(a) alone is NOT pulled out. + assert_optimized_plan_equal!( + plan, + @r" + Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 + Projection: leaf_udf(test.a) + test.b AS __common_expr_1, test.a, test.b, test.c + TableScan: test + " + ) + } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 63236787743a4..9490af0e59749 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -26,17 +26,18 @@ use crate::simplify_expressions::ExprSimplifier; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Column, DFSchemaRef, HashMap, Result, ScalarValue}; +use datafusion_common::{ + Column, DFSchemaRef, HashMap, Result, ScalarValue, assert_or_internal_err, plan_err, +}; use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{ collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, }; use datafusion_expr::{ - expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, - LogicalPlanBuilder, Operator, + BinaryExpr, Cast, EmptyRelation, Expr, ExprSchemable, FetchType, LogicalPlan, + LogicalPlanBuilder, Operator, expr, lit, }; -use datafusion_physical_expr::execution_props::ExecutionProps; /// This struct rewrite the sub query plan by pull up the correlated /// expressions(contains outer reference columns) from the inner subquery's @@ -136,6 +137,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), + // Subquery nodes are scope boundaries for correlation. A nested + // Subquery's outer references belong to a different decorrelation + // level and must not be pulled up into the current scope. + LogicalPlan::Subquery(_) => { + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { @@ -180,7 +187,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { find_join_exprs(subquery_filter_exprs)?; if let Some(in_predicate) = &self.in_predicate_opt { // in_predicate may be already included in the join filters, remove it from the join filters first. - join_filters = remove_duplicated_filter(join_filters, in_predicate); + join_filters = remove_duplicated_filter(join_filters, in_predicate)?; } let correlated_subquery_cols = collect_subquery_cols(&join_filters, subquery_schema)?; @@ -461,25 +468,39 @@ fn collect_local_correlated_cols( } } -fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { - filters +fn remove_duplicated_filter( + filters: Vec, + in_predicate: &Expr, +) -> Result> { + // We assume below that swapping the order of operands to an operator does + // not change behavior, which is only true if the operator is commutative. + assert_or_internal_err!( + match in_predicate { + Expr::BinaryExpr(b) => b.op.swap() == Some(b.op), + _ => true, + }, + "remove_duplicated_filter: in_predicate must use a commutative operator" + ); + + Ok(filters .into_iter() .filter(|filter| { if filter == in_predicate { return false; } - // ignore the binary order + // Treat swapped operand order to a binary operator as equivalent !match (filter, in_predicate) { (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { - (a_expr.op == b_expr.op) - && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) - || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) + a_expr.op == b_expr.op + && ((a_expr.left == b_expr.left && a_expr.right == b_expr.right) + || (a_expr.left == b_expr.right + && a_expr.right == b_expr.left)) } _ => false, } }) - .collect::>() + .collect::>()) } fn agg_exprs_evaluation_result_on_empty_batch( @@ -491,26 +512,21 @@ fn agg_exprs_evaluation_result_on_empty_batch( let result_expr = e .clone() .transform_up(|expr| { - let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { - if func.name() == "count" { - Transformed::yes(Expr::Literal( - ScalarValue::Int64(Some(0)), - None, - )) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null, None)) - } - } - _ => Transformed::no(expr), + let new_expr = if let Expr::AggregateFunction(agg) = &expr { + let return_type = expr.get_type(schema.as_ref())?; + let default_value = agg.func.default_value(&return_type)?; + Transformed::yes(Expr::Literal(default_value, None)) + } else { + Transformed::no(expr) }; Ok(new_expr) }) .data()?; let result_expr = result_expr.unalias(); - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); + let info = SimplifyContext::builder() + .with_schema(Arc::clone(schema)) + .build(); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; expr_result_map_for_count_bug.insert(e.schema_name().to_string(), result_expr); @@ -543,8 +559,9 @@ fn proj_exprs_evaluation_result_on_empty_batch( .data()?; if result_expr.ne(expr) { - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); + let info = SimplifyContext::builder() + .with_schema(Arc::clone(schema)) + .build(); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; let expr_name = match expr { @@ -584,8 +601,7 @@ fn filter_exprs_evaluation_result_on_empty_batch( .data()?; let pull_up_expr = if result_expr.ne(filter_expr) { - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(schema); + let info = SimplifyContext::builder().with_schema(schema).build(); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; match &result_expr { diff --git a/datafusion/optimizer/src/decorrelate_lateral_join.rs b/datafusion/optimizer/src/decorrelate_lateral_join.rs index 7d2072ad1ce99..a8df5e69e3f33 100644 --- a/datafusion/optimizer/src/decorrelate_lateral_join.rs +++ b/datafusion/optimizer/src/decorrelate_lateral_join.rs @@ -17,27 +17,28 @@ //! [`DecorrelateLateralJoin`] decorrelates logical plans produced by lateral joins. -use std::collections::BTreeSet; +use std::sync::Arc; -use crate::decorrelate::PullUpCorrelatedExpr; +use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; +use crate::utils::evaluates_to_null; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_expr::{lit, Join}; +use datafusion_expr::{Expr, Join, expr}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::Result; -use datafusion_expr::logical_plan::JoinType; +use datafusion_common::{Column, DFSchema, Result, ScalarValue, TableReference}; +use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; -use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, SubqueryAlias}; /// Optimizer rule for rewriting lateral joins to joins #[derive(Default, Debug)] pub struct DecorrelateLateralJoin {} impl DecorrelateLateralJoin { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self::default() } @@ -70,74 +71,303 @@ impl OptimizerRule for DecorrelateLateralJoin { } } -// Build the decorrelated join based on the original lateral join query. For now, we only support cross/inner -// lateral joins. +// Build the decorrelated join based on the original lateral join query. +// Supports INNER and LEFT lateral joins. fn rewrite_internal(join: Join) -> Result> { - if join.join_type != JoinType::Inner { + if !matches!(join.join_type, JoinType::Inner | JoinType::Left) { return Ok(Transformed::no(LogicalPlan::Join(join))); } + let original_join_type = join.join_type; - match join.right.apply_with_subqueries(|p| { - // TODO: support outer joins - if p.contains_outer_reference() { - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - })? { - TreeNodeRecursion::Stop => {} - TreeNodeRecursion::Continue => { - // The left side contains outer references, we need to decorrelate it. - return Ok(Transformed::new( - LogicalPlan::Join(join), - false, - TreeNodeRecursion::Jump, - )); - } - TreeNodeRecursion::Jump => { - unreachable!("") - } - } - - let LogicalPlan::Subquery(subquery) = join.right.as_ref() else { + // The right side is wrapped in a Subquery node when it contains outer + // references. Quickly skip joins that don't have this structure. + let Some((subquery, alias)) = extract_lateral_subquery(join.right.as_ref()) else { return Ok(Transformed::no(LogicalPlan::Join(join))); }; - if join.join_type != JoinType::Inner { + // If the subquery has no outer references, there is nothing to decorrelate. + // A LATERAL with no outer references is just a cross join. + let has_outer_refs = matches!( + subquery.subquery.apply_with_subqueries(|p| { + if p.contains_outer_reference() { + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })?, + TreeNodeRecursion::Stop + ); + if !has_outer_refs { return Ok(Transformed::no(LogicalPlan::Join(join))); } + let subquery_plan = subquery.subquery.as_ref(); + let original_join_filter = join.filter.clone(); + + // Walk the subquery plan bottom-up, extracting correlated filter + // predicates into join conditions and converting ungrouped aggregates + // into group-by aggregates keyed on the correlation columns. let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); let rewritten_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(Transformed::no(LogicalPlan::Join(join))); } - let mut all_correlated_cols = BTreeSet::new(); - pull_up - .correlated_subquery_cols_map - .values() - .for_each(|cols| all_correlated_cols.extend(cols.clone())); - let join_filter_opt = conjunction(pull_up.join_filters); - let join_filter = match join_filter_opt { - Some(join_filter) => join_filter, - None => lit(true), - }; - // -- inner join but the right side always has one row, we need to rewrite it to a left join - // SELECT * FROM t0, LATERAL (SELECT sum(v1) FROM t1 WHERE t0.v0 = t1.v0); - // -- inner join but the right side number of rows is related to the filter (join) condition, so keep inner join. - // SELECT * FROM t0, LATERAL (SELECT * FROM t1 WHERE t0.v0 = t1.v0); - let new_plan = LogicalPlanBuilder::from(join.left) - .join_on( - rewritten_subquery, + // TODO: support HAVING in lateral subqueries. + // + if pull_up.pull_up_having_expr.is_some() { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + + // The correlation predicates (extracted from the subquery's WHERE) become + // the rewritten join's ON clause. See below for discussion of how the + // user's original ON clause is handled. + let correlation_filter = conjunction(pull_up.join_filters); + + // Look up each aggregate's default value on empty input (e.g., COUNT → 0, + // SUM → NULL). This must happen before wrapping in SubqueryAlias, because + // the map is keyed by LogicalPlan and wrapping changes the plan. + let collected_count_expr_map = pull_up + .collected_count_expr_map + .get(&rewritten_subquery) + .cloned(); + + // Re-wrap in SubqueryAlias if the original had one, preserving the alias name. + // The SubqueryAlias re-qualifies all columns with the alias, so we must also + // rewrite column references in both the correlation and ON-clause filters. + let (right_plan, correlation_filter, original_join_filter) = + if let Some(ref alias) = alias { + let inner_schema = Arc::clone(rewritten_subquery.schema()); + let right = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(rewritten_subquery), + alias.clone(), + )?); + let corr = correlation_filter + .map(|f| requalify_filter(f, &inner_schema, alias)) + .transpose()?; + let on = original_join_filter + .map(|f| requalify_filter(f, &inner_schema, alias)) + .transpose()?; + (right, corr, on) + } else { + (rewritten_subquery, correlation_filter, original_join_filter) + }; + + // For LEFT lateral joins, verify that all column references in the + // correlation filter are resolvable within the join's left and right + // schemas. If the lateral subquery references columns from an outer scope, + // the extracted filter will contain unresolvable columns and we must skip + // decorrelation. + // + // INNER lateral joins do not need this check: later optimizer passes + // (filter pushdown, join reordering) can restructure the plan to resolve + // cross-scope references. LEFT joins cannot be freely reordered. + if original_join_type == JoinType::Left + && let Some(ref filter) = correlation_filter + { + let left_schema = join.left.schema(); + let right_schema = right_plan.schema(); + let has_outer_scope_refs = filter + .column_refs() + .iter() + .any(|col| !left_schema.has_column(col) && !right_schema.has_column(col)); + if has_outer_scope_refs { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } + } + + // Use a left join when the user wrote LEFT LATERAL or when a scalar + // aggregation was pulled up (preserves outer rows with no matches). + let join_type = + if original_join_type == JoinType::Left || pull_up.pulled_up_scalar_agg { + JoinType::Left + } else { + JoinType::Inner + }; + + // The correlation predicates (extracted from the subquery's WHERE) are + // turned into the rewritten join's ON clause. There are three cases that + // determine how the user's original ON clause is handled: + // + // - INNER lateral: user ON clause becomes a post-join filter. This restores + // inner-join semantics if the join is upgraded to LEFT for count-bug + // handling. + // + // - LEFT lateral with grouped (or no) agg: user ON clause is merged into + // the rewritten ON clause, alongside the correlation predicates. LEFT + // join semantics correctly preserve unmatched rows with NULLs. + // + // - LEFT lateral with an ungrouped aggregate (which decorrelation converts + // to a group-by keyed on the correlation columns): user ON clause cannot + // be placed in the join condition (it would conflict with count-bug + // compensation) or as a post-join filter (that would remove + // left-preserved rows). Instead, a projection is added after count-bug + // compensation that replaces each right-side column with NULL when the ON + // condition is not satisfied: + // + // CASE WHEN (on_cond) IS NOT TRUE THEN NULL ELSE END + // + // This simulates LEFT JOIN semantics for the user's ON clause without + // interfering with count-bug compensation. + let (join_filter, post_join_filter, on_condition_for_projection) = + if original_join_type == JoinType::Left { if pull_up.pulled_up_scalar_agg { - JoinType::Left + (correlation_filter, None, original_join_filter) } else { - JoinType::Inner - }, - Some(join_filter), - )? + let combined = conjunction( + correlation_filter.into_iter().chain(original_join_filter), + ); + (combined, None, None) + } + } else { + (correlation_filter, original_join_filter, None) + }; + + let left_field_count = join.left.schema().fields().len(); + let new_plan = LogicalPlanBuilder::from(join.left) + .join_on(right_plan, join_type, join_filter)? .build()?; - // TODO: handle count(*) bug + + // Handle the count bug: in the rewritten left join, unmatched outer + // rows get NULLs for all right-side columns. But some aggregates + // have non-NULL defaults on empty input (e.g., COUNT returns 0, not + // NULL). Add a projection that wraps those columns: + // CASE WHEN __always_true IS NULL THEN ELSE END + let new_plan = if let Some(expr_map) = collected_count_expr_map { + let join_schema = new_plan.schema(); + let alias_qualifier = alias.as_ref(); + let mut proj_exprs: Vec = vec![]; + + for (i, (qualifier, field)) in join_schema.iter().enumerate() { + let col = Expr::Column(Column::new(qualifier.cloned(), field.name())); + + // Only compensate right-side (subquery) fields. Left-side fields + // may share a name with an aggregate alias but must not be wrapped. + let name = field.name(); + if i >= left_field_count + && let Some(default_value) = expr_map.get(name.as_str()) + && !evaluates_to_null(default_value.clone(), default_value.column_refs())? + { + // Column whose aggregate doesn't naturally return NULL + // on empty input (e.g., COUNT returns 0). Wrap it. + let indicator_col = + Column::new(alias_qualifier.cloned(), UN_MATCHED_ROW_INDICATOR); + let case_expr = Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::IsNull(Box::new(Expr::Column(indicator_col)))), + Box::new(default_value.clone()), + )], + else_expr: Some(Box::new(col)), + }); + proj_exprs.push(Expr::Alias(expr::Alias { + expr: Box::new(case_expr), + relation: qualifier.cloned(), + name: name.to_string(), + metadata: None, + })); + continue; + } + proj_exprs.push(col); + } + + LogicalPlanBuilder::from(new_plan) + .project(proj_exprs)? + .build()? + } else { + new_plan + }; + + // For LEFT lateral joins with an ungrouped aggregate, simulate LEFT JOIN + // semantics for the user's ON clause by adding a projection that replaces + // right-side columns with NULL when the ON condition is false (see + // commentary above). + // + // Note: the ON condition expression is duplicated per column, so this + // assumes it is deterministic. + let new_plan = if let Some(on_cond) = on_condition_for_projection { + let schema = Arc::clone(new_plan.schema()); + let mut proj_exprs: Vec = vec![]; + + for (i, (qualifier, field)) in schema.iter().enumerate() { + let col = Expr::Column(Column::new(qualifier.cloned(), field.name())); + + if i < left_field_count { + proj_exprs.push(col); + continue; + } + + let typed_null = + Expr::Literal(ScalarValue::try_from(field.data_type())?, None); + let case_expr = Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::IsNotTrue(Box::new(on_cond.clone()))), + Box::new(typed_null), + )], + else_expr: Some(Box::new(col)), + }); + proj_exprs.push(case_expr.alias_qualified(qualifier.cloned(), field.name())); + } + + LogicalPlanBuilder::from(new_plan) + .project(proj_exprs)? + .build()? + } else { + new_plan + }; + + // Apply the original ON clause as a post-join filter (INNER lateral only). + let new_plan = if let Some(on_filter) = post_join_filter { + LogicalPlanBuilder::from(new_plan) + .filter(on_filter)? + .build()? + } else { + new_plan + }; + Ok(Transformed::new(new_plan, true, TreeNodeRecursion::Jump)) } + +/// Extract the Subquery and optional alias from a lateral join's right side. +fn extract_lateral_subquery( + plan: &LogicalPlan, +) -> Option<(Subquery, Option)> { + match plan { + LogicalPlan::Subquery(sq) => Some((sq.clone(), None)), + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + if let LogicalPlan::Subquery(sq) = input.as_ref() { + Some((sq.clone(), Some(alias.clone()))) + } else { + None + } + } + _ => None, + } +} + +/// Rewrite column references in a join filter expression so that columns +/// belonging to the inner (right) side use the SubqueryAlias qualifier. +/// +/// The `PullUpCorrelatedExpr` pass extracts join filters with the inner +/// columns qualified by their original table names (e.g., `t2.t1_id`). +/// When the inner plan is wrapped in a `SubqueryAlias("sub")`, those +/// columns are re-qualified as `sub.t1_id`. This function applies the +/// same requalification to the filter so it matches the aliased schema. +fn requalify_filter( + filter: Expr, + inner_schema: &DFSchema, + alias: &TableReference, +) -> Result { + filter + .transform(|expr| { + if let Expr::Column(col) = &expr + && inner_schema.has_column(col) + { + let new_col = Column::new(Some(alias.clone()), col.name.clone()); + return Ok(Transformed::yes(Expr::Column(new_col))); + } + Ok(Transformed::no(expr)) + }) + .data() +} diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 0590aba52bfab..0609109ec6e58 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -27,14 +27,17 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{assert_or_internal_err, plan_err, Column, Result}; +use datafusion_common::{ + Column, DFSchemaRef, ExprSchema, NullEquality, Result, assert_or_internal_err, + plan_err, +}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, expr_to_columns, split_conjunction_owned}; use datafusion_expr::{ - exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, - LogicalPlan, LogicalPlanBuilder, Operator, + BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, exists, + in_subquery, lit, not, not_exists, not_in_subquery, }; use log::debug; @@ -44,7 +47,7 @@ use log::debug; pub struct DecorrelatePredicateSubquery {} impl DecorrelatePredicateSubquery { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self::default() } @@ -86,6 +89,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = Arc::unwrap_or_clone(filter.input); + let original_schema = cur_input.schema().columns(); for subquery_expr in with_subqueries { match extract_subquery_info(subquery_expr) { // The subquery expression is at the top level of the filter @@ -112,6 +116,13 @@ impl OptimizerRule for DecorrelatePredicateSubquery { let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; cur_input = LogicalPlan::Filter(new_filter); } + + if cur_input.schema().fields().len() != original_schema.len() { + cur_input = LogicalPlanBuilder::from(cur_input) + .project(original_schema.into_iter().map(Expr::from))? + .build()?; + } + Ok(Transformed::yes(cur_input)) } @@ -310,6 +321,39 @@ fn mark_join( ) } +/// Check if join keys in the join filter may contain NULL values +/// +/// Returns true if any join key column is nullable on either side. +/// This is used to optimize null-aware anti joins: if all join keys are non-nullable, +/// we can use a regular anti join instead of the more expensive null-aware variant. +fn join_keys_may_be_null( + join_filter: &Expr, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, +) -> Result { + // Extract columns from the join filter + let mut columns = std::collections::HashSet::new(); + expr_to_columns(join_filter, &mut columns)?; + + // Check if any column is nullable + for col in columns { + // Check in left schema + if let Ok(field) = left_schema.field_from_column(&col) + && field.as_ref().is_nullable() + { + return Ok(true); + } + // Check in right schema + if let Ok(field) = right_schema.field_from_column(&col) + && field.as_ref().is_nullable() + { + return Ok(true); + } + } + + Ok(false) +} + fn build_join( left: &LogicalPlan, subquery: &LogicalPlan, @@ -364,8 +408,8 @@ fn build_join( })), ) => { let right_col = create_col_from_scalar_expr(right.deref(), alias)?; - let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); - in_predicate + + Expr::eq(left.deref().clone(), Expr::Column(right_col)) } (None, None) => lit(true), _ => return Ok(None), @@ -403,6 +447,8 @@ fn build_join( // Degenerate case: no right columns referenced by the predicate(s) sub_query_alias.clone() }; + + // Mark joins don't use null-aware semantics (they use three-valued logic with mark column) let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(right_projected, join_type, Some(join_filter))? .build()?; @@ -415,10 +461,36 @@ fn build_join( return Ok(Some(new_plan)); } + // Determine if this should be a null-aware anti join + // Null-aware semantics are only needed for NOT IN subqueries, not NOT EXISTS: + // - NOT IN: Uses three-valued logic, requires null-aware handling + // - NOT EXISTS: Uses two-valued logic, regular anti join is correct + // We can distinguish them: NOT IN has in_predicate_opt, NOT EXISTS does not + // + // Additionally, if the join keys are non-nullable on both sides, we don't need + // null-aware semantics because NULLs cannot exist in the data. + let null_aware = join_type == JoinType::LeftAnti + && in_predicate_opt.is_some() + && join_keys_may_be_null(&join_filter, left.schema(), sub_query_alias.schema())?; + // join our sub query into the main plan - let new_plan = LogicalPlanBuilder::from(left.clone()) - .join_on(sub_query_alias, join_type, Some(join_filter))? - .build()?; + let new_plan = if null_aware { + // Use join_detailed_with_options to set null_aware flag + LogicalPlanBuilder::from(left.clone()) + .join_detailed_with_options( + sub_query_alias, + join_type, + (Vec::::new(), Vec::::new()), // No equijoin keys, filter-based join + Some(join_filter), + NullEquality::NullEqualsNothing, + true, // null_aware + )? + .build()? + } else { + LogicalPlanBuilder::from(left.clone()) + .join_on(sub_query_alias, join_type, Some(join_filter))? + .build()? + }; debug!( "predicate subquery optimized:\n{}", new_plan.display_indent() @@ -474,7 +546,7 @@ mod tests { use crate::assert_optimized_plan_eq_display_indent_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::builder::table_source; - use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; + use datafusion_expr::{and, binary_expr, col, out_ref_col, table_scan}; macro_rules! assert_optimized_plan_equal { ( @@ -613,7 +685,7 @@ mod tests { assert_optimized_plan_equal!( plan, - @r###" + @r" Projection: customer.c_custkey [c_custkey:Int64] LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8] LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8] @@ -624,7 +696,7 @@ mod tests { SubqueryAlias: __correlated_sq_2 [o_custkey:Int64] Projection: orders.o_custkey [o_custkey:Int64] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - "### + " ) } @@ -1672,13 +1744,14 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] - LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] - Projection: orders.o_custkey [o_custkey:Int64] - Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] + Filter: __correlated_sq_1.mark OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + LeftMark Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, mark:Boolean] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __correlated_sq_1 [o_custkey:Int64] + Projection: orders.o_custkey [o_custkey:Int64] + Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] " ) } @@ -1977,7 +2050,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [arr:Int32;N] Unnest: lists[sq.arr|depth=1] structs[] [arr:Int32;N] - TableScan: sq [arr:List(Field { data_type: Int32, nullable: true });N] + TableScan: sq [arr:List(Int32);N] " ) } @@ -2012,7 +2085,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __correlated_sq_1 [a:UInt32;N] Unnest: lists[sq.a|depth=1] structs[] [a:UInt32;N] - TableScan: sq [a:List(Field { data_type: UInt32, nullable: true });N] + TableScan: sq [a:List(UInt32);N] " ) } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index ae1d7df46d52e..95b70da443d88 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,20 +20,20 @@ use crate::{OptimizerConfig, OptimizerRule}; use std::sync::Arc; use crate::join_key_set::JoinKeySet; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{NullEquality, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator}; +use datafusion_expr::{ExprSchemable, Operator, and, build_join_schema}; #[derive(Default, Debug)] pub struct EliminateCrossJoin; impl EliminateCrossJoin { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -85,6 +85,17 @@ impl OptimizerRule for EliminateCrossJoin { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + // Fast path: nothing to do if the plan contains no `Join` nodes. + // Without this guard the rule still falls through to + // `rewrite_children`, which walks the entire plan, processes + // uncorrelated subqueries, and rewrites every direct child via + // `map_children` (clone-on-write) — paid by every query in the + // logical optimizer pipeline. Same shape as the + // `plan_has_subqueries` fast-path landed in #22298. + if !plan_has_joins(&plan) { + return Ok(Transformed::no(plan)); + } + let plan_schema = Arc::clone(plan.schema()); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; @@ -207,12 +218,45 @@ impl OptimizerRule for EliminateCrossJoin { } } +/// Returns `true` if `plan` contains at least one [`LogicalPlan::Join`] +/// node, either directly in its tree *or* inside an embedded subquery +/// plan reachable through `Expr::ScalarSubquery` / `Expr::InSubquery` +/// / `Expr::Exists` / `Expr::SetComparison`. +/// +/// Used as a fast-path gate at the top of [`EliminateCrossJoin::rewrite`] +/// so that join-free plans skip the full recursive rewrite. Subquery +/// traversal matters because `rewrite_children` also dives into +/// uncorrelated subqueries via `map_uncorrelated_subqueries`; ignoring +/// them here would skip optimizing a `CROSS JOIN` that sits only inside +/// an `IN (SELECT ... FROM a, b)`-style predicate. +/// +/// `LogicalPlan::apply_with_subqueries` already implements the +/// "walk this node + every child + every subquery plan" traversal we +/// need, so the helper is a thin wrapper around it. +fn plan_has_joins(plan: &LogicalPlan) -> bool { + let mut found = false; + let _ = plan.apply_with_subqueries(|node| { + if matches!(node, LogicalPlan::Join(_)) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + found +} + fn rewrite_children( optimizer: &impl OptimizerRule, plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + // Process uncorrelated subqueries in expressions, then direct children. + let transformed_plan = plan + .map_uncorrelated_subqueries(|input| optimizer.rewrite(input, config))? + .transform_sibling(|plan| { + plan.map_children(|input| optimizer.rewrite(input, config)) + })?; // recompute schema if the plan was transformed if transformed_plan.transformed { @@ -276,10 +320,9 @@ fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { join_type: JoinType::Inner, .. }) = child + && !can_flatten_join_inputs(child) { - if !can_flatten_join_inputs(child) { - return false; - } + return false; } } true @@ -316,10 +359,10 @@ fn find_inner_join( )?; // Save join keys - if let Some((valid_l, valid_r)) = key_pair { - if can_hash(&valid_l.get_type(left_input.schema())?) { - join_keys.push((valid_l, valid_r)); - } + if let Some((valid_l, valid_r)) = key_pair + && can_hash(&valid_l.get_type(left_input.schema())?) + { + join_keys.push((valid_l, valid_r)); } } @@ -342,6 +385,7 @@ fn find_inner_join( filter: None, schema: join_schema, null_equality, + null_aware: false, })); } } @@ -364,6 +408,7 @@ fn find_inner_join( join_type: JoinType::Inner, join_constraint: JoinConstraint::On, null_equality, + null_aware: false, })) } @@ -449,9 +494,9 @@ mod tests { use crate::test::*; use datafusion_expr::{ + Operator::{And, Or}, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, - Operator::{And, Or}, }; use insta::assert_snapshot; @@ -523,7 +568,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -609,7 +654,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -635,7 +680,7 @@ mod tests { plan, @ r" Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -857,7 +902,7 @@ mod tests { plan, @ r" Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] @@ -937,7 +982,7 @@ mod tests { TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t3 [a:UInt32, b:UInt32, c:UInt32] TableScan: t4 [a:UInt32, b:UInt32, c:UInt32] " @@ -1011,7 +1056,7 @@ mod tests { Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] @@ -1247,7 +1292,7 @@ mod tests { plan, @ r" Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32] TableScan: t1 [a:UInt32, b:UInt32, c:UInt32] TableScan: t2 [a:UInt32, b:UInt32, c:UInt32] " @@ -1368,6 +1413,7 @@ mod tests { filter: None, schema: join_schema, null_equality: NullEquality::NullEqualsNull, // Test preservation + null_aware: false, }); // Apply filter that can create join conditions @@ -1411,4 +1457,102 @@ mod tests { Ok(()) } + + // ---------------- fast-path tests ---------------- + + /// `plan_has_joins` detects a `Join` at the root of the plan. + #[test] + fn plan_has_joins_detects_root_join() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?) + .cross_join(test_table_scan_with_name("t2")?)? + .build()?; + assert!(plan_has_joins(&plan)); + Ok(()) + } + + /// `plan_has_joins` detects a `Join` nested under other operators. + #[test] + fn plan_has_joins_detects_nested_join() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?) + .cross_join(test_table_scan_with_name("t2")?)? + .filter(col("t1.a").eq(col("t2.a")))? + .project(vec![col("t1.a")])? + .build()?; + assert!(plan_has_joins(&plan)); + Ok(()) + } + + /// Join-free plans return `false` so the fast-path in `rewrite` can + /// bail out before doing any recursion. + #[test] + fn plan_has_joins_returns_false_for_join_free_plan() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?) + .filter(col("a").gt(lit(0_i32)))? + .project(vec![col("a"), col("b")])? + .build()?; + assert!(!plan_has_joins(&plan)); + Ok(()) + } + + /// `plan_has_joins` walks into embedded subquery plans — e.g. an + /// outer `Filter` whose predicate is `IN (SELECT ... FROM a, b)` + /// where the inner plan contains a `CROSS JOIN`. Without this the + /// fast-path would silently skip optimizing joins-in-subqueries + /// because `LogicalPlan::apply` doesn't descend into subquery + /// plan trees. + #[test] + fn plan_has_joins_detects_join_inside_subquery() -> Result<()> { + use datafusion_expr::in_subquery; + + // Subquery plan that itself contains a join. + let subquery_plan = + LogicalPlanBuilder::from(test_table_scan_with_name("sub_t1")?) + .cross_join(test_table_scan_with_name("sub_t2")?)? + .project(vec![col("sub_t1.a")])? + .build()?; + + // Outer plan with NO direct Join — only the IN subquery reaches one. + let outer = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?) + .filter(in_subquery(col("a"), Arc::new(subquery_plan)))? + .project(vec![col("a")])? + .build()?; + + assert!( + plan_has_joins(&outer), + "plan_has_joins must descend into subquery plans" + ); + Ok(()) + } + + /// `EliminateCrossJoin::rewrite` short-circuits on join-free plans: + /// no recursion into `rewrite_children`, no `Transformed::yes`, + /// the plan comes back identical. + #[test] + fn rewrite_short_circuits_when_plan_has_no_joins() -> Result<()> { + let plan = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?) + .filter(col("a").gt(lit(0_i32)))? + .project(vec![col("a"), col("b")])? + .build()?; + + let starting_display = plan.display_indent_schema().to_string(); + let starting_schema = Arc::clone(plan.schema()); + + let rule = EliminateCrossJoin::new(); + let Transformed { + transformed, + data: optimized_plan, + .. + } = rule.rewrite(plan, &OptimizerContext::new())?; + + assert!( + !transformed, + "join-free plan should not be marked as transformed" + ); + assert_eq!(&starting_schema, optimized_plan.schema()); + assert_eq!( + starting_display, + optimized_plan.display_indent_schema().to_string() + ); + Ok(()) + } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index a6651df938a70..97aa6e1d8480d 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -20,7 +20,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::Result; +use datafusion_common::{Result, get_required_sort_exprs_indices, internal_err}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; use std::hash::{Hash, Hasher}; @@ -32,7 +32,7 @@ use indexmap::IndexSet; pub struct EliminateDuplicatedExpr; impl EliminateDuplicatedExpr { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -76,12 +76,36 @@ impl OptimizerRule for EliminateDuplicatedExpr { .map(|wrapper| wrapper.0) .collect(); + let sort_expr_names = unique_exprs + .iter() + .map(|sort_expr| sort_expr.expr.schema_name().to_string()) + .collect::>(); + let required_indices = get_required_sort_exprs_indices( + sort.input.schema().as_ref(), + &sort_expr_names, + ); + + let unique_exprs = if required_indices.len() < unique_exprs.len() { + required_indices + .into_iter() + .map(|idx| unique_exprs[idx].clone()) + .collect() + } else { + unique_exprs + }; + let transformed = if len != unique_exprs.len() { Transformed::yes } else { Transformed::no }; + if unique_exprs.is_empty() { + return internal_err!( + "FD pruning unexpectedly removed all ORDER BY expressions" + ); + } + Ok(transformed(LogicalPlan::Sort(Sort { expr: unique_exprs, input: sort.input, @@ -118,9 +142,9 @@ impl OptimizerRule for EliminateDuplicatedExpr { #[cfg(test)] mod tests { use super::*; + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; - use crate::OptimizerContext; use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; @@ -130,7 +154,8 @@ mod tests { @ $expected:literal $(,)? ) => {{ let optimizer_ctx = OptimizerContext::new().with_max_passes(1); - let rules: Vec> = vec![Arc::new(EliminateDuplicatedExpr::new())]; + let rules: Vec> = + vec![Arc::new(EliminateDuplicatedExpr::new())]; assert_optimized_plan_eq_snapshot!( optimizer_ctx, rules, diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 1b763d6f8957b..8be5fb0857a9e 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -34,7 +34,7 @@ use crate::{OptimizerConfig, OptimizerRule}; pub struct EliminateFilter; impl EliminateFilter { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -81,10 +81,10 @@ impl OptimizerRule for EliminateFilter { mod tests { use std::sync::Arc; - use crate::assert_optimized_plan_eq_snapshot; use crate::OptimizerContext; + use crate::assert_optimized_plan_eq_snapshot; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr}; + use datafusion_expr::{Expr, col, lit, logical_plan::builder::LogicalPlanBuilder}; use crate::eliminate_filter::EliminateFilter; use crate::test::*; diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 4e16fc0aa159c..e21241ba7d993 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause +//! [`EliminateGroupByConstant`] removes constant and functionally redundant +//! expressions from `GROUP BY` clause use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; +use std::collections::HashSet; + use datafusion_common::Result; +use datafusion_common::tree_node::Transformed; use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; /// Optimizer rule that removes constant expressions from `GROUP BY` clause @@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant { ) -> Result> { match plan { LogicalPlan::Aggregate(aggregate) => { - let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate + // Collect bare column references in GROUP BY + let group_by_columns: HashSet<&datafusion_common::Column> = aggregate + .group_expr + .iter() + .filter_map(|expr| match expr { + Expr::Column(c) => Some(c), + _ => None, + }) + .collect(); + + let (redundant, required): (Vec<_>, Vec<_>) = aggregate .group_expr .iter() - .partition(|expr| is_constant_expression(expr)); - - // If no constant expressions found (nothing to optimize) or - // constant expression is the only expression in aggregate, - // optimization is skipped - if const_group_expr.is_empty() - || (!const_group_expr.is_empty() - && nonconst_group_expr.is_empty() - && aggregate.aggr_expr.is_empty()) + .partition(|expr| is_redundant_group_expr(expr, &group_by_columns)); + + if redundant.is_empty() + || (required.is_empty() && aggregate.aggr_expr.is_empty()) { return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); } let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( aggregate.input, - nonconst_group_expr.into_iter().cloned().collect(), + required.into_iter().cloned().collect(), aggregate.aggr_expr.clone(), )?); @@ -91,23 +99,47 @@ impl OptimizerRule for EliminateGroupByConstant { } } -/// Checks if expression is constant, and can be eliminated from group by. -/// -/// Intended to be used only within this rule, helper function, which heavily -/// relies on `SimplifyExpressions` result. -fn is_constant_expression(expr: &Expr) -> bool { +/// Checks if a GROUP BY expression is redundant (can be removed without +/// changing grouping semantics). An expression is redundant if it is a +/// deterministic function of constants and columns already present as bare +/// column references in the GROUP BY. +fn is_redundant_group_expr( + expr: &Expr, + group_by_columns: &HashSet<&datafusion_common::Column>, +) -> bool { + // Bare column references are never redundant - they define the grouping + if matches!(expr, Expr::Column(_)) { + return false; + } + is_deterministic_of(expr, group_by_columns) +} + +/// Returns true if `expr` is a deterministic expression whose only column +/// references are contained in `known_columns`. +fn is_deterministic_of( + expr: &Expr, + known_columns: &HashSet<&datafusion_common::Column>, +) -> bool { match expr { - Expr::Alias(e) => is_constant_expression(&e.expr), + Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns), + Expr::Column(c) => known_columns.contains(c), + Expr::Literal(_, _) => true, Expr::BinaryExpr(e) => { - is_constant_expression(&e.left) && is_constant_expression(&e.right) + is_deterministic_of(&e.left, known_columns) + && is_deterministic_of(&e.right, known_columns) } - Expr::Literal(_, _) => true, Expr::ScalarFunction(e) => { matches!( e.func.signature().volatility, Volatility::Immutable | Volatility::Stable - ) && e.args.iter().all(is_constant_expression) + ) && e + .args + .iter() + .all(|arg| is_deterministic_of(arg, known_columns)) } + Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns), + Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns), + Expr::Negative(e) => is_deterministic_of(e, known_columns), _ => false, } } @@ -115,16 +147,15 @@ fn is_constant_expression(expr: &Expr) -> bool { #[cfg(test)] mod tests { use super::*; + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; - use crate::OptimizerContext; use arrow::datatypes::DataType; - use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, lit, ColumnarValue, LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, TypeSignature, + ColumnarValue, LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, TypeSignature, col, lit, }; use datafusion_functions_aggregate::expr_fn::count; @@ -161,9 +192,6 @@ mod tests { } impl ScalarUDFImpl for ScalarUDFMock { - fn as_any(&self) -> &dyn std::any::Any { - self - } fn name(&self) -> &str { "scalar_fn_mock" } @@ -268,6 +296,43 @@ mod tests { ") } + #[test] + fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> { + let scan = test_table_scan()?; + // GROUP BY a, a - 1, a - 2, a - 3 -> GROUP BY a + let plan = LogicalPlanBuilder::from(scan) + .aggregate( + vec![ + col("a"), + col("a") - lit(1u32), + col("a") - lit(2u32), + col("a") - lit(3u32), + ], + vec![count(col("c"))], + )? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - UInt32(3), count(test.c) + Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]] + TableScan: test + ") + } + + #[test] + fn test_no_eliminate_independent_columns() -> Result<()> { + // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by column) + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![col("a"), col("b") - lit(1u32)], vec![count(col("c"))])? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], aggr=[[count(test.c)]] + TableScan: test + ") + } + #[test] fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> { let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility( diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 412bbea2ae92c..885910c1e4182 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -22,8 +22,8 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ - logical_plan::{EmptyRelation, LogicalPlan}, Expr, + logical_plan::{EmptyRelation, LogicalPlan}, }; /// Eliminates joins when join condition is false. @@ -74,9 +74,9 @@ impl OptimizerRule for EliminateJoin { #[cfg(test)] mod tests { + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_join::EliminateJoin; - use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::JoinType::Inner; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 8e25d3246f6c2..1ec3c856080eb 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -18,8 +18,8 @@ //! [`EliminateLimit`] eliminates `LIMIT` when possible use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_common::tree_node::Transformed; use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; use std::sync::Arc; @@ -34,7 +34,7 @@ use std::sync::Arc; pub struct EliminateLimit; impl EliminateLimit { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -77,7 +77,7 @@ impl OptimizerRule for EliminateLimit { } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { // If fetch is `None` and skip is 0, then Limit takes no effect and // we can remove it. Its input also can be Limit, so we should apply again. - #[allow(clippy::used_underscore_binding)] + #[expect(clippy::used_underscore_binding)] return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); } Ok(Transformed::no(LogicalPlan::Limit(limit))) @@ -90,14 +90,13 @@ impl OptimizerRule for EliminateLimit { #[cfg(test)] mod tests { use super::*; - use crate::test::*; use crate::OptimizerContext; + use crate::test::*; use datafusion_common::Column; use datafusion_expr::{ col, - logical_plan::{builder::LogicalPlanBuilder, JoinType}, + logical_plan::{JoinType, builder::LogicalPlanBuilder}, }; - use std::sync::Arc; use crate::assert_optimized_plan_eq_snapshot; use crate::push_down_limit::PushDownLimit; diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 45877642f2766..4691eaf48b0b9 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -15,44 +15,71 @@ // specific language governing permissions and limitations // under the License. -//! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins +//! [`EliminateOuterJoin`] rewrites outer joins to simpler join types when +//! filters make the outer rows unnecessary (e.g. `LEFT`/`RIGHT` to `INNER`, +//! and `FULL` to `LEFT`/`RIGHT`/`INNER`). +use crate::push_down_filter::replace_cols_by_name; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, Result}; -use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; +use datafusion_common::{Column, DFSchema, Result, qualified_name}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, Projection}; use datafusion_expr::{Expr, Filter, Operator}; use crate::optimizer::ApplyOrder; use datafusion_common::tree_node::Transformed; -use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; +use datafusion_expr::expr::{BinaryExpr, Cast, InList, Like, TryCast}; +use std::collections::HashMap; use std::sync::Arc; +/// Attempt to simplify outer joins when filters make their null-padded +/// rows impossible to observe. /// -/// Attempt to replace outer joins with inner joins. +/// Outer joins are generally more expensive than inner joins and can block +/// predicate pushdown and other optimizations. When a filter above an outer +/// join removes every row the join would add for unmatched input rows, the +/// join can be changed to a cheaper join type. /// -/// Outer joins are typically more expensive to compute at runtime -/// than inner joins and prevent various forms of predicate pushdown -/// and other optimizations, so removing them if possible is beneficial. +/// For example: /// -/// Inner joins filter out rows that do match. Outer joins pass rows -/// that do not match padded with nulls. If there is a filter in the -/// query that would filter any such null rows after the join the rows -/// introduced by the outer join are filtered. +/// ```sql +/// SELECT ... +/// FROM a LEFT JOIN b ON ... +/// WHERE b.xx = 100 +/// ``` /// -/// For example, in the `select ... from a left join b on ... where b.xx = 100;` +/// For unmatched rows from `a`, the LEFT JOIN would produce a row with +/// `b.xx` set to NULL. The predicate `b.xx = 100` does not pass for those +/// rows, so the query does not need the LEFT JOIN's null-padded output and +/// the join can be rewritten as an inner join. /// -/// For rows when `b.xx` is null (as it would be after an outer join), -/// the `b.xx = 100` predicate filters them out and there is no -/// need to produce null rows for output. +/// The same reasoning can also simplify FULL joins to LEFT, RIGHT, or INNER +/// joins when filters remove the rows padded on one or both sides. /// -/// Generally, an outer join can be rewritten to inner join if the -/// filters from the WHERE clause return false while any inputs are -/// null and columns of those quals are come from nullable side of -/// outer join. +/// This rule looks for a filter above an outer join: +/// +/// ```text +/// Filter(predicate) +/// Join(LEFT/RIGHT/FULL) +/// ``` +/// +/// It also handles plan shapes where projection pruning has inserted one or +/// more Projection nodes between the filter and join: +/// +/// ```text +/// Filter(predicate over projection output) +/// Projection(...) +/// ... +/// Join(LEFT/RIGHT/FULL) +/// ``` +/// +/// In the projection case, the rule rewrites a copy of the predicate through +/// each Projection so it can analyze the predicate against the Join inputs. +/// The original filter predicate and Projection nodes are preserved when the +/// plan is rebuilt. #[derive(Default, Debug)] pub struct EliminateOuterJoin; impl EliminateOuterJoin { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -77,59 +104,136 @@ impl OptimizerRule for EliminateOuterJoin { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - match plan { - LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Join(join) => { - let mut non_nullable_cols: Vec = vec![]; - - extract_non_nullable_columns( - &filter.predicate, - &mut non_nullable_cols, - join.left.schema(), - join.right.schema(), - true, - ); + let LogicalPlan::Filter(filter) = plan else { + return Ok(Transformed::no(plan)); + }; - let new_join_type = if join.join_type.is_outer() { - let mut left_non_nullable = false; - let mut right_non_nullable = false; - for col in non_nullable_cols.iter() { - if join.left.schema().has_column(col) { - left_non_nullable = true; - } - if join.right.schema().has_column(col) { - right_non_nullable = true; - } - } - eliminate_outer( - join.join_type, - left_non_nullable, - right_non_nullable, - ) - } else { - join.join_type - }; + // Descend through one or more Projection nodes until we find a Join. + // For each Projection we encounter, rewrite a working copy of the + // predicate by replacing references to projection output columns with + // the expressions that define them. Keep the filter's original + // predicate intact for eventual use in the rebuilt plan; the rewritten + // predicate is used only for the null-rejection analysis. + let mut rewritten_predicate = filter.predicate.clone(); + let mut projections: Vec = Vec::new(); + let mut cur = Arc::clone(&filter.input); - let new_join = Arc::new(LogicalPlan::Join(Join { - left: join.left, - right: join.right, - join_type: new_join_type, - join_constraint: join.join_constraint, - on: join.on.clone(), - filter: join.filter.clone(), - schema: Arc::clone(&join.schema), - null_equality: join.null_equality, - })); - Filter::try_new(filter.predicate, new_join) - .map(|f| Transformed::yes(LogicalPlan::Filter(f))) + let new_join = loop { + match cur.as_ref() { + LogicalPlan::Projection(p) => { + rewritten_predicate = + inline_through_projection(rewritten_predicate, p)?; + let next = Arc::clone(&p.input); + projections.push(p.clone()); + cur = next; + } + LogicalPlan::Join(join) => { + let Some(new_join) = try_simplify_join(join, &rewritten_predicate) + else { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + }; + break new_join; } - filter_input => { - filter.input = Arc::new(filter_input); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + _ => { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - }, - _ => Ok(Transformed::no(plan)), + } + }; + + let rebuilt_inner = rewrap_projections(new_join, projections); + Filter::try_new(filter.predicate, Arc::new(rebuilt_inner)) + .map(|f| Transformed::yes(LogicalPlan::Filter(f))) + } +} + +/// Run the null-rejection analysis on `predicate` against `join`'s left/right +/// schemas. Return `Some(new_join_plan)` if the join type can be tightened +/// (e.g. LEFT → INNER), `None` otherwise. +fn try_simplify_join(join: &Join, predicate: &Expr) -> Option { + if !join.join_type.is_outer() { + return None; + } + + let mut null_rejecting_cols: Vec = vec![]; + extract_null_rejecting_columns( + predicate, + &mut null_rejecting_cols, + join.left.schema(), + join.right.schema(), + true, + ); + + let mut left_non_nullable = false; + let mut right_non_nullable = false; + for col in null_rejecting_cols.iter() { + if join.left.schema().has_column(col) { + left_non_nullable = true; } + if join.right.schema().has_column(col) { + right_non_nullable = true; + } + } + + let new_join_type = + eliminate_outer(join.join_type, left_non_nullable, right_non_nullable); + if new_join_type == join.join_type { + return None; + } + + Some(LogicalPlan::Join(Join { + left: Arc::clone(&join.left), + right: Arc::clone(&join.right), + join_type: new_join_type, + join_constraint: join.join_constraint, + on: join.on.clone(), + filter: join.filter.clone(), + schema: Arc::clone(&join.schema), + null_equality: join.null_equality, + null_aware: join.null_aware, + })) +} + +/// Substitute the projection's output column references in `predicate` with +/// the projection's defining expressions (stripped of any `Alias` wrapper). +/// The result expresses `predicate` over the projection's *input* schema. +/// +/// Unlike `PushDownFilter`, this rule does not change expression evaluation +/// behavior (in fact, the rewritten expressions are only used for analysis +/// purposes). Therefore, function volatility and `MoveTowardsLeafNodes` +/// placement can be ignored here. +fn inline_through_projection(predicate: Expr, p: &Projection) -> Result { + let mut map: HashMap = HashMap::new(); + for ((qualifier, field), expr) in p.schema.iter().zip(p.expr.iter()) { + map.insert( + qualified_name(qualifier, field.name()), + unalias(expr).clone(), + ); + } + replace_cols_by_name(predicate, &map) +} + +/// Re-attach a stack of projections above `new_inner`, restoring the original +/// plan shape with the new (possibly retyped) join at the bottom. Projection +/// schemas are reused as-is; only nullability of columns sourced from the +/// formerly-outer side may have changed, and the existing rule already takes +/// this looser-schema approach at the join itself. +fn rewrap_projections( + new_inner: LogicalPlan, + projections: Vec, +) -> LogicalPlan { + let mut current = new_inner; + for mut p in projections.into_iter().rev() { + p.input = Arc::new(current); + current = LogicalPlan::Projection(p); + } + current +} + +fn unalias(expr: &Expr) -> &Expr { + if let Expr::Alias(a) = expr { + unalias(&a.expr) + } else { + expr } } @@ -138,88 +242,59 @@ pub fn eliminate_outer( left_non_nullable: bool, right_non_nullable: bool, ) -> JoinType { - let mut new_join_type = join_type; - match join_type { - JoinType::Left => { - if right_non_nullable { - new_join_type = JoinType::Inner; - } - } - JoinType::Right => { - if left_non_nullable { - new_join_type = JoinType::Inner; - } - } - JoinType::Full => { - if left_non_nullable && right_non_nullable { - new_join_type = JoinType::Inner; - } else if left_non_nullable { - new_join_type = JoinType::Left; - } else if right_non_nullable { - new_join_type = JoinType::Right; - } - } - _ => {} + match (join_type, left_non_nullable, right_non_nullable) { + (JoinType::Left, _, true) => JoinType::Inner, + (JoinType::Right, true, _) => JoinType::Inner, + (JoinType::Full, true, true) => JoinType::Inner, + (JoinType::Full, true, false) => JoinType::Left, + (JoinType::Full, false, true) => JoinType::Right, + _ => join_type, } - new_join_type } -/// Recursively traverses expr, if expr returns false when -/// any inputs are null, treats columns of both sides as non_nullable columns. +/// Find the columns that `expr` rejects NULL on. If any of these columns are +/// NULL, `expr` is guaranteed to evaluate to NULL or false, and the row +/// therefore cannot survive a WHERE clause. Matching columns are appended to +/// `null_rejecting_cols`. +/// +/// The caller uses the result to decide whether an outer join's null-padded +/// rows could survive the predicate above the join: if a column from the +/// nullable side appears in `null_rejecting_cols`, it cannot, and the outer +/// join can be converted to an inner join. /// -/// For and/or expr, extracts from all sub exprs and merges the columns. -/// For or expr, if one of sub exprs returns true, discards all columns from or expr. -/// For IS NOT NULL/NOT expr, always returns false for NULL input. -/// extracts columns from these exprs. -/// For all other exprs, fall through -fn extract_non_nullable_columns( +/// `left_schema` and `right_schema` are the join's two child schemas. +/// `top_level` is true at the root of the WHERE predicate and false on each +/// recursion. +fn extract_null_rejecting_columns( expr: &Expr, - non_nullable_cols: &mut Vec, + null_rejecting_cols: &mut Vec, left_schema: &Arc, right_schema: &Arc, top_level: bool, ) { match expr { Expr::Column(col) => { - non_nullable_cols.push(col.clone()); + null_rejecting_cols.push(col.clone()); } Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - // If one of the inputs are null for these operators, the results should be false. - Operator::Eq - | Operator::NotEq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq => { - extract_non_nullable_columns( - left, - non_nullable_cols, - left_schema, - right_schema, - false, - ); - extract_non_nullable_columns( - right, - non_nullable_cols, - left_schema, - right_schema, - false, - ) - } Operator::And | Operator::Or => { - // treat And as Or if does not from top level, such as - // not (c1 < 10 and c2 > 100) + // AND distributes only down a top-level AND chain in the WHERE + // clause: each conjunct is independently null- rejecting, so + // any column either side discovers is a column the WHERE + // rejects NULL on. Once an AND appears below any other context, + // we fall back to the per-side analysis used for OR, because + // the context might influence whether the row is filtered. if top_level && *op == Operator::And { - extract_non_nullable_columns( + extract_null_rejecting_columns( left, - non_nullable_cols, + null_rejecting_cols, left_schema, right_schema, top_level, ); - extract_non_nullable_columns( + extract_null_rejecting_columns( right, - non_nullable_cols, + null_rejecting_cols, left_schema, right_schema, top_level, @@ -227,76 +302,139 @@ fn extract_non_nullable_columns( return; } - let mut left_non_nullable_cols: Vec = vec![]; - let mut right_non_nullable_cols: Vec = vec![]; - - extract_non_nullable_columns( + // OR (and nested AND): a row survives if EITHER operand returns + // true. We can credit a join side as null-rejecting only when + // BOTH operands independently reject NULL on a column from that + // side — otherwise the other branch could let the NULL row + // through. + let mut left_cols: Vec = vec![]; + let mut right_cols: Vec = vec![]; + extract_null_rejecting_columns( left, - &mut left_non_nullable_cols, + &mut left_cols, left_schema, right_schema, top_level, ); - extract_non_nullable_columns( + extract_null_rejecting_columns( right, - &mut right_non_nullable_cols, + &mut right_cols, left_schema, right_schema, top_level, ); - // for query: select *** from a left join b where b.c1 ... or b.c2 ... - // this can be eliminated to inner join. - // for query: select *** from a left join b where a.c1 ... or b.c2 ... - // this can not be eliminated. - // If columns of relation exist in both sub exprs, any columns of this relation - // can be added to non nullable columns. - if !left_non_nullable_cols.is_empty() - && !right_non_nullable_cols.is_empty() - { - for left_col in &left_non_nullable_cols { - for right_col in &right_non_nullable_cols { - if (left_schema.has_column(left_col) - && left_schema.has_column(right_col)) - || (right_schema.has_column(left_col) - && right_schema.has_column(right_col)) - { - non_nullable_cols.push(left_col.clone()); - break; - } - } + let find_on = |cols: &[Column], schema: &DFSchema| { + cols.iter().find(|c| schema.has_column(c)).cloned() + }; + for schema in [left_schema, right_schema] { + if let (Some(c), Some(_)) = + (find_on(&left_cols, schema), find_on(&right_cols, schema)) + { + null_rejecting_cols.push(c); } } } + // Any other operator that DataFusion declares as NULL-on-NULL: + // recurse into both operands so we collect their columns. + op if op.returns_null_on_null() => { + extract_null_rejecting_columns( + left, + null_rejecting_cols, + left_schema, + right_schema, + false, + ); + extract_null_rejecting_columns( + right, + null_rejecting_cols, + left_schema, + right_schema, + false, + ) + } + // All other operators (notably including IS [ NOT ] DISTINCT FROM) + // are declared as not null-propagating, so they don't contribute + // any null-rejecting columns. _ => {} }, - Expr::Not(arg) => extract_non_nullable_columns( + Expr::Not(arg) | Expr::Negative(arg) => extract_null_rejecting_columns( arg, - non_nullable_cols, + null_rejecting_cols, left_schema, right_schema, false, ), - Expr::IsNotNull(arg) => { + // IS NOT NULL / IS TRUE / IS FALSE / IS NOT UNKNOWN all return FALSE on + // NULL input. At the top of a WHERE clause, that FALSE filters the row + // and so we can recurse; below the top level the surrounding context + // may transform that FALSE into something that accepts NULL rows, + // making the recursion unsound. + Expr::IsNotNull(arg) + | Expr::IsTrue(arg) + | Expr::IsFalse(arg) + | Expr::IsNotUnknown(arg) => { if !top_level { return; } - extract_non_nullable_columns( + extract_null_rejecting_columns( arg, - non_nullable_cols, + null_rejecting_cols, left_schema, right_schema, false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( + Expr::Cast(Cast { expr, field: _ }) + | Expr::TryCast(TryCast { expr, field: _ }) => extract_null_rejecting_columns( + expr, + null_rejecting_cols, + left_schema, + right_schema, + false, + ), + // IN list and BETWEEN are null-rejecting on the input expression: + // NULL input yields a NULL result, regardless of whether the list + // or range bounds themselves contain NULLs. + Expr::InList(InList { expr, .. }) => extract_null_rejecting_columns( expr, - non_nullable_cols, + null_rejecting_cols, + left_schema, + right_schema, + false, + ), + Expr::Between(between) => extract_null_rejecting_columns( + &between.expr, + null_rejecting_cols, left_schema, right_schema, false, ), + Expr::Like(Like { expr, pattern, .. }) => { + extract_null_rejecting_columns( + expr, + null_rejecting_cols, + left_schema, + right_schema, + false, + ); + extract_null_rejecting_columns( + pattern, + null_rejecting_cols, + left_schema, + right_schema, + false, + ); + } + // Anything not handled above contributes no null-rejecting + // columns. Two categories worth calling out: + // - IS NULL, IS NOT TRUE, IS NOT FALSE, IS UNKNOWN — return + // TRUE on NULL input, so they actively *accept* NULL rows + // and are intentionally excluded. + // - Function calls (scalar / aggregate / window / UDF), + // scalar subqueries, struct/list accessors, aliases, + // literals, etc. — we don't have a uniform NULL-propagation + // guarantee for these cases, so we conservatively skip them. _ => {} } } @@ -304,15 +442,16 @@ fn extract_non_nullable_columns( #[cfg(test)] mod tests { use super::*; + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::test::*; - use crate::OptimizerContext; use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; use datafusion_expr::{ + Operator::{And, Or}, binary_expr, cast, col, lit, logical_plan::builder::LogicalPlanBuilder, - try_cast, - Operator::{And, Or}, + not, try_cast, }; macro_rules! assert_optimized_plan_equal { @@ -436,11 +575,138 @@ mod tests { } #[test] - fn eliminate_full_with_type_cast() -> Result<()> { + fn eliminate_left_with_in_list() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; - // eliminate to inner join + // t2.b IN (1, 2, 3) rejects nulls — if t2.b is NULL the IN returns + // NULL which is filtered out. So Left Join should become Inner Join. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").in_list(vec![lit(1u32), lit(2u32), lit(3u32)], false))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(2), UInt32(3)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_in_list_containing_null() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IN list with NULL still rejects null input columns: + // if t2.b is NULL, NULL IN (1, NULL) evaluates to NULL, which is filtered out + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter( + col("t2.b") + .in_list(vec![lit(1u32), lit(ScalarValue::UInt32(None))], false), + )? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(NULL)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_not_in_list() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // NOT IN also rejects nulls: if t2.b is NULL, NOT (NULL IN (...)) + // evaluates to NULL, which is filtered out + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").in_list(vec![lit(1u32), lit(2u32)], true))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b NOT IN ([UInt32(1), UInt32(2)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_between() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // BETWEEN rejects nulls: if t2.b is NULL, NULL BETWEEN 1 AND 10 + // evaluates to NULL, which is filtered out + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").between(lit(1u32), lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b BETWEEN UInt32(1) AND UInt32(10) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_right_with_between() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Right join: filter on left (nullable) side with BETWEEN should convert to Inner + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Right, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t1.b").between(lit(1u32), lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b BETWEEN UInt32(1) AND UInt32(10) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_with_between() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Full join with BETWEEN on both sides should become Inner let plan = LogicalPlanBuilder::from(t1) .join( t2, @@ -449,17 +715,865 @@ mod tests { None, )? .filter(binary_expr( - cast(col("t1.b"), DataType::Int64).gt(lit(10u32)), + col("t1.b").between(lit(1u32), lit(10u32)), And, - try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), + col("t2.b").between(lit(5u32), lit(20u32)), ))? .build()?; assert_optimized_plan_equal!(plan, @r" - Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) + Filter: t1.b BETWEEN UInt32(1) AND UInt32(10) AND t2.b BETWEEN UInt32(5) AND UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_with_in_list() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Full join with IN filters on both sides should become Inner + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t1.b").in_list(vec![lit(1u32), lit(2u32)], false), + And, + col("t2.b").in_list(vec![lit(3u32), lit(4u32)], false), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b IN ([UInt32(1), UInt32(2)]) AND t2.b IN ([UInt32(3), UInt32(4)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_in_list_or_is_null() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // WHERE (t2.b IN (1, 2)) OR (t2.b IS NULL) + // The OR with IS NULL makes the predicate null-tolerant: + // when t2.b is NULL, IS NULL returns true, so the whole OR is true. + // The outer join must be preserved. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b").in_list(vec![lit(1u32), lit(2u32)], false), + Or, + col("t2.b").is_null(), + ))? + .build()?; + + // Should NOT be converted to Inner — OR with IS NULL preserves null rows + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(2)]) OR t2.b IS NULL + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_like() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // LIKE rejects nulls: if t2.b is NULL, the result is NULL (filtered out) + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").like(lit("%pattern%")))? + .build()?; + + assert_optimized_plan_equal!(plan, @r#" + Filter: t2.b LIKE Utf8("%pattern%") Inner Join: t1.a = t2.a TableScan: t1 TableScan: t2 + "#) + } + + #[test] + fn eliminate_left_with_like_pattern_column() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // LIKE with nullable column on the pattern side: + // 'x' LIKE t2.b → if t2.b is NULL, result is NULL (filtered out) + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(lit("x").like(col("t2.b")))? + .build()?; + + assert_optimized_plan_equal!(plan, @r#" + Filter: Utf8("x") LIKE t2.b + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + "#) + } + + #[test] + fn eliminate_full_with_like_cross_side() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // LIKE with columns from both sides: t1.c LIKE t2.b + // If t1 is NULL → NULL LIKE t2.b → NULL → filtered out (left non-nullable) + // If t2 is NULL → t1.c LIKE NULL → NULL → filtered out (right non-nullable) + // Both sides are non-nullable → FULL → INNER + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t1.c").like(col("t2.b")))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.c LIKE t2.b + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_is_true() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS TRUE rejects nulls: if the expression is NULL, IS TRUE returns false + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").gt(lit(10u32)).is_true())? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) IS TRUE + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_is_false() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS FALSE rejects nulls: if the expression is NULL, IS FALSE returns false + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").gt(lit(10u32)).is_false())? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) IS FALSE + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_is_not_unknown() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS NOT UNKNOWN rejects nulls: if the expression is NULL, IS NOT UNKNOWN returns false + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").gt(lit(10u32)).is_not_unknown())? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) IS NOT UNKNOWN + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_is_not_true() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS NOT TRUE is NOT null-rejecting: if the expression is NULL, + // IS NOT TRUE returns true, so null rows pass through + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").gt(lit(10u32)).is_not_true())? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) IS NOT TRUE + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_is_unknown() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS UNKNOWN is NOT null-rejecting: if the expression is NULL, + // IS UNKNOWN returns true, so null rows pass through + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").gt(lit(10u32)).is_unknown())? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) IS UNKNOWN + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_not_is_true() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // NOT( IS TRUE) is equivalent to ( IS NOT TRUE): TRUE when + // is FALSE OR NULL. So `WHERE NOT((t2.b > 5) IS TRUE)` accepts + // rows where t2.b is NULL (because t2.b > 5 is NULL → IS TRUE is + // FALSE → NOT FALSE = TRUE). The LEFT JOIN must NOT be converted. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(not(col("t2.b").gt(lit(5u32)).is_true()))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: NOT t2.b > UInt32(5) IS TRUE + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_not_is_false() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Same shape, IS FALSE: NOT( IS FALSE) accepts NULL on the + // inner column. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(not(col("t2.b").gt(lit(5u32)).is_false()))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: NOT t2.b > UInt32(5) IS FALSE + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_not_is_not_unknown() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Same shape, IS NOT UNKNOWN: NOT( IS NOT UNKNOWN) is + // equivalent to ( IS UNKNOWN), which is TRUE when is NULL. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(not(col("t2.b").gt(lit(5u32)).is_not_unknown()))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: NOT t2.b > UInt32(5) IS NOT UNKNOWN + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_with_type_cast() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // eliminate to inner join + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + cast(col("t1.b"), DataType::Int64).gt(lit(10u32)), + And, + try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: CAST(t1.b AS Int64) > UInt32(10) AND TRY_CAST(t2.c AS Int64) < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + // ----- FULL JOIN → LEFT / RIGHT tests ----- + #[test] + fn eliminate_full_to_left_with_left_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // FULL JOIN with null-rejecting filter only on left side → LEFT JOIN + // (left side becomes non-nullable, right side stays nullable) + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t1.b").gt(lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_to_right_with_right_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // FULL JOIN with null-rejecting filter only on right side → RIGHT JOIN + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").in_list(vec![lit(1u32), lit(2u32)], false))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(2)]) + Right Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_full_to_left_with_like() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // FULL JOIN with LIKE on left side only → LEFT JOIN + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t1.b").like(lit("%val%")))? + .build()?; + + assert_optimized_plan_equal!(plan, @r#" + Filter: t1.b LIKE Utf8("%val%") + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + "#) + } + + #[test] + fn eliminate_full_to_right_with_is_true() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // FULL JOIN with IS TRUE on right side only → RIGHT JOIN + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(col("t2.b").gt(lit(10u32)).is_true())? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) IS TRUE + Right Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + // ----- Nested AND / OR tests ----- + + #[test] + fn eliminate_left_with_and_multiple_null_rejecting() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Multiple null-rejecting predicates combined with AND on nullable side + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b").in_list(vec![lit(1u32), lit(2u32)], false), + And, + col("t2.c").between(lit(5u32), lit(20u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IN ([UInt32(1), UInt32(2)]) AND t2.c BETWEEN UInt32(5) AND UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_or_same_side() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // OR of two null-rejecting predicates on different columns of the same + // nullable side. If t2 rows are NULL (from LEFT JOIN), both t2.b and + // t2.c are NULL, so the entire OR evaluates to NULL → filtered out. + // This IS null-rejecting, so join should be eliminated. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b").gt(lit(10u32)), + Or, + col("t2.c").lt(lit(20u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) OR t2.c < UInt32(20) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_or_cross_side() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // OR with columns from different sides — t1.b (preserved) OR t2.b + // (nullable). When t2 is NULL, t1.b > 10 can still be true, so the + // OR is NOT null-rejecting. Join must be preserved. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t1.b").gt(lit(10u32)), + Or, + col("t2.b").lt(lit(20u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t1.b > UInt32(10) OR t2.b < UInt32(20) + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + // ----- Mixed predicate tests ----- + + #[test] + fn eliminate_full_with_mixed_predicates() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // FULL JOIN with different null-rejecting expr types on each side: + // LIKE on left, BETWEEN on right → INNER JOIN + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Full, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t1.b").like(lit("%pattern%")), + And, + col("t2.b").between(lit(1u32), lit(10u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r#" + Filter: t1.b LIKE Utf8("%pattern%") AND t2.b BETWEEN UInt32(1) AND UInt32(10) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + "#) + } + + #[test] + fn eliminate_left_with_is_true_and_in_list() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // AND of IS TRUE and IN on nullable side — both null-rejecting + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b").gt(lit(5u32)).is_true(), + And, + col("t2.c").in_list(vec![lit(1u32), lit(2u32)], false), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(5) IS TRUE AND t2.c IN ([UInt32(1), UInt32(2)]) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + // ----- Filter pierces a Projection to reach the Join ----- + + #[test] + fn eliminate_left_through_projection() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Filter → Projection → LeftJoin is the shape produced by projection + // pruning in queries such as TPC-DS q49, where the post-join + // Projection sits between the filter and the join. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .project(vec![col("t1.a"), col("t2.b").alias("bb")])? + .filter(col("bb").gt(lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: bb > UInt32(10) + Projection: t1.a, t2.b AS bb + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_through_projection_with_or_cross_side() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // After inlining the filter is still t1.b > 10 OR t2.b < 20, which + // is null-tolerant when t2 is NULL (the t1.b clause can still hold). + // The LEFT JOIN must be preserved. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .project(vec![col("t1.b").alias("x"), col("t2.b").alias("y")])? + .filter(binary_expr( + col("x").gt(lit(10u32)), + Or, + col("y").lt(lit(20u32)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: x > UInt32(10) OR y < UInt32(20) + Projection: t1.b AS x, t2.b AS y + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_through_projection_with_only_left_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // A filter that constrains only the preserved (left) side of a + // LEFT JOIN does not justify converting it to INNER — the LEFT + // would still pass nullable right-side rows that the filter + // accepts. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .project(vec![col("t1.b").alias("x"), col("t2.b")])? + .filter(col("x").gt(lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: x > UInt32(10) + Projection: t1.b AS x, t2.b + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn eliminate_left_with_arithmetic_predicate() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // t2.b * 2 + 1 > 10 is null-rejecting on t2.b: arithmetic + // operators propagate NULL, so the whole expression is NULL when + // t2.b is NULL, and NULL > 10 is filtered out by WHERE. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter( + binary_expr( + binary_expr(col("t2.b"), Operator::Multiply, lit(2u32)), + Operator::Plus, + lit(1u32), + ) + .gt(lit(10u32)), + )? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b * UInt32(2) + UInt32(1) > UInt32(10) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + #[test] + fn eliminate_left_with_negative_predicate() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Unary minus propagates NULL: -NULL is NULL, so `WHERE -t2.b > 0` + // is null-rejecting on t2.b. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(Expr::Negative(Box::new(col("t2.b"))).gt(lit(0u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: (- t2.b) > UInt32(0) + Inner Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_is_distinct_from() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS DISTINCT FROM is NOT null-rejecting: t2.b IS DISTINCT FROM 5 is + // true when t2.b is NULL (NULL is distinct from 5). Padding rows from + // a LEFT JOIN would survive the filter, so the LEFT JOIN must stay. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b"), + Operator::IsDistinctFrom, + lit(5u32), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS DISTINCT FROM UInt32(5) + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_left_with_is_not_distinct_from() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // IS NOT DISTINCT FROM is also not null-rejecting: t2.b IS NOT + // DISTINCT FROM NULL is true when t2.b is NULL. The LEFT JOIN must + // stay. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .filter(binary_expr( + col("t2.b"), + Operator::IsNotDistinctFrom, + lit(ScalarValue::UInt32(None)), + ))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b IS NOT DISTINCT FROM UInt32(NULL) + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 + ") + } + + #[test] + fn no_eliminate_through_non_transparent() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + + // Limit is intentionally not treated as transparent: a Limit below + // the Filter changes which rows survive, so swapping LEFT→INNER + // beneath it could yield a different surviving-row set even when + // the filter is null-rejecting on the right side. + let plan = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Left, + (vec![Column::from_name("a")], vec![Column::from_name("a")]), + None, + )? + .limit(0, Some(5))? + .filter(col("t2.b").gt(lit(10u32)))? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Filter: t2.b > UInt32(10) + Limit: skip=0, fetch=5 + Left Join: t1.a = t2.a + TableScan: t1 + TableScan: t2 ") } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 9228e84abf931..0a50761e8a9f7 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -19,7 +19,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{assert_or_internal_err, DFSchema}; +use datafusion_common::{DFSchema, assert_or_internal_err}; use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; @@ -42,7 +42,7 @@ type EquijoinPredicate = (Expr, Expr); pub struct ExtractEquijoinPredicate; impl ExtractEquijoinPredicate { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -76,6 +76,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -117,6 +118,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { // According to `is not distinct from`'s semantics, it's // safe to override it null_equality: NullEquality::NullEqualsNull, + null_aware, }))); } } @@ -132,6 +134,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }))) } else { Ok(Transformed::no(LogicalPlan::Join(Join { @@ -143,6 +146,7 @@ impl OptimizerRule for ExtractEquijoinPredicate { join_constraint, schema, null_equality, + null_aware, }))) } } @@ -273,7 +277,7 @@ mod tests { use crate::test::*; use arrow::datatypes::DataType; use datafusion_expr::{ - col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, + JoinType, col, lit, logical_plan::builder::LogicalPlanBuilder, }; use std::sync::Arc; diff --git a/datafusion/optimizer/src/extract_leaf_expressions.rs b/datafusion/optimizer/src/extract_leaf_expressions.rs new file mode 100644 index 0000000000000..185f9d045f10f --- /dev/null +++ b/datafusion/optimizer/src/extract_leaf_expressions.rs @@ -0,0 +1,3089 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Two-pass optimizer pipeline that pushes cheap expressions (like struct field +//! access `user['status']`) closer to data sources, enabling early data reduction +//! and source-level optimizations (e.g., Parquet column pruning). See +//! [`ExtractLeafExpressions`] (pass 1) and [`PushDownLeafProjections`] (pass 2). + +use indexmap::{IndexMap, IndexSet}; +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{Column, DFSchema, Result, qualified_name}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{Expr, ExpressionPlacement, Projection}; + +use crate::optimizer::ApplyOrder; +use crate::push_down_filter::replace_cols_by_name; +use crate::utils::{ColumnReference, has_all_column_refs, schema_columns}; +use crate::{OptimizerConfig, OptimizerRule}; + +/// Prefix for aliases generated by the extraction optimizer passes. +/// +/// This prefix is **reserved for internal optimizer use**. User-defined aliases +/// starting with this prefix may be misidentified as optimizer-generated +/// extraction aliases, leading to unexpected behavior. Do not use this prefix +/// in user queries. +const EXTRACTED_EXPR_PREFIX: &str = "__datafusion_extracted"; + +/// Returns `true` if any sub-expression in `exprs` has +/// [`ExpressionPlacement::MoveTowardsLeafNodes`] placement. +/// +/// This is a lightweight pre-check that short-circuits as soon as one +/// extractable expression is found, avoiding the expensive allocations +/// (column HashSets, extractors, expression rewrites) that the full +/// extraction pipeline requires. +fn has_extractable_expr(exprs: &[Expr]) -> bool { + exprs.iter().any(|expr| { + expr.exists(|e| Ok(e.placement() == ExpressionPlacement::MoveTowardsLeafNodes)) + .unwrap_or(false) + }) +} + +/// Extracts `MoveTowardsLeafNodes` sub-expressions from non-projection nodes +/// into **extraction projections** (pass 1 of 2). +/// +/// This handles Filter, Sort, Limit, Aggregate, and Join nodes. For Projection +/// nodes, extraction and pushdown are handled by [`PushDownLeafProjections`]. +/// +/// # Key Concepts +/// +/// **Extraction projection**: a projection inserted *below* a node that +/// pre-computes a cheap expression and exposes it under an alias +/// (`__datafusion_extracted_N`). The parent node then references the alias +/// instead of the original expression. +/// +/// **Recovery projection**: a projection inserted *above* a node to restore +/// the original output schema when extraction changes it. +/// Schema-preserving nodes (Filter, Sort, Limit) gain extra columns from +/// the extraction projection that bubble up; the recovery projection selects +/// only the original columns to hide the extras. +/// +/// # Example +/// +/// Given a filter with a struct field access: +/// +/// ```text +/// Filter: user['status'] = 'active' +/// TableScan: t [id, user] +/// ``` +/// +/// This rule: +/// 1. Inserts an **extraction projection** below the filter: +/// 2. Adds a **recovery projection** above to hide the extra column: +/// +/// ```text +/// Projection: id, user <-- recovery projection +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction projection +/// TableScan: t [id, user] +/// ``` +/// +/// **Important:** The `PushDownFilter` rule is aware of projections created by this rule +/// and will not push filters through them. It uses `ExpressionPlacement` to detect +/// `MoveTowardsLeafNodes` expressions and skip filter pushdown past them. +#[derive(Default, Debug)] +pub struct ExtractLeafExpressions {} + +impl ExtractLeafExpressions { + /// Create a new [`ExtractLeafExpressions`] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for ExtractLeafExpressions { + fn name(&self) -> &str { + "extract_leaf_expressions" + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_leaf_expression_pushdown { + return Ok(Transformed::no(plan)); + } + let alias_generator = config.alias_generator(); + + // Advance the alias generator past any user-provided __datafusion_extracted_N + // aliases to prevent collisions when generating new extraction aliases. + advance_generator_past_existing(&plan, alias_generator)?; + + plan.transform_down_with_subqueries(|plan| { + extract_from_plan(plan, alias_generator) + }) + } +} + +/// Scans the current plan node's expressions for pre-existing +/// `__datafusion_extracted_N` aliases and advances the generator +/// counter past them to avoid collisions with user-provided aliases. +fn advance_generator_past_existing( + plan: &LogicalPlan, + alias_generator: &AliasGenerator, +) -> Result<()> { + plan.apply(|plan| { + plan.expressions().iter().try_for_each(|expr| { + expr.apply(|e| { + if let Expr::Alias(alias) = e + && let Some(id) = alias + .name + .strip_prefix(EXTRACTED_EXPR_PREFIX) + .and_then(|s| s.strip_prefix('_')) + .and_then(|s| s.parse().ok()) + { + alias_generator.update_min_id(id); + } + Ok(TreeNodeRecursion::Continue) + })?; + Ok::<(), datafusion_common::error::DataFusionError>(()) + })?; + Ok(TreeNodeRecursion::Continue) + }) + .map(|_| ()) +} + +/// Extracts `MoveTowardsLeafNodes` sub-expressions from a plan node. +/// +/// Works for any number of inputs (0, 1, 2, …N). For multi-input nodes +/// like Join, each extracted sub-expression is routed to the correct input +/// by checking which input's schema contains all of the expression's column +/// references. +fn extract_from_plan( + plan: LogicalPlan, + alias_generator: &Arc, +) -> Result> { + // Only extract from plan types whose output schema is predictable after + // expression rewriting. Nodes like Window derive column names from + // their expressions, so rewriting `get_field` inside a window function + // changes the output schema and breaks the recovery projection. + if !matches!( + &plan, + LogicalPlan::Aggregate(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Join(_) + ) { + return Ok(Transformed::no(plan)); + } + + let inputs = plan.inputs(); + if inputs.is_empty() { + return Ok(Transformed::no(plan)); + } + + // Fast pre-check: skip all allocations if no extractable expressions exist + if !has_extractable_expr(&plan.expressions()) { + return Ok(Transformed::no(plan)); + } + + // Save original output schema before any transformation + let original_schema = Arc::clone(plan.schema()); + + // Build per-input schemas from borrowed inputs (before plan is consumed + // by map_expressions). We only need schemas and column sets for routing; + // the actual inputs are cloned later only if extraction succeeds. + let input_schemas: Vec> = + inputs.iter().map(|i| Arc::clone(i.schema())).collect(); + + // Build per-input extractors + let mut extractors: Vec = input_schemas + .iter() + .map(|schema| LeafExpressionExtractor::new(schema.as_ref(), alias_generator)) + .collect(); + + // Build per-input column sets for routing expressions to the correct input + let input_column_sets: Vec> = + input_schemas + .iter() + .map(|schema| schema_columns(schema.as_ref())) + .collect(); + + // Transform expressions via map_expressions with routing + let transformed = plan.map_expressions(|expr| { + routing_extract(expr, &mut extractors, &input_column_sets) + })?; + + // If no expressions were rewritten, nothing was extracted + if !transformed.transformed { + return Ok(transformed); + } + + // Clone inputs now that we know extraction succeeded. Wrap in Arc + // upfront since build_extraction_projection expects &Arc. + let owned_inputs: Vec> = transformed + .data + .inputs() + .into_iter() + .map(|i| Arc::new(i.clone())) + .collect(); + + // Build per-input extraction projections (None means no extractions for that input) + let new_inputs: Vec = owned_inputs + .into_iter() + .zip(extractors.iter()) + .map(|(input_arc, extractor)| { + match extractor.build_extraction_projection(&input_arc)? { + Some(plan) => Ok(plan), + // No extractions for this input — recover the LogicalPlan + // without cloning (refcount is 1 since build returned None). + None => Ok(Arc::unwrap_or_clone(input_arc)), + } + }) + .collect::>>()?; + + // Rebuild the plan keeping its rewritten expressions but replacing + // inputs with the new extraction projections. + let new_plan = transformed + .data + .with_new_exprs(transformed.data.expressions(), new_inputs)?; + + // Add recovery projection if the output schema changed + let recovered = build_recovery_projection(original_schema.as_ref(), new_plan)?; + + Ok(Transformed::yes(recovered)) +} + +/// Given an expression, returns the index of the input whose columns fully +/// cover the expression's column references. +/// Returns `None` if the expression references columns from multiple inputs +/// or if multiple inputs match (ambiguous, e.g. unqualified columns present +/// in both sides of a join). +fn find_owning_input( + expr: &Expr, + input_column_sets: &[std::collections::HashSet], +) -> Option { + let mut found = None; + for (idx, cols) in input_column_sets.iter().enumerate() { + if has_all_column_refs(expr, cols) { + if found.is_some() { + // Ambiguous — multiple inputs match + return None; + } + found = Some(idx); + } + } + found +} + +/// Walks an expression tree top-down, extracting `MoveTowardsLeafNodes` +/// sub-expressions and routing each to the correct per-input extractor. +fn routing_extract( + expr: Expr, + extractors: &mut [LeafExpressionExtractor], + input_column_sets: &[std::collections::HashSet], +) -> Result> { + expr.transform_down(|e| { + // Skip expressions already aliased with extracted expression pattern + if let Expr::Alias(alias) = &e + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + return Ok(Transformed { + data: e, + transformed: false, + tnr: TreeNodeRecursion::Jump, + }); + } + + // Don't extract Alias nodes directly — preserve the alias and let + // transform_down recurse into the inner expression + if matches!(&e, Expr::Alias(_)) { + return Ok(Transformed::no(e)); + } + + match e.placement() { + ExpressionPlacement::MoveTowardsLeafNodes => { + if let Some(idx) = find_owning_input(&e, input_column_sets) { + let col_ref = extractors[idx].add_extracted(e)?; + Ok(Transformed::yes(col_ref)) + } else { + // References columns from multiple inputs — cannot extract + Ok(Transformed::no(e)) + } + } + ExpressionPlacement::Column => { + // Track columns that the parent node references so the + // extraction projection includes them as pass-through. + // Without this, the extraction projection would only + // contain __datafusion_extracted_N aliases, and the parent couldn't + // resolve its other column references. + if let Expr::Column(col) = &e + && let Some(idx) = find_owning_input(&e, input_column_sets) + { + extractors[idx].columns_needed.insert(col.clone()); + } + Ok(Transformed::no(e)) + } + _ => Ok(Transformed::no(e)), + } + }) +} + +/// Rewrites extraction pairs and column references from one qualifier +/// space to another. +/// +/// Builds a replacement map by zipping `from_schema` (whose qualifiers +/// currently appear in `pairs` / `columns`) with `to_schema` (the +/// qualifiers we want), then applies `replace_cols_by_name`. +/// +/// Used for SubqueryAlias (alias-space -> input-space) and Union +/// (union output-space -> per-branch input-space). +fn remap_pairs_and_columns( + pairs: &[(Expr, String)], + columns: &IndexSet, + from_schema: &DFSchema, + to_schema: &DFSchema, +) -> Result { + let mut replace_map = HashMap::new(); + for ((from_q, from_f), (to_q, to_f)) in from_schema.iter().zip(to_schema.iter()) { + replace_map.insert( + qualified_name(from_q, from_f.name()), + Expr::Column(Column::new(to_q.cloned(), to_f.name())), + ); + } + let remapped_pairs: Vec<(Expr, String)> = pairs + .iter() + .map(|(expr, alias)| { + Ok(( + replace_cols_by_name(expr.clone(), &replace_map)?, + alias.clone(), + )) + }) + .collect::>()?; + let remapped_columns: IndexSet = columns + .iter() + .filter_map(|col| { + let rewritten = + replace_cols_by_name(Expr::Column(col.clone()), &replace_map).ok()?; + if let Expr::Column(c) = rewritten { + Some(c) + } else { + Some(col.clone()) + } + }) + .collect(); + Ok(ExtractionTarget { + pairs: remapped_pairs, + columns: remapped_columns, + }) +} + +// ============================================================================= +// Helper Types & Functions for Extraction Targeting +// ============================================================================= + +/// A bundle of extraction pairs (expression + alias) and standalone columns +/// that need to be pushed through a plan node. +struct ExtractionTarget { + /// Extracted expressions paired with their generated aliases. + pairs: Vec<(Expr, String)>, + /// Standalone column references needed by the parent node. + columns: IndexSet, +} + +/// Build a replacement map from a projection: output_column_name -> underlying_expr. +/// +/// This is used to resolve column references through a renaming projection. +/// For example, if a projection has `user AS x`, this maps `x` -> `col("user")`. +fn build_projection_replace_map(projection: &Projection) -> HashMap { + projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + let key = Column::from((qualifier, field)).flat_name(); + (key, expr.clone().unalias()) + }) + .collect() +} + +/// Build a recovery projection to restore the original output schema. +/// +/// After extraction, a node's output schema may differ from the original: +/// +/// - **Schema-preserving nodes** (Filter/Sort/Limit): the extraction projection +/// below adds extra `__datafusion_extracted_N` columns that bubble up through +/// the node. Recovery selects only the original columns to hide the extras. +/// ```text +/// Original schema: [id, user] +/// After extraction: [__datafusion_extracted_1, id, user] ← extra column leaked through +/// Recovery: SELECT id, user FROM ... ← hides __datafusion_extracted_1 +/// ``` +/// +/// - **Schema-defining nodes** (Aggregate): same number of columns but names +/// may differ because extracted aliases replaced the original expressions. +/// Recovery maps positionally, aliasing where names changed. +/// ```text +/// Original: [SUM(user['balance'])] +/// After: [SUM(__datafusion_extracted_1)] ← name changed +/// Recovery: SUM(__datafusion_extracted_1) AS "SUM(user['balance'])" +/// ``` +/// +/// - **Schemas identical** → no recovery projection needed. +fn build_recovery_projection( + original_schema: &DFSchema, + input: LogicalPlan, +) -> Result { + let new_schema = input.schema(); + let orig_len = original_schema.fields().len(); + let new_len = new_schema.fields().len(); + + if orig_len == new_len { + // Same number of fields — check if schemas are identical + let schemas_match = original_schema.iter().zip(new_schema.iter()).all( + |((orig_q, orig_f), (new_q, new_f))| { + orig_f.name() == new_f.name() && orig_q == new_q + }, + ); + if schemas_match { + return Ok(input); + } + + // Schema-defining nodes (Aggregate, Join): names may differ at some + // positions because extracted aliases replaced the original expressions. + // Map positionally, aliasing where the name changed. + // + // Invariant: `with_new_exprs` on all supported node types (Aggregate, + // Filter, Sort, Limit, Join) preserves column order, so positional + // mapping is safe here. + debug_assert!( + orig_len == new_len, + "build_recovery_projection: positional mapping requires same field count, \ + got original={orig_len} vs new={new_len}" + ); + let mut proj_exprs = Vec::with_capacity(orig_len); + for (i, (orig_qualifier, orig_field)) in original_schema.iter().enumerate() { + let (new_qualifier, new_field) = new_schema.qualified_field(i); + if orig_field.name() == new_field.name() && orig_qualifier == new_qualifier { + proj_exprs.push(Expr::from((orig_qualifier, orig_field))); + } else { + let new_col = Expr::Column(Column::from((new_qualifier, new_field))); + proj_exprs.push( + new_col.alias_qualified(orig_qualifier.cloned(), orig_field.name()), + ); + } + } + let projection = Projection::try_new(proj_exprs, Arc::new(input))?; + Ok(LogicalPlan::Projection(projection)) + } else { + // Schema-preserving nodes: new schema has extra extraction columns. + // Original columns still exist by name; select them to hide extras. + let col_exprs: Vec = original_schema.iter().map(Expr::from).collect(); + let projection = Projection::try_new(col_exprs, Arc::new(input))?; + Ok(LogicalPlan::Projection(projection)) + } +} + +/// Collects `MoveTowardsLeafNodes` sub-expressions found during expression +/// tree traversal and can build an extraction projection from them. +/// +/// # Example +/// +/// Given `Filter: user['status'] = 'active' AND user['name'] IS NOT NULL`: +/// - `add_extracted(user['status'])` → stores it, returns `col("__datafusion_extracted_1")` +/// - `add_extracted(user['name'])` → stores it, returns `col("__datafusion_extracted_2")` +/// - `build_extraction_projection()` produces: +/// `Projection: user['status'] AS __datafusion_extracted_1, user['name'] AS __datafusion_extracted_2, ` +struct LeafExpressionExtractor<'a> { + /// Extracted expressions: maps expression -> alias + extracted: IndexMap, + /// Columns referenced by extracted expressions or the parent node, + /// included as pass-through in the extraction projection. + columns_needed: IndexSet, + /// Input schema + input_schema: &'a DFSchema, + /// Alias generator + alias_generator: &'a Arc, +} + +impl<'a> LeafExpressionExtractor<'a> { + fn new(input_schema: &'a DFSchema, alias_generator: &'a Arc) -> Self { + Self { + extracted: IndexMap::new(), + columns_needed: IndexSet::new(), + input_schema, + alias_generator, + } + } + + /// Adds an expression to extracted set, returns column reference. + fn add_extracted(&mut self, expr: Expr) -> Result { + // Deduplication: reuse existing alias if same expression + if let Some(alias) = self.extracted.get(&expr) { + return Ok(Expr::Column(Column::new_unqualified(alias))); + } + + // Track columns referenced by this expression + for col in expr.column_refs() { + self.columns_needed.insert(col.clone()); + } + + // Generate unique alias + let alias = self.alias_generator.next(EXTRACTED_EXPR_PREFIX); + self.extracted.insert(expr, alias.clone()); + + Ok(Expr::Column(Column::new_unqualified(&alias))) + } + + /// Builds an extraction projection above the given input, or merges into + /// it if the input is already a projection. Delegates to + /// [`build_extraction_projection_impl`]. + /// + /// Returns `None` if there are no extractions. + fn build_extraction_projection( + &self, + input: &Arc, + ) -> Result> { + if self.extracted.is_empty() { + return Ok(None); + } + let pairs: Vec<(Expr, String)> = self + .extracted + .iter() + .map(|(e, a)| (e.clone(), a.clone())) + .collect(); + let proj = build_extraction_projection_impl( + &pairs, + &self.columns_needed, + input, + self.input_schema, + )?; + Ok(Some(LogicalPlan::Projection(proj))) + } +} + +/// Build an extraction projection above the target node (shared by both passes). +/// +/// If the target is an existing projection, merges into it. This requires +/// resolving column references through the projection's rename mapping: +/// if the projection has `user AS u`, and an extracted expression references +/// `u['name']`, we must rewrite it to `user['name']` since the merged +/// projection reads from the same input as the original. +/// +/// Deduplicates by resolved expression equality and adds pass-through +/// columns as needed. Otherwise builds a fresh projection with extracted +/// expressions + ALL input schema columns. +fn build_extraction_projection_impl( + extracted_exprs: &[(Expr, String)], + columns_needed: &IndexSet, + target: &Arc, + target_schema: &DFSchema, +) -> Result { + if let LogicalPlan::Projection(existing) = target.as_ref() { + // Merge into existing projection + let mut proj_exprs = existing.expr.clone(); + + // Build a map of existing expressions (by Expr equality) to their aliases + let existing_extractions: IndexMap = existing + .expr + .iter() + .filter_map(|e| { + if let Expr::Alias(alias) = e + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + return Some((*alias.expr.clone(), alias.name.clone())); + } + None + }) + .collect(); + + // Resolve column references through the projection's rename mapping + let replace_map = build_projection_replace_map(existing); + + // Add new extracted expressions, resolving column refs through the projection + for (expr, alias) in extracted_exprs { + let resolved = replace_cols_by_name(expr.clone().alias(alias), &replace_map)?; + let resolved_inner = if let Expr::Alias(a) = &resolved { + a.expr.as_ref() + } else { + &resolved + }; + if let Some(existing_alias) = existing_extractions.get(resolved_inner) { + // Same expression already extracted under a different alias — + // add the expression with the new alias so both names are + // available in the output. We can't reference the existing alias + // as a column within the same projection, so we duplicate the + // computation. + if existing_alias != alias { + proj_exprs.push(resolved); + } + } else { + proj_exprs.push(resolved); + } + } + + // Add any new pass-through columns that aren't already in the projection. + // We check against existing.input.schema() (the projection's source) rather + // than target_schema (the projection's output) because columns produced + // by alias expressions (e.g., CSE's __common_expr_N) exist in the output but + // not the input, and cannot be added as pass-through Column references. + let existing_cols: IndexSet = existing + .expr + .iter() + .filter_map(|e| { + if let Expr::Column(c) = e { + Some(c.clone()) + } else { + None + } + }) + .collect(); + + let input_schema = existing.input.schema(); + for col in columns_needed { + let col_expr = Expr::Column(col.clone()); + let resolved = replace_cols_by_name(col_expr, &replace_map)?; + if let Expr::Column(resolved_col) = &resolved + && !existing_cols.contains(resolved_col) + && input_schema.has_column(resolved_col) + { + proj_exprs.push(Expr::Column(resolved_col.clone())); + } + // If resolved to non-column expr, it's already computed by existing projection + } + + Projection::try_new(proj_exprs, Arc::clone(&existing.input)) + } else { + // Build new projection with extracted expressions + all input columns + let mut proj_exprs = Vec::new(); + for (expr, alias) in extracted_exprs { + proj_exprs.push(expr.clone().alias(alias)); + } + for (qualifier, field) in target_schema.iter() { + proj_exprs.push(Expr::from((qualifier, field))); + } + Projection::try_new(proj_exprs, Arc::clone(target)) + } +} + +// ============================================================================= +// Pass 2: PushDownLeafProjections +// ============================================================================= + +/// Pushes extraction projections down through schema-preserving nodes towards +/// leaf nodes (pass 2 of 2, after [`ExtractLeafExpressions`]). +/// +/// Handles two types of projections: +/// - **Pure extraction projections** (all `__datafusion_extracted` aliases + columns): +/// pushes through Filter/Sort/Limit, merges into existing projections, or routes +/// into multi-input node inputs (Join, SubqueryAlias, etc.) +/// - **Mixed projections** (user projections containing `MoveTowardsLeafNodes` +/// sub-expressions): splits into a recovery projection + extraction projection, +/// then pushes the extraction projection down. +/// +/// # Example: Pushing through a Filter +/// +/// After pass 1, the extraction projection sits directly below the filter: +/// ```text +/// Projection: id, user <-- recovery +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction +/// TableScan: t [id, user] +/// ``` +/// +/// Pass 2 pushes the extraction projection through the recovery and filter, +/// and a subsequent `OptimizeProjections` pass removes the (now-redundant) +/// recovery projection: +/// ```text +/// Filter: __datafusion_extracted_1 = 'active' +/// Projection: user['status'] AS __datafusion_extracted_1, id, user <-- extraction (pushed down) +/// TableScan: t [id, user] +/// ``` +#[derive(Default, Debug)] +pub struct PushDownLeafProjections {} + +impl PushDownLeafProjections { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for PushDownLeafProjections { + fn name(&self) -> &str { + "push_down_leaf_projections" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_leaf_expression_pushdown { + return Ok(Transformed::no(plan)); + } + let alias_generator = config.alias_generator(); + match try_push_input(&plan, alias_generator)? { + Some(new_plan) => Ok(Transformed::yes(new_plan)), + None => Ok(Transformed::no(plan)), + } + } +} + +/// Attempts to push a projection's extractable expressions further down. +/// +/// Returns `Some(new_subtree)` if the projection was pushed down or merged, +/// `None` if there is nothing to push or the projection sits above a barrier. +fn try_push_input( + input: &LogicalPlan, + alias_generator: &Arc, +) -> Result> { + let LogicalPlan::Projection(proj) = input else { + return Ok(None); + }; + split_and_push_projection(proj, alias_generator) +} + +/// Splits a projection into extractable pieces, pushes them towards leaf +/// nodes, and adds a recovery projection if needed. +/// +/// Handles both: +/// - **Pure extraction projections** (all `__datafusion_extracted` aliases + columns) +/// - **Mixed projections** (containing `MoveTowardsLeafNodes` sub-expressions) +/// +/// Returns `Some(new_subtree)` if extractions were pushed down, +/// `None` if there is nothing to extract or push. +/// +/// # Example: Mixed Projection +/// +/// ```text +/// Input plan: +/// Projection: user['name'] IS NOT NULL AS has_name, id +/// Filter: ... +/// TableScan +/// +/// Phase 1 (Split): +/// extraction_pairs: [(user['name'], "__datafusion_extracted_1")] +/// recovery_exprs: [__datafusion_extracted_1 IS NOT NULL AS has_name, id] +/// +/// Phase 2 (Push): +/// Push extraction projection through Filter toward TableScan +/// +/// Phase 3 (Recovery): +/// Projection: __datafusion_extracted_1 IS NOT NULL AS has_name, id <-- recovery +/// Filter: ... +/// Projection: user['name'] AS __datafusion_extracted_1, id <-- extraction (pushed) +/// TableScan +/// ``` +fn split_and_push_projection( + proj: &Projection, + alias_generator: &Arc, +) -> Result> { + // Fast pre-check: skip if there are no pre-existing extracted aliases + // and no new extractable expressions. + let has_existing_extracted = proj.expr.iter().any(|e| { + matches!(e, Expr::Alias(alias) if alias.name.starts_with(EXTRACTED_EXPR_PREFIX)) + }); + if !has_existing_extracted && !has_extractable_expr(&proj.expr) { + return Ok(None); + } + + let input = &proj.input; + let input_schema = input.schema(); + + // ── Phase 1: Split ────────────────────────────────────────────────── + // For each projection expression, collect extraction pairs and build + // recovery expressions. + // + // Pre-existing `__datafusion_extracted` aliases are inserted into the + // extractor's `IndexMap` with the **full** `Expr::Alias(…)` as the key, + // so the alias name participates in equality. This prevents collisions + // when CSE rewrites produce the same inner expression under different + // alias names (e.g. `__common_expr_4 AS __datafusion_extracted_1` and + // `__common_expr_4 AS __datafusion_extracted_3`). New extractions from + // `routing_extract` use bare (non-Alias) keys and get normal dedup. + // + // When building the final `extraction_pairs`, the Alias wrapper is + // stripped so consumers see the usual `(inner_expr, alias_name)` tuples. + + let mut extractors = vec![LeafExpressionExtractor::new( + input_schema.as_ref(), + alias_generator, + )]; + let input_column_sets = vec![schema_columns(input_schema.as_ref())]; + + let original_schema = proj.schema.as_ref(); + let mut recovery_exprs: Vec = Vec::with_capacity(proj.expr.len()); + let mut needs_recovery = false; + let mut has_new_extractions = false; + let mut proj_exprs_captured: usize = 0; + // Track standalone column expressions (Case B) to detect column refs + // from extracted aliases (Case A) that aren't also standalone expressions. + let mut standalone_columns: IndexSet = IndexSet::new(); + + for (expr, (qualifier, field)) in proj.expr.iter().zip(original_schema.iter()) { + if let Expr::Alias(alias) = expr + && alias.name.starts_with(EXTRACTED_EXPR_PREFIX) + { + // Insert the full Alias expression as the key so that + // distinct alias names don't collide in the IndexMap. + let alias_name = alias.name.clone(); + + for col_ref in alias.expr.column_refs() { + extractors[0].columns_needed.insert(col_ref.clone()); + } + + extractors[0] + .extracted + .insert(expr.clone(), alias_name.clone()); + recovery_exprs.push(Expr::Column(Column::new_unqualified(&alias_name))); + proj_exprs_captured += 1; + } else if let Expr::Column(col) = expr { + // Plain column pass-through — track it in the extractor + extractors[0].columns_needed.insert(col.clone()); + standalone_columns.insert(col.clone()); + recovery_exprs.push(expr.clone()); + proj_exprs_captured += 1; + } else { + // Everything else: run through routing_extract + let transformed = + routing_extract(expr.clone(), &mut extractors, &input_column_sets)?; + if transformed.transformed { + has_new_extractions = true; + } + let transformed_expr = transformed.data; + + // Build recovery expression, aliasing back to original name if needed + let original_name = field.name(); + let needs_alias = if let Expr::Column(col) = &transformed_expr { + col.name.as_str() != original_name + } else { + let expr_name = transformed_expr.schema_name().to_string(); + original_name != &expr_name + }; + let recovery_expr = if needs_alias { + needs_recovery = true; + transformed_expr + .clone() + .alias_qualified(qualifier.cloned(), original_name) + } else { + transformed_expr.clone() + }; + + // If the expression was transformed (i.e., has extracted sub-parts), + // it differs from what the pushed projection outputs → needs recovery. + // Also, any non-column, non-__datafusion_extracted expression needs recovery + // because the pushed extraction projection won't output it directly. + if transformed.transformed || !matches!(expr, Expr::Column(_)) { + needs_recovery = true; + } + + recovery_exprs.push(recovery_expr); + } + } + + // Build extraction_pairs, stripping the Alias wrapper from pre-existing + // entries (they used the full Alias as the map key to avoid dedup). + let extractor = &extractors[0]; + let extraction_pairs: Vec<(Expr, String)> = extractor + .extracted + .iter() + .map(|(e, a)| match e { + Expr::Alias(alias) => (*alias.expr.clone(), a.clone()), + _ => (e.clone(), a.clone()), + }) + .collect(); + let columns_needed = &extractor.columns_needed; + + // If no extractions found, nothing to do + if extraction_pairs.is_empty() { + return Ok(None); + } + + // If columns_needed has entries that aren't standalone projection columns + // (i.e., they came from column refs inside extracted aliases), a merge + // into an inner projection will widen the schema with those extra columns, + // requiring a recovery projection to restore the original schema. + if columns_needed + .iter() + .any(|c| !standalone_columns.contains(c)) + { + needs_recovery = true; + } + + // ── Phase 2: Push down ────────────────────────────────────────────── + let proj_input = Arc::clone(&proj.input); + let pushed = push_extraction_pairs( + &extraction_pairs, + columns_needed, + proj, + &proj_input, + alias_generator, + proj_exprs_captured, + )?; + + // ── Phase 3: Recovery ─────────────────────────────────────────────── + // Determine the base plan: either the pushed result or an in-place extraction. + let base_plan = match pushed { + Some(plan) => plan, + None => { + if !has_new_extractions { + // Only pre-existing __datafusion_extracted aliases and columns, no new + // extractions from routing_extract. The original projection is + // already an extraction projection that couldn't be pushed + // further. Return None. + return Ok(None); + } + // Build extraction projection in-place (couldn't push down) + let input_arc = Arc::clone(input); + let extraction = build_extraction_projection_impl( + &extraction_pairs, + columns_needed, + &input_arc, + input_schema.as_ref(), + )?; + LogicalPlan::Projection(extraction) + } + }; + + // Wrap with recovery projection if the output schema changed + if needs_recovery { + let recovery = LogicalPlan::Projection(Projection::try_new( + recovery_exprs, + Arc::new(base_plan), + )?); + Ok(Some(recovery)) + } else { + Ok(Some(base_plan)) + } +} + +/// Returns true if the plan is a Projection where ALL expressions are either +/// `Alias(EXTRACTED_EXPR_PREFIX, ...)` or `Column`, with at least one extraction. +/// Such projections can safely be pushed further without re-extraction. +fn is_pure_extraction_projection(plan: &LogicalPlan) -> bool { + let LogicalPlan::Projection(proj) = plan else { + return false; + }; + let mut has_extraction = false; + for expr in &proj.expr { + match expr { + Expr::Alias(alias) if alias.name.starts_with(EXTRACTED_EXPR_PREFIX) => { + has_extraction = true; + } + Expr::Column(_) => {} + _ => return false, + } + } + has_extraction +} + +/// Pushes extraction pairs down through the projection's input node, +/// dispatching to the appropriate handler based on the input node type. +fn push_extraction_pairs( + pairs: &[(Expr, String)], + columns_needed: &IndexSet, + proj: &Projection, + proj_input: &Arc, + alias_generator: &Arc, + proj_exprs_captured: usize, +) -> Result> { + match proj_input.as_ref() { + // Merge into existing projection, then try to push the result further down. + // Only merge when every expression in the outer projection is fully + // captured as either an extraction pair (Case A: __datafusion_extracted + // alias) or a plain column (Case B). Uncaptured expressions (e.g. + // `col AS __common_expr_1` from CSE, or complex expressions with + // extracted sub-parts) would be lost during the merge. + LogicalPlan::Projection(_) if proj_exprs_captured == proj.expr.len() => { + let target_schema = Arc::clone(proj_input.schema()); + let merged = build_extraction_projection_impl( + pairs, + columns_needed, + proj_input, + target_schema.as_ref(), + )?; + let merged_plan = LogicalPlan::Projection(merged); + + // After merging, try to push the result further down, but ONLY + // if the merged result is still a pure extraction projection + // (all __datafusion_extracted aliases + columns). If the merge inherited + // bare MoveTowardsLeafNodes expressions from the inner projection, + // pushing would re-extract them into new aliases and fail when + // the (None, true) fallback can't find the original aliases. + // This handles: Extraction → Recovery(cols) → Filter → ... → TableScan + // by pushing through the recovery projection AND the filter in one pass. + if is_pure_extraction_projection(&merged_plan) + && let Some(pushed) = try_push_input(&merged_plan, alias_generator)? + { + return Ok(Some(pushed)); + } + Ok(Some(merged_plan)) + } + // Generic: handles Filter/Sort/Limit (via recursion), + // SubqueryAlias (with qualifier remap in try_push_into_inputs), + // Join, and anything else. + // Safely bails out for nodes that don't pass through extracted + // columns (Aggregate, Window) via the output schema check. + _ => try_push_into_inputs( + pairs, + columns_needed, + proj_input.as_ref(), + alias_generator, + ), + } +} + +/// Routes extraction pairs and columns to the appropriate inputs. +/// +/// - **Union**: broadcasts to every input via [`remap_pairs_and_columns`]. +/// - **Other nodes**: routes each expression to the one input that owns +/// all of its column references (via [`find_owning_input`]). +/// +/// Returns `None` if any expression can't be routed or no input has pairs. +fn route_to_inputs( + pairs: &[(Expr, String)], + columns: &IndexSet, + node: &LogicalPlan, + input_column_sets: &[std::collections::HashSet], + input_schemas: &[Arc], +) -> Result>> { + let num_inputs = input_schemas.len(); + let mut per_input: Vec = (0..num_inputs) + .map(|_| ExtractionTarget { + pairs: vec![], + columns: IndexSet::new(), + }) + .collect(); + + if matches!(node, LogicalPlan::Union(_)) { + // Union output schema and each input schema have the same fields by + // index but may differ in qualifiers (e.g. output `s` vs input + // `simple_struct.s`). Remap pairs/columns to each input's space. + let union_schema = node.schema(); + for (idx, input_schema) in input_schemas.iter().enumerate() { + per_input[idx] = + remap_pairs_and_columns(pairs, columns, union_schema, input_schema)?; + } + } else { + for (expr, alias) in pairs { + match find_owning_input(expr, input_column_sets) { + Some(idx) => per_input[idx].pairs.push((expr.clone(), alias.clone())), + None => return Ok(None), // Cross-input expression — bail out + } + } + for col in columns { + let col_expr = Expr::Column(col.clone()); + match find_owning_input(&col_expr, input_column_sets) { + Some(idx) => { + per_input[idx].columns.insert(col.clone()); + } + None => return Ok(None), // Ambiguous column — bail out + } + } + } + + // Check at least one input has extractions to push + if per_input.iter().all(|t| t.pairs.is_empty()) { + return Ok(None); + } + + Ok(Some(per_input)) +} + +/// Pushes extraction expressions into a node's inputs by routing each +/// expression to the input that owns all of its column references. +/// +/// Works for any number of inputs (1, 2, …N). For single-input nodes, +/// all expressions trivially route to that input. For multi-input nodes +/// (Join, etc.), each expression is routed to the side that owns its columns. +/// +/// Returns `Some(new_node)` if all expressions could be routed AND the +/// rebuilt node's output schema contains all extracted aliases. +/// Returns `None` if any expression references columns from multiple inputs +/// or the node doesn't pass through the extracted columns. +/// +/// # Example: Join with expressions from both sides +/// +/// ```text +/// Extraction projection above a Join: +/// Projection: left.user['name'] AS __datafusion_extracted_1, right.order['total'] AS __datafusion_extracted_2, ... +/// Join: left.id = right.user_id +/// TableScan: left [id, user] +/// TableScan: right [user_id, order] +/// +/// After routing each expression to its owning input: +/// Join: left.id = right.user_id +/// Projection: user['name'] AS __datafusion_extracted_1, id, user <-- left-side extraction +/// TableScan: left [id, user] +/// Projection: order['total'] AS __datafusion_extracted_2, user_id, order <-- right-side extraction +/// TableScan: right [user_id, order] +/// ``` +fn try_push_into_inputs( + pairs: &[(Expr, String)], + columns_needed: &IndexSet, + node: &LogicalPlan, + alias_generator: &Arc, +) -> Result> { + let inputs = node.inputs(); + if inputs.is_empty() { + return Ok(None); + } + + // Unnest may output a column with the same name but different value/type + // than its input column. Name-based routing cannot distinguish those. + // On top of that Unnest can't go through the `node.with_new_exprs(node.expressions(), new_inputs)` rebuild + if matches!(node, LogicalPlan::Unnest(_)) { + return Ok(None); + } + + // SubqueryAlias remaps qualifiers between input and output. + // Rewrite pairs/columns from alias-space to input-space before routing. + let remapped = if let LogicalPlan::SubqueryAlias(sa) = node { + remap_pairs_and_columns(pairs, columns_needed, &sa.schema, sa.input.schema())? + } else { + ExtractionTarget { + pairs: pairs.to_vec(), + columns: columns_needed.clone(), + } + }; + let pairs = &remapped.pairs[..]; + let columns_needed = &remapped.columns; + + // Build per-input schemas and column sets for routing + let input_schemas: Vec> = + inputs.iter().map(|i| Arc::clone(i.schema())).collect(); + let input_column_sets: Vec> = + input_schemas.iter().map(|s| schema_columns(s)).collect(); + + // Route pairs and columns to the appropriate inputs + let per_input = match route_to_inputs( + pairs, + columns_needed, + node, + &input_column_sets, + &input_schemas, + )? { + Some(routed) => routed, + None => return Ok(None), + }; + + let num_inputs = inputs.len(); + + // Build per-input extraction projections and push them as far as possible + // immediately. This is critical because map_children preserves cached schemas, + // so if the TopDown pass later pushes a child further (changing its output + // schema), the parent node's schema becomes stale. + let mut new_inputs: Vec = Vec::with_capacity(num_inputs); + for (idx, input) in inputs.into_iter().enumerate() { + if per_input[idx].pairs.is_empty() { + new_inputs.push(input.clone()); + } else { + let input_arc = Arc::new(input.clone()); + let target_schema = Arc::clone(input.schema()); + let proj = build_extraction_projection_impl( + &per_input[idx].pairs, + &per_input[idx].columns, + &input_arc, + target_schema.as_ref(), + )?; + // Verify all requested aliases appear in the projection's output. + // A merge may deduplicate if the same expression already exists + // under a different alias, leaving the requested alias missing. + let proj_schema = proj.schema.as_ref(); + for (_expr, alias) in &per_input[idx].pairs { + if !proj_schema.fields().iter().any(|f| f.name() == alias) { + return Ok(None); + } + } + let proj_plan = LogicalPlan::Projection(proj); + // Try to push the extraction projection further down within + // this input (e.g., through Filter → existing extraction projection). + // This ensures the input's output schema is stable and won't change + // when the TopDown pass later visits children. + match try_push_input(&proj_plan, alias_generator)? { + Some(pushed) => new_inputs.push(pushed), + None => new_inputs.push(proj_plan), + } + } + } + + // Rebuild the node with new inputs + let new_node = node.with_new_exprs(node.expressions(), new_inputs)?; + + // Safety check: verify all extracted aliases appear in the rebuilt + // node's output schema. Nodes like Aggregate define their own output + // and won't pass through extracted columns — bail out for those. + let output_schema = new_node.schema(); + for (_expr, alias) in pairs { + if !output_schema.fields().iter().any(|f| f.name() == alias) { + return Ok(None); + } + } + + Ok(Some(new_node)) +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::optimize_projections::OptimizeProjections; + use crate::test::udfs::PlacementTestUDF; + use crate::test::*; + use crate::{Optimizer, OptimizerContext}; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{ + ScalarUDF, col, lit, logical_plan::builder::LogicalPlanBuilder, + }; + + fn leaf_udf(expr: Expr, name: &str) -> Expr { + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + )), + vec![expr, lit(name)], + )) + } + + // ========================================================================= + // Combined optimization stage formatter + // ========================================================================= + + /// Runs all 4 optimization stages and returns a single formatted string. + /// Stages that produce the same plan as the previous stage show + /// "(same as )" to reduce noise. + /// + /// Stages: + /// 1. **Original** - OptimizeProjections only (baseline) + /// 2. **After Extraction** - + ExtractLeafExpressions + /// 3. **After Pushdown** - + PushDownLeafProjections + /// 4. **Optimized** - + final OptimizeProjections + fn format_optimization_stages(plan: &LogicalPlan) -> Result { + let run = |rules: Vec>| -> Result { + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = Optimizer::with_rules(rules); + let optimized = optimizer.optimize(plan.clone(), &ctx, |_, _| {})?; + Ok(format!("{optimized}")) + }; + + let original = run(vec![Arc::new(OptimizeProjections::new())])?; + + let after_extract = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + ])?; + + let after_pushdown = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + ])?; + + let optimized = run(vec![ + Arc::new(OptimizeProjections::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + Arc::new(OptimizeProjections::new()), + ])?; + + let mut out = format!("## Original Plan\n{original}"); + + out.push_str("\n\n## After Extraction\n"); + if after_extract == original { + out.push_str("(same as original)"); + } else { + out.push_str(&after_extract); + } + + out.push_str("\n\n## After Pushdown\n"); + if after_pushdown == after_extract { + out.push_str("(same as after extraction)"); + } else { + out.push_str(&after_pushdown); + } + + out.push_str("\n\n## Optimized\n"); + if optimized == after_pushdown { + out.push_str("(same as after pushdown)"); + } else { + out.push_str(&optimized); + } + + Ok(out) + } + + /// Assert all optimization stages for a plan in a single insta snapshot. + macro_rules! assert_stages { + ($plan:expr, @ $expected:literal $(,)?) => {{ + let result = format_optimization_stages(&$plan)?; + insta::assert_snapshot!(result, @ $expected); + Ok::<(), datafusion_common::DataFusionError>(()) + }}; + } + + #[test] + fn test_extract_from_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .select(vec![ + table_scan + .schema() + .index_of_column_by_name(None, "id") + .unwrap(), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: test.id + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id + TableScan: test projection=[id, user] + "#) + } + + #[test] + fn test_no_extraction_for_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("a").eq(lit(1)))? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Filter: test.a = Int32(1) + TableScan: test projection=[a, b, c] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + #[test] + fn test_extract_from_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_extract_from_projection_with_subexpression() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf(col("user"), "name") + .is_not_null() + .alias("has_name"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS has_name + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_projection_no_extraction_for_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + #[test] + fn test_filter_with_deduplication() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field_access = leaf_udf(col("user"), "name"); + // Filter with the same expression used twice + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + field_access + .clone() + .is_not_null() + .and(field_access.is_null()), + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL AND leaf_udf(test.user, Utf8("name")) IS NULL + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL AND __datafusion_extracted_1 IS NULL + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_already_leaf_expression_in_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "name").eq(lit("test")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) = Utf8("test") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("test") + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_extract_from_aggregate_group_by() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![leaf_udf(col("user"), "status")], vec![count(lit(1))])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("status"))]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## After Extraction + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_extract_from_aggregate_args() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("user")], + vec![count(leaf_udf(col("user"), "value"))], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value")))]] + TableScan: test projection=[user] + + ## After Extraction + Projection: test.user, COUNT(__datafusion_extracted_1) AS COUNT(leaf_udf(test.user,Utf8("value"))) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__datafusion_extracted_1)]] + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_projection_with_filter_combined() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[user] + + ## After Extraction + Projection: leaf_udf(test.user, Utf8("name")) + Projection: test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + "#) + } + + #[test] + fn test_projection_preserves_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![leaf_udf(col("user"), "name").alias("username")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) AS username + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS username + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) AS username + TableScan: test projection=[user] + "#) + } + + /// Test: Projection with different field than Filter + /// SELECT id, s['label'] FROM t WHERE s['value'] > 150 + /// Both s['label'] and s['value'] should be in a single extraction projection. + #[test] + fn test_projection_different_field_from_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "value").gt(lit(150)))? + .project(vec![col("user"), leaf_udf(col("user"), "label")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Filter: leaf_udf(test.user, Utf8("value")) > Int32(150) + TableScan: test projection=[user] + + ## After Extraction + Projection: test.user, leaf_udf(test.user, Utf8("label")) + Projection: test.user + Filter: __datafusion_extracted_1 > Int32(150) + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: test.user, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("label")) + Filter: __datafusion_extracted_1 > Int32(150) + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user, leaf_udf(test.user, Utf8("label")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + + ## Optimized + (same as after pushdown) + "#) + } + + #[test] + fn test_projection_deduplication() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field = leaf_udf(col("user"), "name"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![field.clone(), field.clone().alias("name2")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_1 AS name2 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("name")) AS name2 + TableScan: test projection=[user] + "#) + } + + // ========================================================================= + // Additional tests for code coverage + // ========================================================================= + + /// Extractions push through Sort nodes to reach the TableScan. + #[test] + fn test_extract_through_sort() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![col("user").sort(true, true)])? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Sort: test.user ASC NULLS FIRST + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Sort: test.user ASC NULLS FIRST + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extractions push through Limit nodes to reach the TableScan. + #[test] + fn test_extract_through_limit() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .limit(0, Some(10))? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + Limit: skip=0, fetch=10 + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Limit: skip=0, fetch=10 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Limit: skip=0, fetch=10 + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Aliased aggregate functions like count(...).alias("cnt") are handled. + #[test] + fn test_extract_from_aliased_aggregate() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("user")], + vec![count(leaf_udf(col("user"), "value")).alias("cnt")], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(leaf_udf(test.user, Utf8("value"))) AS cnt]] + TableScan: test projection=[user] + + ## After Extraction + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__datafusion_extracted_1) AS cnt]] + Projection: leaf_udf(test.user, Utf8("value")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Aggregates with no MoveTowardsLeafNodes expressions return unchanged. + #[test] + fn test_aggregate_no_extraction() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![count(col("b"))])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b)]] + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Projections containing extracted expression aliases are skipped (already extracted). + #[test] + fn test_skip_extracted_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf(col("user"), "name").alias("__datafusion_extracted_manual"), + col("user"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_manual, test.user + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Multiple extractions merge into a single extracted expression projection. + #[test] + fn test_merge_into_existing_extracted_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .filter(leaf_udf(col("user"), "name").is_not_null())? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("name")) IS NOT NULL + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.id, test.user + Projection: test.id, test.user + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, test.user + Filter: __datafusion_extracted_1 IS NOT NULL + Projection: test.id, test.user, __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + "#) + } + + /// Extractions push through passthrough projections (columns only). + #[test] + fn test_extract_through_passthrough_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user")])? + .project(vec![leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) + TableScan: test projection=[user] + "#) + } + + /// Projections with aliased columns (nothing to extract) return unchanged. + #[test] + fn test_projection_early_return_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a").alias("x"), col("b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Projection: test.a AS x, test.b + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Projections with arithmetic expressions but no MoveTowardsLeafNodes return unchanged. + #[test] + fn test_projection_with_arithmetic_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![(col("a") + col("b")).alias("sum")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Projection: test.a + test.b AS sum + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Aggregate extractions merge into existing extracted projection created by Filter. + #[test] + fn test_aggregate_merge_into_extracted_projection() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .aggregate(vec![leaf_udf(col("user"), "name")], vec![count(lit(1))])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Aggregate: groupBy=[[leaf_udf(test.user, Utf8("name"))]], aggr=[[COUNT(Int32(1))]] + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[user] + + ## After Extraction + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + Projection: test.user + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.user + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("name")), COUNT(Int32(1)) + Aggregate: groupBy=[[__datafusion_extracted_1]], aggr=[[COUNT(Int32(1))]] + Projection: __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Projection containing a MoveTowardsLeafNodes sub-expression above an + /// Aggregate. Aggregate blocks pushdown, so the (None, true) recovery + /// fallback path fires: in-place extraction + recovery projection. + #[test] + fn test_projection_with_leaf_expr_above_aggregate() -> Result<()> { + use datafusion_expr::test::function_stub::count; + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("user")], vec![count(lit(1))])? + .project(vec![ + leaf_udf(col("user"), "name") + .is_not_null() + .alias("has_name"), + col("COUNT(Int32(1))"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS has_name, COUNT(Int32(1)) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + + ## Optimized + Projection: leaf_udf(test.user, Utf8("name")) IS NOT NULL AS has_name, COUNT(Int32(1)) + Aggregate: groupBy=[[test.user]], aggr=[[COUNT(Int32(1))]] + TableScan: test projection=[user] + "#) + } + + /// Merging adds new pass-through columns not in the existing extracted projection. + #[test] + fn test_merge_with_new_columns() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("a"), "x").eq(lit(1)))? + .filter(leaf_udf(col("b"), "y").eq(lit(2)))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.b, Utf8("y")) = Int32(2) + Filter: leaf_udf(test.a, Utf8("x")) = Int32(1) + TableScan: test projection=[a, b, c] + + ## After Extraction + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_1 = Int32(2) + Projection: leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1, test.a, test.b, test.c + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_2 = Int32(1) + Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c + TableScan: test projection=[a, b, c] + + ## After Pushdown + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_1 = Int32(2) + Filter: __datafusion_extracted_2 = Int32(1) + Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c, leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1 + TableScan: test projection=[a, b, c] + + ## Optimized + Projection: test.a, test.b, test.c + Filter: __datafusion_extracted_1 = Int32(2) + Projection: test.a, test.b, test.c, __datafusion_extracted_1 + Filter: __datafusion_extracted_2 = Int32(1) + Projection: leaf_udf(test.a, Utf8("x")) AS __datafusion_extracted_2, test.a, test.b, test.c, leaf_udf(test.b, Utf8("y")) AS __datafusion_extracted_1 + TableScan: test projection=[a, b, c] + "#) + } + + // ========================================================================= + // Join extraction tests + // ========================================================================= + + /// Create a second table scan with struct field for join tests + fn test_table_scan_with_struct_named(name: &str) -> Result { + use arrow::datatypes::Schema; + let schema = Schema::new(test_table_scan_with_struct_fields()); + datafusion_expr::logical_plan::table_scan(Some(name), &schema, None)?.build() + } + + /// Extraction from equijoin keys (`on` expressions). + #[test] + fn test_extract_from_join_on() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_with_expr_keys( + right, + JoinType::Inner, + ( + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + ), + None, + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: __datafusion_extracted_1 = __datafusion_extracted_2 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extraction from non-equi join filter. + #[test] + fn test_extract_from_join_filter() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.user").eq(col("right.user")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + ], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.user = right.user AND __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Extraction from both left and right sides of a join. + #[test] + fn test_extract_from_join_both_sides() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.user").eq(col("right.user")), + leaf_udf(col("test.user"), "status").eq(lit("active")), + leaf_udf(col("right.user"), "role").eq(lit("admin")), + ], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.user = right.user AND leaf_udf(test.user, Utf8("status")) = Utf8("active") AND leaf_udf(right.user, Utf8("role")) = Utf8("admin") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.user = right.user AND __datafusion_extracted_1 = Utf8("active") AND __datafusion_extracted_2 = Utf8("admin") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// Join with no MoveTowardsLeafNodes expressions returns unchanged. + #[test] + fn test_extract_from_join_no_extraction() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan()?; + let right = test_table_scan_with_name("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["a"], vec!["a"]), None)? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + Inner Join: test.a = right.a + TableScan: test projection=[a, b, c] + TableScan: right projection=[a, b, c] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Join followed by filter with extraction. + #[test] + fn test_extract_from_filter_above_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_with_expr_keys( + right, + JoinType::Inner, + ( + vec![leaf_udf(col("user"), "id")], + vec![leaf_udf(col("user"), "id")], + ), + None, + )? + .filter(leaf_udf(col("test.user"), "status").eq(lit("active")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + Inner Join: leaf_udf(test.user, Utf8("id")) = leaf_udf(right.user, Utf8("id")) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, right.id, right.user + Projection: test.id, test.user, right.id, right.user + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.id, test.user, __datafusion_extracted_1, right.id, right.user + Inner Join: __datafusion_extracted_2 = __datafusion_extracted_3 + Projection: leaf_udf(test.user, Utf8("id")) AS __datafusion_extracted_2, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("id")) AS __datafusion_extracted_3, right.id, right.user + TableScan: right projection=[id, user] + "#) + } + + /// Extraction projection (get_field in SELECT) above a Join pushes into + /// the correct input side. + #[test] + fn test_extract_projection_above_join() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join(right, JoinType::Inner, (vec!["id"], vec!["id"]), None)? + .project(vec![ + leaf_udf(col("test.user"), "status"), + leaf_udf(col("right.user"), "role"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(test.user, Utf8("status")), leaf_udf(right.user, Utf8("role")) + Inner Join: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), __datafusion_extracted_2 AS leaf_udf(right.user,Utf8("role")) + Inner Join: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(test.user,Utf8("status")), __datafusion_extracted_2 AS leaf_udf(right.user,Utf8("role")) + Inner Join: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("role")) AS __datafusion_extracted_2, right.id + TableScan: right projection=[id, user] + "#) + } + + /// Join where both sides have same-named columns: a qualified reference + /// to the right side must be routed to the right input, not the left. + #[test] + fn test_extract_from_join_qualified_right_side() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + // Filter references right.user explicitly — must route to right side + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Inner, + vec![ + col("test.id").eq(col("right.id")), + leaf_udf(col("right.user"), "status").eq(lit("active")), + ], + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Inner Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Inner Join: Filter: test.id = right.id AND __datafusion_extracted_1 = Utf8("active") + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + "#) + } + + /// When both inputs contain the same unqualified column, an unqualified + /// column reference is ambiguous and `find_owning_input` must return + /// `None` rather than always returning 0 (the left side). + #[test] + fn test_find_owning_input_ambiguous_unqualified_column() { + use std::collections::HashSet; + + // Simulate schema_columns output for two sides of a join where both + // have a "user" column — each set contains the qualified and + // unqualified form. + let relation = "test".into(); + let left_cols: HashSet = [ + ColumnReference::new(Some(&relation), "user"), + ColumnReference::new_unqualified("user"), + ] + .into_iter() + .collect(); + + let relation = "right".into(); + let right_cols: HashSet = [ + ColumnReference::new(Some(&relation), "user"), + ColumnReference::new_unqualified("user"), + ] + .into_iter() + .collect(); + + let input_column_sets = vec![left_cols, right_cols]; + + // Unqualified "user" matches both sets — must return None (ambiguous) + let unqualified = Expr::Column(Column::new_unqualified("user")); + assert_eq!(find_owning_input(&unqualified, &input_column_sets), None); + + // Qualified "right.user" matches only the right set — must return Some(1) + let qualified_right = Expr::Column(Column::new(Some("right"), "user")); + assert_eq!( + find_owning_input(&qualified_right, &input_column_sets), + Some(1) + ); + + // Qualified "test.user" matches only the left set — must return Some(0) + let qualified_left = Expr::Column(Column::new(Some("test"), "user")); + assert_eq!( + find_owning_input(&qualified_left, &input_column_sets), + Some(0) + ); + } + + /// Two leaf_udf expressions from different sides of a Join in a Filter. + /// Each is routed to its respective input side independently. + #[test] + fn test_extract_from_join_cross_input_expression() -> Result<()> { + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + datafusion_expr::JoinType::Inner, + vec![col("test.id").eq(col("right.id"))], + )? + .filter( + leaf_udf(col("test.user"), "status") + .eq(leaf_udf(col("right.user"), "status")), + )? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(test.user, Utf8("status")) = leaf_udf(right.user, Utf8("status")) + Inner Join: Filter: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = __datafusion_extracted_2 + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_2, test.id, test.user, right.id, right.user + Inner Join: Filter: test.id = right.id + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, test.user, right.id, right.user + Filter: __datafusion_extracted_1 = __datafusion_extracted_2 + Inner Join: Filter: test.id = right.id + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_2, right.id, right.user + TableScan: right projection=[id, user] + + ## Optimized + (same as after pushdown) + "#) + } + + // ========================================================================= + // Column-rename through intermediate node tests + // ========================================================================= + + /// Projection with leaf expr above Filter above renaming Projection. + #[test] + fn test_extract_through_filter_with_column_rename() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(col("x").is_not_null())? + .project(vec![leaf_udf(col("x"), "a")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(x, Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(x,Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(x,Utf8("a")) + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Same as above but with a partial extraction (leaf + arithmetic). + #[test] + fn test_extract_partial_through_filter_with_column_rename() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(col("x").is_not_null())? + .project(vec![leaf_udf(col("x"), "a").is_not_null()])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(x, Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 IS NOT NULL AS leaf_udf(x,Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 IS NOT NULL AS leaf_udf(x,Utf8("a")) IS NOT NULL + Filter: x IS NOT NULL + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Tests merge_into_extracted_projection path through a renaming projection. + #[test] + fn test_extract_from_filter_above_renaming_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user").alias("x")])? + .filter(leaf_udf(col("x"), "a").eq(lit("active")))? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Filter: leaf_udf(x, Utf8("a")) = Utf8("active") + Projection: test.user AS x + TableScan: test projection=[user] + + ## After Extraction + Projection: x + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: x + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: test.user AS x, leaf_udf(test.user, Utf8("a")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + // ========================================================================= + // SubqueryAlias extraction tests + // ========================================================================= + + /// Extraction projection pushes through SubqueryAlias. + #[test] + fn test_extract_through_subquery_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .project(vec![leaf_udf(col("sub.user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(sub.user, Utf8("name")) + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(sub.user,Utf8("name")) + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(sub.user,Utf8("name")) + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Extraction projection pushes through SubqueryAlias + Filter. + #[test] + fn test_extract_through_subquery_alias_with_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .filter(leaf_udf(col("sub.user"), "status").eq(lit("active")))? + .project(vec![leaf_udf(col("sub.user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(sub.user, Utf8("name")) + Filter: leaf_udf(sub.user, Utf8("status")) = Utf8("active") + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Extraction + Projection: leaf_udf(sub.user, Utf8("name")) + Projection: sub.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(sub.user, Utf8("status")) AS __datafusion_extracted_1, sub.user + SubqueryAlias: sub + TableScan: test projection=[user] + + ## After Pushdown + Projection: __datafusion_extracted_2 AS leaf_udf(sub.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_2 AS leaf_udf(sub.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + SubqueryAlias: sub + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[user] + "#) + } + + /// Two layers of SubqueryAlias: extraction pushes through both. + #[test] + fn test_extract_through_nested_subquery_alias() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("inner_sub")? + .alias("outer_sub")? + .project(vec![leaf_udf(col("outer_sub.user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: leaf_udf(outer_sub.user, Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + TableScan: test projection=[user] + + ## After Extraction + (same as original) + + ## After Pushdown + Projection: __datafusion_extracted_1 AS leaf_udf(outer_sub.user,Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1, test.user + TableScan: test projection=[user] + + ## Optimized + Projection: __datafusion_extracted_1 AS leaf_udf(outer_sub.user,Utf8("name")) + SubqueryAlias: outer_sub + SubqueryAlias: inner_sub + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_1 + TableScan: test projection=[user] + "#) + } + + /// Plain columns through SubqueryAlias -- no extraction needed. + #[test] + fn test_subquery_alias_no_extraction() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .alias("sub")? + .project(vec![col("sub.a"), col("sub.b")])? + .build()?; + + assert_stages!(plan, @" + ## Original Plan + SubqueryAlias: sub + TableScan: test projection=[a, b] + + ## After Extraction + (same as original) + + ## After Pushdown + (same as after extraction) + + ## Optimized + (same as after pushdown) + ") + } + + /// Two UDFs with the same `name()` but different concrete types should NOT be + /// deduplicated -- they are semantically different expressions that happen to + /// collide on `schema_name()`. + #[test] + fn test_different_udfs_same_schema_name_not_deduplicated() -> Result<()> { + let udf_a = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(1), + )); + let udf_b = Arc::new(ScalarUDF::new_from_impl( + PlacementTestUDF::new() + .with_placement(ExpressionPlacement::MoveTowardsLeafNodes) + .with_id(2), + )); + + let expr_a = Expr::ScalarFunction(ScalarFunction::new_udf( + udf_a, + vec![col("user"), lit("field")], + )); + let expr_b = Expr::ScalarFunction(ScalarFunction::new_udf( + udf_b, + vec![col("user"), lit("field")], + )); + + // Verify preconditions: same schema_name but different Expr + assert_eq!( + expr_a.schema_name().to_string(), + expr_b.schema_name().to_string(), + "Both expressions should have the same schema_name" + ); + assert_ne!( + expr_a, expr_b, + "Expressions should NOT be equal (different UDF instances)" + ); + + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(expr_a.clone().eq(lit("a")).and(expr_b.clone().eq(lit("b"))))? + .select(vec![ + table_scan + .schema() + .index_of_column_by_name(None, "id") + .unwrap(), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id + Filter: leaf_udf(test.user, Utf8("field")) = Utf8("a") AND leaf_udf(test.user, Utf8("field")) = Utf8("b") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("a") AND __datafusion_extracted_2 = Utf8("b") + Projection: leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + (same as after extraction) + + ## Optimized + Projection: test.id + Filter: __datafusion_extracted_1 = Utf8("a") AND __datafusion_extracted_2 = Utf8("b") + Projection: leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_1, leaf_udf(test.user, Utf8("field")) AS __datafusion_extracted_2, test.id + TableScan: test projection=[id, user] + "#) + } + + // ========================================================================= + // Filter pushdown interaction tests + // ========================================================================= + + /// Extraction pushdown through a filter that already had its own + /// `leaf_udf` extracted. + #[test] + fn test_extraction_pushdown_through_filter_with_extracted_predicate() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").eq(lit("active")))? + .project(vec![col("id"), leaf_udf(col("user"), "name")])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Filter: leaf_udf(test.user, Utf8("status")) = Utf8("active") + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")) + Filter: __datafusion_extracted_1 = Utf8("active") + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + "#) + } + + /// Same expression in filter predicate and projection output. + #[test] + fn test_extraction_pushdown_same_expr_in_filter_and_projection() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let field_expr = leaf_udf(col("user"), "status"); + let plan = LogicalPlanBuilder::from(table_scan) + .filter(field_expr.clone().gt(lit(5)))? + .project(vec![col("id"), field_expr])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("status")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_2 + TableScan: test projection=[id, user] + "#) + } + + /// Left join with a `leaf_udf` filter on the right side AND + /// the projection also selects `leaf_udf` from the right side. + #[test] + fn test_left_join_with_filter_and_projection_extraction() -> Result<()> { + use datafusion_expr::JoinType; + + let left = test_table_scan_with_struct()?; + let right = test_table_scan_with_struct_named("right")?; + + let plan = LogicalPlanBuilder::from(left) + .join_on( + right, + JoinType::Left, + vec![ + col("test.id").eq(col("right.id")), + leaf_udf(col("right.user"), "status").gt(lit(5)), + ], + )? + .project(vec![ + col("test.id"), + leaf_udf(col("test.user"), "name"), + leaf_udf(col("right.user"), "status"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Left Join: Filter: test.id = right.id AND leaf_udf(right.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + TableScan: right projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(right.user, Utf8("status")) + Projection: test.id, test.user, right.id, right.user + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user + TableScan: right projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(right.user,Utf8("status")) + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.id, test.user + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, right.user, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: right projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(right.user,Utf8("status")) + Left Join: Filter: test.id = right.id AND __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, test.id + TableScan: test projection=[id, user] + Projection: leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_1, right.id, leaf_udf(right.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: right projection=[id, user] + "#) + } + + /// Extraction projection pushed through a filter whose predicate + /// references a different extracted expression. + #[test] + fn test_pure_extraction_proj_push_through_filter() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(leaf_udf(col("user"), "status").gt(lit(5)))? + .project(vec![ + col("id"), + leaf_udf(col("user"), "name"), + leaf_udf(col("user"), "status"), + ])? + .build()?; + + assert_stages!(plan, @r#" + ## Original Plan + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Filter: leaf_udf(test.user, Utf8("status")) > Int32(5) + TableScan: test projection=[id, user] + + ## After Extraction + Projection: test.id, leaf_udf(test.user, Utf8("name")), leaf_udf(test.user, Utf8("status")) + Projection: test.id, test.user + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user + TableScan: test projection=[id, user] + + ## After Pushdown + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, test.user, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: test projection=[id, user] + + ## Optimized + Projection: test.id, __datafusion_extracted_2 AS leaf_udf(test.user,Utf8("name")), __datafusion_extracted_3 AS leaf_udf(test.user,Utf8("status")) + Filter: __datafusion_extracted_1 > Int32(5) + Projection: leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1, test.id, leaf_udf(test.user, Utf8("name")) AS __datafusion_extracted_2, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_3 + TableScan: test projection=[id, user] + "#) + } + + /// When an extraction projection's __extracted alias references a column + /// (e.g. `user`) that is NOT a standalone expression in the projection, + /// the merge into the inner projection should still succeed. + #[test] + fn test_merge_extraction_into_projection_with_column_ref_inflation() -> Result<()> { + let table_scan = test_table_scan_with_struct()?; + + // Inner projection (simulates a trimmed projection) + let inner = LogicalPlanBuilder::from(table_scan) + .project(vec![col("user"), col("id")])? + .build()?; + + // Outer projection: __extracted alias + id (but NOT user as standalone). + // The alias references `user` internally, inflating columns_needed. + let plan = LogicalPlanBuilder::from(inner) + .project(vec![ + leaf_udf(col("user"), "status") + .alias(format!("{EXTRACTED_EXPR_PREFIX}_1")), + col("id"), + ])? + .build()?; + + // Run only PushDownLeafProjections + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = + Optimizer::with_rules(vec![Arc::new(PushDownLeafProjections::new())]); + let result = optimizer.optimize(plan, &ctx, |_, _| {})?; + + // With the fix: merge succeeds → extraction merged into inner projection. + // Without the fix: merge rejected → two separate projections remain. + insta::assert_snapshot!(format!("{result}"), @r#" + Projection: __datafusion_extracted_1, test.id + Projection: test.user, test.id, leaf_udf(test.user, Utf8("status")) AS __datafusion_extracted_1 + TableScan: test + "#); + + Ok(()) + } + + /// Regression test for the `Assertion failed: expr.is_empty(): Unnest` + /// internal error. + /// + /// `try_push_into_inputs` rebuilds the parent node via + /// `node.with_new_exprs(node.expressions(), new_inputs)`. For `Unnest`, + /// `apply_expressions` exposes the `exec_columns` as `Expr::Column`s + /// (so `expressions()` is **non-empty**), but `with_new_exprs` for + /// `Unnest` immediately calls `assert_no_expressions(expr)?` and errors + /// out. The optimizer should treat `Unnest` as a barrier and bail + /// instead of attempting to push through it. + #[test] + fn test_no_push_through_unnest() -> Result<()> { + use arrow::datatypes::{DataType, Field, Schema}; + + let schema = Schema::new(vec![ + Field::new("list_col", DataType::new_list(DataType::Int32, true), true), + Field::new("other_col", DataType::Int32, true), + ]); + let table_scan = + datafusion_expr::logical_plan::table_scan(Some("t"), &schema, None)? + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .unnest_column("list_col")? + .filter(leaf_udf(col("list_col"), "x").eq(lit(1i32)))? + .build()?; + + let ctx = OptimizerContext::new().with_max_passes(1); + let optimizer = Optimizer::with_rules(vec![ + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), + ]); + let optimized = optimizer.optimize(plan, &ctx, |_, _| {})?; + + insta::assert_snapshot!(format!("{optimized}"), @r#" + Projection: list_col, t.other_col + Filter: __datafusion_extracted_1 = Int32(1) + Projection: leaf_udf(list_col, Utf8("x")) AS __datafusion_extracted_1, list_col, t.other_col + Unnest: lists[t.list_col|depth=1] structs[] + TableScan: t + "#); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 8ad7fa53c0e33..c8f419d3e543e 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -23,7 +23,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{NullEquality, Result}; use datafusion_expr::utils::conjunction; -use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; +use datafusion_expr::{Expr, ExprSchemable, LogicalPlan, logical_plan::Filter}; use std::sync::Arc; /// The FilterNullJoinKeys rule will identify joins with equi-join conditions @@ -108,12 +108,12 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { use super::*; - use crate::assert_optimized_plan_eq_snapshot; use crate::OptimizerContext; + use crate::assert_optimized_plan_eq_snapshot; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; + use datafusion_expr::{JoinType, LogicalPlanBuilder, col, lit}; macro_rules! assert_optimized_plan_equal { ( diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs index 0a97173b30966..de795c0aeacfa 100644 --- a/datafusion/optimizer/src/join_key_set.rs +++ b/datafusion/optimizer/src/join_key_set.rs @@ -157,7 +157,7 @@ impl Equivalent<(Expr, Expr)> for ExprPair<'_> { #[cfg(test)] mod test { use crate::join_key_set::JoinKeySet; - use datafusion_expr::{col, Expr}; + use datafusion_expr::{Expr, col}; #[test] fn test_insert() { diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 7632ff858df61..fbe7ad2f4d327 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! # DataFusion Optimizer @@ -59,6 +57,7 @@ pub mod eliminate_nested_union { } pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; +pub mod extract_leaf_expressions; pub mod filter_null_join_keys; pub mod optimize_projections; pub mod optimize_unions; @@ -67,9 +66,11 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; +pub mod unions_to_filter; pub mod utils; #[cfg(test)] @@ -84,7 +85,7 @@ pub(crate) mod join_key_set; mod plan_signature; #[cfg(test)] -#[ctor::ctor] +#[ctor::ctor(unsafe)] fn init() { // Enable RUST_LOG logging configuration for test let _ = env_logger::try_init(); diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index ee7b006a2d496..acdbf71d05d5c 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -21,24 +21,21 @@ mod required_indices; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use std::collections::HashSet; use std::sync::Arc; use datafusion_common::{ - assert_eq_or_internal_err, get_required_group_by_exprs_indices, - internal_datafusion_err, internal_err, Column, DFSchema, HashMap, JoinType, Result, + Column, DFSchema, HashMap, JoinType, Result, assert_eq_or_internal_err, + get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, }; use datafusion_expr::expr::Alias; use datafusion_expr::{ - logical_plan::LogicalPlan, Aggregate, Distinct, EmptyRelation, Expr, Projection, - TableScan, Unnest, Window, + Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScanBuilder, Unnest, + Window, logical_plan::LogicalPlan, }; use crate::optimize_projections::required_indices::RequiredIndices; use crate::utils::NamePreserver; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeContainer}; /// Optimizer rule to prune unnecessary columns from intermediate schemas /// inside the [`LogicalPlan`]. This rule: @@ -77,7 +74,7 @@ use datafusion_common::tree_node::{ pub struct OptimizeProjections {} impl OptimizeProjections { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -136,9 +133,11 @@ fn optimize_projections( // their parents' required indices. match plan { LogicalPlan::Projection(proj) => { - return merge_consecutive_projections(proj)?.transform_data(|proj| { - rewrite_projection_given_requirements(proj, config, &indices) - }) + return merge_consecutive_projections(proj)? + .transform_data(|proj| { + rewrite_projection_given_requirements(proj, config, &indices) + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Aggregate(aggregate) => { // Split parent requirements to GROUP BY and aggregate sections: @@ -147,26 +146,39 @@ fn optimize_projections( // `aggregate.aggr_expr`: let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); - // Get absolutely necessary GROUP BY fields: - let group_by_expr_existing = aggregate - .group_expr - .iter() - .map(|group_by_expr| group_by_expr.schema_name().to_string()) - .collect::>(); - - let new_group_bys = if let Some(simplest_groupby_indices) = - get_required_group_by_exprs_indices( - aggregate.input.schema(), - &group_by_expr_existing, - ) { - // Some of the fields in the GROUP BY may be required by the - // parent even if these fields are unnecessary in terms of - // functional dependency. - group_by_reqs - .append(&simplest_groupby_indices) - .get_at_indices(&aggregate.group_expr) - } else { + // Get absolutely necessary GROUP BY fields. + // + // When the input has no functional dependencies, we can + // short-circuit this analysis. + let new_group_bys = if aggregate + .input + .schema() + .functional_dependencies() + .is_empty() + { aggregate.group_expr + } else { + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.schema_name().to_string()) + .collect::>(); + + if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) + { + // Some of the fields in the GROUP BY may be required by + // the parent even if these fields are unnecessary in + // terms of functional dependency. + group_by_reqs + .append(&simplest_groupby_indices) + .get_at_indices(&aggregate.group_expr) + } else { + aggregate.group_expr + } }; // Only use the absolutely necessary aggregate expressions required @@ -210,7 +222,8 @@ fn optimize_projections( new_aggr_expr, ) .map(LogicalPlan::Aggregate) - }); + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::Window(window) => { let input_schema = Arc::clone(window.input.schema()); @@ -250,33 +263,22 @@ fn optimize_projections( .map(LogicalPlan::Window) .map(Transformed::yes) } - }); + })? + .transform_data(|plan| optimize_subqueries(plan, config)); } LogicalPlan::TableScan(table_scan) => { - let TableScan { - table_name, - source, - projection, - filters, - fetch, - projected_schema: _, - } = table_scan; - // Get indices referred to in the original (schema with all fields) // given projected indices. - let projection = match &projection { + let projection = match &table_scan.projection { Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), None => indices.into_inner(), }; - return TableScan::try_new( - table_name, - source, - Some(projection), - filters, - fetch, - ) - .map(LogicalPlan::TableScan) - .map(Transformed::yes); + let new_scan = TableScanBuilder::from(table_scan) + .with_projection(Some(projection)) + .build()?; + + return Transformed::yes(LogicalPlan::TableScan(new_scan)) + .transform_data(|plan| optimize_subqueries(plan, config)); } // Other node types are handled below _ => {} @@ -334,29 +336,34 @@ fn optimize_projections( .collect() } LogicalPlan::Extension(extension) => { - let Some(necessary_children_indices) = + if let Some(necessary_children_indices) = extension.node.necessary_children_exprs(indices.indices()) - else { - // Requirements from parent cannot be routed down to user defined logical plan safely - return Ok(Transformed::no(plan)); - }; - let children = extension.node.inputs(); - assert_eq_or_internal_err!( - children.len(), - necessary_children_indices.len(), - "Inconsistent length between children and necessary children indices. \ + { + let children = extension.node.inputs(); + assert_eq_or_internal_err!( + children.len(), + necessary_children_indices.len(), + "Inconsistent length between children and necessary children indices. \ Make sure `.necessary_children_exprs` implementation of the \ `UserDefinedLogicalNode` is consistent with actual children length \ for the node." - ); - children - .into_iter() - .zip(necessary_children_indices) - .map(|(child, necessary_indices)| { - RequiredIndices::new_from_indices(necessary_indices) - .with_plan_exprs(&plan, child.schema()) - }) - .collect::>>()? + ); + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + RequiredIndices::new_from_indices(necessary_indices) + .with_plan_exprs(&plan, child.schema()) + }) + .collect::>>()? + } else { + // Requirements from parent cannot be routed down to user defined logical plan safely + // Assume it requires all input exprs here + plan.inputs() + .into_iter() + .map(RequiredIndices::new_for_all_exprs) + .collect() + } } LogicalPlan::EmptyRelation(_) | LogicalPlan::Values(_) @@ -364,34 +371,21 @@ fn optimize_projections( // These operators have no inputs, so stop the optimization process. return Ok(Transformed::no(plan)); } - LogicalPlan::RecursiveQuery(recursive) => { - // Only allow subqueries that reference the current CTE; nested subqueries are not yet - // supported for projection pushdown for simplicity. - // TODO: be able to do projection pushdown on recursive CTEs with subqueries - if plan_contains_other_subqueries( - recursive.static_term.as_ref(), - &recursive.name, - ) || plan_contains_other_subqueries( - recursive.recursive_term.as_ref(), - &recursive.name, - ) { - return Ok(Transformed::no(plan)); - } - - plan.inputs() - .into_iter() - .map(|input| { - indices - .clone() - .with_projection_beneficial() - .with_plan_exprs(&plan, input.schema()) - }) - .collect::>>()? + LogicalPlan::RecursiveQuery(_) => { + // optimize the static and recursive terms: treat each recursive CTE term like a + // standalone subquery: optimize its internals, but do not push parent required indices + // through the RecursiveQuery boundary, as this can otherwise lead to bugs + // (see: https://github.com/apache/datafusion/issues/22249) + return plan.map_children(|c| { + let indices = RequiredIndices::new_for_all_exprs(&c); + optimize_projections(c, config, indices) + }); } LogicalPlan::Join(join) => { let left_len = join.left.schema().fields().len(); + let right_len = join.right.schema().fields().len(); let (left_req_indices, right_req_indices) = - split_join_requirements(left_len, indices, &join.join_type); + split_join_requirements(left_len, right_len, indices, &join.join_type); let left_indices = left_req_indices.with_plan_exprs(&plan, join.left.schema())?; let right_indices = @@ -463,6 +457,9 @@ fn optimize_projections( ) })?; + let transformed_plan = + transformed_plan.transform_data(|plan| optimize_subqueries(plan, config))?; + // If any of the children are transformed, we need to potentially update the plan's schema if transformed_plan.transformed { transformed_plan.map_data(|plan| plan.recompute_schema()) @@ -471,8 +468,19 @@ fn optimize_projections( } } -/// Merges consecutive projections. -/// +/// Optimizes uncorrelated subquery plans embedded in expressions of the given +/// plan node (e.g., `Expr::ScalarSubquery`). `map_children` only visits direct +/// plan inputs, so subqueries must be handled separately. +fn optimize_subqueries( + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + plan.map_uncorrelated_subqueries(|subquery_plan| { + let indices = RequiredIndices::new_for_all_exprs(&subquery_plan); + optimize_projections(subquery_plan, config, indices) + }) +} + /// Given a projection `proj`, this function attempts to merge it with a previous /// projection if it exists and if merging is beneficial. Merging is considered /// beneficial when expressions in the current projection are non-trivial and @@ -504,6 +512,30 @@ fn optimize_projections( /// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). /// - `Err(error)`: An error occurred during the function call. fn merge_consecutive_projections(proj: Projection) -> Result> { + // Collapse the whole chain in one pass; otherwise an N-deep chain needs + // N outer optimizer passes to fully fold. + let mut current = proj; + let mut transformed_any = false; + loop { + let Transformed { + data, transformed, .. + } = merge_consecutive_projections_one_level(current)?; + current = data; + if !transformed { + break; + } + transformed_any = true; + } + Ok(if transformed_any { + Transformed::yes(current) + } else { + Transformed::no(current) + }) +} + +fn merge_consecutive_projections_one_level( + proj: Projection, +) -> Result> { let Projection { expr, input, @@ -530,15 +562,14 @@ fn merge_consecutive_projections(proj: Projection) -> Result 1 - && !is_expr_trivial( - &prev_projection.expr - [prev_projection.schema.index_of_column(col).unwrap()], - ) + && !prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()] + .placement() + .should_push_to_leaves() }) { // no change return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); @@ -565,7 +596,22 @@ fn merge_consecutive_projections(proj: Projection) -> Result rewrite_expr(*expr, &prev_projection).map(|result| { result.update_data(|expr| { - Expr::Alias(Alias::new(expr, relation, name).with_metadata(metadata)) + // After substitution, the inner expression may now have the + // same schema_name as the alias (e.g. when an extraction + // alias like `__extracted_1 AS f(x)` is resolved back to + // `f(x)`). Wrapping in a redundant self-alias causes a + // cosmetic `f(x) AS f(x)` due to Display vs schema_name + // formatting differences. Drop the alias when it matches. + if metadata.is_none() && expr.schema_name().to_string() == name { + expr + } else { + Expr::Alias(Alias { + expr: Box::new(expr), + relation, + name, + metadata, + }) + } }) }), e => rewrite_expr(e, &prev_projection), @@ -591,11 +637,6 @@ fn merge_consecutive_projections(proj: Projection) -> Result bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_, _)) -} - /// Rewrites a projection expression using the projection before it (i.e. its input) /// This is a subroutine to the `merge_consecutive_projections` function. /// @@ -675,56 +716,6 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { }) } -/// Accumulates outer-referenced columns by the -/// given expression, `expr`. -/// -/// # Parameters -/// -/// * `expr` - The expression to analyze for outer-referenced columns. -/// * `columns` - A mutable reference to a `HashSet` where detected -/// columns are collected. -fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { - // inspect_expr_pre doesn't handle subquery references, so find them explicitly - expr.apply(|expr| { - match expr { - Expr::OuterReferenceColumn(_, col) => { - columns.insert(col); - } - Expr::ScalarSubquery(subquery) => { - outer_columns_helper_multi(&subquery.outer_ref_columns, columns); - } - Expr::Exists(exists) => { - outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns); - } - Expr::InSubquery(insubquery) => { - outer_columns_helper_multi( - &insubquery.subquery.outer_ref_columns, - columns, - ); - } - _ => {} - }; - Ok(TreeNodeRecursion::Continue) - }) - // unwrap: closure above never returns Err, so can not be Err here - .unwrap(); -} - -/// A recursive subroutine that accumulates outer-referenced columns by the -/// given expressions (`exprs`). -/// -/// # Parameters -/// -/// * `exprs` - The expressions to analyze for outer-referenced columns. -/// * `columns` - A mutable reference to a `HashSet` where detected -/// columns are collected. -fn outer_columns_helper_multi<'a, 'b>( - exprs: impl IntoIterator, - columns: &'b mut HashSet<&'a Column>, -) { - exprs.into_iter().for_each(|e| outer_columns(e, columns)); -} - /// Splits requirement indices for a join into left and right children based on /// the join type. /// @@ -745,6 +736,7 @@ fn outer_columns_helper_multi<'a, 'b>( /// # Parameters /// /// * `left_len` - The length of the left child. +/// * `right_len` - The length of the right child. /// * `indices` - A slice of requirement indices. /// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). /// @@ -756,21 +748,29 @@ fn outer_columns_helper_multi<'a, 'b>( /// adjusted based on the join type. fn split_join_requirements( left_len: usize, + right_len: usize, indices: RequiredIndices, join_type: &JoinType, ) -> (RequiredIndices, RequiredIndices) { match join_type { // In these cases requirements are split between left/right children: - JoinType::Inner - | JoinType::Left - | JoinType::Right - | JoinType::Full - | JoinType::LeftMark - | JoinType::RightMark => { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: indices.split_off(left_len) } + JoinType::LeftMark => { + // LeftMark output: [left_cols(0..left_len), mark] + // The mark column is synthetic (produced by the join itself), + // so discard it and route only to the left child. + let (left_indices, _mark) = indices.split_off(left_len); + (left_indices, RequiredIndices::new()) + } + JoinType::RightMark => { + // Same as LeftMark, but for the right child. + let (right_indices, _mark) = indices.split_off(right_len); + (RequiredIndices::new(), right_indices) + } // All requirements can be re-routed to left child directly. JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndices::new()), // All requirements can be re-routed to right side directly. @@ -876,65 +876,6 @@ pub fn is_projection_unnecessary( )) } -/// Returns true if the plan subtree contains any subqueries that are not the -/// CTE reference itself. This treats any non-CTE [`LogicalPlan::SubqueryAlias`] -/// node (including aliased relations) as a blocker, along with expression-level -/// subqueries like scalar, EXISTS, or IN. These cases prevent projection -/// pushdown for now because we cannot safely reason about their column usage. -fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool { - if let LogicalPlan::SubqueryAlias(alias) = plan { - if alias.alias.table() != cte_name - && !subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) - { - return true; - } - } - - let mut found = false; - plan.apply_expressions(|expr| { - if expr_contains_subquery(expr) { - found = true; - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }) - .expect("expression traversal never fails"); - if found { - return true; - } - - plan.inputs() - .into_iter() - .any(|child| plan_contains_other_subqueries(child, cte_name)) -} - -fn expr_contains_subquery(expr: &Expr) -> bool { - expr.exists(|e| match e { - Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true), - _ => Ok(false), - }) - // Safe unwrap since we are doing a simple boolean check - .unwrap() -} - -fn subquery_alias_targets_recursive_cte(plan: &LogicalPlan, cte_name: &str) -> bool { - match plan { - LogicalPlan::TableScan(scan) => scan.table_name.table() == cte_name, - LogicalPlan::SubqueryAlias(alias) => { - subquery_alias_targets_recursive_cte(alias.input.as_ref(), cte_name) - } - _ => { - let inputs = plan.inputs(); - if inputs.len() == 1 { - subquery_alias_targets_recursive_cte(inputs[0], cte_name) - } else { - false - } - } - } -} - #[cfg(test)] mod tests { use std::cmp::Ordering; @@ -957,14 +898,15 @@ mod tests { }; use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - binary_expr, build_join_schema, + BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, Projection, + UserDefinedLogicalNodeCore, WindowFunctionDefinition, binary_expr, + build_join_schema, builder::table_scan_with_filters, col, expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, - not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, - Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, + not, try_cast, when, }; use insta::assert_snapshot; @@ -1171,6 +1113,57 @@ mod tests { } } + /// A user-defined node that does NOT implement `necessary_children_exprs`, + /// so the optimizer cannot determine which columns are required from its + /// children and must assume all columns are needed. + #[derive(Debug, Hash, PartialEq, Eq)] + struct OpaqueRequirementsUserDefined { + input: Arc, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for OpaqueRequirementsUserDefined { + fn partial_cmp(&self, other: &Self) -> Option { + self.input + .partial_cmp(&other.input) + .filter(|cmp| *cmp != Ordering::Equal || self == other) + } + } + + impl UserDefinedLogicalNodeCore for OpaqueRequirementsUserDefined { + fn name(&self) -> &str { + "OpaqueRequirementsUserDefined" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + Ok(Self { + input: Arc::new(inputs.swap_remove(0)), + schema: Arc::clone(&self.schema), + }) + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "OpaqueRequirementsUserDefined") + } + } + #[test] fn merge_two_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -2203,6 +2196,29 @@ mod tests { Ok(()) } + #[test] + fn test_continue_processing_through_extension() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![col("a")])? + .project(vec![col("a")])? + .build()?; + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(OpaqueRequirementsUserDefined { + input: Arc::new(plan), + schema: Arc::clone(table_scan.schema()), + }), + }); + let plan = optimize(plan).expect("failed to optimize plan"); + assert_optimized_plan_equal!( + plan, + @r" + OpaqueRequirementsUserDefined + TableScan: test projection=[a] + " + ) + } + /// tests that it removes an aggregate is never used downstream #[test] fn table_unused_aggregate() -> Result<()> { @@ -2310,6 +2326,68 @@ mod tests { ) } + // Regression test for https://github.com/apache/datafusion/issues/20083 + // Optimizer must not fail when LeftMark joins from EXISTS OR EXISTS + // feed into a Left join. + #[test] + fn optimize_projections_exists_or_exists_with_outer_join() -> Result<()> { + use datafusion_expr::utils::disjunction; + use datafusion_expr::{exists, out_ref_col}; + + let table_a = test_table_scan_with_name("a")?; + let table_b = test_table_scan_with_name("b")?; + + let sq_a = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq_a")?) + .filter(col("sq_a.a").eq(out_ref_col(DataType::UInt32, "a.a")))? + .project(vec![lit(1)])? + .build()?, + ); + + let sq_b = Arc::new( + LogicalPlanBuilder::from(test_table_scan_with_name("sq_b")?) + .filter(col("sq_b.b").eq(out_ref_col(DataType::UInt32, "a.b")))? + .project(vec![lit(1)])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(table_a) + .filter(disjunction(vec![exists(sq_a), exists(sq_b)]).unwrap())? + .join(table_b, JoinType::Left, (vec!["a"], vec!["a"]), None)? + .build()?; + + let optimizer = Optimizer::new(); + let config = OptimizerContext::new(); + optimizer.optimize(plan, &config, observe)?; + + Ok(()) + } + + #[test] + fn optimize_projections_left_mark_join_with_projection() -> Result<()> { + let table_a = test_table_scan_with_name("a")?; + let table_b = test_table_scan_with_name("b")?; + let table_c = test_table_scan_with_name("c")?; + + let plan = LogicalPlanBuilder::from(table_a) + .join(table_b, JoinType::LeftMark, (vec!["a"], vec!["a"]), None)? + .project(vec![col("a.a"), col("a.b"), col("a.c")])? + .join(table_c, JoinType::Left, (vec!["a"], vec!["a"]), None)? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Left Join: a.a = c.a + Projection: a.a, a.b, a.c + LeftMark Join: a.a = b.a + TableScan: a projection=[a, b, c] + TableScan: b projection=[a] + TableScan: c projection=[a, b, c] + " + ) + } + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} fn optimize(plan: LogicalPlan) -> Result { diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index c1e0885c9b5f2..5e73a9fbeceda 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -17,8 +17,7 @@ //! [`RequiredIndices`] helper for OptimizeProjection -use crate::optimize_projections::outer_columns; -use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{Column, DFSchemaRef, Result}; use datafusion_expr::{Expr, LogicalPlan}; @@ -105,44 +104,59 @@ impl RequiredIndices { /// Adds the indices of the fields referred to by the given expression /// `expr` within the given schema (`input_schema`). /// - /// Self is NOT compacted (and thus this method is not pub) + /// Self is NOT compacted (duplicate indices are removed by a subsequent + /// [`Self::compact`] call), and thus this method is not pub. /// /// # Parameters /// /// * `input_schema`: The input schema to analyze for index requirements. /// * `expr`: An expression for which we want to find necessary field indices. fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) { - // TODO could remove these clones (and visit the expression directly) - let mut cols = expr.column_refs(); - // Get outer-referenced (subquery) columns: - outer_columns(expr, &mut cols); - self.indices.reserve(cols.len()); - for col in cols { - if let Some(idx) = input_schema.maybe_index_of_column(col) { - self.indices.push(idx); + // `apply` does not descend into subqueries, so recurse manually to + // handle those cases. + expr.apply(|e| { + match e { + Expr::Column(c) | Expr::OuterReferenceColumn(_, c) => { + if let Some(idx) = input_schema.maybe_index_of_column(c) { + self.indices.push(idx); + } + } + Expr::ScalarSubquery(sub) => { + self.add_exprs(input_schema, &sub.outer_ref_columns); + } + Expr::Exists(ex) => { + self.add_exprs(input_schema, &ex.subquery.outer_ref_columns); + } + Expr::InSubquery(isq) => { + self.add_exprs(input_schema, &isq.subquery.outer_ref_columns); + } + _ => {} } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); + } + + /// Like [`Self::add_expr`], but for multiple expressions. + fn add_exprs<'a>( + &mut self, + input_schema: &DFSchemaRef, + exprs: impl IntoIterator, + ) { + for expr in exprs { + self.add_expr(input_schema, expr); } } /// Adds the indices of the fields referred to by the given expressions - /// `within the given schema. - /// - /// # Parameters - /// - /// * `input_schema`: The input schema to analyze for index requirements. - /// * `exprs`: the expressions for which we want to find field indices. + /// within the given schema. pub fn with_exprs<'a>( - self, + mut self, schema: &DFSchemaRef, exprs: impl IntoIterator, ) -> Self { - exprs - .into_iter() - .fold(self, |mut acc, expr| { - acc.add_expr(schema, expr); - acc - }) - .compact() + self.add_exprs(schema, exprs); + self.compact() } /// Adds all `indices` into this instance. diff --git a/datafusion/optimizer/src/optimize_unions.rs b/datafusion/optimizer/src/optimize_unions.rs index cfabd512b427b..80f8ebeef1697 100644 --- a/datafusion/optimizer/src/optimize_unions.rs +++ b/datafusion/optimizer/src/optimize_unions.rs @@ -18,10 +18,10 @@ //! [`OptimizeUnions`]: removes `Union` nodes in the logical plan. use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::Transformed; use datafusion_common::Result; +use datafusion_common::tree_node::Transformed; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; -use datafusion_expr::{Distinct, LogicalPlan, Union}; +use datafusion_expr::{Distinct, LogicalPlan, Projection, Union}; use itertools::Itertools; use std::sync::Arc; @@ -32,7 +32,7 @@ use std::sync::Arc; pub struct OptimizeUnions; impl OptimizeUnions { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -64,11 +64,11 @@ impl OptimizerRule for OptimizeUnions { let inputs = inputs .into_iter() .flat_map(extract_plans_from_union) - .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) + .map(|plan| Ok(Arc::new(coerce_plan_expr_for_schema(plan, &schema)?))) .collect::>>()?; Ok(Transformed::yes(LogicalPlan::Union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), + inputs, schema, }))) } @@ -105,6 +105,38 @@ fn extract_plans_from_union(plan: Arc) -> Vec { .into_iter() .map(Arc::unwrap_or_clone) .collect::>(), + // While unnesting, unwrap a Projection whose input is a nested Union, + // flatten the inner Union, and push the same Projection down onto + // each of the nested Union’s children. + // + // Example: + // Union { Projection { Union { plan1, plan2 } }, plan3 } + // => Union { Projection { plan1 }, Projection { plan2 }, plan3 } + LogicalPlan::Projection(Projection { + expr, + input, + schema, + .. + }) => match Arc::unwrap_or_clone(input) { + LogicalPlan::Union(Union { inputs, .. }) => inputs + .into_iter() + .map(Arc::unwrap_or_clone) + .map(|plan| { + LogicalPlan::Projection( + Projection::try_new_with_schema( + expr.clone(), + Arc::new(plan), + Arc::clone(&schema), + ) + .unwrap(), + ) + }) + .collect::>(), + + plan => vec![LogicalPlan::Projection( + Projection::try_new_with_schema(expr, Arc::new(plan), schema).unwrap(), + )], + }, plan => vec![plan], } } @@ -119,10 +151,10 @@ fn extract_plan_from_distinct(plan: Arc) -> Arc { #[cfg(test)] mod tests { use super::*; - use crate::analyzer::type_coercion::TypeCoercion; + use crate::OptimizerContext; use crate::analyzer::Analyzer; + use crate::analyzer::type_coercion::TypeCoercion; use crate::assert_optimized_plan_eq_snapshot; - use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; @@ -331,6 +363,27 @@ mod tests { ") } + #[test] + fn eliminate_nested_union_in_projection() -> Result<()> { + let plan_builder = table_scan(Some("table"), &schema(), None)?; + + let plan = plan_builder + .clone() + .union(plan_builder.clone().build()?)? + .project(vec![col("id").alias("table_id"), col("key"), col("value")])? + .union(plan_builder.build()?)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Union + Projection: id AS table_id, key, value + TableScan: table + Projection: id AS table_id, key, value + TableScan: table + TableScan: table + ") + } + #[test] fn eliminate_nested_union_with_type_cast_projection() -> Result<()> { let table_1 = table_scan( @@ -444,9 +497,7 @@ mod tests { OptimizerContext::new().with_max_passes(1), vec![Arc::new(OptimizeUnions::new())], plan, - @r" - TableScan: table - " + @"TableScan: table" ) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 421563d5e7e88..a765d7f27a51e 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -22,15 +22,24 @@ use std::sync::Arc; use chrono::{DateTime, Utc}; use datafusion_expr::registry::FunctionRegistry; -use datafusion_expr::{assert_expected_schema, InvariantLevel}; +use datafusion_expr::{InvariantLevel, assert_expected_schema}; use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; -use datafusion_common::{internal_err, DFSchema, DataFusionError, HashSet, Result}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; +use datafusion_common::{DFSchema, DataFusionError, HashSet, Result, internal_err}; +use datafusion_expr::dml::CopyTo; use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::{ + Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, + DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, Projection, + RecursiveQuery, Repartition, Sort, Statement, Subquery, SubqueryAlias, Union, Unnest, + Window, +}; use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_lateral_join::DecorrelateLateralJoin; @@ -43,6 +52,7 @@ use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; +use crate::extract_leaf_expressions::{ExtractLeafExpressions, PushDownLeafProjections}; use crate::filter_null_join_keys::FilterNullJoinKeys; use crate::optimize_projections::OptimizeProjections; use crate::optimize_unions::OptimizeUnions; @@ -51,17 +61,19 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; +use crate::unions_to_filter::UnionsToFilter; use crate::utils::log_plan; -/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which -/// computes the same results, but in a potentially more efficient -/// way. If there are no suitable transformations for the input plan, -/// the optimizer should simply return it unmodified. +/// Transforms one [`LogicalPlan`] into another which computes the same results, +/// but in a potentially more efficient way. /// -/// To change the semantics of a `LogicalPlan`, see [`AnalyzerRule`] +/// See notes on [`Self::rewrite`] for details on how to implement an `OptimizerRule`. +/// +/// To change the semantics of a `LogicalPlan`, see [`AnalyzerRule`]. /// /// Use [`SessionState::add_optimizer_rule`] to register additional /// `OptimizerRule`s. @@ -86,8 +98,40 @@ pub trait OptimizerRule: Debug { true } - /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes` - /// if the plan was rewritten and `Transformed::no` if it was not. + /// Try to rewrite `plan` to an optimized form, returning [`Transformed::yes`] + /// if the plan was rewritten and [`Transformed::no`] if it was not. + /// + /// # Notes for implementations: + /// + /// ## Return the same plan if no changes were made + /// + /// If there are no suitable transformations for the input plan, + /// the optimizer should simply return it unmodified. + /// + /// The optimizer will call `rewrite` several times until a fixed point is + /// reached, so it is important that `rewrite` return [`Transformed::no`] if + /// the output is the same. + /// + /// ## Matching on functions + /// + /// The rule should avoid function-specific transformations, and instead use + /// methods on [`ScalarUDFImpl`] and [`AggregateUDFImpl`]. Specifically, the + /// rule should not check function names as functions can be overridden, and + /// may not have the same semantics as the functions provided with + /// DataFusion. + /// + /// For example, if a rule rewrites a function based on the check + /// `func.name() == "sum"`, it may rewrite the plan incorrectly if the + /// registered `sum` function has different semantics (for example, the + /// `sum` function from the `datafusion-spark` crate). + /// + /// There are still several cases that rely on function name checking in + /// the rules included with DataFusion. Please see [#18643] for more details + /// and to help remove these cases. + /// + /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl + /// [`AggregateUDFImpl`]: datafusion_expr::ScalarUDFImpl + /// [#18643]: https://github.com/apache/datafusion/issues/18643 fn rewrite( &self, _plan: LogicalPlan, @@ -100,8 +144,9 @@ pub trait OptimizerRule: Debug { /// Options to control the DataFusion Optimizer. pub trait OptimizerConfig { /// Return the time at which the query execution started. This - /// time is used as the value for now() - fn query_execution_start_time(&self) -> DateTime; + /// time is used as the value for `now()`. If `None`, time-dependent + /// functions like `now()` will not be simplified during optimization. + fn query_execution_start_time(&self) -> Option>; /// Return alias generator used to generate unique aliases for subqueries fn alias_generator(&self) -> &Arc; @@ -118,8 +163,9 @@ pub trait OptimizerConfig { #[derive(Debug)] pub struct OptimizerContext { /// Query execution start time that can be used to rewrite - /// expressions such as `now()` to use a literal value instead - query_execution_start_time: DateTime, + /// expressions such as `now()` to use a literal value instead. + /// If `None`, time-dependent functions will not be simplified. + query_execution_start_time: Option>, /// Alias generator used to generate unique aliases for subqueries alias_generator: Arc, @@ -139,7 +185,7 @@ impl OptimizerContext { /// Create a optimizer config with provided [ConfigOptions]. pub fn new_with_config_options(options: Arc) -> Self { Self { - query_execution_start_time: Utc::now(), + query_execution_start_time: Some(Utc::now()), alias_generator: Arc::new(AliasGenerator::new()), options, } @@ -153,13 +199,19 @@ impl OptimizerContext { self } - /// Specify whether the optimizer should skip rules that produce - /// errors, or fail the query + /// Set the query execution start time pub fn with_query_execution_start_time( mut self, - query_execution_tart_time: DateTime, + query_execution_start_time: DateTime, ) -> Self { - self.query_execution_start_time = query_execution_tart_time; + self.query_execution_start_time = Some(query_execution_start_time); + self + } + + /// Clear the query execution start time. When `None`, time-dependent + /// functions like `now()` will not be simplified during optimization. + pub fn without_query_execution_start_time(mut self) -> Self { + self.query_execution_start_time = None; self } @@ -185,7 +237,7 @@ impl Default for OptimizerContext { } impl OptimizerConfig for OptimizerContext { - fn query_execution_start_time(&self) -> DateTime { + fn query_execution_start_time(&self) -> Option> { self.query_execution_start_time } @@ -226,8 +278,19 @@ impl Default for Optimizer { impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { + // NOTEs: + // - The order of rules in this list is important, as it determines the + // order in which they are applied. + // - Adding a new rule here is expensive as it will be applied to all + // queries, and will likely increase the optimization time. Please extend + // existing rules when possible, rather than adding a new rule. + // If you do add a new rule considering having aggressive no-op paths + // (e.g. if the plan doesn't contain any of the nodes you are looking for + // return `Transformed::no`; only works if you control the traversal). let rules: Vec> = vec![ + Arc::new(RewriteSetComparison::new()), Arc::new(OptimizeUnions::new()), + Arc::new(UnionsToFilter::new()), Arc::new(SimplifyExpressions::new()), Arc::new(ReplaceDistinctWithAggregate::new()), Arc::new(EliminateJoin::new()), @@ -250,6 +313,8 @@ impl Optimizer { // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), Arc::new(CommonSubexprEliminate::new()), + Arc::new(ExtractLeafExpressions::new()), + Arc::new(PushDownLeafProjections::new()), Arc::new(OptimizeProjections::new()), ]; @@ -288,9 +353,7 @@ impl TreeNodeRewriter for Rewriter<'_> { fn f_down(&mut self, node: LogicalPlan) -> Result> { if self.apply_order == ApplyOrder::TopDown { - { - self.rule.rewrite(node, self.config) - } + self.rule.rewrite(node, self.config) } else { Ok(Transformed::no(node)) } @@ -298,15 +361,220 @@ impl TreeNodeRewriter for Rewriter<'_> { fn f_up(&mut self, node: LogicalPlan) -> Result> { if self.apply_order == ApplyOrder::BottomUp { - { - self.rule.rewrite(node, self.config) - } + self.rule.rewrite(node, self.config) } else { Ok(Transformed::no(node)) } } } +/// Applies `f` to each child (input) of `plan` in place, using +/// [`Arc::make_mut`] for copy-on-write semantics on `Arc` +/// children. When the `Arc` refcount is 1 (the common case here) +/// `Arc::make_mut` hands out a `&mut` without cloning; when it is >1 the +/// inner value is cloned first. +/// +/// Returns `Ok(true)` if any child was modified by `f`. +/// +/// This is deliberately private to the optimizer rather than a method on +/// [`LogicalPlan`]: it is an implementation detail of in-place rewriting, and +/// the `Arc::make_mut` approach does not generalize to the other tree types +/// (`Expr` children are `Box`ed; `PhysicalExpr`/`ExecutionPlan` children are +/// `Arc`, which `Arc::make_mut` cannot handle). If `TreeNode` ever +/// grows an in-place traversal this logic can move there. +/// +/// # Error semantics +/// +/// If `f` returns `Err` for a child, that error is returned immediately; +/// children visited earlier keep whatever modifications `f` already applied +/// to them — they are **not** rolled back. +fn map_children_mut Result>( + plan: &mut LogicalPlan, + mut f: F, +) -> Result { + Ok(match plan { + LogicalPlan::Projection(Projection { input, .. }) + | LogicalPlan::Filter(Filter { input, .. }) + | LogicalPlan::Repartition(Repartition { input, .. }) + | LogicalPlan::Window(Window { input, .. }) + | LogicalPlan::Aggregate(Aggregate { input, .. }) + | LogicalPlan::Sort(Sort { input, .. }) + | LogicalPlan::Limit(Limit { input, .. }) + | LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) + | LogicalPlan::Analyze(Analyze { input, .. }) + | LogicalPlan::Dml(DmlStatement { input, .. }) + | LogicalPlan::Copy(CopyTo { input, .. }) + | LogicalPlan::Unnest(Unnest { input, .. }) => f(Arc::make_mut(input))?, + LogicalPlan::Subquery(Subquery { subquery, .. }) => f(Arc::make_mut(subquery))?, + LogicalPlan::Join(Join { left, right, .. }) => { + let l = f(Arc::make_mut(left))?; + let r = f(Arc::make_mut(right))?; + l || r + } + LogicalPlan::Union(Union { inputs, .. }) => { + let mut changed = false; + for input in inputs { + changed |= f(Arc::make_mut(input))?; + } + changed + } + LogicalPlan::Distinct(Distinct::All(input)) => f(Arc::make_mut(input))?, + LogicalPlan::Distinct(Distinct::On(DistinctOn { input, .. })) => { + f(Arc::make_mut(input))? + } + LogicalPlan::Explain(Explain { plan, .. }) => f(Arc::make_mut(plan))?, + LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { + input, + .. + })) + | LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { input, .. })) => { + f(Arc::make_mut(input))? + } + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => { + let s = f(Arc::make_mut(static_term))?; + let r = f(Arc::make_mut(recursive_term))?; + s || r + } + LogicalPlan::Statement(Statement::Prepare(p)) => f(Arc::make_mut(&mut p.input))?, + LogicalPlan::Extension(Extension { node }) => { + let inputs = node.inputs(); + if inputs.is_empty() { + false + } else { + // Extension nodes don't expose mutable children, + // fall back to the ownership-based API + let mut changed = false; + let exprs = node.expressions(); + let new_inputs: Vec = inputs + .into_iter() + .map(|input| { + let mut plan = input.clone(); + if f(&mut plan)? { + changed = true; + } + Ok(plan) + }) + .collect::>>()?; + if changed { + *node = node.with_exprs_and_inputs(exprs, new_inputs)?; + } + changed + } + } + // plans without inputs + LogicalPlan::TableScan { .. } + | LogicalPlan::EmptyRelation { .. } + | LogicalPlan::Values { .. } + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_)) + | LogicalPlan::Ddl(DdlStatement::CreateCatalogSchema(_)) + | LogicalPlan::Ddl(DdlStatement::CreateCatalog(_)) + | LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) + | LogicalPlan::Ddl(DdlStatement::DropTable(_)) + | LogicalPlan::Ddl(DdlStatement::DropView(_)) + | LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(_)) + | LogicalPlan::Ddl(DdlStatement::CreateFunction(_)) + | LogicalPlan::Ddl(DdlStatement::DropFunction(_)) + | LogicalPlan::Statement(_) => false, + }) +} + +/// Rewrites a plan tree in place using `Arc::make_mut` for +/// copy-on-write semantics on `Arc` children. +/// +/// This avoids the `Arc::unwrap_or_clone` + `Arc::new` cycle that the +/// ownership-based `TreeNode::rewrite` performs at every child node. +/// +/// # Error semantics +/// +/// On `Err`, `*plan` is left in an **unspecified** state and must not be used. +/// Note this is different than consuming APIs such as [`TreeNode::rewrite`] +/// where the original plan is freed and no longer available on error +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] +fn rewrite_plan_in_place( + plan: &mut LogicalPlan, + apply_order: ApplyOrder, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result { + // f_down phase + let mut changed = false; + if apply_order == ApplyOrder::TopDown { + // `rule.rewrite()` takes the plan by value, so bridge the `&mut` to an + // owned value with `std::mem::take`. `LogicalPlan::default()` is a cheap + // empty placeholder (shared empty schema, no allocation) and is + // overwritten with the rule's output on the next line. + let owned = std::mem::take(plan); + let result = rule.rewrite(owned, config)?; + *plan = result.data; + changed |= result.transformed; + // Respect TreeNodeRecursion::Stop/Jump from the rule + if result.tnr == TreeNodeRecursion::Stop { + return Ok(changed); + } + } + + // Recurse into children using Arc::make_mut (zero-cost when refcount == 1) + changed |= map_children_mut(plan, |child| { + rewrite_plan_in_place(child, apply_order, rule, config) + })?; + + // f_up phase + if apply_order == ApplyOrder::BottomUp { + let owned = std::mem::take(plan); + let result = rule.rewrite(owned, config)?; + *plan = result.data; + changed |= result.transformed; + } + + Ok(changed) +} + +/// Returns true if the plan contains any subquery expressions +/// (EXISTS, IN subquery, scalar subquery, set comparison). +/// +/// Used to determine whether the more expensive `rewrite_with_subqueries` +/// traversal is needed. When the plan has no subqueries, the cheaper +/// `rewrite` traversal is sufficient since all plan nodes are reachable +/// via direct children. +fn plan_has_subqueries(plan: &LogicalPlan) -> bool { + let mut found = false; + let _ = plan.apply(|node| { + if found { + return Ok(TreeNodeRecursion::Stop); + } + node.apply_expressions(|expr| { + if found { + return Ok(TreeNodeRecursion::Stop); + } + expr.apply(|e| { + if matches!( + e, + Expr::Exists(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + ) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + })?; + Ok(if found { + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + }) + }); + found +} + impl Optimizer { /// Optimizes the logical plan by applying optimizer rules, and /// invoking observer function after each call @@ -336,6 +604,14 @@ impl Optimizer { while i < options.optimizer.max_passes { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); + // Check once per pass whether the plan contains subquery + // expressions. When there are no subqueries, we use the + // cheaper `rewrite` traversal instead of + // `rewrite_with_subqueries`, avoiding the per-node + // map_subqueries call that walks all expression trees + // via ownership-based transform_down. + let has_subqueries = plan_has_subqueries(&new_plan); + for rule in &self.rules { // If skipping failed rules, copy plan before attempting to rewrite // as rewriting is destructive @@ -348,9 +624,42 @@ impl Optimizer { let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite_with_subqueries( - &mut Rewriter::new(apply_order, rule.as_ref(), config), - ), + Some(apply_order) => { + if has_subqueries { + // Plans with subqueries need the full + // rewrite_with_subqueries traversal to + // recurse into subquery plans. + new_plan.rewrite_with_subqueries( + &mut Rewriter::new( + apply_order, + rule.as_ref(), + config, + ), + ) + } else { + // No subqueries: use in-place rewriting + // with Arc::make_mut for zero-cost CoW on + // children, avoiding Arc unwrap/rewrap. + // + // On error `new_plan` is left in an unspecified + // state (see `rewrite_plan_in_place`); the result + // handling below discards it, restoring `prev_plan` + // when `skip_failed_rules` is set or propagating + // the error otherwise. + rewrite_plan_in_place( + &mut new_plan, + apply_order, + rule.as_ref(), + config, + ) + .map(|transformed| { + Transformed::new_transformed( + std::mem::take(&mut new_plan), + transformed, + ) + }) + } + } // rule handles recursion itself None => { rule.rewrite(new_plan, config) @@ -446,7 +755,7 @@ impl Optimizer { /// These are invariants which should hold true before and after [`LogicalPlan`] optimization. /// /// This differs from [`LogicalPlan::check_invariants`], which addresses if a singular -/// LogicalPlan is valid. Instead this address if the optimization was valid based upon permitted changes. +/// LogicalPlan is valid. Instead, this address if the optimization was valid based upon permitted changes. fn assert_valid_optimization( plan: &LogicalPlan, prev_schema: &Arc, @@ -464,10 +773,10 @@ mod tests { use datafusion_common::tree_node::Transformed; use datafusion_common::{ - assert_contains, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, + DFSchema, DFSchemaRef, DataFusionError, Result, assert_contains, plan_err, }; use datafusion_expr::logical_plan::EmptyRelation; - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; + use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, Projection, col, lit}; use crate::optimizer::Optimizer; use crate::test::test_table_scan; diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 73e6b418272a9..6f46d7b663342 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -89,7 +89,7 @@ mod tests { use std::sync::Arc; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{lit, LogicalPlan}; + use datafusion_expr::{LogicalPlan, lit}; use crate::plan_signature::get_node_number; diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 629b13e4001d8..18ddc361a0692 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -19,11 +19,11 @@ use std::sync::Arc; -use datafusion_common::tree_node::Transformed; use datafusion_common::JoinType; -use datafusion_common::{plan_err, Result}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Column, DFSchemaRef, Result, ScalarValue, plan_err}; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{EmptyRelation, Projection, Union}; +use datafusion_expr::{EmptyRelation, Expr, GroupingSet, Projection, Union, cast, lit}; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; @@ -33,7 +33,7 @@ use crate::{OptimizerConfig, OptimizerRule}; pub struct PropagateEmptyRelation; impl PropagateEmptyRelation { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -73,12 +73,8 @@ impl OptimizerRule for PropagateEmptyRelation { Ok(Transformed::no(plan)) } LogicalPlan::Join(ref join) => { - // TODO: For Join, more join type need to be careful: - // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side - // columns + right side columns replaced with null values. - // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side - // columns + left side columns replaced with null values. let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; + let left_field_count = join.left.schema().fields().len(); match join.join_type { // For Full Join, only both sides are empty, the Join result is empty. @@ -88,6 +84,24 @@ impl OptimizerRule for PropagateEmptyRelation { schema: Arc::clone(&join.schema), }), )), + // For Full Join, if one side is empty, replace with a + // Projection that null-pads the empty side's columns. + JoinType::Full if right_empty => { + Ok(Transformed::yes(build_null_padded_projection( + Arc::clone(&join.left), + &join.schema, + left_field_count, + true, + )?)) + } + JoinType::Full if left_empty => { + Ok(Transformed::yes(build_null_padded_projection( + Arc::clone(&join.right), + &join.schema, + left_field_count, + false, + )?)) + } JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes( LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -100,12 +114,32 @@ impl OptimizerRule for PropagateEmptyRelation { schema: Arc::clone(&join.schema), }), )), + // Left Join with empty right: all left rows survive + // with NULLs for right columns. + JoinType::Left if right_empty => { + Ok(Transformed::yes(build_null_padded_projection( + Arc::clone(&join.left), + &join.schema, + left_field_count, + true, + )?)) + } JoinType::Right if right_empty => Ok(Transformed::yes( LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::clone(&join.schema), }), )), + // Right Join with empty left: all right rows survive + // with NULLs for left columns. + JoinType::Right if left_empty => { + Ok(Transformed::yes(build_null_padded_projection( + Arc::clone(&join.right), + &join.schema, + left_field_count, + false, + )?)) + } JoinType::LeftSemi if left_empty || right_empty => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -140,10 +174,16 @@ impl OptimizerRule for PropagateEmptyRelation { } } LogicalPlan::Aggregate(ref agg) => { - if !agg.group_expr.is_empty() { - if let Some(empty_plan) = empty_child(&plan)? { - return Ok(Transformed::yes(empty_plan)); - } + // An aggregate over an empty input can be eliminated only when + // there is no empty grouping set. An empty grouping set `()` + // (from `GROUPING SETS(())`, `ROLLUP(...)`, or `CUBE(...)`) + // always produces exactly one row even on empty input, so it + // must not be replaced by an empty relation. + if !agg.group_expr.is_empty() + && !has_empty_grouping_set(&agg.group_expr) + && let Some(empty_plan) = empty_child(&plan)? + { + return Ok(Transformed::yes(empty_plan)); } Ok(Transformed::no(LogicalPlan::Aggregate(agg.clone()))) } @@ -230,18 +270,93 @@ fn empty_child(plan: &LogicalPlan) -> Result> { } } +/// Builds a Projection that replaces one side of an outer join with NULL literals. +/// +/// When one side of an outer join is an `EmptyRelation`, the join can be eliminated +/// by projecting the surviving side's columns as-is and replacing the empty side's +/// columns with `CAST(NULL AS )`. +/// +/// The join schema is used as the projection's output schema to preserve nullability +/// guarantees (important for FULL JOIN where the surviving side's columns are marked +/// nullable in the join schema even if they aren't in the source schema). +/// +/// # Example +/// +/// For a `LEFT JOIN` where the right side is empty: +/// ```text +/// Left Join (orders.id = returns.order_id) Projection(orders.id, orders.amount, +/// ├── TableScan: orders => CAST(NULL AS Int64) AS order_id, +/// └── EmptyRelation CAST(NULL AS Utf8) AS reason) +/// └── TableScan: orders +/// ``` +fn build_null_padded_projection( + surviving_plan: Arc, + join_schema: &DFSchemaRef, + left_field_count: usize, + empty_side_is_right: bool, +) -> Result { + let exprs = join_schema + .iter() + .enumerate() + .map(|(i, (qualifier, field))| { + let on_empty_side = if empty_side_is_right { + i >= left_field_count + } else { + i < left_field_count + }; + + if on_empty_side { + cast(lit(ScalarValue::Null), field.data_type().clone()) + .alias_qualified(qualifier.cloned(), field.name()) + } else { + Expr::Column(Column::new(qualifier.cloned(), field.name())) + } + }) + .collect::>(); + + Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + exprs, + surviving_plan, + Arc::clone(join_schema), + )?)) +} + +/// Returns `true` if any grouping set in the list of GROUP BY expressions is +/// the empty set `()`. +/// +/// An empty grouping set acts as a "grand total" group: the aggregate must +/// always produce **exactly one row** for it, even when the input is empty. +/// This means an aggregate with an empty grouping set cannot be replaced by +/// an empty relation. +/// +/// The three forms that can contain an empty grouping set: +/// - `GROUPING SETS (…, (), …)` — explicitly listed. +/// - `ROLLUP(exprs)` — always expands to include `()`. +/// - `CUBE(exprs)` — always expands to include `()`. +fn has_empty_grouping_set(group_expr: &[Expr]) -> bool { + match group_expr.first() { + Some(Expr::GroupingSet(GroupingSet::GroupingSets(groups))) => { + groups.iter().any(|g| g.is_empty()) + } + // Both ROLLUP and CUBE always include the empty grouping set (). + Some(Expr::GroupingSet(GroupingSet::Rollup(_))) + | Some(Expr::GroupingSet(GroupingSet::Cube(_))) => true, + _ => false, + } +} + #[cfg(test)] mod tests { - use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Column, DFSchema, JoinType}; + use datafusion_common::{Column, DFSchema}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator, + Operator, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, }; + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::eliminate_filter::EliminateFilter; use crate::optimize_unions::OptimizeUnions; @@ -249,7 +364,6 @@ mod tests { assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields, test_table_scan_with_name, }; - use crate::OptimizerContext; use super::*; @@ -571,6 +685,111 @@ mod tests { assert_empty_left_empty_right_lp(true, false, JoinType::RightAnti, false) } + #[test] + fn test_left_join_right_empty_null_pad() -> Result<()> { + let left = + LogicalPlanBuilder::from(test_table_scan_with_name("left")?).build()?; + let right_empty = LogicalPlanBuilder::from(test_table_scan_with_name("right")?) + .filter(lit(false))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .join_using( + right_empty, + JoinType::Left, + vec![Column::from_name("a".to_string())], + )? + .build()?; + + let expected = "Projection: left.a, left.b, left.c, CAST(NULL AS UInt32) AS a, CAST(NULL AS UInt32) AS b, CAST(NULL AS UInt32) AS c\n TableScan: left"; + assert_together_optimized_plan(plan, expected, true) + } + + #[test] + fn test_right_join_left_empty_null_pad() -> Result<()> { + let left_empty = LogicalPlanBuilder::from(test_table_scan_with_name("left")?) + .filter(lit(false))? + .build()?; + let right = + LogicalPlanBuilder::from(test_table_scan_with_name("right")?).build()?; + + let plan = LogicalPlanBuilder::from(left_empty) + .join_using( + right, + JoinType::Right, + vec![Column::from_name("a".to_string())], + )? + .build()?; + + let expected = "Projection: CAST(NULL AS UInt32) AS a, CAST(NULL AS UInt32) AS b, CAST(NULL AS UInt32) AS c, right.a, right.b, right.c\n TableScan: right"; + assert_together_optimized_plan(plan, expected, true) + } + + #[test] + fn test_full_join_right_empty_null_pad() -> Result<()> { + let left = + LogicalPlanBuilder::from(test_table_scan_with_name("left")?).build()?; + let right_empty = LogicalPlanBuilder::from(test_table_scan_with_name("right")?) + .filter(lit(false))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .join_using( + right_empty, + JoinType::Full, + vec![Column::from_name("a".to_string())], + )? + .build()?; + + let expected = "Projection: left.a, left.b, left.c, CAST(NULL AS UInt32) AS a, CAST(NULL AS UInt32) AS b, CAST(NULL AS UInt32) AS c\n TableScan: left"; + assert_together_optimized_plan(plan, expected, true) + } + + #[test] + fn test_full_join_left_empty_null_pad() -> Result<()> { + let left_empty = LogicalPlanBuilder::from(test_table_scan_with_name("left")?) + .filter(lit(false))? + .build()?; + let right = + LogicalPlanBuilder::from(test_table_scan_with_name("right")?).build()?; + + let plan = LogicalPlanBuilder::from(left_empty) + .join_using( + right, + JoinType::Full, + vec![Column::from_name("a".to_string())], + )? + .build()?; + + let expected = "Projection: CAST(NULL AS UInt32) AS a, CAST(NULL AS UInt32) AS b, CAST(NULL AS UInt32) AS c, right.a, right.b, right.c\n TableScan: right"; + assert_together_optimized_plan(plan, expected, true) + } + + #[test] + fn test_left_join_complex_on_right_empty_null_pad() -> Result<()> { + let left = + LogicalPlanBuilder::from(test_table_scan_with_name("left")?).build()?; + let right_empty = LogicalPlanBuilder::from(test_table_scan_with_name("right")?) + .filter(lit(false))? + .build()?; + + // Complex ON condition: left.a = right.a AND left.b > right.b + let plan = LogicalPlanBuilder::from(left) + .join( + right_empty, + JoinType::Left, + ( + vec![Column::from_name("a".to_string())], + vec![Column::from_name("a".to_string())], + ), + Some(col("left.b").gt(col("right.b"))), + )? + .build()?; + + let expected = "Projection: left.a, left.b, left.c, CAST(NULL AS UInt32) AS a, CAST(NULL AS UInt32) AS b, CAST(NULL AS UInt32) AS c\n TableScan: left"; + assert_together_optimized_plan(plan, expected, true) + } + #[test] fn test_empty_with_non_empty() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ea0980ad4e1c7..f30b1187b7bca 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -23,28 +23,34 @@ use std::sync::Arc; use arrow::datatypes::DataType; use indexmap::IndexSet; use itertools::Itertools; +use log::{Level, debug, log_enabled}; +use datafusion_common::instant::Instant; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - assert_eq_or_internal_err, assert_or_internal_err, internal_err, plan_err, - qualified_name, Column, DFSchema, Result, + Column, DFSchema, Result, assert_eq_or_internal_err, internal_err, plan_err, + qualified_name, }; use datafusion_expr::expr::WindowFunction; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; use datafusion_expr::utils::{ conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; use datafusion_expr::{ - and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, + BinaryExpr, Distinct, Expr, Filter, Operator, Projection, + TableProviderFilterPushDown, and, or, }; use crate::optimizer::ApplyOrder; -use crate::simplify_expressions::simplify_predicates; -use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; +use crate::simplify_expressions::{reorder_predicates, simplify_predicates}; +use crate::utils::{ + ColumnReference, has_all_column_refs, is_restrict_null_predicate, schema_columns, +}; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_expr::ExpressionPlacement; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. @@ -61,7 +67,7 @@ use crate::{OptimizerConfig, OptimizerRule}; /// Sort (a, b) /// ``` /// -/// A better plan is to filter the data *before* the Sort, which sorts fewer +/// A better plan is to filter the data *before* the Sort, which sorts fewer /// rows and therefore does less work overall: /// /// ```text @@ -75,7 +81,7 @@ use crate::{OptimizerConfig, OptimizerRule}; /// different result. /// /// ```text -/// Filter (a > 10) <-- can not move this Filter before the limit +/// Filter (a > 10) <-- cannot move this Filter before the limit /// Limit (fetch=3) /// Sort (a, b) /// ``` @@ -85,46 +91,46 @@ use crate::{OptimizerConfig, OptimizerRule}; /// satisfies `filter(op(data)) = op(filter(data))`. /// /// The filter-commutative property is plan and column-specific. A filter on `a` -/// can be pushed through a `Aggregate(group_by = [a], agg=[sum(b))`. However, a -/// filter on `sum(b)` can not be pushed through the same aggregate. +/// can be pushed through a `Aggregate(group_by = [a], agg=[sum(b)])`. However, a +/// filter on `sum(b)` cannot be pushed through the same aggregate. /// /// # Handling Conjunctions /// -/// It is possible to only push down **part** of a filter expression if is +/// It is possible to only push down **part** of a filter expression if it is /// connected with `AND`s (more formally if it is a "conjunction"). /// /// For example, given the following plan: /// /// ```text /// Filter(a > 10 AND sum(b) < 5) -/// Aggregate(group_by = [a], agg = [sum(b)) +/// Aggregate(group_by = [a], agg = [sum(b)]) /// ``` /// -/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not. -/// Therefore it is possible to only push part of the expression, resulting in: +/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not. +/// Therefore it is possible to only push down part of the expression, resulting in: /// /// ```text /// Filter(sum(b) < 5) -/// Aggregate(group_by = [a], agg = [sum(b)) +/// Aggregate(group_by = [a], agg = [sum(b)]) /// Filter(a > 10) /// ``` /// /// # Handling Column Aliases /// -/// This optimizer must sometimes handle re-writing filter expressions when they -/// pushed, for example if there is a projection that aliases `a+1` to `"b"`: +/// This optimizer must sometimes handle rewriting filter expressions when they are +/// pushed. For example, consider a projection that aliases `a+1` to `"b"`: /// /// ```text /// Filter (b > 10) /// Projection: [a+1 AS "b"] <-- changes the name of `a+1` to `b` /// ``` /// -/// To apply the filter prior to the `Projection`, all references to `b` must be +/// To push this filter below the `Projection`, all references to `b` must be /// rewritten to `a+1`: /// /// ```text -/// Projection: a AS "b" -/// Filter: (a + 1 > 10) <--- changed from b to a + 1 +/// Projection: [a+1 AS "b"] +/// Filter: (a+1 > 10) <--- changed from b to a+1 /// ``` /// # Implementation Notes /// @@ -175,27 +181,9 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { } } -/// For a given JOIN type, determine whether each input of the join is preserved -/// for the join condition (`ON` clause filters). -/// -/// It is only correct to push filters below a join for preserved inputs. -/// -/// # Return Value -/// A tuple of booleans - (left_preserved, right_preserved). -/// -/// See [`lr_is_preserved`] for a definition of "preserved". +/// See [`JoinType::on_lr_is_preserved`] for details. pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { - match join_type { - JoinType::Inner => (true, true), - JoinType::Left => (false, true), - JoinType::Right => (true, false), - JoinType::Full => (false, false), - JoinType::LeftSemi | JoinType::RightSemi => (true, true), - JoinType::LeftAnti => (false, true), - JoinType::RightAnti => (true, false), - JoinType::LeftMark => (false, true), - JoinType::RightMark => (true, false), - } + join_type.on_lr_is_preserved() } /// Evaluates the columns referenced in the given expression to see if they refer @@ -205,11 +193,11 @@ struct ColumnChecker<'a> { /// schema of left join input left_schema: &'a DFSchema, /// columns in left_schema, computed on demand - left_columns: Option>, + left_columns: Option>>, /// schema of right join input right_schema: &'a DFSchema, /// columns in left_schema, computed on demand - right_columns: Option>, + right_columns: Option>>, } impl<'a> ColumnChecker<'a> { @@ -239,20 +227,6 @@ impl<'a> ColumnChecker<'a> { } } -/// Returns all columns in the schema -fn schema_columns(schema: &DFSchema) -> HashSet { - schema - .iter() - .flat_map(|(qualifier, field)| { - [ - Column::new(qualifier.cloned(), field.name()), - // we need to push down filter using unqualified column as well - Column::new_unqualified(field.name()), - ] - }) - .collect::>() -} - /// Determine whether the predicate can evaluate as the join conditions fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { let mut is_evaluate = true; @@ -263,6 +237,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) | Expr::Unnest(_) => { @@ -288,7 +263,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::InList { .. } - | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), + | Expr::ScalarFunction(_) + | Expr::HigherOrderFunction(_) + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => Ok(TreeNodeRecursion::Continue), // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] Expr::AggregateFunction(_) @@ -334,10 +312,8 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { /// * do nothing. fn extract_or_clauses_for_join<'a>( filters: &'a [Expr], - schema: &'a DFSchema, + schema_cols: &'a HashSet, ) -> impl Iterator + 'a { - let schema_columns = schema_columns(schema); - // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { if let Expr::BinaryExpr(BinaryExpr { @@ -346,8 +322,8 @@ fn extract_or_clauses_for_join<'a>( right, }) = expr { - let left_expr = extract_or_clause(left.as_ref(), &schema_columns); - let right_expr = extract_or_clause(right.as_ref(), &schema_columns); + let left_expr = extract_or_clause(left.as_ref(), schema_cols); + let right_expr = extract_or_clause(right.as_ref(), schema_cols); // If nothing can be extracted from any sub clauses, do nothing for this OR clause. if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) { @@ -369,7 +345,10 @@ fn extract_or_clauses_for_join<'a>( /// Otherwise, return None. /// /// For other clause, apply the rule above to extract clause. -fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { +fn extract_or_clause( + expr: &Expr, + schema_columns: &HashSet, +) -> Option { let mut predicate = None; match expr { @@ -435,6 +414,10 @@ fn push_down_all_join( // 3) should be kept as filter conditions let left_schema = join.left.schema(); let right_schema = join.right.schema(); + + let left_schema_columns = schema_columns(left_schema.as_ref()); + let right_schema_columns = schema_columns(right_schema.as_ref()); + let mut left_push = vec![]; let mut right_push = vec![]; let mut keep_predicates = vec![]; @@ -454,39 +437,48 @@ fn push_down_all_join( } } - // For infer predicates, if they can not push through join, just drop them + // Push predicates inferred from the join expression for predicate in inferred_join_predicates { - if left_preserved && checker.is_left_only(&predicate) { + if checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved && checker.is_right_only(&predicate) { + } else if checker.is_right_only(&predicate) { right_push.push(predicate); } } let mut on_filter_join_conditions = vec![]; let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type); - - if !on_filter.is_empty() { - for on in on_filter { - if on_left_preserved && checker.is_left_only(&on) { - left_push.push(on) - } else if on_right_preserved && checker.is_right_only(&on) { - right_push.push(on) - } else { - on_filter_join_conditions.push(on) - } + for on in on_filter { + if on_left_preserved && checker.is_left_only(&on) { + left_push.push(on) + } else if on_right_preserved && checker.is_right_only(&on) { + right_push.push(on) + } else { + on_filter_join_conditions.push(on) } } // Extract from OR clause, generate new predicates for both side of join if possible. // We only track the unpushable predicates above. if left_preserved { - left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema)); - left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema)); + left_push.extend(extract_or_clauses_for_join( + &keep_predicates, + &left_schema_columns, + )); + left_push.extend(extract_or_clauses_for_join( + &join_conditions, + &left_schema_columns, + )); } if right_preserved { - right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema)); - right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); + right_push.extend(extract_or_clauses_for_join( + &keep_predicates, + &right_schema_columns, + )); + right_push.extend(extract_or_clauses_for_join( + &join_conditions, + &right_schema_columns, + )); } // For predicates from join filter, we should check with if a join side is preserved @@ -494,55 +486,71 @@ fn push_down_all_join( if on_left_preserved { left_push.extend(extract_or_clauses_for_join( &on_filter_join_conditions, - left_schema, + &left_schema_columns, )); } if on_right_preserved { right_push.extend(extract_or_clauses_for_join( &on_filter_join_conditions, - right_schema, + &right_schema_columns, )); } + // Add any new join conditions as the non join predicates + let join_conditions_empty = join_conditions.is_empty(); + join_conditions.extend(on_filter_join_conditions); + join.filter = conjunction(join_conditions); + + if join_conditions_empty && left_push.is_empty() && right_push.is_empty() { + // wrap the join on the filter whose predicates must be kept, if any + return Ok(Transformed::no(with_filters( + keep_predicates, + LogicalPlan::Join(join), + ))); + } + if let Some(predicate) = conjunction(left_push) { - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); + join.left = Arc::new(LogicalPlan::Filter(Filter::new(predicate, join.left))); } + if let Some(predicate) = conjunction(right_push) { - join.right = - Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + join.right = Arc::new(LogicalPlan::Filter(Filter::new(predicate, join.right))); } - // Add any new join conditions as the non join predicates - join_conditions.extend(on_filter_join_conditions); - join.filter = conjunction(join_conditions); - // wrap the join on the filter whose predicates must be kept, if any - let plan = LogicalPlan::Join(join); - let plan = if let Some(predicate) = conjunction(keep_predicates) { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) - } else { - plan - }; - Ok(Transformed::yes(plan)) + Ok(Transformed::yes(with_filters( + keep_predicates, + LogicalPlan::Join(join), + ))) } fn push_down_join( - join: Join, - parent_predicate: Option<&Expr>, + mut join: Join, + parent_predicate: Option, ) -> Result> { // Split the parent predicate into individual conjunctive parts. - let predicates = parent_predicate - .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone())); + let predicates = parent_predicate.map_or_else(Vec::new, split_conjunction_owned); // Extract conjunctions from the JOIN's ON filter, if present. let on_filters = join .filter - .as_ref() - .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone())); + .take() + .map_or_else(Vec::new, split_conjunction_owned); // Are there any new join predicates that can be inferred from the filter expressions? - let inferred_join_predicates = - infer_join_predicates(&join, &predicates, &on_filters)?; + let inferred_join_predicates = with_debug_timing("infer_join_predicates", || { + infer_join_predicates(&join, &predicates, &on_filters) + })?; + + if log_enabled!(Level::Debug) { + debug!( + "push_down_filter: join_type={:?}, parent_predicates={}, on_filters={}, inferred_join_predicates={}", + join.join_type, + predicates.len(), + on_filters.len(), + inferred_join_predicates.len() + ); + } if on_filters.is_empty() && predicates.is_empty() @@ -616,7 +624,7 @@ impl InferredPredicates { fn new(join_type: JoinType) -> Self { Self { predicates: vec![], - is_inner_join: matches!(join_type, JoinType::Inner), + is_inner_join: join_type == JoinType::Inner, } } @@ -766,22 +774,37 @@ impl OptimizerRule for PushDownFilter { fn rewrite( &self, plan: LogicalPlan, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, ) -> Result> { + let _ = config; if let LogicalPlan::Join(join) = plan { return push_down_join(join, None); }; - let plan_schema = Arc::clone(plan.schema()); - let LogicalPlan::Filter(mut filter) = plan else { return Ok(Transformed::no(plan)); }; let predicate = split_conjunction_owned(filter.predicate.clone()); let old_predicate_len = predicate.len(); - let new_predicates = simplify_predicates(predicate)?; - if old_predicate_len != new_predicates.len() { + let new_predicates = + with_debug_timing("simplify_predicates", || simplify_predicates(predicate))?; + + if log_enabled!(Level::Debug) { + debug!( + "push_down_filter: simplify_predicates old_count={}, new_count={}", + old_predicate_len, + new_predicates.len() + ); + } + + // Place cheap predicates before expensive ones, so the `AND` + // evaluator's right-side short-circuit can skip evaluating expensive + // predicates on rows that have already been filtered out. + let (new_predicates, reorder_changed) = reorder_predicates(new_predicates); + + let count_changed = old_predicate_len != new_predicates.len(); + if count_changed || reorder_changed { let Some(new_predicate) = conjunction(new_predicates) else { // new_predicates is empty - remove the filter entirely // Return the child plan without the filter @@ -790,49 +813,56 @@ impl OptimizerRule for PushDownFilter { filter.predicate = new_predicate; } - match Arc::unwrap_or_clone(filter.input) { - LogicalPlan::Filter(child_filter) => { - let parents_predicates = split_conjunction_owned(filter.predicate); + // If the child has a fetch (limit) or skip (offset), pushing a filter + // below it would change semantics: the limit/offset should apply before + // the filter, not after. + if filter.input.fetch()?.is_some() || filter.input.skip()?.is_some() { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } - // remove duplicated filters - let child_predicates = split_conjunction_owned(child_filter.predicate); - let new_predicates = parents_predicates - .into_iter() - .chain(child_predicates) - // use IndexSet to remove dupes while preserving predicate order - .collect::>() - .into_iter() - .collect::>(); + match Arc::unwrap_or_clone(filter.input) { + LogicalPlan::Filter(mut child_filter) => { + // Child filters first to preserve execution order. + // Use IndexSet to remove duplicates while preserving predicate order. + let new_predicates: IndexSet = + split_conjunction_owned(child_filter.predicate) + .into_iter() + .chain(split_conjunction_owned(filter.predicate)) + .collect(); let Some(new_predicate) = conjunction(new_predicates) else { return plan_err!("at least one expression exists"); }; - let new_filter = LogicalPlan::Filter(Filter::try_new( - new_predicate, - child_filter.input, - )?); - #[allow(clippy::used_underscore_binding)] - self.rewrite(new_filter, _config) + + child_filter.predicate = new_predicate; + self.rewrite(LogicalPlan::Filter(child_filter), config) } - LogicalPlan::Repartition(repartition) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Repartition(repartition), new_filter) + LogicalPlan::Repartition(mut repartition) => { + filter.input = repartition.input; + repartition.input = Arc::new(LogicalPlan::Filter(filter)); + Ok(Transformed::yes(LogicalPlan::Repartition(repartition))) } LogicalPlan::Distinct(distinct) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(distinct.input())) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Distinct(distinct), new_filter) + let distinct = match distinct { + Distinct::All(input) => { + filter.input = input; + Distinct::All(Arc::new(LogicalPlan::Filter(filter))) + } + Distinct::On(mut distinct) => { + filter.input = distinct.input; + distinct.input = Arc::new(LogicalPlan::Filter(filter)); + Distinct::On(distinct) + } + }; + + Ok(Transformed::yes(LogicalPlan::Distinct(distinct))) } - LogicalPlan::Sort(sort) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(&sort.input)) - .map(LogicalPlan::Filter)?; - insert_below(LogicalPlan::Sort(sort), new_filter) + LogicalPlan::Sort(mut sort) => { + filter.input = sort.input; + sort.input = Arc::new(LogicalPlan::Filter(filter)); + Ok(Transformed::yes(LogicalPlan::Sort(sort))) } - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(mut subquery_alias) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -844,30 +874,24 @@ impl OptimizerRule for PushDownFilter { Expr::Column(Column::new(qualifier.cloned(), field.name())), ); } - let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - let new_filter = LogicalPlan::Filter(Filter::try_new( - new_predicate, - Arc::clone(&subquery_alias.input), - )?); - insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) + filter.predicate = replace_cols_by_name(filter.predicate, &replace_map)?; + filter.input = subquery_alias.input; + subquery_alias.input = Arc::new(LogicalPlan::Filter(filter)); + Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) } LogicalPlan::Projection(projection) => { let predicates = split_conjunction_owned(filter.predicate.clone()); - let (new_projection, keep_predicate) = + let (mut result, keep_predicates) = rewrite_projection(predicates, projection)?; - if new_projection.transformed { - match keep_predicate { - None => Ok(new_projection), - Some(keep_predicate) => new_projection.map_data(|child_plan| { - Filter::try_new(keep_predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) - }), - } + if result.transformed { + result.data = with_filters(keep_predicates, result.data) } else { - filter.input = Arc::new(new_projection.data); - Ok(Transformed::no(LogicalPlan::Filter(filter))) + filter.input = Arc::new(result.data); + result.data = LogicalPlan::Filter(filter) } + + Ok(result) } LogicalPlan::Unnest(mut unnest) => { let predicates = split_conjunction_owned(filter.predicate.clone()); @@ -878,11 +902,10 @@ impl OptimizerRule for PushDownFilter { for idx in &unnest.struct_type_columns { let (sub_qualifier, field) = unnest.input.schema().qualified_field(*idx); - let field_name = field.name().clone(); - if let DataType::Struct(children) = field.data_type() { + let field_name = field.name(); for child in children { - let child_name = child.name().clone(); + let child_name = child.name(); unnest_struct_columns.push(Column::new( sub_qualifier.cloned(), format!("{field_name}.{child_name}"), @@ -925,29 +948,21 @@ impl OptimizerRule for PushDownFilter { // Filter // Unnest Input (Projection) - let unnest_input = std::mem::take(&mut unnest.input); - - let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new( - conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. - unnest_input, - )?); - + // Safe to unwrap since non_unnest_predicates is not empty. + filter.predicate = conjunction(non_unnest_predicates).unwrap(); + filter.input = unnest.input; // Directly assign new filter plan as the new unnest's input. // The new filter plan will go through another rewrite pass since the rule itself // is applied recursively to all the child from top to down - let unnest_plan = - insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?; - - match conjunction(unnest_predicates) { - None => Ok(unnest_plan), - Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter( - Filter::try_new(predicate, Arc::new(unnest_plan.data))?, - ))), - } + unnest.input = Arc::new(LogicalPlan::Filter(filter)); + Ok(Transformed::yes(with_filters( + unnest_predicates, + LogicalPlan::Unnest(unnest), + ))) } - LogicalPlan::Union(ref union) => { + LogicalPlan::Union(mut union) => { let mut inputs = Vec::with_capacity(union.inputs.len()); - for input in &union.inputs { + for input in union.inputs { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in input.schema().iter().enumerate() { let (union_qualifier, union_field) = @@ -960,72 +975,51 @@ impl OptimizerRule for PushDownFilter { let push_predicate = replace_cols_by_name(filter.predicate.clone(), &replace_map)?; - inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( + inputs.push(Arc::new(LogicalPlan::Filter(Filter::new( push_predicate, - Arc::clone(input), - )?))) + input, + )))) } - Ok(Transformed::yes(LogicalPlan::Union(Union { - inputs, - schema: Arc::clone(&plan_schema), - }))) + + union.inputs = inputs; + Ok(Transformed::yes(LogicalPlan::Union(union))) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(mut agg) => { // We can push down Predicate which in groupby_expr. - let group_expr_columns = agg - .group_expr - .iter() - .map(|e| { - let (relation, name) = e.qualified_name(); - Column::new(relation, name) - }) - .collect::>(); + let group_expr_columns = expr_columns(&agg.group_expr); - let predicates = split_conjunction_owned(filter.predicate); + // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)] + // After push, we need to replace `a+b` with Column(a)+Column(b) + // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))} + let mut replace_map = HashMap::new(); + for expr in &agg.group_expr { + replace_map.insert(expr.schema_name().to_string(), unalias(expr)); + } + let predicates = split_conjunction_owned(filter.predicate); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; for expr in predicates { let cols = expr.column_refs(); if cols.iter().all(|c| group_expr_columns.contains(c)) { - push_predicates.push(expr); + push_predicates.push(replace_cols_by_name(expr, &replace_map)?); } else { keep_predicates.push(expr); } } - // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)] - // After push, we need to replace `a+b` with Column(a)+Column(b) - // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))} - let mut replace_map = HashMap::new(); - for expr in &agg.group_expr { - replace_map.insert(expr.schema_name().to_string(), expr.clone()); - } - let replaced_push_predicates = push_predicates - .into_iter() - .map(|expr| replace_cols_by_name(expr, &replace_map)) - .collect::>>()?; - - let agg_input = Arc::clone(&agg.input); - Transformed::yes(LogicalPlan::Aggregate(agg)) - .transform_data(|new_plan| { - // If we have a filter to push, we push it down to the input of the aggregate - if let Some(predicate) = conjunction(replaced_push_predicates) { - let new_filter = make_filter(predicate, agg_input)?; - insert_below(new_plan, new_filter) - } else { - Ok(Transformed::no(new_plan)) - } - })? - .map_data(|child_plan| { - // if there are any remaining predicates we can't push, add them - // back as a filter - if let Some(predicate) = conjunction(keep_predicates) { - make_filter(predicate, Arc::new(child_plan)) - } else { - Ok(child_plan) - } - }) + // If we have a filter to push, we push it down to the input of the aggregate + let result = if let Some(predicate) = conjunction(push_predicates) { + filter.predicate = predicate; + filter.input = agg.input; + agg.input = Arc::new(LogicalPlan::Filter(filter)); + Transformed::yes(LogicalPlan::Aggregate(agg)) + } else { + Transformed::no(LogicalPlan::Aggregate(agg)) + }; + + // If there are any remaining predicates we can't push, add them back as a filter + result.map_data(|plan| Ok(with_filters(keep_predicates, plan))) } // Tries to push filters based on the partition key(s) of the window function(s) used. // Example: @@ -1037,22 +1031,16 @@ impl OptimizerRule for PushDownFilter { // Filter: (b > 1) and (c > 1) // Window: func() PARTITION BY [a] ... // Filter: (a > 1) - LogicalPlan::Window(window) => { + LogicalPlan::Window(mut window) => { // Retrieve the set of potential partition keys where we can push filters by. // Unlike aggregations, where there is only one statement per SELECT, there can be // multiple window functions, each with potentially different partition keys. // Therefore, we need to ensure that any potential partition key returned is used in // ALL window functions. Otherwise, filters cannot be pushed by through that column. - let extract_partition_keys = |func: &WindowFunction| { - func.params - .partition_by - .iter() - .map(|c| { - let (relation, name) = c.qualified_name(); - Column::new(relation, name) - }) - .collect::>() - }; + fn extract_partition_keys(func: &WindowFunction) -> HashSet { + expr_columns(&func.params.partition_by) + } + let potential_partition_keys = window .window_expr .iter() @@ -1102,33 +1090,32 @@ impl OptimizerRule for PushDownFilter { // place, so we can use `push_predicates` directly. This is consistent with other // optimizers, such as the one used by Postgres. - let window_input = Arc::clone(&window.input); - Transformed::yes(LogicalPlan::Window(window)) - .transform_data(|new_plan| { - // If we have a filter to push, we push it down to the input of the window - if let Some(predicate) = conjunction(push_predicates) { - let new_filter = make_filter(predicate, window_input)?; - insert_below(new_plan, new_filter) - } else { - Ok(Transformed::no(new_plan)) - } - })? - .map_data(|child_plan| { - // if there are any remaining predicates we can't push, add them - // back as a filter - if let Some(predicate) = conjunction(keep_predicates) { - make_filter(predicate, Arc::new(child_plan)) - } else { - Ok(child_plan) - } - }) + // If we have a filter to push, we push it down to the input of the aggregate + let result = if let Some(predicate) = conjunction(push_predicates) { + filter.predicate = predicate; + filter.input = window.input; + window.input = Arc::new(LogicalPlan::Filter(filter)); + Transformed::yes(LogicalPlan::Window(window)) + } else { + Transformed::no(LogicalPlan::Window(window)) + }; + + // If there are any remaining predicates we can't push, add them back as a filter + result.map_data(|plan| Ok(with_filters(keep_predicates, plan))) } - LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), - LogicalPlan::TableScan(scan) => { + LogicalPlan::Join(join) => push_down_join(join, Some(filter.predicate)), + LogicalPlan::TableScan(mut scan) => { let filter_predicates = split_conjunction(&filter.predicate); + // Filters containing scalar subqueries cannot be pushed to + // providers because the subquery result is not available + // until execution time. + let (subquery_filters, pushdown_candidates): (Vec<&Expr>, Vec<&Expr>) = + filter_predicates + .into_iter() + .partition(|pred| pred.contains_scalar_subquery()); let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = - filter_predicates + pushdown_candidates .into_iter() .partition(|pred| pred.is_volatile()); @@ -1144,13 +1131,21 @@ impl OptimizerRule for PushDownFilter { non_volatile_filters.len() ); + if supported_filters + .iter() + .all(|res| res == &TableProviderFilterPushDown::Unsupported) + { + filter.input = Arc::new(LogicalPlan::TableScan(scan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + // Compose scan filters from non-volatile filters of `Exact` or `Inexact` pushdown type - let zip = non_volatile_filters.into_iter().zip(supported_filters); + let zip = non_volatile_filters.iter().zip(supported_filters.iter()); let new_scan_filters = zip .clone() - .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported) - .map(|(pred, _)| pred); + .filter(|(_, res)| *res != &TableProviderFilterPushDown::Unsupported) + .map(|(&pred, _)| pred); // Add new scan filters let new_scan_filters: Vec = scan @@ -1161,26 +1156,31 @@ impl OptimizerRule for PushDownFilter { .cloned() .collect(); - // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, and also include volatile filters + if supported_filters + .iter() + .all(|res| res == &TableProviderFilterPushDown::Inexact) + && scan.filters == new_scan_filters + { + filter.input = Arc::new(LogicalPlan::TableScan(scan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } else { + scan.filters = new_scan_filters; + } + + // Compose predicates to be of `Unsupported` or `Inexact` pushdown type, + // and also include volatile and subquery-containing filters let new_predicate: Vec = zip - .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) - .map(|(pred, _)| pred) + .filter(|(_, res)| *res != &TableProviderFilterPushDown::Exact) + .map(|(&pred, _)| pred) .chain(volatile_filters) + .chain(subquery_filters) .cloned() .collect(); - let new_scan = LogicalPlan::TableScan(TableScan { - filters: new_scan_filters, - ..scan - }); - - Transformed::yes(new_scan).transform_data(|new_scan| { - if let Some(predicate) = conjunction(new_predicate) { - make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) - } else { - Ok(Transformed::no(new_scan)) - } - }) + Ok(Transformed::yes(with_filters( + new_predicate, + LogicalPlan::TableScan(scan), + ))) } LogicalPlan::Extension(extension_plan) => { // This check prevents the Filter from being removed when the extension node has no children, @@ -1195,17 +1195,16 @@ impl OptimizerRule for PushDownFilter { // determine if we can push any predicates down past the extension node // each element is true for push, false to keep - let predicate_push_or_keep = split_conjunction(&filter.predicate) - .iter() - .map(|expr| { - let cols = expr.column_refs(); - if cols.iter().any(|c| prevent_cols.contains(&c.name)) { - Ok(false) // No push (keep) - } else { - Ok(true) // push - } - }) - .collect::>>()?; + let predicate_push_or_keep: Vec = + split_conjunction(&filter.predicate) + .iter() + .map(|expr| { + !expr + .column_refs() + .iter() + .any(|c| prevent_cols.contains(&c.name)) + }) + .collect(); // all predicates are kept, no changes needed if predicate_push_or_keep.iter().all(|&x| !x) { @@ -1218,7 +1217,7 @@ impl OptimizerRule for PushDownFilter { let mut push_predicates = vec![]; for (push, expr) in predicate_push_or_keep .into_iter() - .zip(split_conjunction_owned(filter.predicate).into_iter()) + .zip(split_conjunction_owned(filter.predicate)) { if !push { keep_predicates.push(expr); @@ -1227,33 +1226,25 @@ impl OptimizerRule for PushDownFilter { } } - let new_children = match conjunction(push_predicates) { - Some(predicate) => extension_plan - .node - .inputs() - .into_iter() - .map(|child| { - Ok(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(child.clone()), - )?)) - }) - .collect::>>()?, - None => extension_plan.node.inputs().into_iter().cloned().collect(), - }; + // Unwrap - push_predicates is not empty, predicate_push_or_keep checked. + let predicate = conjunction(push_predicates).unwrap(); + let new_children = extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + LogicalPlan::Filter(Filter::new( + predicate.clone(), + Arc::new(child.clone()), + )) + }) + .collect(); + // extension with new inputs. - let child_plan = LogicalPlan::Extension(extension_plan); - let new_extension = - child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - - let new_plan = match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_extension), - )?), - None => new_extension, - }; - Ok(Transformed::yes(new_plan)) + let extension = LogicalPlan::Extension(extension_plan); + let new_plan = + extension.with_new_exprs(extension.expressions(), new_children)?; + Ok(Transformed::yes(with_filters(keep_predicates, new_plan))) } child => { filter.input = Arc::new(child); @@ -1293,159 +1284,167 @@ impl OptimizerRule for PushDownFilter { fn rewrite_projection( predicates: Vec, mut projection: Projection, -) -> Result<(Transformed, Option)> { - // A projection is filter-commutable if it do not contain volatile predicates or contain volatile - // predicates that are not used in the filter. However, we should re-writes all predicate expressions. - // collect projection. - let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection +) -> Result<(Transformed, Vec)> { + // Partition projection expressions into non-pushable vs pushable. + // Non-pushable expressions are volatile (must not be duplicated) or + // MoveTowardsLeafNodes (cheap expressions like get_field where re-inlining + // into a filter causes optimizer instability — ExtractLeafExpressions will + // undo the push-down, creating an infinite loop that runs until the + // iteration limit is hit). + let (non_pushable_map, pushable_map) = projection .schema .iter() .zip(projection.expr.iter()) .map(|((qualifier, field), expr)| { - // strip alias, as they should not be part of filters - let expr = expr.clone().unalias(); - - (qualified_name(qualifier, field.name()), expr) + (qualified_name(qualifier, field.name()), unalias(expr)) }) - .partition(|(_, value)| value.is_volatile()); + .partition(|(_, value)| { + value.is_volatile() + || value.placement() == ExpressionPlacement::MoveTowardsLeafNodes + }); let mut push_predicates = vec![]; let mut keep_predicates = vec![]; for expr in predicates { - if contain(&expr, &volatile_map) { + if contain(&expr, &non_pushable_map) { keep_predicates.push(expr); } else { push_predicates.push(expr); } } - match conjunction(push_predicates) { - Some(expr) => { - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(expr, &non_volatile_map)?, - std::mem::take(&mut projection.input), - )?); - - projection.input = Arc::new(new_filter); - - Ok(( - Transformed::yes(LogicalPlan::Projection(projection)), - conjunction(keep_predicates), - )) - } - None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)), - } + let projection = if let Some(expr) = conjunction(push_predicates) { + // re-write all filters based on this projection + // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" + projection.input = Arc::new(LogicalPlan::Filter(Filter::new( + replace_cols_by_name(expr, &pushable_map)?, + projection.input, + ))); + + Transformed::yes(LogicalPlan::Projection(projection)) + } else { + Transformed::no(LogicalPlan::Projection(projection)) + }; + + Ok((projection, keep_predicates)) } /// Creates a new LogicalPlan::Filter node. +/// +/// Deprecated: use [`Filter::try_new`] directly. +#[deprecated] pub fn make_filter(predicate: Expr, input: Arc) -> Result { Filter::try_new(predicate, input).map(LogicalPlan::Filter) } -/// Replace the existing child of the single input node with `new_child`. -/// -/// Starting: -/// ```text -/// plan -/// child -/// ``` -/// -/// Ending: -/// ```text -/// plan -/// new_child -/// ``` -fn insert_below( - plan: LogicalPlan, - new_child: LogicalPlan, -) -> Result> { - let mut new_child = Some(new_child); - let transformed_plan = plan.map_children(|_child| { - if let Some(new_child) = new_child.take() { - Ok(Transformed::yes(new_child)) - } else { - // already took the new child - internal_err!("node had more than one input") - } - })?; - - // make sure we did the actual replacement - assert_or_internal_err!(new_child.is_none(), "node had no inputs"); - - Ok(transformed_plan) -} - impl PushDownFilter { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } } +fn with_debug_timing(label: &'static str, f: F) -> Result +where + F: FnOnce() -> Result, +{ + if !log_enabled!(Level::Debug) { + return f(); + } + let start = Instant::now(); + let result = f(); + debug!( + "push_down_filter_timing: section={label}, elapsed_us={}", + start.elapsed().as_micros() + ); + result +} + /// replaces columns by its name on the projection. pub fn replace_cols_by_name( e: Expr, - replace_map: &HashMap, + replace_map: &HashMap>, ) -> Result { e.transform_up(|expr| { - Ok(if let Expr::Column(c) = &expr { - match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::yes(new_c.clone()), - None => Transformed::no(expr), - } + if let Expr::Column(c) = &expr + && let Some(new_expr) = replace_map.get(&c.flat_name()) + { + Ok(Transformed::yes(new_expr.as_ref().clone())) } else { - Transformed::no(expr) - }) + Ok(Transformed::no(expr)) + } }) .data() } +/// Unalias expression reference. +fn unalias(expr: &Expr) -> &Expr { + if let Expr::Alias(alias) = expr { + unalias(&alias.expr) + } else { + expr + } +} + /// check whether the expression uses the columns in `check_map`. -fn contain(e: &Expr, check_map: &HashMap) -> bool { +fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; e.apply(|expr| { - Ok(if let Expr::Column(c) = &expr { - match check_map.get(&c.flat_name()) { - Some(_) => { - is_contain = true; - TreeNodeRecursion::Stop - } - None => TreeNodeRecursion::Continue, - } + if let Expr::Column(c) = &expr + && check_map.contains_key(&c.flat_name()) + { + is_contain = true; + Ok(TreeNodeRecursion::Stop) } else { - TreeNodeRecursion::Continue - }) + Ok(TreeNodeRecursion::Continue) + } }) .unwrap(); is_contain } +fn with_filters(predicates: Vec, plan: LogicalPlan) -> LogicalPlan { + if let Some(predicate) = conjunction(predicates) { + LogicalPlan::Filter(Filter::new(predicate, Arc::new(plan))) + } else { + plan + } +} + +fn expr_columns(exprs: &[Expr]) -> HashSet { + exprs + .iter() + .map(|expr| { + let (relation, name) = expr.qualified_name(); + Column::new(relation, name) + }) + .collect() +} + #[cfg(test)] mod tests { - use std::any::Any; use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use arrow::datatypes::{Field, Schema, SchemaRef}; use async_trait::async_trait; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; - use datafusion_expr::expr::{ScalarFunction, WindowFunction}; + use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension, - LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - TableSource, TableType, UserDefinedLogicalNodeCore, Volatility, - WindowFunctionDefinition, + ColumnarValue, ExprFunctionExt, Extension, LogicalPlanBuilder, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableScan, TableSource, + TableType, UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition, col, + in_list, in_subquery, lit, }; + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::optimizer::Optimizer; use crate::simplify_expressions::SimplifyExpressions; + use crate::test::udfs::leaf_udf_expr; use crate::test::*; - use crate::OptimizerContext; use datafusion_expr::test::function_stub::sum; use insta::assert_snapshot; @@ -1484,6 +1483,17 @@ mod tests { }}; } + /// For testing that we don't return [Transformed::yes] when not necessary, + /// as it triggers rebuilding parent plan nodes. + macro_rules! assert_plan_not_transformed { + ($plan:expr) => {{ + let transformed = PushDownFilter::new() + .rewrite($plan, &OptimizerContext::new()) + .expect("failed to optimize plan"); + assert!(!transformed.transformed); + }}; + } + #[test] fn filter_before_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -1632,6 +1642,8 @@ mod tests { .aggregate(vec![col("a")], vec![sum(col("b")).alias("b")])? .filter(col("b").gt(lit(10i64)))? .build()?; + assert_plan_not_transformed!(plan.clone()); + // filter of aggregate is after aggregation since they are non-commutative assert_optimized_plan_equal!( plan, @@ -1824,6 +1836,7 @@ mod tests { .window(vec![window])? .filter(col("c").gt(lit(10i64)))? .build()?; + assert_plan_not_transformed!(plan.clone()); assert_optimized_plan_equal!( plan, @@ -2331,7 +2344,7 @@ mod tests { plan, @r" Projection: test.a, test1.d - Cross Join: + Cross Join: Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.d, test1.e, test1.f @@ -2361,7 +2374,7 @@ mod tests { plan, @r" Projection: test.a, test1.a - Cross Join: + Cross Join: Projection: test.a, test.b, test.c TableScan: test, full_filters=[test.a = Int32(1)] Projection: test1.a, test1.b, test1.c @@ -2431,7 +2444,7 @@ mod tests { plan, @r" Projection: test.a - Filter: test.a >= Int64(1) AND test.a <= Int64(1) + Filter: test.a <= Int64(1) AND test.a >= Int64(1) Limit: skip=0, fetch=1 TableScan: test " @@ -2720,8 +2733,7 @@ mod tests { ) } - /// post-left-join predicate on a column common to both sides is only pushed to the left side - /// i.e. - not duplicated to the right side + /// post-left-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_left_join_on_common() -> Result<()> { let table_scan = test_table_scan()?; @@ -2749,20 +2761,19 @@ mod tests { TableScan: test2 ", ); - // filter sent to left side of the join, not the right + // filter sent to left side of the join and to the right assert_optimized_plan_equal!( plan, @r" Left Join: Using test.a = test2.a TableScan: test, full_filters=[test.a <= Int64(1)] Projection: test2.a - TableScan: test2 + TableScan: test2, full_filters=[test2.a <= Int64(1)] " ) } - /// post-right-join predicate on a column common to both sides is only pushed to the right side - /// i.e. - not duplicated to the left side. + /// post-right-join predicate on a column common to both sides is pushed to both sides #[test] fn filter_using_right_join_on_common() -> Result<()> { let table_scan = test_table_scan()?; @@ -2790,12 +2801,12 @@ mod tests { TableScan: test2 ", ); - // filter sent to right side of join, not duplicated to the left + // filter sent to right side of join, sent to the left as well assert_optimized_plan_equal!( plan, @r" Right Join: Using test.a = test2.a - TableScan: test + TableScan: test, full_filters=[test.a <= Int64(1)] Projection: test2.a TableScan: test2, full_filters=[test2.a <= Int64(1)] " @@ -2977,7 +2988,7 @@ mod tests { Projection: test.a, test.b, test.c TableScan: test Projection: test2.a, test2.b, test2.c - TableScan: test2, full_filters=[test2.c > UInt32(4)] + TableScan: test2, full_filters=[test2.a > UInt32(1), test2.c > UInt32(4)] " ) } @@ -3051,6 +3062,7 @@ mod tests { Some(filter), )? .build()?; + assert_plan_not_transformed!(plan.clone()); // not part of the test, just good to know: assert_snapshot!(plan, @@ -3099,10 +3111,6 @@ mod tests { .map(|_| self.filter_support.clone()) .collect()) } - - fn as_any(&self) -> &dyn Any { - self - } } fn table_scan_with_pushdown_provider_builder( @@ -3119,6 +3127,7 @@ mod tests { projection, source: Arc::new(test_provider), fetch: None, + statistics_requests: std::collections::BTreeSet::new(), }); Ok(LogicalPlanBuilder::from(table_scan)) @@ -3161,15 +3170,16 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let optimized_plan = PushDownFilter::new() + let optimized = PushDownFilter::new() .rewrite(plan, &OptimizerContext::new()) - .expect("failed to optimize plan") - .data; + .expect("failed to optimize plan"); + assert!(optimized.transformed); + assert_plan_not_transformed!(optimized.data.clone()); // Optimizing the same plan multiple times should produce the same plan // each time. assert_optimized_plan_equal!( - optimized_plan, + optimized.data, @r" Filter: a = Int64(1) TableScan: test, partial_filters=[a = Int64(1)] @@ -3181,6 +3191,7 @@ mod tests { fn filter_with_table_provider_unsupported() -> Result<()> { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Unsupported)?; + assert_plan_not_transformed!(plan.clone()); assert_optimized_plan_equal!( plan, @@ -3212,6 +3223,28 @@ mod tests { ) } + #[test] + fn multi_combined_two_filters() -> Result<()> { + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Inexact, + vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))], + Some(vec![0]), + )? + .filter(col("a").eq(lit(10i64)))? + .filter(col("b").gt(lit(11i64)))? + .project(vec![col("a"), col("b")])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + Filter: a = Int64(10) AND b > Int64(11) + TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)] + " + ) + } + #[test] fn multi_combined_filter_exact() -> Result<()> { let plan = table_scan_with_pushdown_provider_builder( @@ -3232,6 +3265,27 @@ mod tests { ) } + #[test] + fn multi_combined_two_filters_exact() -> Result<()> { + let plan = table_scan_with_pushdown_provider_builder( + TableProviderFilterPushDown::Exact, + vec![], + Some(vec![0]), + )? + .filter(col("a").eq(lit(10i64)))? + .filter(col("b").gt(lit(11i64)))? + .project(vec![col("a"), col("b")])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: a, b + TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)] + " + ) + } + #[test] fn test_filter_with_alias() -> Result<()> { // in table scan the true col name is 'test.a', @@ -3924,9 +3978,6 @@ mod tests { } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "TestScalarUDF" } @@ -4131,7 +4182,7 @@ mod tests { plan, @r" Projection: a, b - Filter: t.a > Int32(5) AND t.b > Int32(10) AND TestScalarUDF() > Float64(0.1) + Filter: TestScalarUDF() > Float64(0.1) AND t.a > Int32(5) AND t.b > Int32(10) TableScan: test " ) @@ -4222,4 +4273,127 @@ mod tests { " ) } + + /// Test that filters are NOT pushed through MoveTowardsLeafNodes projections. + /// These are cheap expressions (like get_field) where re-inlining into a filter + /// has no benefit and causes optimizer instability — ExtractLeafExpressions will + /// undo the push-down, creating an infinite loop that runs until the iteration + /// limit is hit. + #[test] + fn filter_not_pushed_through_move_towards_leaves_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + // Create a projection with a MoveTowardsLeafNodes expression + let proj = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf_expr(col("a")).alias("val"), + col("b"), + col("c"), + ])? + .build()?; + + // Put a filter on the MoveTowardsLeafNodes column + let plan = LogicalPlanBuilder::from(proj) + .filter(col("val").gt(lit(150i64)))? + .build()?; + + // Filter should NOT be pushed through — val maps to a MoveTowardsLeafNodes expr + assert_optimized_plan_equal!( + plan, + @r" + Filter: val > Int64(150) + Projection: leaf_udf(test.a) AS val, test.b, test.c + TableScan: test + " + ) + } + + /// Test mixed predicates: Column predicate pushed, MoveTowardsLeafNodes kept. + #[test] + fn filter_mixed_predicates_partial_push() -> Result<()> { + let table_scan = test_table_scan()?; + + // Create a projection with both MoveTowardsLeafNodes and Column expressions + let proj = LogicalPlanBuilder::from(table_scan) + .project(vec![ + leaf_udf_expr(col("a")).alias("val"), + col("b"), + col("c"), + ])? + .build()?; + + // Filter with both: val > 150 (MoveTowardsLeafNodes) AND b > 5 (Column) + let plan = LogicalPlanBuilder::from(proj) + .filter(col("val").gt(lit(150i64)).and(col("b").gt(lit(5i64))))? + .build()?; + + // val > 150 should be kept above, b > 5 should be pushed through + assert_optimized_plan_equal!( + plan, + @r" + Filter: val > Int64(150) + Projection: leaf_udf(test.a) AS val, test.b, test.c + TableScan: test, full_filters=[test.b > Int64(5)] + " + ) + } + + #[test] + fn filter_not_pushed_down_through_table_scan_with_fetch() -> Result<()> { + let scan = test_table_scan()?; + let scan_with_fetch = match scan { + LogicalPlan::TableScan(scan) => LogicalPlan::TableScan(TableScan { + fetch: Some(10), + ..scan + }), + _ => unreachable!(), + }; + let plan = LogicalPlanBuilder::from(scan_with_fetch) + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter must NOT be pushed into the table scan when it has a fetch (limit) + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a > Int64(10) + TableScan: test, fetch=10 + " + ) + } + + #[test] + fn filter_push_down_through_sort_without_fetch() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort(vec![col("a").sort(true, true)])? + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter should be pushed below the sort + assert_optimized_plan_equal!( + plan, + @r" + Sort: test.a ASC NULLS FIRST + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) + } + + #[test] + fn filter_not_pushed_down_through_sort_with_fetch() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .sort_with_limit(vec![col("a").sort(true, true)], Some(5))? + .filter(col("a").gt(lit(10i64)))? + .build()?; + // Filter must NOT be pushed below the sort when it has a fetch (limit), + // because the limit should apply before the filter. + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.a > Int64(10) + Sort: test.a ASC NULLS FIRST, fetch=5 + TableScan: test + " + ) + } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 80d4a2de6679d..4a26cd5884f6b 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -23,11 +23,11 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::combine_limit; -use datafusion_common::Result; use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; -use datafusion_expr::{lit, FetchType, SkipType}; +use datafusion_expr::{FetchType, SkipType, lit}; /// Optimization rule that tries to push down `LIMIT`. //. It will push down through projection, limits (taking the smaller limit) @@ -35,7 +35,7 @@ use datafusion_expr::{lit, FetchType, SkipType}; pub struct PushDownLimit {} impl PushDownLimit { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -47,10 +47,11 @@ impl OptimizerRule for PushDownLimit { true } + #[expect(clippy::only_used_in_recursion)] fn rewrite( &self, plan: LogicalPlan, - _config: &dyn OptimizerConfig, + config: &dyn OptimizerConfig, ) -> Result> { let LogicalPlan::Limit(mut limit) = plan else { return Ok(Transformed::no(plan)); @@ -81,8 +82,7 @@ impl OptimizerRule for PushDownLimit { }); // recursively reapply the rule on the new plan - #[allow(clippy::used_underscore_binding)] - return self.rewrite(plan, _config); + return self.rewrite(plan, config); } // no fetch to push, so return the original plan @@ -281,8 +281,8 @@ mod test { use crate::OptimizerContext; use datafusion_common::DFSchemaRef; use datafusion_expr::{ - col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, - UserDefinedLogicalNodeCore, + Expr, Extension, UserDefinedLogicalNodeCore, col, exists, + logical_plan::builder::LogicalPlanBuilder, }; use datafusion_functions_aggregate::expr_fn::max; @@ -1044,7 +1044,7 @@ mod test { plan, @r" Limit: skip=0, fetch=1000 - Cross Join: + Cross Join: Limit: skip=0, fetch=1000 TableScan: test, fetch=1000 Limit: skip=0, fetch=1000 @@ -1067,7 +1067,7 @@ mod test { plan, @r" Limit: skip=1000, fetch=1000 - Cross Join: + Cross Join: Limit: skip=0, fetch=2000 TableScan: test, fetch=2000 Limit: skip=0, fetch=2000 diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 215f5e240d5de..06df61e766615 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -25,8 +25,8 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, lit, ExprFunctionExt, Limit, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; +use datafusion_expr::{ExprFunctionExt, Limit, LogicalPlanBuilder, col, lit}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -69,7 +69,7 @@ use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; pub struct ReplaceDistinctWithAggregate {} impl ReplaceDistinctWithAggregate { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -109,7 +109,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { .enumerate() .all(|(idx, f_idx)| idx == *f_idx) { - return Ok(Transformed::yes(input.as_ref().clone())); + return Ok(Transformed::yes(Arc::unwrap_or_clone(input))); } } @@ -214,7 +214,7 @@ mod tests { use crate::OptimizerContext; use datafusion_common::Result; use datafusion_expr::{ - col, logical_plan::builder::LogicalPlanBuilder, table_scan, Expr, + Expr, col, logical_plan::builder::LogicalPlanBuilder, table_scan, }; use datafusion_functions_aggregate::sum::sum; diff --git a/datafusion/optimizer/src/rewrite_set_comparison.rs b/datafusion/optimizer/src/rewrite_set_comparison.rs new file mode 100644 index 0000000000000..c8c35b518743a --- /dev/null +++ b/datafusion/optimizer/src/rewrite_set_comparison.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule rewriting `SetComparison` subqueries (e.g. `= ANY`, +//! `> ALL`) into boolean expressions built from `EXISTS` subqueries +//! that capture SQL three-valued logic. + +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err}; +use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; +use datafusion_expr::logical_plan::Subquery; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::{Expr, LogicalPlan, lit}; +use std::sync::Arc; + +use datafusion_expr::utils::merge_schema; + +/// Rewrite `SetComparison` expressions to scalar subqueries that return the +/// correct boolean value (including SQL NULL semantics). After this rule +/// runs, later rules such as `ScalarSubqueryToJoin` can decorrelate and +/// remove the remaining subquery. +#[derive(Debug, Default)] +pub struct RewriteSetComparison; + +impl RewriteSetComparison { + /// Create a new `RewriteSetComparison` optimizer rule. + pub fn new() -> Self { + Self + } + + fn rewrite_plan(&self, plan: LogicalPlan) -> Result> { + let schema = merge_schema(&plan.inputs()); + plan.map_expressions(|expr| { + expr.transform_up(|expr| rewrite_set_comparison(expr, &schema)) + }) + } +} + +impl OptimizerRule for RewriteSetComparison { + fn name(&self) -> &str { + "rewrite_set_comparison" + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan)) + } +} + +fn rewrite_set_comparison( + expr: Expr, + outer_schema: &DFSchema, +) -> Result> { + match expr { + Expr::SetComparison(set_comparison) => { + let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?; + Ok(Transformed::yes(rewritten)) + } + _ => Ok(Transformed::no(expr)), + } +} + +fn build_set_comparison_subquery( + set_comparison: SetComparison, + outer_schema: &DFSchema, +) -> Result { + let SetComparison { + expr, + subquery, + op, + quantifier, + } = set_comparison; + + let left_expr = to_outer_reference(*expr, outer_schema)?; + let subquery_schema = subquery.subquery.schema(); + if subquery_schema.fields().is_empty() { + return plan_err!("single expression required."); + } + // avoid `head_output_expr` for aggr/window plan, it will gives group-by expr if exists + let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0))); + + let comparison = Expr::BinaryExpr(expr::BinaryExpr::new( + Box::new(left_expr), + op, + Box::new(right_expr), + )); + + let true_exists = + exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?; + let null_exists = + exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?; + + let result_expr = match quantifier { + SetQuantifier::Any => Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(true_exists), Box::new(lit(true))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(false))), + }), + SetQuantifier::All => { + let false_exists = + exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?; + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(false_exists), Box::new(lit(false))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(true))), + }) + } + }; + + Ok(result_expr) +} + +fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result { + let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone()) + .filter(filter)? + .build()?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists { + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: subquery.spans.clone(), + }, + negated: false, + })) +} + +fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result { + expr.transform_up(|expr| match expr { + Expr::Column(col) => { + let field = outer_schema.field_from_column(&col)?; + Ok(Transformed::yes(Expr::OuterReferenceColumn( + Arc::clone(field), + col, + ))) + } + Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + _ => Ok(Transformed::no(expr)), + }) + .map(|t| t.data) +} diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 2df1be1b7f0ba..44011a125ba96 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -30,38 +30,50 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{assert_or_internal_err, plan_err, Column, Result, ScalarValue}; +use datafusion_common::{Column, Result, ScalarValue, assert_or_internal_err, plan_err}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; -use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, lit, not, when}; -/// Optimizer rule for rewriting subquery filters to joins -/// and places additional projection on top of the filter, to preserve -/// original schema. +/// Optimizer rule that rewrites scalar subquery filters to joins and places an +/// additional projection on top of the filter to preserve the original schema. +/// +/// When [`datafusion_common::config::OptimizerOptions::enable_physical_uncorrelated_scalar_subquery`] is +/// true (the default), only *correlated* scalar subqueries are rewritten here; +/// uncorrelated ones are left for physical execution via `ScalarSubqueryExec`. +/// When the option is false, all scalar subqueries — correlated and +/// uncorrelated — are rewritten to left joins by this rule. #[derive(Default, Debug)] pub struct ScalarSubqueryToJoin {} impl ScalarSubqueryToJoin { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self::default() } - /// Finds expressions that have a scalar subquery in them (and recurses when found) + /// Finds expressions that contain correlated scalar subqueries (and + /// recurses when found). /// /// # Arguments - /// * `predicate` - A conjunction to split and search + /// * `predicate` - A conjunction to split and search. + /// * `alias_gen` - Generator used to produce unique aliases for each + /// extracted scalar subquery (e.g. `__scalar_sq_1`, `__scalar_sq_2`). + /// Each subquery is replaced by a column reference using the generated + /// alias, and the same alias is later used to construct the join. /// /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( &self, predicate: &Expr, alias_gen: &Arc, + physical_uncorrelated: bool, ) -> Result<(Vec<(Subquery, String)>, Expr)> { let mut extract = ExtractScalarSubQuery { sub_query_info: vec![], alias_gen, + physical_uncorrelated, }; predicate .clone() @@ -83,15 +95,23 @@ impl OptimizerRule for ScalarSubqueryToJoin { ) -> Result> { match plan { LogicalPlan::Filter(filter) => { + let physical_uncorrelated = config + .options() + .optimizer + .enable_physical_uncorrelated_scalar_subquery; // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !contains_scalar_subquery(&filter.predicate) { + // there are no scalar subqueries this rule should rewrite + if !contains_scalar_subquery_to_rewrite( + &filter.predicate, + physical_uncorrelated, + ) { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), + physical_uncorrelated, )?; assert_or_internal_err!( @@ -102,18 +122,17 @@ impl OptimizerRule for ScalarSubqueryToJoin { // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = filter.input.as_ref().clone(); for (subquery, alias) in subqueries { - if let Some((optimized_subquery, expr_check_map)) = + if let Some((optimized_subquery, compensation_exprs)) = build_join(&subquery, &cur_input, &alias)? { - if !expr_check_map.is_empty() { + if !compensation_exprs.is_empty() { rewrite_expr = rewrite_expr .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr + if let Some(compensation_expr) = expr .try_as_col() - .and_then(|col| expr_check_map.get(&col.name)) + .and_then(|col| compensation_exprs.get(col)) { - Ok(Transformed::yes(map_expr.clone())) + Ok(Transformed::yes(compensation_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -137,23 +156,33 @@ impl OptimizerRule for ScalarSubqueryToJoin { Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { - // Optimization: skip the rest of the rule and its copies if - // there are no scalar subqueries - if !projection.expr.iter().any(contains_scalar_subquery) { + let physical_uncorrelated = config + .options() + .optimizer + .enable_physical_uncorrelated_scalar_subquery; + // Optimization: skip the rest of the rule and its copies if there + // are no scalar subqueries this rule should rewrite + if !projection.expr.iter().any(|expr| { + contains_scalar_subquery_to_rewrite(expr, physical_uncorrelated) + }) { return Ok(Transformed::no(LogicalPlan::Projection(projection))); } let mut all_subqueries = vec![]; - let mut expr_to_rewrite_expr_map = HashMap::new(); - let mut subquery_to_expr_map = HashMap::new(); - for expr in projection.expr.iter() { - let (subqueries, rewrite_exprs) = - self.extract_subquery_exprs(expr, config.alias_generator())?; - for (subquery, _) in &subqueries { - subquery_to_expr_map.insert(subquery.clone(), expr.clone()); + let mut alias_to_index: HashMap = HashMap::new(); + let mut rewrite_exprs: Vec = + Vec::with_capacity(projection.expr.len()); + for (idx, expr) in projection.expr.iter().enumerate() { + let (subqueries, rewrite_expr) = self.extract_subquery_exprs( + expr, + config.alias_generator(), + physical_uncorrelated, + )?; + for (_, alias) in &subqueries { + alias_to_index.insert(alias.clone(), idx); } all_subqueries.extend(subqueries); - expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); + rewrite_exprs.push(rewrite_expr); } assert_or_internal_err!( !all_subqueries.is_empty(), @@ -162,33 +191,27 @@ impl OptimizerRule for ScalarSubqueryToJoin { // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = projection.input.as_ref().clone(); for (subquery, alias) in all_subqueries { - if let Some((optimized_subquery, expr_check_map)) = + if let Some((optimized_subquery, compensation_exprs)) = build_join(&subquery, &cur_input, &alias)? { cur_input = optimized_subquery; - if !expr_check_map.is_empty() { - if let Some(expr) = subquery_to_expr_map.get(&subquery) { - if let Some(rewrite_expr) = - expr_to_rewrite_expr_map.get(expr) - { - let new_expr = rewrite_expr - .clone() - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = - expr.try_as_col().and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; - expr_to_rewrite_expr_map.insert(expr, new_expr); - } - } + if !compensation_exprs.is_empty() + && let Some(&idx) = alias_to_index.get(&alias) + { + let new_expr = rewrite_exprs[idx] + .clone() + .transform_up(|expr| { + if let Some(compensation_expr) = expr + .try_as_col() + .and_then(|col| compensation_exprs.get(col)) + { + Ok(Transformed::yes(compensation_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + rewrite_exprs[idx] = new_expr; } } else { // if we can't handle all of the subqueries then bail for now @@ -197,14 +220,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { } let mut proj_exprs = vec![]; - for expr in projection.expr.iter() { + for (expr, new_expr) in projection.expr.iter().zip(rewrite_exprs) { let old_expr_name = expr.schema_name().to_string(); - let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); let new_expr_name = new_expr.schema_name().to_string(); if new_expr_name != old_expr_name { - proj_exprs.push(new_expr.clone().alias(old_expr_name)) + proj_exprs.push(new_expr.alias(old_expr_name)) } else { - proj_exprs.push(new_expr.clone()); + proj_exprs.push(new_expr); } } let new_plan = LogicalPlanBuilder::from(cur_input) @@ -226,16 +248,28 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } -/// Returns true if the expression has a scalar subquery somewhere in it -/// false otherwise -fn contains_scalar_subquery(expr: &Expr) -> bool { - expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) - .expect("Inner is always Ok") +/// Returns true if the expression contains a scalar subquery that this rule +/// should rewrite to a join. +/// +/// When `enable_physical_uncorrelated_scalar_subquery` is true (the default) only +/// correlated scalar subqueries are rewritten — uncorrelated ones are handled +/// by the physical planner via `ScalarSubqueryExec`. When it is false, all +/// scalar subqueries (correlated and uncorrelated) are rewritten. +fn contains_scalar_subquery_to_rewrite(expr: &Expr, physical_uncorrelated: bool) -> bool { + expr.exists(|expr| { + Ok(matches!( + expr, + Expr::ScalarSubquery(sq) + if !physical_uncorrelated || !sq.outer_ref_columns.is_empty() + )) + }) + .expect("Inner is always Ok") } struct ExtractScalarSubQuery<'a> { sub_query_info: Vec<(Subquery, String)>, alias_gen: &'a Arc, + physical_uncorrelated: bool, } impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { @@ -243,19 +277,25 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { fn f_down(&mut self, expr: Expr) -> Result> { match expr { - Expr::ScalarSubquery(subquery) => { - let subqry_alias = self.alias_gen.next("__scalar_sq"); - self.sub_query_info - .push((subquery.clone(), subqry_alias.clone())); + // Match scalar subqueries this rule should rewrite to a join. When + // `physical_uncorrelated` is true, only correlated subqueries are + // rewritten — uncorrelated ones are handled later by the physical + // planner. When false, both are rewritten. + Expr::ScalarSubquery(ref subquery) + if !self.physical_uncorrelated + || !subquery.outer_ref_columns.is_empty() => + { + let subquery = subquery.clone(); let scalar_expr = subquery .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; + let subqry_alias = self.alias_gen.next("__scalar_sq"); + let col = + create_col_from_scalar_expr(&scalar_expr, subqry_alias.clone())?; + self.sub_query_info.push((subquery, subqry_alias)); Ok(Transformed::new( - Expr::Column(create_col_from_scalar_expr( - &scalar_expr, - subqry_alias, - )?), + Expr::Column(col), true, TreeNodeRecursion::Jump, )) @@ -276,134 +316,122 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { /// /// ```text /// select c.id from customers c -/// left join (select c_id, avg(total) as val from orders group by c_id) o on o.c_id = c.c_id -/// where c.balance > o.val +/// left join (select c_id, avg(total) from orders group by c_id) o +/// on o.c_id = c.id +/// where c.balance > o."avg(total)" /// ``` /// -/// Or a query like: -/// -/// ```text -/// select id from customers where balance > -/// (select avg(total) from orders) -/// ``` -/// -/// and optimizes it into: -/// -/// ```text -/// select c.id from customers c -/// left join (select avg(total) as val from orders) a -/// where c.balance > a.val -/// ``` +/// When [`datafusion_common::config::OptimizerOptions::enable_physical_uncorrelated_scalar_subquery`] is +/// false, this function also handles uncorrelated scalar subqueries, rewriting +/// them as a `Left Join: Filter: Boolean(true)` instead of leaving them for +/// `ScalarSubqueryExec`. /// /// # Arguments /// -/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders) -/// * `filter_input` - The non-subquery portion (from customers) -/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) -/// * `subquery_alias` - Subquery aliases +/// * `subquery` - The scalar subquery to rewrite (correlated, or uncorrelated +/// when `enable_physical_uncorrelated_scalar_subquery` is false). +/// * `outer_input` - The outer plan that the decorrelated subquery is +/// left-joined onto — the input of the `Filter` or `Projection` node +/// that contained the subquery. +/// * `subquery_alias` - The unique alias assigned to the decorrelated +/// subquery; used both to qualify the join condition and to produce +/// column references for the caller to substitute. +/// +/// Returns `Ok(None)` if the subquery cannot be decorrelated. On success, +/// returns the rewritten outer plan and a map from each count-bug-affected +/// column to its `CASE WHEN __always_true IS NULL THEN ... END` compensation +/// expression, which the caller must substitute into any expression that +/// references those columns. fn build_join( subquery: &Subquery, - filter_input: &LogicalPlan, + outer_input: &LogicalPlan, subquery_alias: &str, -) -> Result)>> { +) -> Result)>> { + // `build_join` also handles uncorrelated scalar subqueries (as a left + // join with `Boolean(true)`) when the + // `enable_physical_uncorrelated_scalar_subquery` option is disabled. let subquery_plan = subquery.subquery.as_ref(); let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); - let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; + let decorrelated_subquery = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } - let collected_count_expr_map = - pull_up.collected_count_expr_map.get(&new_plan).cloned(); - let sub_query_alias = LogicalPlanBuilder::from(new_plan) + let collected_count_expr_map = pull_up + .collected_count_expr_map + .get(&decorrelated_subquery) + .cloned(); + let aliased_subquery = LogicalPlanBuilder::from(decorrelated_subquery) .alias(subquery_alias.to_string())? .build()?; - let mut all_correlated_cols = BTreeSet::new(); - pull_up + let all_correlated_cols: BTreeSet = pull_up .correlated_subquery_cols_map .values() - .for_each(|cols| all_correlated_cols.extend(cols.clone())); + .flatten() + .cloned() + .collect(); - // alias the join filter + // Correlated columns now live in the decorrelated subquery's output, + // so re-qualify them with the subquery alias. let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; - // join our sub query into the main plan - let new_plan = if join_filter_opt.is_none() { - match filter_input { - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: true, - schema: _, - }) => sub_query_alias, - _ => { - // if not correlated, group down to 1 row and left join on that (preserving row count) - LogicalPlanBuilder::from(filter_input.clone()) - .join_on( - sub_query_alias, - JoinType::Left, - vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)], - )? - .build()? - } - } - } else { - // left join if correlated, grouping by the join keys so we don't change row count - LogicalPlanBuilder::from(filter_input.clone()) - .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? - .build()? - }; - let mut computation_project_expr = HashMap::new(); + // When pull-up did not extract any usable join keys (a correlated subquery + // whose predicate references only outer columns), fall back to `ON true`: + // the decorrelated subquery still yields at most one row per outer row + // because its aggregate is grouped by the (empty) set of correlated inner + // columns. + let join_filter = join_filter_opt.or_else(|| Some(lit(true))); + + let new_plan = LogicalPlanBuilder::from(outer_input.clone()) + .join_on(aliased_subquery, JoinType::Left, join_filter)? + .build()?; + + // Add count-bug compensation for each of the subquery's projected + // expressions that yield non-NULL values on empty input. We wrap each + // such expression in a CASE that substitutes the empty-input value + // when the LEFT JOIN produced synthetic right-side NULLs (no inner + // row matched), and uses the actual right-side value (which may + // itself be NULL) otherwise. + let mut compensation_exprs = HashMap::new(); if let Some(expr_map) = collected_count_expr_map { + let mut expr_rewrite = TypeCoercionRewriter { + schema: new_plan.schema(), + }; + let having_arm = pull_up + .pull_up_having_expr + .as_ref() + .map(|f| (not(f.clone()), lit(ScalarValue::Null))); for (name, result) in expr_map { if evaluates_to_null(result.clone(), result.column_refs())? { - // If expr always returns null when column is null, skip processing + // Aggregates whose empty-input value is NULL (max/min/sum/…) + // need no compensation: the LEFT JOIN already produces NULL + // for unmatched outer rows. continue; } - let computer_expr = if let Some(filter) = &pull_up.pull_up_having_expr { - Expr::Case(expr::Case { - expr: None, - when_then_expr: vec![ - ( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), - Box::new(result), - ), - ( - Box::new(Expr::Not(Box::new(filter.clone()))), - Box::new(Expr::Literal(ScalarValue::Null, None)), - ), - ], - else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( - name.clone(), - )))), - }) - } else { - Expr::Case(expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), - Box::new(result), - )], - else_expr: Some(Box::new(Expr::Column(Column::new_unqualified( - name.clone(), - )))), - }) - }; - let mut expr_rewrite = TypeCoercionRewriter { - schema: new_plan.schema(), - }; - computation_project_expr - .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); + + let indicator_col = + Column::new(Some(subquery_alias), UN_MATCHED_ROW_INDICATOR); + // Qualify with the subquery alias to avoid ambiguity when the + // outer table has a column with the same name as the aggregate. + let value_col = Column::new(Some(subquery_alias), name); + + let mut builder = when(Expr::Column(indicator_col).is_null(), result); + if let Some((when_expr, then_expr)) = &having_arm { + builder = builder.when(when_expr.clone(), then_expr.clone()); + } + let compensation_expr = builder.otherwise(Expr::Column(value_col.clone()))?; + compensation_exprs.insert( + value_col, + compensation_expr.rewrite(&mut expr_rewrite).data()?, + ); } } - Ok(Some((new_plan, computation_project_expr))) + Ok(Some((new_plan, compensation_exprs))) } #[cfg(test)] @@ -417,7 +445,7 @@ mod tests { use datafusion_expr::test::function_stub::sum; use crate::assert_optimized_plan_eq_display_indent_snapshot; - use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; + use datafusion_expr::{Between, col, expr, out_ref_col, scalar_subquery}; use datafusion_functions_aggregate::min_max::{max, min}; macro_rules! assert_optimized_plan_equal { @@ -628,15 +656,13 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -831,7 +857,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r#" - Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(NULL AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N] + Projection: customer.c_custkey, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN CASE WHEN CAST(Float64(NULL) AS Boolean) THEN Utf8("a") ELSE Utf8("b") END ELSE __scalar_sq_1.CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END END AS CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END [c_custkey:Int64, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N] Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8;N, o_custkey:Int64;N, __always_true:Boolean;N] TableScan: customer [c_custkey:Int64, c_name:Utf8] SubqueryAlias: __scalar_sq_1 [CASE WHEN max(orders.o_totalprice) THEN Utf8("a") ELSE Utf8("b") END:Utf8, o_custkey:Int64, __always_true:Boolean] @@ -1033,14 +1059,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey < () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1063,14 +1087,12 @@ mod tests { plan, @r" Projection: customer.c_custkey [c_custkey:Int64] - Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] - Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] " ) } @@ -1161,21 +1183,66 @@ mod tests { assert_optimized_plan_equal!( plan, @r" + Projection: customer.c_custkey [c_custkey:Int64] + Filter: customer.c_custkey BETWEEN () AND () [c_custkey:Int64, c_name:Utf8] + Subquery: [min(orders.o_custkey):Int64;N] + Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + Subquery: [max(orders.o_custkey):Int64;N] + Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] + Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] + TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + " + ) + } + + #[test] + fn uncorrelated_scalar_subquery_rewritten_when_flag_off() -> Result<()> { + use datafusion_common::config::ConfigOptions; + + let sq = Arc::new( + LogicalPlanBuilder::from(scan_tpch_table("orders")) + .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? + .project(vec![max(col("orders.o_custkey"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) + .filter(col("customer.c_custkey").eq(scalar_subquery(sq)))? + .project(vec![col("customer.c_custkey")])? + .build()?; + + let mut options = ConfigOptions::default(); + options + .optimizer + .enable_physical_uncorrelated_scalar_subquery = false; + let context = crate::OptimizerContext::new_with_config_options(Arc::new(options)); + + let rule: Arc = + Arc::new(ScalarSubqueryToJoin::new()); + let optimizer = crate::Optimizer::with_rules(vec![rule]); + let optimized_plan = optimizer + .optimize(plan, &context, |_, _| {}) + .expect("failed to optimize plan"); + let formatted_plan = optimized_plan.display_indent_schema(); + + insta::assert_snapshot!( + formatted_plan, + @r" Projection: customer.c_custkey [c_custkey:Int64] Projection: customer.c_custkey, customer.c_name [c_custkey:Int64, c_name:Utf8] - Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N] - Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8] - SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N] - Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N] - Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N] + Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + Left Join: Filter: Boolean(true) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N] + TableScan: customer [c_custkey:Int64, c_name:Utf8] + SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N] Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N] Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] " - ) + ); + + Ok(()) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 366c99ce8f28b..39c8541b51b2f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -18,7 +18,7 @@ //! Expression simplification API use arrow::{ - array::{new_null_array, AsArray}, + array::{Array, AsArray, new_null_array}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; @@ -26,38 +26,45 @@ use std::borrow::Cow; use std::collections::HashSet; use std::ops::Not; use std::sync::Arc; +use std::sync::LazyLock; +use datafusion_common::config::ConfigOptions; +use datafusion_common::nested_struct::has_one_of_more_common_fields; use datafusion_common::{ + DFSchema, DataFusionError, Result, ScalarValue, exec_datafusion_err, internal_err, +}; +use datafusion_common::{ + HashMap, cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, - HashMap, -}; -use datafusion_common::{ - exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::HigherOrderFunction; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, + BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, preimage::PreimageResult, }; +use datafusion_expr::{Cast, TryCast, simplify::ExprSimplifyResult}; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ expr::{InList, InSubquery}, utils::{iter_conjunction, iter_conjunction_owned}, }; -use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; -use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::SimplifyContext; use crate::simplify_expressions::regex::simplify_regex_expr; use crate::simplify_expressions::unwrap_cast::{ is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary, is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist, unwrap_cast_in_comparison_for_binary, }; -use crate::simplify_expressions::SimplifyInfo; +use crate::{ + analyzer::type_coercion::TypeCoercionRewriter, + simplify_expressions::udf_preimage::rewrite_with_preimage, +}; use datafusion_expr::expr_rewriter::rewrite_with_guarantees_map; use datafusion_expr_common::casts::try_cast_literal_to_type; use indexmap::IndexSet; @@ -72,7 +79,6 @@ use regex::Regex; /// ``` /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_common::{DataFusionError, ToDFSchema}; -/// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_expr::{col, lit}; /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; @@ -83,8 +89,7 @@ use regex::Regex; /// .unwrap(); /// /// // Create the simplifier -/// let props = ExecutionProps::new(); -/// let context = SimplifyContext::new(&props).with_schema(schema); +/// let context = SimplifyContext::builder().with_schema(schema).build(); /// let simplifier = ExprSimplifier::new(context); /// /// // Use the simplifier @@ -96,8 +101,8 @@ use regex::Regex; /// let simplified = simplifier.simplify(expr).unwrap(); /// assert_eq!(simplified, col("b").lt(lit(2))); /// ``` -pub struct ExprSimplifier { - info: S, +pub struct ExprSimplifier { + info: SimplifyContext, /// Guarantees about the values of columns. This is provided by the user /// in [ExprSimplifier::with_guarantees()]. guarantees: Vec<(Expr, NullableInterval)>, @@ -111,13 +116,12 @@ pub struct ExprSimplifier { pub const THRESHOLD_INLINE_INLIST: usize = 3; pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; -impl ExprSimplifier { - /// Create a new `ExprSimplifier` with the given `info` such as an - /// instance of [`SimplifyContext`]. See - /// [`simplify`](Self::simplify) for an example. +impl ExprSimplifier { + /// Create a new `ExprSimplifier` with the given [`SimplifyContext`]. + /// See [`simplify`](Self::simplify) for an example. /// /// [`SimplifyContext`]: datafusion_expr::simplify::SimplifyContext - pub fn new(info: S) -> Self { + pub fn new(info: SimplifyContext) -> Self { Self { info, guarantees: vec![], @@ -142,40 +146,21 @@ impl ExprSimplifier { /// `b > 2` /// /// ``` - /// use arrow::datatypes::DataType; - /// use datafusion_common::DFSchema; + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_common::{DFSchema, ToDFSchema}; /// use datafusion_common::Result; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::simplify::SimplifyContext; - /// use datafusion_expr::simplify::SimplifyInfo; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; /// use std::sync::Arc; /// - /// /// Simple implementation that provides `Simplifier` the information it needs - /// /// See SimplifyContext for a structure that does this. - /// #[derive(Default)] - /// struct Info { - /// execution_props: ExecutionProps, - /// }; - /// - /// impl SimplifyInfo for Info { - /// fn is_boolean_type(&self, expr: &Expr) -> Result { - /// Ok(false) - /// } - /// fn nullable(&self, expr: &Expr) -> Result { - /// Ok(true) - /// } - /// fn execution_props(&self) -> &ExecutionProps { - /// &self.execution_props - /// } - /// fn get_data_type(&self, expr: &Expr) -> Result { - /// Ok(DataType::Int32) - /// } - /// } - /// + /// // Create a schema and SimplifyContext + /// let schema = Schema::new(vec![Field::new("b", DataType::Int32, true)]) + /// .to_dfschema_ref() + /// .unwrap(); /// // Create the simplifier - /// let simplifier = ExprSimplifier::new(Info::default()); + /// let context = SimplifyContext::builder().with_schema(schema).build(); + /// let simplifier = ExprSimplifier::new(context); /// /// // b < 2 /// let b_lt_2 = col("b").gt(lit(2)); @@ -201,7 +186,7 @@ impl ExprSimplifier { since = "48.0.0", note = "Use `simplify_with_cycle_count_transformed` instead" )] - #[allow(unused_mut)] + #[expect(unused_mut)] pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { let (transformed, cycle_count) = self.simplify_with_cycle_count_transformed(expr)?; @@ -225,7 +210,8 @@ impl ExprSimplifier { mut expr: Expr, ) -> Result<(Transformed, u32)> { let mut simplifier = Simplifier::new(&self.info); - let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; + let config_options = Some(Arc::clone(self.info.config_options())); + let mut const_evaluator = ConstEvaluator::try_new(config_options)?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let guarantees_map: HashMap<&Expr, &NullableInterval> = self.guarantees.iter().map(|(k, v)| (k, v)).collect(); @@ -287,7 +273,6 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_expr::{col, lit, Expr}; @@ -302,8 +287,7 @@ impl ExprSimplifier { /// .unwrap(); /// /// // Create the simplifier - /// let props = ExecutionProps::new(); - /// let context = SimplifyContext::new(&props).with_schema(schema); + /// let context = SimplifyContext::builder().with_schema(schema).build(); /// /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5) /// let expr_x = col("x").gt_eq(lit(3_i64)); @@ -349,7 +333,6 @@ impl ExprSimplifier { /// ```rust /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_expr::{col, lit, Expr}; @@ -364,8 +347,7 @@ impl ExprSimplifier { /// .unwrap(); /// /// // Create the simplifier - /// let props = ExecutionProps::new(); - /// let context = SimplifyContext::new(&props).with_schema(schema); + /// let context = SimplifyContext::builder().with_schema(schema).build(); /// let simplifier = ExprSimplifier::new(context); /// /// // Expression: a = c AND 1 = b @@ -410,7 +392,6 @@ impl ExprSimplifier { /// use arrow::datatypes::{DataType, Field, Schema}; /// use datafusion_expr::{col, lit, Expr}; /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; - /// use datafusion_expr::execution_props::ExecutionProps; /// use datafusion_expr::simplify::SimplifyContext; /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; /// @@ -420,9 +401,7 @@ impl ExprSimplifier { /// .to_dfschema_ref().unwrap(); /// /// // Create the simplifier - /// let props = ExecutionProps::new(); - /// let context = SimplifyContext::new(&props) - /// .with_schema(schema); + /// let context = SimplifyContext::builder().with_schema(schema).build(); /// let simplifier = ExprSimplifier::new(context); /// /// // Expression: a IS NOT NULL @@ -496,12 +475,11 @@ impl TreeNodeRewriter for Canonicalizer { } } -#[allow(rustdoc::private_intra_doc_links)] /// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time. /// /// Note it does not handle algebraic rewrites such as `(a or false)` /// --> `a`, which is handled by [`Simplifier`] -struct ConstEvaluator<'a> { +struct ConstEvaluator { /// `can_evaluate` is used during the depth-first-search of the /// `Expr` tree to track if any siblings (or their descendants) were /// non evaluatable (e.g. had a column reference or volatile @@ -515,13 +493,15 @@ struct ConstEvaluator<'a> { /// means there were no non evaluatable siblings (or their /// descendants) so this `Expr` can be evaluated can_evaluate: Vec, - - execution_props: &'a ExecutionProps, - input_schema: DFSchema, - input_batch: RecordBatch, + /// Execution properties needed to call [`create_physical_expr`]. + /// `ConstEvaluator` only evaluates expressions without column references + /// (i.e. constant expressions) and doesn't use the variable binding features + /// of `ExecutionProps` (we explicitly filter out [`Expr::ScalarVariable`]). + /// The `config_options` are passed from the session to allow scalar functions + /// to access configuration like timezone. + execution_props: ExecutionProps, } -#[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { // Expr was simplified and contains the new expression @@ -532,7 +512,7 @@ enum ConstSimplifyResult { SimplifyRuntimeError(DataFusionError, Expr), } -impl TreeNodeRewriter for ConstEvaluator<'_> { +impl TreeNodeRewriter for ConstEvaluator { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result> { @@ -580,10 +560,9 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { // This provides clearer error messages and fails fast. if let Expr::Cast(Cast { ref expr, .. }) | Expr::TryCast(TryCast { ref expr, .. }) = expr + && matches!(expr.as_ref(), Expr::Literal(_, _)) { - if matches!(expr.as_ref(), Expr::Literal(_, _)) { - return Err(err); - } + return Err(err); } // For other expressions (like CASE, COALESCE), preserve the original // to allow short-circuit evaluation at execution time @@ -596,29 +575,38 @@ impl TreeNodeRewriter for ConstEvaluator<'_> { } } -impl<'a> ConstEvaluator<'a> { - /// Create a new `ConstantEvaluator`. Session constants (such as - /// the time for `now()` are taken from the passed - /// `execution_props`. - pub fn try_new(execution_props: &'a ExecutionProps) -> Result { +static DUMMY_SCHEMA: LazyLock> = + LazyLock::new(|| Arc::new(Schema::new(vec![Field::new(".", DataType::Null, true)]))); + +static DUMMY_DF_SCHEMA: LazyLock = + LazyLock::new(|| DFSchema::try_from(Arc::clone(&*DUMMY_SCHEMA)).unwrap()); + +static DUMMY_BATCH: LazyLock = LazyLock::new(|| { + // Need a single "input" row to produce a single output row + let col = new_null_array(&DataType::Null, 1); + RecordBatch::try_new(DUMMY_SCHEMA.clone(), vec![col]).unwrap() +}); + +impl ConstEvaluator { + /// Create a new `ConstantEvaluator`. + /// + /// Note: `ConstEvaluator` filters out expressions with scalar variables + /// (like `$var`) and volatile functions, so it creates its own default + /// `ExecutionProps` internally. The filtered expressions will be evaluated + /// at runtime where proper variable bindings are available. + /// + /// The `config_options` parameter is used to pass session configuration + /// (like timezone) to scalar functions during constant evaluation. + pub fn try_new(config_options: Option>) -> Result { // The dummy column name is unused and doesn't matter as only // expressions without column references can be evaluated - static DUMMY_COL_NAME: &str = "."; - let schema = Arc::new(Schema::new(vec![Field::new( - DUMMY_COL_NAME, - DataType::Null, - true, - )])); - let input_schema = DFSchema::try_from(Arc::clone(&schema))?; - // Need a single "input" row to produce a single output row - let col = new_null_array(&DataType::Null, 1); - let input_batch = RecordBatch::try_new(schema, vec![col])?; + + let mut execution_props = ExecutionProps::new(); + execution_props.config_options = config_options; Ok(Self { can_evaluate: vec![], execution_props, - input_schema, - input_batch, }) } @@ -649,6 +637,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } | Expr::GroupingSet(_) @@ -657,6 +646,37 @@ impl<'a> ConstEvaluator<'a> { Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } + Expr::HigherOrderFunction(HigherOrderFunction { func, .. }) => { + Self::volatility_ok(func.signature().volatility) + } + Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr, field }) => { + if let ( + Ok(DataType::Struct(source_fields)), + DataType::Struct(target_fields), + ) = (expr.get_type(&DFSchema::empty()), field.data_type()) + { + // Don't const-fold struct casts with different field counts + if source_fields.len() != target_fields.len() { + return false; + } + + // Skip const-folding when there is no field name overlap + if !has_one_of_more_common_fields(&source_fields, target_fields) { + return false; + } + + // Don't const-fold struct casts with empty (0-row) literals + // The simplifier uses a 1-row input batch, which causes dimension mismatches + // when evaluating 0-row struct literals + if let Expr::Literal(ScalarValue::Struct(struct_array), _) = + expr.as_ref() + && struct_array.len() == 0 + { + return false; + } + } + true + } Expr::Literal(_, _) | Expr::Alias(..) | Expr::Unnest(_) @@ -675,9 +695,9 @@ impl<'a> ConstEvaluator<'a> { | Expr::Like { .. } | Expr::SimilarTo { .. } | Expr::Case(_) - | Expr::Cast { .. } - | Expr::TryCast { .. } - | Expr::InList { .. } => true, + | Expr::InList { .. } + | Expr::Lambda(_) + | Expr::LambdaVariable(_) => true, } } @@ -688,12 +708,12 @@ impl<'a> ConstEvaluator<'a> { } let phys_expr = - match create_physical_expr(&expr, &self.input_schema, self.execution_props) { + match create_physical_expr(&expr, &DUMMY_DF_SCHEMA, &self.execution_props) { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; let metadata = phys_expr - .return_field(self.input_batch.schema_ref()) + .return_field(DUMMY_BATCH.schema_ref()) .ok() .and_then(|f| { let m = f.metadata(); @@ -702,7 +722,7 @@ impl<'a> ConstEvaluator<'a> { false => Some(FieldMetadata::from(m)), } }); - let col_val = match phys_expr.evaluate(&self.input_batch) { + let col_val = match phys_expr.evaluate(&DUMMY_BATCH) { Ok(v) => v, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; @@ -710,7 +730,10 @@ impl<'a> ConstEvaluator<'a> { ColumnarValue::Array(a) => { if a.len() != 1 { ConstSimplifyResult::SimplifyRuntimeError( - exec_datafusion_err!("Could not evaluate the expression, found a result of length {}", a.len()), + exec_datafusion_err!( + "Could not evaluate the expression, found a result of length {}", + a.len() + ), expr, ) } else if as_list_array(&a).is_ok() { @@ -745,17 +768,17 @@ impl<'a> ConstEvaluator<'a> { /// * `false = true` and `true = false` to `false` /// * `!!expr` to `expr` /// * `expr = null` and `expr != null` to `null` -struct Simplifier<'a, S> { - info: &'a S, +struct Simplifier<'a> { + info: &'a SimplifyContext, } -impl<'a, S> Simplifier<'a, S> { - pub fn new(info: &'a S) -> Self { +impl<'a> Simplifier<'a> { + pub fn new(info: &'a SimplifyContext) -> Self { Self { info } } } -impl TreeNodeRewriter for Simplifier<'_, S> { +impl TreeNodeRewriter for Simplifier<'_> { type Node = Expr; /// rewrite the expression simplifying any constant expressions @@ -1050,9 +1073,27 @@ impl TreeNodeRewriter for Simplifier<'_, S> { right: left_right, })) } else { - return internal_err!("can_reduce_to_equal_statement should only be called with a BinaryExpr"); + return internal_err!( + "can_reduce_to_equal_statement should only be called with a BinaryExpr" + ); } } + // A = L1 AND A != L2 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&left, &right) => { + Transformed::yes(*left) + } + // A != L2 AND A = L1 --> A = L1 (when L1 != L2) + Expr::BinaryExpr(BinaryExpr { + left, + op: And, + right, + }) if is_eq_and_ne_with_different_literal(&right, &left) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -1620,17 +1661,19 @@ impl TreeNodeRewriter for Simplifier<'_, S> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => Transformed::yes(simplify_regex_expr(left, op, right)?), + }) => simplify_regex_expr(left, op, right)?, // Rules for Like Expr::Like(like) => { // `\` is implicit escape, see https://github.com/apache/datafusion/issues/13291 let escape_char = like.escape_char.unwrap_or('\\'); - match as_string_scalar(&like.pattern) { - Some((data_type, pattern_str)) => { + + match StringScalar::try_from_expr(&like.pattern) { + Some(string_scalar) => { + let pattern_str = string_scalar.as_str(); match pattern_str { None => return Ok(Transformed::yes(lit_bool_null())), - Some(pattern_str) if pattern_str == "%" => { + Some("%") => { // exp LIKE '%' is // - when exp is not NULL, it's true // - when exp is NULL, it's NULL @@ -1657,15 +1700,15 @@ impl TreeNodeRewriter for Simplifier<'_, S> { { // Repeated occurrences of wildcard are redundant so remove them // exp LIKE '%%' --> exp LIKE '%' - let simplified_pattern = Regex::new("%%+") - .unwrap() - .replace_all(pattern_str, "%") - .to_string(); + + static LIKE_REGEX: LazyLock = + LazyLock::new(|| Regex::new("%%+").unwrap()); + let simplified_pattern = + LIKE_REGEX.replace_all(pattern_str, "%").to_string(); Transformed::yes(Expr::Like(Like { - pattern: Box::new(to_string_scalar( - &data_type, - Some(simplified_pattern), - )), + pattern: Box::new( + string_scalar.to_expr(&simplified_pattern), + ), ..like })) } @@ -1747,6 +1790,8 @@ impl TreeNodeRewriter for Simplifier<'_, S> { }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); let rhs = to_inlist(*right).unwrap(); + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] + // Expr contains Arc with interior mutability but is intentionally used as hash key let mut seen: HashSet = HashSet::new(); let list = lhs .list @@ -1960,30 +2005,184 @@ impl TreeNodeRewriter for Simplifier<'_, S> { })) } + // ======================================= + // preimage_in_comparison + // ======================================= + // + // For case: + // date_part('YEAR', expr) op literal + // + // For details see datafusion_expr::ScalarUDFImpl::preimage + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + use datafusion_expr::Operator::*; + let is_preimage_op = matches!( + op, + Eq | NotEq + | Lt + | LtEq + | Gt + | GtEq + | IsDistinctFrom + | IsNotDistinctFrom + ); + if !is_preimage_op || is_null(&right) { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + } + + if let PreimageResult::Range { interval, expr } = + get_preimage(left.as_ref(), right.as_ref(), info)? + { + rewrite_with_preimage(*interval, op, expr)? + } else if let Some(swapped) = op.swap() { + if let PreimageResult::Range { interval, expr } = + get_preimage(right.as_ref(), left.as_ref(), info)? + { + rewrite_with_preimage(*interval, swapped, expr)? + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } else { + Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right })) + } + } + // For case: + // date_part('YEAR', expr) IN (literal1, literal2, ...) + Expr::InList(InList { + expr, + list, + negated, + }) => { + if list.len() > THRESHOLD_INLINE_INLIST || list.iter().any(is_null) { + return Ok(Transformed::no(Expr::InList(InList { + expr, + list, + negated, + }))); + } + + let (op, combiner): (Operator, fn(Expr, Expr) -> Expr) = + if negated { (NotEq, and) } else { (Eq, or) }; + + let mut rewritten: Option = None; + for item in &list { + let PreimageResult::Range { interval, expr } = + get_preimage(expr.as_ref(), item, info)? + else { + return Ok(Transformed::no(Expr::InList(InList { + expr, + list, + negated, + }))); + }; + + let range_expr = rewrite_with_preimage(*interval, op, expr)?.data; + rewritten = Some(match rewritten { + None => range_expr, + Some(acc) => combiner(acc, range_expr), + }); + } + + if let Some(rewritten) = rewritten { + Transformed::yes(rewritten) + } else { + Transformed::no(Expr::InList(InList { + expr, + list, + negated, + })) + } + } + // no additional rewrites possible expr => Transformed::no(expr), }) } } -fn as_string_scalar(expr: &Expr) -> Option<(DataType, &Option)> { +fn get_preimage( + left_expr: &Expr, + right_expr: &Expr, + info: &SimplifyContext, +) -> Result { + let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else { + return Ok(PreimageResult::None); + }; + if !is_literal_or_literal_cast(right_expr) { + return Ok(PreimageResult::None); + } + if func.signature().volatility != Volatility::Immutable { + return Ok(PreimageResult::None); + } + func.preimage(args, right_expr, info) +} + +fn is_literal_or_literal_cast(expr: &Expr) -> bool { match expr { - Expr::Literal(ScalarValue::Utf8(s), _) => Some((DataType::Utf8, s)), - Expr::Literal(ScalarValue::LargeUtf8(s), _) => Some((DataType::LargeUtf8, s)), - Expr::Literal(ScalarValue::Utf8View(s), _) => Some((DataType::Utf8View, s)), - _ => None, + Expr::Literal(_, _) => true, + Expr::Cast(Cast { expr, .. }) => matches!(expr.as_ref(), Expr::Literal(_, _)), + Expr::TryCast(TryCast { expr, .. }) => { + matches!(expr.as_ref(), Expr::Literal(_, _)) + } + _ => false, } } -fn to_string_scalar(data_type: &DataType, value: Option) -> Expr { - match data_type { - DataType::Utf8 => Expr::Literal(ScalarValue::Utf8(value), None), - DataType::LargeUtf8 => Expr::Literal(ScalarValue::LargeUtf8(value), None), - DataType::Utf8View => Expr::Literal(ScalarValue::Utf8View(value), None), - _ => unreachable!(), +/// Helper for working with string scalar values (Utf8, LargeUtf8, Utf8View) +pub(crate) enum StringScalar<'a> { + Utf8(&'a ScalarValue), + LargeUtf8(&'a ScalarValue), + Utf8View(&'a ScalarValue), +} + +impl<'a> StringScalar<'a> { + /// Create a `StringScalar` view from an `Expr` if it is a supported string literal. + /// Returns `None` if the expression is not a string literal. + pub(crate) fn try_from_expr(expr: &'a Expr) -> Option { + match expr { + Expr::Literal(scalar, _) => Self::try_from_scalar(scalar), + _ => None, + } + } + + /// Create a `StringScalar` view from a `ScalarValue` if it is a supported string type. + /// Returns `None` if the scalar value is not a supported string type. + fn try_from_scalar(scalar: &'a ScalarValue) -> Option { + match scalar { + ScalarValue::Utf8(_) => Some(Self::Utf8(scalar)), + ScalarValue::LargeUtf8(_) => Some(Self::LargeUtf8(scalar)), + ScalarValue::Utf8View(_) => Some(Self::Utf8View(scalar)), + _ => None, + } + } + + /// Returns the underlying string slice. + pub(crate) fn as_str(&self) -> Option<&'a str> { + match self { + Self::Utf8(scalar) | Self::LargeUtf8(scalar) | Self::Utf8View(scalar) => { + scalar.try_as_str().flatten() + } + } + } + + /// Build a new `Expr` of the same string type with the given value. + pub(crate) fn to_expr(&self, val: &str) -> Expr { + match self { + Self::Utf8(_) => Expr::Literal(ScalarValue::Utf8(Some(val.to_owned())), None), + Self::LargeUtf8(_) => { + Expr::Literal(ScalarValue::LargeUtf8(Some(val.to_owned())), None) + } + Self::Utf8View(_) => { + Expr::Literal(ScalarValue::Utf8View(Some(val.to_owned())), None) + } + } } } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect(); iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile()) @@ -2068,6 +2267,7 @@ fn to_inlist(expr: Expr) -> Option { /// Return the union of two inlist expressions /// maintaining the order of the elements in the two lists +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { // extend the list in l1 with the elements in l2 that are not already in l1 let l1_items: HashSet<_> = l1.list.iter().collect(); @@ -2086,6 +2286,7 @@ fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { /// Return the intersection of two inlist expressions /// maintaining the order of the elements in the two lists +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result { let l2_items = l2.list.iter().collect::>(); @@ -2102,6 +2303,7 @@ fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result Result { let l2_items = l2.list.iter().collect::>(); @@ -2115,7 +2317,7 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result { } /// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL). -fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { +fn is_exactly_true(expr: Expr, info: &SimplifyContext) -> Result { if !info.nullable(&expr)? { Ok(expr) } else { @@ -2131,8 +2333,8 @@ fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { // A / 1 -> A // // Move this function body out of the large match branch avoid stack overflow -fn simplify_right_is_one_case( - info: &S, +fn simplify_right_is_one_case( + info: &SimplifyContext, left: Box, op: &Operator, right: &Expr, @@ -2156,10 +2358,12 @@ fn simplify_right_is_one_case( #[cfg(test)] mod tests { use super::*; - use crate::simplify_expressions::SimplifyContext; use crate::test::test_table_scan_with_name; - use arrow::datatypes::FieldRef; - use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; + use arrow::{ + array::{Int32Array, StructArray}, + datatypes::{FieldRef, Fields}, + }; + use datafusion_common::{DFSchemaRef, ToDFSchema, assert_contains}; use datafusion_expr::{ expr::WindowFunction, function::{ @@ -2185,9 +2389,11 @@ mod tests { // ------------------------------ #[test] fn api_basic() { - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); let expr = lit(1) + lit(2); let expected = lit(3); @@ -2197,9 +2403,10 @@ mod tests { #[test] fn basic_coercion() { let schema = test_schema(); - let props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( - SimplifyContext::new(&props).with_schema(Arc::clone(&schema)), + SimplifyContext::builder() + .with_schema(Arc::clone(&schema)) + .build(), ); // Note expr type is int32 (not int64) @@ -2227,9 +2434,11 @@ mod tests { #[test] fn simplify_and_constant_prop() { - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); // should be able to simplify to false // (i * (1 - 2)) > 0 @@ -2240,9 +2449,11 @@ mod tests { #[test] fn simplify_and_constant_prop_with_case() { - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(test_schema())); + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); // CASE // WHEN i>5 AND false THEN i > 5 @@ -2410,6 +2621,27 @@ mod tests { assert_eq!(simplify(expr_b), expected); } + #[test] + fn test_simplify_eq_and_neq_with_different_literals() { + // A = 1 AND A != 0 --> A = 1 (when 1 != 0) + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(0))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // A != 0 AND A = 1 --> A = 1 (when 1 != 0) + let expr = col("c2").not_eq(lit(0)).and(col("c2").eq(lit(1))); + let expected = col("c2").eq(lit(1)); + assert_eq!(simplify(expr), expected); + + // Should NOT simplify when literals are the same (A = 1 AND A != 1) + // This is a contradiction but handled by other rules + let expr = col("c2").eq(lit(1)).and(col("c2").not_eq(lit(1))); + // Should not be simplified by this rule (left unchanged or handled elsewhere) + let result = simplify(expr.clone()); + // The expression should not have been simplified + assert_eq!(result, expr); + } + #[test] fn test_simplify_multiply_by_one() { let expr_a = col("c2") * lit(1); @@ -2687,6 +2919,21 @@ mod tests { } } + #[test] + fn test_simplify_concat_by_null() { + let null = Expr::Literal(ScalarValue::Utf8(None), None); + // A || null --> null + { + let expr = binary_expr(col("c1"), Operator::StringConcat, null.clone()); + assert_eq!(simplify(expr), null); + } + // null || A --> null + { + let expr = binary_expr(null.clone(), Operator::StringConcat, col("c1")); + assert_eq!(simplify(expr), null); + } + } + #[test] fn test_simplify_composed_bitwise_and() { // ((c2 > 5) & (c1 < 6)) & (c2 > 5) --> (c2 > 5) & (c1 < 6) @@ -3306,6 +3553,32 @@ mod tests { assert_no_change(regex_match(col("c1"), lit("foo|bar|baz|blarg|bozo|etc"))); } + #[test] + fn test_simplify_not_regex_match() { + let pattern = || lit("foo.*"); + + // NOT (c1 ~ pattern) --> c1 !~ pattern + assert_eq!( + simplify(regex_match(col("c1"), pattern()).not()), + regex_not_match(col("c1"), pattern()), + ); + // NOT (c1 !~ pattern) --> c1 ~ pattern + assert_eq!( + simplify(regex_not_match(col("c1"), pattern()).not()), + regex_match(col("c1"), pattern()), + ); + // NOT (c1 ~* pattern) --> c1 !~* pattern + assert_eq!( + simplify(regex_imatch(col("c1"), pattern()).not()), + regex_not_imatch(col("c1"), pattern()), + ); + // NOT (c1 !~* pattern) --> c1 ~* pattern + assert_eq!( + simplify(regex_not_imatch(col("c1"), pattern()).not()), + regex_imatch(col("c1"), pattern()), + ); + } + #[track_caller] fn assert_no_change(expr: Expr) { let optimized = simplify(expr.clone()); @@ -3356,18 +3629,17 @@ mod tests { fn try_simplify(expr: Expr) -> Result { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(schema), - ); + let simplifier = + ExprSimplifier::new(SimplifyContext::builder().with_schema(schema).build()); simplifier.simplify(expr) } fn coerce(expr: Expr) -> Expr { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + SimplifyContext::builder() + .with_schema(Arc::clone(&schema)) + .build(), ); simplifier.coerce(expr, schema.as_ref()).unwrap() } @@ -3378,10 +3650,8 @@ mod tests { fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(schema), - ); + let simplifier = + ExprSimplifier::new(SimplifyContext::builder().with_schema(schema).build()); let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?; Ok((expr.data, count)) } @@ -3395,11 +3665,9 @@ mod tests { guarantees: Vec<(Expr, NullableInterval)>, ) -> Expr { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let simplifier = ExprSimplifier::new( - SimplifyContext::new(&execution_props).with_schema(schema), - ) - .with_guarantees(guarantees); + let simplifier = + ExprSimplifier::new(SimplifyContext::builder().with_schema(schema).build()) + .with_guarantees(guarantees); simplifier.simplify(expr).unwrap() } @@ -4301,8 +4569,7 @@ mod tests { fn just_simplifier_simplify_null_in_empty_inlist() { let simplify = |expr: Expr| -> Expr { let schema = expr_test_schema(); - let execution_props = ExecutionProps::new(); - let info = SimplifyContext::new(&execution_props).with_schema(schema); + let info = SimplifyContext::builder().with_schema(schema).build(); let simplifier = &mut Simplifier::new(&info); expr.rewrite(simplifier) .expect("Failed to simplify expression") @@ -4668,10 +4935,9 @@ mod tests { #[test] fn simplify_common_factor_conjunction_in_disjunction() { - let props = ExecutionProps::new(); let schema = boolean_test_schema(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + ExprSimplifier::new(SimplifyContext::builder().with_schema(schema).build()); let a = || col("A"); let b = || col("B"); @@ -4745,10 +5011,6 @@ mod tests { } impl AggregateUDFImpl for SimplifyMockUdaf { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "mock_simplify" } @@ -4826,10 +5088,6 @@ mod tests { } impl WindowUDFImpl for SimplifyMockUdwf { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "mock_simplify" } @@ -4874,10 +5132,6 @@ mod tests { } } impl ScalarUDFImpl for VolatileUdf { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "VolatileUdf" } @@ -5001,9 +5255,8 @@ mod tests { // The simplification should now fail with an error at plan time let schema = test_schema(); - let props = ExecutionProps::new(); let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + ExprSimplifier::new(SimplifyContext::builder().with_schema(schema).build()); let result = simplifier.simplify(expr); assert!(result.is_err(), "Expected error for invalid cast"); let err_msg = result.unwrap_err().to_string(); @@ -5017,4 +5270,168 @@ mod tests { else_expr: None, }) } + + // -------------------------------- + // --- Struct Cast Tests ----- + // -------------------------------- + + /// Helper to create a `Struct` literal cast expression from `source_fields` and `target_fields`. + fn make_struct_cast_expr(source_fields: Fields, target_fields: Fields) -> Expr { + // Create 1-row struct array (not 0-row) so it can be evaluated by simplifier + let arrays: Vec> = vec![ + Arc::new(Int32Array::from(vec![Some(1)])), + Arc::new(Int32Array::from(vec![Some(2)])), + ]; + let struct_array = StructArray::try_new(source_fields, arrays, None).unwrap(); + + Expr::Cast(Cast::new( + Box::new(Expr::Literal( + ScalarValue::Struct(Arc::new(struct_array)), + None, + )), + DataType::Struct(target_fields), + )) + } + + #[test] + fn test_struct_cast_different_field_counts_not_foldable() { + // Test that struct casts with different field counts are NOT marked as foldable + // When field counts differ, const-folding should not be attempted + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + Arc::new(Field::new("z", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); + + // The cast should remain unchanged since field counts differ + let result = simplifier.simplify(expr.clone()).unwrap(); + // Ensure const-folding was not attempted (the expression remains exactly the same) + assert_eq!( + result, expr, + "Struct cast with different field counts should remain unchanged (no const-folding)" + ); + } + + #[test] + fn test_struct_cast_same_field_count_foldable() { + // Test that struct casts with same field counts can be considered for const-folding + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); + + // The cast should be simplified + let result = simplifier.simplify(expr.clone()).unwrap(); + // Struct casts with same field count should be const-folded to a literal + assert!(matches!(result, Expr::Literal(_, _))); + // Ensure the simplifier made a change (not identical to original) + assert_ne!( + result, expr, + "Struct cast with same field count should be simplified (not identical to input)" + ); + } + + #[test] + fn test_struct_cast_different_names_same_count() { + // Test struct cast with same field count but different names + // Field count matches; simplification should be skipped because names do not overlap + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("x", DataType::Int32, true)), + Arc::new(Field::new("y", DataType::Int32, true)), + ]); + + let expr = make_struct_cast_expr(source_fields, target_fields); + + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); + + // The cast should remain unchanged because there is no name overlap + let result = simplifier.simplify(expr.clone()).unwrap(); + assert_eq!( + result, expr, + "Struct cast with different names but same field count should not be simplified" + ); + } + + #[test] + fn test_struct_cast_empty_array_not_foldable() { + // Test that struct casts with 0-row (empty) struct arrays are NOT const-folded + // The simplifier uses a 1-row input batch, which causes dimension mismatches + // when evaluating 0-row struct literals + + let source_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + let target_fields = Fields::from(vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]); + + // Create a 0-row (empty) struct array + let arrays: Vec> = vec![ + Arc::new(Int32Array::new(vec![].into(), None)), + Arc::new(Int32Array::new(vec![].into(), None)), + ]; + let struct_array = StructArray::try_new(source_fields, arrays, None).unwrap(); + + let expr = Expr::Cast(Cast::new( + Box::new(Expr::Literal( + ScalarValue::Struct(Arc::new(struct_array)), + None, + )), + DataType::Struct(target_fields), + )); + + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(test_schema()) + .build(), + ); + + // The cast should remain unchanged since the struct array is empty (0-row) + let result = simplifier.simplify(expr.clone()).unwrap(); + assert_eq!( + result, expr, + "Struct cast with empty (0-row) array should remain unchanged" + ); + } } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index a1c1dc17d2945..17112d4f0ae24 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -19,10 +19,10 @@ use super::THRESHOLD_INLINE_INLIST; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; -use datafusion_expr::expr::InList; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_expr::Expr; +use datafusion_expr::expr::InList; pub(super) struct ShortenInListSimplifier {} @@ -43,52 +43,50 @@ impl TreeNodeRewriter for ShortenInListSimplifier { ref list, negated, }) = expr + && !list.is_empty() + && ( + // For lists with only 1 value we allow more complex expressions to be simplified + // e.g SUBSTR(c1, 2, 3) IN ('1') -> SUBSTR(c1, 2, 3) = '1' + // for more than one we avoid repeating this potentially expensive + // expressions + list.len() == 1 + || list.len() <= THRESHOLD_INLINE_INLIST + && expr.try_as_col().is_some() + ) { - if !list.is_empty() - && ( - // For lists with only 1 value we allow more complex expressions to be simplified - // e.g SUBSTR(c1, 2, 3) IN ('1') -> SUBSTR(c1, 2, 3) = '1' - // for more than one we avoid repeating this potentially expensive - // expressions - list.len() == 1 - || list.len() <= THRESHOLD_INLINE_INLIST - && expr.try_as_col().is_some() - ) - { - let first_val = list[0].clone(); - if negated { - return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( - (*expr.clone()).not_eq(first_val), - |acc, y| { - // Note that `A and B and C and D` is a left-deep tree structure - // as such we want to maintain this structure as much as possible - // to avoid reordering the expression during each optimization - // pass. - // - // Left-deep tree structure for `A and B and C and D`: - // ``` - // & - // / \ - // & D - // / \ - // & C - // / \ - // A B - // ``` - // - // The code below maintain the left-deep tree structure. - acc.and((*expr.clone()).not_eq(y)) - }, - ))); - } else { - return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( - (*expr.clone()).eq(first_val), - |acc, y| { - // Same reasoning as above - acc.or((*expr.clone()).eq(y)) - }, - ))); - } + let first_val = list[0].clone(); + if negated { + return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( + (*expr.clone()).not_eq(first_val), + |acc, y| { + // Note that `A and B and C and D` is a left-deep tree structure + // as such we want to maintain this structure as much as possible + // to avoid reordering the expression during each optimization + // pass. + // + // Left-deep tree structure for `A and B and C and D`: + // ``` + // & + // / \ + // & D + // / \ + // & C + // / \ + // A B + // ``` + // + // The code below maintain the left-deep tree structure. + acc.and((*expr.clone()).not_eq(y)) + }, + ))); + } else { + return Ok(Transformed::yes(list.iter().skip(1).cloned().fold( + (*expr.clone()).eq(first_val), + |acc, y| { + // Same reasoning as above + acc.or((*expr.clone()).eq(y)) + }, + ))); } } diff --git a/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs b/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs new file mode 100644 index 0000000000000..21389cf326c24 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/linear_aggregates.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplification to refactor multiple aggregate functions to use the same aggregate function + +use datafusion_common::HashMap; +use datafusion_expr::expr::AggregateFunctionParams; +use datafusion_expr::{BinaryExpr, Expr}; +use datafusion_expr_common::operator::Operator; + +/// Threshold of the number of aggregates that share similar arguments before +/// triggering rewrite. +/// +/// There is a threshold because the canonical SUM rewrite described in +/// [`AggregateUDFImpl::simplify_expr_op_literal`] actually results in more +/// aggregates (2) for each original aggregate. It is important that CSE then +/// eliminate them. +/// +/// [`AggregateUDFImpl::simplify_expr_op_literal`]: datafusion_expr::AggregateUDFImpl::simplify_expr_op_literal +const DUPLICATE_THRESHOLD: usize = 2; + +/// Rewrites multiple aggregate expressions that have a common linear component +/// into multiple aggregate expressions that share that common component. +/// +/// For example, rewrites patterns such as +/// * `SUM(x + 1), SUM(x + 2), ...` +/// +/// Into +/// * `SUM(x) + 1 * COUNT(x), SUM(x) + 2 * COUNT(x), ...` +/// +/// See the background [`AggregateUDFImpl::simplify_expr_op_literal`] for details. +/// +/// Returns `true` if any of the arguments are rewritten (modified), `false` +/// otherwise. +/// +/// ## Design goals: +/// 1. Keep the aggregate specific logic out of the optimizer (can't depend directly on SUM) +/// 2. Optimize for the case that this rewrite will not apply (it almost never does) +/// +/// [`AggregateUDFImpl::simplify_expr_op_literal`]: datafusion_expr::AggregateUDFImpl::simplify_expr_op_literal +pub(super) fn rewrite_multiple_linear_aggregates( + agg_expr: &mut [Expr], +) -> datafusion_common::Result { + // map : count of expressions that have a common argument + let mut common_args = HashMap::new(); + + // First pass -- figure out any aggregates that can be split and have common + // expressions. + for agg in agg_expr.iter() { + let Expr::AggregateFunction(agg_function) = agg else { + continue; + }; + + let Some(arg) = candidate_linear_param(&agg_function.params) else { + continue; + }; + + let Some(expr_literal) = ExprLiteral::try_new(arg) else { + continue; + }; + + let counter = common_args.entry(expr_literal.expr()).or_insert(0); + *counter += 1; + } + + // (agg_index, new_expr) + let mut new_aggs = vec![]; + + // Second pass, actually rewrite any aggregates that have a common + // expression and enough duplicates. + for (idx, agg) in agg_expr.iter().enumerate() { + let Expr::AggregateFunction(agg_function) = agg else { + continue; + }; + + let Some(arg) = candidate_linear_param(&agg_function.params) else { + continue; + }; + + let Some(expr_literal) = ExprLiteral::try_new(arg) else { + continue; + }; + + // Not enough common expressions to make it worth rewriting + if common_args.get(expr_literal.expr()).unwrap_or(&0) < &DUPLICATE_THRESHOLD { + continue; + } + + if let Some(new_agg_function) = agg_function.func.simplify_expr_op_literal( + agg_function, + expr_literal.expr(), + expr_literal.op(), + expr_literal.lit(), + expr_literal.arg_is_left(), + )? { + new_aggs.push((idx, new_agg_function)); + } + } + + if new_aggs.is_empty() { + return Ok(false); + } + + // Otherwise replace the aggregate expressions + drop(common_args); // release borrow + for (idx, new_agg) in new_aggs { + let orig_name = agg_expr[idx].name_for_alias()?; + agg_expr[idx] = new_agg.alias_if_changed(orig_name)? + } + + Ok(true) +} + +/// Returns Some(&Expr) with the single argument if this is a suitable candidate +/// for the linear rewrite +fn candidate_linear_param(params: &AggregateFunctionParams) -> Option<&Expr> { + // Explicitly destructure to ensure we check all relevant fields + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + // Disqualify anything "non standard" + if *distinct + || filter.is_some() + || !order_by.is_empty() + || null_treatment.is_some() + || args.len() != 1 + { + return None; + } + let arg = args.first()?; + if arg.is_volatile() { + return None; + }; + Some(arg) +} + +/// A view into a [`Expr::BinaryExpr`] that is arbitrary expression and a +/// literal +/// +/// This is an enum to distinguish the direction of the operator arguments +#[derive(Debug, Clone)] +pub enum ExprLiteral<'a> { + /// if the expression is ` ` + ArgOpLit { + arg: &'a Expr, + op: Operator, + lit: &'a Expr, + }, + /// if the expression is ` ` + LitOpArg { + lit: &'a Expr, + op: Operator, + arg: &'a Expr, + }, +} + +impl<'a> ExprLiteral<'a> { + /// Try and split the Expr into its parts + fn try_new(expr: &'a Expr) -> Option { + match expr { + // + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if matches!(left.as_ref(), Expr::Literal(..)) => + { + Some(Self::LitOpArg { + arg: right, + lit: left, + op: *op, + }) + } + + // + + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if matches!(right.as_ref(), Expr::Literal(..)) => + { + Some(Self::ArgOpLit { + arg: left, + lit: right, + op: *op, + }) + } + _ => None, + } + } + + fn expr(&self) -> &'a Expr { + match self { + Self::ArgOpLit { arg, .. } => arg, + Self::LitOpArg { arg, .. } => arg, + } + } + + fn lit(&self) -> &'a Expr { + match self { + Self::ArgOpLit { lit, .. } => lit, + Self::LitOpArg { lit, .. } => lit, + } + } + + fn op(&self) -> Operator { + match self { + Self::ArgOpLit { op, .. } => *op, + Self::LitOpArg { op, .. } => *op, + } + } + + fn arg_is_left(&self) -> bool { + matches!(self, Self::ArgOpLit { .. }) + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index e238fca32689d..e0b53b79d468c 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -20,16 +20,21 @@ pub mod expr_simplifier; mod inlist_simplifier; +mod linear_aggregates; mod regex; +mod reorder_predicates; pub mod simplify_exprs; +pub mod simplify_literal; mod simplify_predicates; +mod udf_preimage; mod unwrap_cast; mod utils; // backwards compatibility -pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; +pub use datafusion_expr::simplify::SimplifyContext; pub use expr_simplifier::*; +pub(crate) use reorder_predicates::reorder_predicates; pub use simplify_exprs::*; pub use simplify_predicates::simplify_predicates; diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 82c5ea3d8d820..df4c344b2e407 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. +use datafusion_common::tree_node::Transformed; use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::{lit, BinaryExpr, Expr, Like, Operator}; +use datafusion_expr::{BinaryExpr, Expr, Like, Operator, lit}; use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look}; +use crate::simplify_expressions::expr_simplifier::StringScalar; + /// Maximum number of regex alternations (`foo|bar|...`) that will be expanded into multiple `LIKE` expressions. const MAX_REGEX_ALTERNATIONS_EXPANSION: usize = 4; @@ -36,59 +39,76 @@ const ANY_CHAR_REGEX_PATTERN: &str = ".*"; /// - partial anchored regex patterns (e.g. `^foo`) to `LIKE 'foo%'` /// - combinations (alternatives) of the above, will be concatenated with `OR` or `AND` /// - `EQ .*` to NotNull -/// - `NE .*` means IS EMPTY +/// - `NE .*` to col IS NULL AND Boolean(NULL) (false for any string, or NULL if col is NULL) /// /// Dev note: unit tests of this function are in `expr_simplifier.rs`, case `test_simplify_regex`. pub fn simplify_regex_expr( left: Box, op: Operator, right: Box, -) -> Result { - let mode = OperatorMode::new(&op); +) -> Result> { + // Check if the right operand is a supported string literal + let Some(string_scalar) = StringScalar::try_from_expr(right.as_ref()) else { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + }; + let pattern = string_scalar.as_str(); + let Some(pattern) = pattern else { + return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))); + }; - if let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = right.as_ref() { - // Handle the special case for ".*" pattern - if pattern == ANY_CHAR_REGEX_PATTERN { - let new_expr = if mode.not { - // not empty - let empty_lit = Box::new(lit("")); - Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right: empty_lit, - }) - } else { - // not null - left.is_not_null() - }; - return Ok(new_expr); - } + let mode = OperatorMode::new(&op); + // Handle the special case for ".*" pattern + if pattern == ANY_CHAR_REGEX_PATTERN { + let new_expr = if mode.not { + let null_bool = lit(ScalarValue::Boolean(None)); + Expr::BinaryExpr(BinaryExpr { + left: Box::new(left.is_null()), + op: Operator::And, + right: Box::new(null_bool), + }) + } else { + // not null + left.is_not_null() + }; + return Ok(Transformed::yes(new_expr)); + } - match regex_syntax::Parser::new().parse(pattern) { - Ok(hir) => { - let kind = hir.kind(); - if let HirKind::Alternation(alts) = kind { - if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION { - if let Some(expr) = lower_alt(&mode, &left, alts) { - return Ok(expr); - } - } - } else if let Some(expr) = lower_simple(&mode, &left, &hir) { - return Ok(expr); + match regex_syntax::Parser::new().parse(pattern) { + Ok(hir) => { + let kind = hir.kind(); + if let HirKind::Alternation(alts) = kind { + if alts.len() <= MAX_REGEX_ALTERNATIONS_EXPANSION + && let Some(expr) = lower_alt(&mode, &left, alts, &string_scalar) + { + return Ok(Transformed::yes(expr)); } + } else if let Some(expr) = lower_simple(&mode, &left, &hir, &string_scalar) { + return Ok(Transformed::yes(expr)); } - Err(e) => { - // error out early since the execution may fail anyways - return Err(DataFusionError::Context( - "Invalid regex".to_owned(), - Box::new(DataFusionError::External(Box::new(e))), - )); - } + } + Err(e) => { + // error out early since the execution may fail anyways + return Err(DataFusionError::Context( + "Invalid regex".to_owned(), + Box::new(DataFusionError::External(Box::new(e))), + )); } } // Leave untouched if optimization didn't work - Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) + Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))) } #[derive(Debug)] @@ -117,11 +137,11 @@ impl OperatorMode { } /// Creates an [`LIKE`](Expr::Like) from the given `LIKE` pattern. - fn expr(&self, expr: Box, pattern: String) -> Expr { + fn expr(&self, expr: Box, pattern: Box) -> Expr { let like = Like { negated: self.not, expr, - pattern: Box::new(Expr::Literal(ScalarValue::from(pattern), None)), + pattern, escape_char: None, case_insensitive: self.i, }; @@ -263,20 +283,23 @@ fn partial_anchored_literal_to_like(v: &[Hir]) -> Option { /// Extracts a string literal expression assuming that [`is_anchored_literal`] /// returned true. -fn anchored_literal_to_expr(v: &[Hir]) -> Option { +fn anchored_literal_to_expr(v: &[Hir], string_scalar: &StringScalar) -> Option { match v.len() { - 2 => Some(lit("")), + 2 => Some(string_scalar.to_expr("")), 3 => { let HirKind::Literal(l) = v[1].kind() else { return None; }; - like_str_from_literal(l).map(lit) + like_str_from_literal(l).map(|s| string_scalar.to_expr(s)) } _ => None, } } -fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { +fn anchored_alternation_to_exprs( + v: &[Hir], + string_scalar: &StringScalar, +) -> Option> { if 3 != v.len() { return None; } @@ -287,11 +310,12 @@ fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { let mut literals = Vec::with_capacity(alters.len()); for hir in alters { let mut is_safe = false; - if let HirKind::Literal(l) = hir.kind() { - if let Some(safe_literal) = str_from_literal(l).map(lit) { - literals.push(safe_literal); - is_safe = true; - } + if let HirKind::Literal(l) = hir.kind() + && let Some(safe_literal) = + str_from_literal(l).map(|s| string_scalar.to_expr(s)) + { + literals.push(safe_literal); + is_safe = true; } if !is_safe { @@ -301,7 +325,9 @@ fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { return Some(literals); } else if let HirKind::Literal(l) = sub.kind() { - if let Some(safe_literal) = str_from_literal(l).map(lit) { + if let Some(safe_literal) = + str_from_literal(l).map(|s| string_scalar.to_expr(s)) + { return Some(vec![safe_literal]); } return None; @@ -311,29 +337,48 @@ fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { } /// Tries to lower (transform) a simple regex pattern to a LIKE expression. -fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { +fn lower_simple( + mode: &OperatorMode, + left: &Expr, + hir: &Hir, + string_scalar: &StringScalar, +) -> Option { match hir.kind() { HirKind::Empty => { - return Some(mode.expr(Box::new(left.clone()), "%".to_owned())); + return Some( + mode.expr(Box::new(left.clone()), Box::new(string_scalar.to_expr("%"))), + ); } HirKind::Literal(l) => { let s = like_str_from_literal(l)?; - return Some(mode.expr(Box::new(left.clone()), format!("%{s}%"))); + return Some(mode.expr( + Box::new(left.clone()), + Box::new(string_scalar.to_expr(&format!("%{s}%"))), + )); } HirKind::Concat(inner) if is_anchored_literal(inner) => { - return anchored_literal_to_expr(inner).map(|right| { - mode.expr_matches_literal(Box::new(left.clone()), Box::new(right)) + return anchored_literal_to_expr(inner, string_scalar).map(|right| { + if mode.i { + // Case-insensitive: use ILIKE for exact match (no wildcards) + mode.expr(Box::new(left.clone()), Box::new(right)) + } else { + // Case-sensitive: use Eq / NotEq + mode.expr_matches_literal(Box::new(left.clone()), Box::new(right)) + } }); } - HirKind::Concat(inner) if is_anchored_capture(inner) => { - return anchored_alternation_to_exprs(inner) + HirKind::Concat(inner) if !mode.i && is_anchored_capture(inner) => { + return anchored_alternation_to_exprs(inner, string_scalar) .map(|right| left.clone().in_list(right, mode.not)); } HirKind::Concat(inner) => { if let Some(pattern) = partial_anchored_literal_to_like(inner) .or_else(|| collect_concat_to_like_string(inner)) { - return Some(mode.expr(Box::new(left.clone()), pattern)); + return Some(mode.expr( + Box::new(left.clone()), + Box::new(string_scalar.to_expr(&pattern)), + )); } } _ => {} @@ -344,11 +389,16 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { /// Calls [`lower_simple`] for each alternative and combine the results with `or` or `and` /// based on [`OperatorMode`]. Any fail attempt to lower an alternative will makes this /// function to return `None`. -fn lower_alt(mode: &OperatorMode, left: &Expr, alts: &[Hir]) -> Option { +fn lower_alt( + mode: &OperatorMode, + left: &Expr, + alts: &[Hir], + string_scalar: &StringScalar, +) -> Option { let mut accu: Option = None; for part in alts { - if let Some(expr) = lower_simple(mode, left, part) { + if let Some(expr) = lower_simple(mode, left, part, string_scalar) { accu = match accu { Some(accu) => { if mode.not { diff --git a/datafusion/optimizer/src/simplify_expressions/reorder_predicates.rs b/datafusion/optimizer/src/simplify_expressions/reorder_predicates.rs new file mode 100644 index 0000000000000..221fa5d20c58c --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/reorder_predicates.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Reorder conjunctive (`AND`) predicates so that cheap predicates run before +//! expensive ones. +//! +//! DataFusion's `AND` evaluator short-circuits the right-hand side when the +//! left-hand side keeps few rows, so leading with a cheap predicate shrinks +//! the batch that expensive ones see. +//! +//! The cost of evaluating a predicate is assessed with a simple, conservative +//! heuristic: we define an allow-list of cheap operations, and consider an +//! expression to be cheap if it consists ONLY of cheap operations; everything +//! else is considered expensive. The sort of stable, so order within each +//! class is preserved. +//! +//! This reordering scheme is intentionally simple; many enhancements are +//! possible (e.g., consider both cost and selectivity, build a more complex +//! cost model, add estimated evaluation cost for individual UDFs). + +use datafusion_common::tree_node::TreeNode; +use datafusion_expr::{BinaryExpr, Expr, Operator}; + +/// Stable partition of `predicates`: cheap first, then expensive. +/// +/// Returns `(predicates, changed)`. When `changed` is `false` the input was +/// already cheap-first and the caller can skip rebuilding the conjunction. +pub(crate) fn reorder_predicates(predicates: Vec) -> (Vec, bool) { + if predicates.len() <= 1 { + return (predicates, false); + } + + // Volatile predicates may have observable side-effects and reordering + // conjuncts can change how many times they evaluate. Preserve user order + // if any predicate contains a volatile expression. + if predicates.iter().any(Expr::is_volatile) { + return (predicates, false); + } + + let classes: Vec = predicates.iter().map(is_cheap_predicate).collect(); + + // A reorder is needed iff an expensive predicate precedes a cheap one + let needs_reorder = classes.windows(2).any(|w| !w[0] && w[1]); + if !needs_reorder { + return (predicates, false); + } + + let mut cheap = Vec::with_capacity(predicates.len()); + let mut expensive = Vec::new(); + for (p, is_cheap) in predicates.into_iter().zip(classes) { + if is_cheap { + cheap.push(p); + } else { + expensive.push(p); + } + } + cheap.extend(expensive); + (cheap, true) +} + +/// Returns true if every node in `expr`'s tree is cheap. +fn is_cheap_predicate(expr: &Expr) -> bool { + !expr + .exists(|node| Ok(!is_cheap_node(node))) + .expect("is_cheap_node is infallible") +} + +/// Returns true if `expr` is itself cheap. +/// +/// We use a simple, conservative heuristic to determine if an expression is +/// cheap to evaluate: we enumerate known-cheap operations (e.g., equality +/// comparisons, negations, casts), and consider anything outside this list to +/// be expensive. New/unrecognized expressions therefore default to being +/// expensive. +fn is_cheap_node(expr: &Expr) -> bool { + match expr { + // Direct reads and literals. + Expr::Column(_) + | Expr::Literal(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Placeholder(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::LambdaVariable(_) + // Wrappers; children are walked separately by `is_cheap_predicate`. + | Expr::Alias(_) + // Single-row unary predicates and arithmetic negation. + | Expr::Not(_) + | Expr::Negative(_) + | Expr::IsNull(_) + | Expr::IsNotNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + // Composite cheap forms; child expressions are walked separately. + | Expr::Between(_) + | Expr::Case(_) + | Expr::Cast(_) + | Expr::TryCast(_) + | Expr::InList(_) => true, + // BinaryExpr is cheap unless the operator is LIKE or regexp matching. + Expr::BinaryExpr(BinaryExpr { op, .. }) => !matches!( + op, + Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + ), + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::{col, lit}; + + #[test] + fn like_predicate_moves_after_equality() { + let cheap = col("a").eq(lit(1)); + let expensive = col("b").like(lit("%foo%")); + let (out, changed) = reorder_predicates(vec![expensive.clone(), cheap.clone()]); + assert_eq!(out, vec![cheap, expensive]); + assert!(changed); + } + + #[test] + fn order_among_cheap_predicates_is_preserved() { + let p1 = col("a").eq(lit(1)); + let p2 = col("b").eq(lit(2)); + let p3 = col("c").eq(lit(3)); + let input = vec![p1.clone(), p2.clone(), p3.clone()]; + let (out, changed) = reorder_predicates(input.clone()); + assert_eq!(out, input); + assert!(!changed); + } + + #[test] + fn order_among_expensive_predicates_is_preserved() { + let p1 = col("a").like(lit("%a%")); + let p2 = Expr::BinaryExpr(BinaryExpr::new( + Box::new(col("b")), + Operator::RegexMatch, + Box::new(lit("foo")), + )); + let p3 = col("c").like(lit("%c%")); + let input = vec![p1.clone(), p2.clone(), p3.clone()]; + let (out, changed) = reorder_predicates(input.clone()); + assert_eq!(out, input); + assert!(!changed); + } + + #[test] + fn already_cheap_first_reports_no_change() { + let cheap = col("a").eq(lit(1)); + let expensive = col("b").like(lit("%a%")); + let input = vec![cheap.clone(), expensive.clone()]; + let (out, changed) = reorder_predicates(input.clone()); + assert_eq!(out, input); + assert!(!changed); + } + + #[test] + fn nested_expensive_under_not_is_expensive() { + // The top node is `Not`, which is on the cheap allow-list. The walk + // must descend into the `Like` to flag this predicate as expensive. + let cheap = col("a").eq(lit(1)); + let nested = Expr::Not(Box::new(col("b").like(lit("%foo%")))); + let (out, changed) = reorder_predicates(vec![nested.clone(), cheap.clone()]); + assert_eq!(out, vec![cheap, nested]); + assert!(changed); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 4faf9389cfac4..3e495f5355103 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,19 +20,20 @@ use std::sync::Arc; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; -use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::utils::merge_schema; +use datafusion_common::{Column, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::Expr; +use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection}; +use datafusion_expr::simplify::SimplifyContext; +use datafusion_expr::utils::{ + columnize_expr, find_aggregate_exprs, grouping_set_to_exprlist, merge_schema, +}; +use super::ExprSimplifier; use crate::optimizer::ApplyOrder; +use crate::simplify_expressions::linear_aggregates::rewrite_multiple_linear_aggregates; use crate::utils::NamePreserver; use crate::{OptimizerConfig, OptimizerRule}; -use super::ExprSimplifier; - /// Optimizer Pass that simplifies [`LogicalPlan`]s by rewriting /// [`Expr`]`s evaluating constants and applying algebraic /// simplifications @@ -67,17 +68,14 @@ impl OptimizerRule for SimplifyExpressions { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result, DataFusionError> { - let mut execution_props = ExecutionProps::new(); - execution_props.query_execution_start_time = config.query_execution_start_time(); - execution_props.config_options = Some(config.options()); - Self::optimize_internal(plan, &execution_props) + Self::optimize_internal(plan, config) } } impl SimplifyExpressions { fn optimize_internal( plan: LogicalPlan, - execution_props: &ExecutionProps, + config: &dyn OptimizerConfig, ) -> Result> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(&plan.inputs())) @@ -100,7 +98,11 @@ impl SimplifyExpressions { Arc::new(DFSchema::empty()) }; - let info = SimplifyContext::new(execution_props).with_schema(schema); + let info = SimplifyContext::builder() + .with_schema(schema) + .with_config_options(config.options()) + .with_query_execution_start_time(config.query_execution_start_time()) + .build(); // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) // Just need to rewrite our own expressions @@ -138,17 +140,110 @@ impl SimplifyExpressions { } else { rewrite_expr(expr) } - }) + })? + .transform_data(rewrite_aggregate_non_aggregate_aggr_expr) } } impl SimplifyExpressions { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } } +/// Ensures that `LogicalPlan::Aggregate` is well formed after rewrites +/// by potentially introducing an extra `Projection`. +/// +/// Also applies the [`rewrite_multiple_linear_aggregates`] special case +/// +/// # Rationale: +/// +/// [`LogicalPlan::Aggregate`] requires agg expressions to be (possibly aliased) +/// [`Expr::AggregateFunction`]. Some UDAF simplifiers may return other [`Expr`] +/// variants. +/// +/// # Operation +/// +/// Rewrites things like this (note that `exp1` is not an aggregate): +/// * `Aggregate(group_expr, aggr_expr=[exp1 + agg(exp2)])` +/// +/// into: +/// * `Projection(exp1 + _X)` +/// * ` Aggregate(group_expr, aggr_expr=[agg(exp2) AS _X])` +fn rewrite_aggregate_non_aggregate_aggr_expr( + plan: LogicalPlan, +) -> Result> { + let LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + mut aggr_expr, + schema, + .. + }) = plan + else { + return Ok(Transformed::no(plan)); + }; + + let rewrote_aggs = rewrite_multiple_linear_aggregates(&mut aggr_expr)?; + + // Ensure that all Aggregate arguments are AggregateExpr + if aggr_expr.iter().all(is_top_level_aggregate_expr) { + let new_plan = LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + input, group_expr, aggr_expr, schema, + )?); + return if !rewrote_aggs { + Ok(Transformed::no(new_plan)) + } else { + Ok(Transformed::yes(new_plan)) + }; + } + + // Otherwise we need to add a Projection above Aggregate to calculate + // the final output expressions. + + let inner_aggr_expr = find_aggregate_exprs(aggr_expr.iter()); + let inner_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::clone(&input), + group_expr.clone(), + inner_aggr_expr, + )?); + let inner_aggregate = Arc::new(inner_aggregate); + + let mut projection_exprs = aggregate_output_exprs(&group_expr)?; + projection_exprs.extend(aggr_expr); + let projection_exprs = projection_exprs + .into_iter() + .map(|expr| columnize_expr(expr, inner_aggregate.as_ref())) + .collect::>>()?; + + Ok(Transformed::yes(LogicalPlan::Projection( + Projection::try_new(projection_exprs, inner_aggregate)?, + ))) +} + +fn is_top_level_aggregate_expr(expr: &Expr) -> bool { + matches!( + expr.clone().unalias_nested().data, + Expr::AggregateFunction(_) + ) +} + +fn aggregate_output_exprs(group_expr: &[Expr]) -> Result> { + let mut output_exprs = grouping_set_to_exprlist(group_expr)? + .into_iter() + .cloned() + .collect::>(); + + if matches!(group_expr, [Expr::GroupingSet(_)]) { + output_exprs.push(Expr::Column(Column::from_name( + Aggregate::INTERNAL_GROUPING_ID, + ))); + } + + Ok(output_exprs) +} + #[cfg(test)] mod tests { use std::ops::Not; @@ -156,14 +251,15 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use chrono::{DateTime, Utc}; + use datafusion_common::ScalarValue; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::*; - use datafusion_functions_aggregate::expr_fn::{max, min}; + use datafusion_functions_aggregate::expr_fn::{max, min, sum}; + use crate::OptimizerContext; use crate::assert_optimized_plan_eq_snapshot; use crate::test::{assert_fields_eq, test_table_scan_with_name}; - use crate::OptimizerContext; use super::*; @@ -219,7 +315,7 @@ mod tests { assert_optimized_plan_equal!( table_scan, - @ r"TableScan: test projection=[a], full_filters=[Boolean(true)]" + @ "TableScan: test projection=[a], full_filters=[Boolean(true)]" ) } @@ -252,13 +348,59 @@ mod tests { assert_optimized_plan_equal!( plan, @ r" - Filter: test.b > Int32(1) - Projection: test.a - TableScan: test - " + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } + #[test] + fn test_simplify_udaf_to_non_aggregate_expr() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + let table_scan = table_scan(Some("test"), &schema, None)? + .build() + .expect("building scan"); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(col("a") + lit(2i64))])? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[]], aggr=[[sum(test.a + Int64(2))]] + TableScan: test + " + )?; + Ok(()) + } + + #[test] + fn test_simplify_common_sum_arg() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + let table_scan = table_scan(Some("test"), &schema, None)? + .build() + .expect("building scan"); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![sum(col("a") + lit(2i64)), sum(col("a") + lit(3i64))], + )? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Projection: sum(test.a) + Int64(2) * CAST(count(test.a) AS Int64) AS sum(test.a + Int64(2)), sum(test.a) + Int64(3) * CAST(count(test.a) AS Int64) AS sum(test.a + Int64(3)) + Aggregate: groupBy=[[]], aggr=[[sum(test.a), count(test.a)]] + TableScan: test + " + )?; + Ok(()) + } + #[test] fn test_simplify_optimized_plan_with_or() -> Result<()> { let table_scan = test_table_scan(); @@ -270,10 +412,10 @@ mod tests { assert_optimized_plan_equal!( plan, @ r" - Filter: test.b > Int32(1) - Projection: test.a - TableScan: test - " + Filter: test.b > Int32(1) + Projection: test.a + TableScan: test + " ) } @@ -492,8 +634,7 @@ mod tests { .build()?; let actual = get_optimized_plan_formatted(plan, &time); - let expected = - "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\ + let expected = "Projection: NOT test.a AS Boolean(true) OR Boolean(false) != test.a\ \n TableScan: test"; assert_eq!(expected, actual); @@ -872,7 +1013,7 @@ mod tests { ]); let table_scan = table_scan(Some("test"), &schema, None)?.build()?; - // Test `= ".*"` transforms to true (except for empty strings) + // Test `~ ".*"` transforms to true for any non-NULL string let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexMatch, lit(".*")))? .build()?; @@ -885,22 +1026,22 @@ mod tests { " )?; - // Test `!= ".*"` transforms to checking if the column is empty + // Test `!~ ".*"` preserves NULL semantics while remaining false for non-NULL strings let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotMatch, lit(".*")))? .build()?; assert_optimized_plan_equal!( plan, - @ r#" - Filter: test.a = Utf8("") + @ r" + Filter: test.a IS NULL AND Boolean(NULL) TableScan: test - "# + " )?; // Test case-insensitive versions - // Test `=~ ".*"` (case-insensitive) transforms to true (except for empty strings) + // Test `~* ".*"` transforms to true for any non-NULL string let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("b"), Operator::RegexIMatch, lit(".*")))? .build()?; @@ -913,17 +1054,51 @@ mod tests { " )?; - // Test `!~ ".*"` (case-insensitive) transforms to checking if the column is empty + // Test NULL `!~ ".*"` transforms to Boolean(NULL) + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(binary_expr( + lit(ScalarValue::Utf8(None)), + Operator::RegexNotMatch, + lit(".*"), + ))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(NULL) + TableScan: test + " + )?; + + // Test `!~* ".*"` preserves NULL semantics while remaining false for non-NULL strings let plan = LogicalPlanBuilder::from(table_scan.clone()) .filter(binary_expr(col("a"), Operator::RegexNotIMatch, lit(".*")))? .build()?; assert_optimized_plan_equal!( plan, - @ r#" - Filter: test.a = Utf8("") + @ r" + Filter: test.a IS NULL AND Boolean(NULL) TableScan: test - "# + " + )?; + + // Test NULL `!~* ".*"` transforms to Boolean(NULL) + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .filter(binary_expr( + lit(ScalarValue::Utf8(None)), + Operator::RegexNotIMatch, + lit(".*"), + ))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @ r" + Filter: Boolean(NULL) + TableScan: test + " ) } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs b/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs new file mode 100644 index 0000000000000..72e9dbc99dfae --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parses and simplifies an expression to a literal of a given type. +//! +//! This module provides functionality to parse and simplify static expressions +//! used in SQL constructs like `FROM TABLE SAMPLE (10 + 50 * 2)`. If they are required +//! in a planning (not an execution) phase, they need to be reduced to literals of a given type. + +use crate::simplify_expressions::ExprSimplifier; +use arrow::datatypes::ArrowPrimitiveType; +use datafusion_common::{ + DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err, + plan_err, +}; +use datafusion_expr::Expr; +use datafusion_expr::simplify::SimplifyContext; +use std::sync::Arc; + +/// Parse and simplifies an expression to a numeric literal, +/// corresponding to an arrow primitive type `T` (for example, Float64Type). +/// +/// This function simplifies and coerces the expression, then extracts the underlying +/// native type using `TryFrom`. +/// +/// # Example +/// ```ignore +/// let value: f64 = parse_literal::(expr)?; +/// ``` +pub fn parse_literal(expr: &Expr) -> Result +where + T: ArrowPrimitiveType, + T::Native: TryFrom, +{ + // Empty schema is sufficient because it parses only literal expressions + let schema = DFSchemaRef::new(DFSchema::empty()); + + log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE); + + let simplifier = ExprSimplifier::new( + SimplifyContext::builder() + .with_schema(Arc::clone(&schema)) + .build(), + ); + + // Simplify and coerce expression in case of constant arithmetic operations (e.g., 10 + 5) + let simplified_expr: Expr = simplifier + .simplify(expr.clone()) + .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?; + let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?; + log::debug!("Coerced expression: {:?}", &coerced_expr); + + match coerced_expr { + Expr::Literal(scalar_value, _) => { + // It is a literal - proceed to the underlying value + // Cast to the target type if needed + let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?; + + // Extract the native type + T::Native::try_from(casted_scalar).map_err(|err| { + plan_datafusion_err!( + "Cannot extract {} from scalar value: {err}", + std::any::type_name::() + ) + }) + } + actual => { + plan_err!( + "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression" + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{Float64Type, Int64Type}; + use datafusion_expr::{BinaryExpr, lit}; + use datafusion_expr_common::operator::Operator; + + #[test] + fn test_parse_sql_float_literal() { + let test_cases = vec![ + (Expr::Literal(ScalarValue::Float64(Some(0.0)), None), 0.0), + (Expr::Literal(ScalarValue::Float64(Some(1.0)), None), 1.0), + ( + Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(50.0)), + Operator::Minus, + Box::new(lit(10.0)), + )), + 40.0, + ), + ( + Expr::Literal(ScalarValue::Utf8(Some("1e2".into())), None), + 100.0, + ), + ( + Expr::Literal(ScalarValue::Utf8(Some("2.5e-1".into())), None), + 0.25, + ), + ]; + + for (expr, expected) in test_cases { + let result: Result = parse_literal::(&expr); + + match result { + Ok(value) => { + assert!( + (value - expected).abs() < 1e-10, + "For expression '{expr}': expected {expected}, got {value}", + ); + } + Err(e) => panic!("Failed to parse expression '{expr}': {e}"), + } + } + } + + #[test] + fn test_parse_sql_integer_literal() { + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(2)), + Operator::Plus, + Box::new(lit(4)), + )); + + let result: Result = parse_literal::(&expr); + + match result { + Ok(value) => { + assert_eq!(6, value); + } + Err(e) => panic!("Failed to parse expression: {e}"), + } + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs index e811ce7313102..356f2711b708e 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs @@ -63,12 +63,14 @@ pub fn simplify_predicates(predicates: Vec) -> Result> { | Operator::Eq, right, }) => { - let left_col = extract_column_from_expr(left); - let right_col = extract_column_from_expr(right); - if let (Some(col), Some(_)) = (&left_col, right.as_literal()) { - column_predicates.entry(col.clone()).or_default().push(pred); - } else if let (Some(_), Some(col)) = (left.as_literal(), &right_col) { - column_predicates.entry(col.clone()).or_default().push(pred); + if let (Some(col), Some(_)) = + (extract_column_from_expr(left), right.as_literal()) + { + column_predicates.entry(col).or_default().push(pred); + } else if let (Some(_), Some(col)) = + (left.as_literal(), extract_column_from_expr(right)) + { + column_predicates.entry(col).or_default().push(pred); } else { other_predicates.push(pred); } diff --git a/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs new file mode 100644 index 0000000000000..d888a54d56574 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/udf_preimage.rs @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{Result, internal_err, tree_node::Transformed}; +use datafusion_expr::{Expr, Operator, and, lit, or}; +use datafusion_expr_common::interval_arithmetic::Interval; + +/// Rewrites a binary expression using its "preimage" +/// +/// Specifically it rewrites expressions of the form ` OP x` (e.g. ` = +/// x`) where `` is known to have a pre-image (aka the entire single +/// range for which it is valid) and `x` is not `NULL` +/// +/// For details see [`datafusion_expr::ScalarUDFImpl::preimage`] +pub(super) fn rewrite_with_preimage( + preimage_interval: Interval, + op: Operator, + expr: Expr, +) -> Result> { + let (lower, upper) = preimage_interval.into_bounds(); + let (lower, upper) = (lit(lower), lit(upper)); + + let rewritten_expr = match op { + // < x ==> < lower + Operator::Lt => expr.lt(lower), + // >= x ==> >= lower + Operator::GtEq => expr.gt_eq(lower), + // > x ==> >= upper + Operator::Gt => expr.gt_eq(upper), + // <= x ==> < upper + Operator::LtEq => expr.lt(upper), + // = x ==> ( >= lower) and ( < upper) + Operator::Eq => and(expr.clone().gt_eq(lower), expr.lt(upper)), + // != x ==> ( < lower) or ( >= upper) + Operator::NotEq => or(expr.clone().lt(lower), expr.gt_eq(upper)), + // is not distinct from x ==> ( is NULL and x is NULL) or (( >= lower) and ( < upper)) + // but since x is always not NULL => ( is not NULL) and ( >= lower) and ( < upper) + Operator::IsNotDistinctFrom => expr + .clone() + .is_not_null() + .and(expr.clone().gt_eq(lower)) + .and(expr.lt(upper)), + // is distinct from x ==> ( < lower) or ( >= upper) or ( is NULL and x is not NULL) or ( is not NULL and x is NULL) + // but given that x is always not NULL => ( < lower) or ( >= upper) or ( is NULL) + Operator::IsDistinctFrom => expr + .clone() + .lt(lower) + .or(expr.clone().gt_eq(upper)) + .or(expr.is_null()), + _ => return internal_err!("Expect comparison operators"), + }; + Ok(Transformed::yes(rewritten_expr)) +} + +#[cfg(test)] +mod test { + + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; + use datafusion_expr::{ + ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult, + simplify::SimplifyContext, + }; + + use super::Interval; + use crate::simplify_expressions::ExprSimplifier; + + fn is_distinct_from(left: Expr, right: Expr) -> Expr { + binary_expr(left, Operator::IsDistinctFrom, right) + } + + fn is_not_distinct_from(left: Expr, right: Expr) -> Expr { + binary_expr(left, Operator::IsNotDistinctFrom, right) + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct PreimageUdf { + /// Defaults to an exact signature with one Int32 argument and Immutable volatility + signature: Signature, + /// If true, returns a preimage; otherwise, returns None + enabled: bool, + } + + impl PreimageUdf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable), + enabled: true, + } + } + + /// Set the enabled flag + fn with_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + /// Set the volatility + fn with_volatility(mut self, volatility: Volatility) -> Self { + self.signature.volatility = volatility; + self + } + } + + impl ScalarUDFImpl for PreimageUdf { + fn name(&self) -> &str { + "preimage_func" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500)))) + } + + fn preimage( + &self, + args: &[Expr], + lit_expr: &Expr, + _info: &SimplifyContext, + ) -> Result { + if !self.enabled { + return Ok(PreimageResult::None); + } + if args.len() != 1 { + return Ok(PreimageResult::None); + } + + let expr = args.first().cloned().expect("Should be column expression"); + match lit_expr { + Expr::Literal(ScalarValue::Int32(Some(500)), _) => { + Ok(PreimageResult::Range { + expr, + interval: Box::new(Interval::try_new( + ScalarValue::Int32(Some(100)), + ScalarValue::Int32(Some(200)), + )?), + }) + } + Expr::Literal(ScalarValue::Int32(Some(600)), _) => { + Ok(PreimageResult::Range { + expr, + interval: Box::new(Interval::try_new( + ScalarValue::Int32(Some(300)), + ScalarValue::Int32(Some(400)), + )?), + }) + } + _ => Ok(PreimageResult::None), + } + } + } + + fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { + let simplify_context = SimplifyContext::builder() + .with_schema(Arc::clone(schema)) + .build(); + ExprSimplifier::new(simplify_context) + .simplify(expr) + .unwrap() + } + + fn preimage_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new()).call(vec![col("x")]) + } + + fn non_immutable_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new().with_volatility(Volatility::Volatile)) + .call(vec![col("x")]) + } + + fn no_preimage_udf_expr() -> Expr { + ScalarUDF::new_from_impl(PreimageUdf::new().with_enabled(false)) + .call(vec![col("x")]) + } + + fn test_schema() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![Field::new("x", DataType::Int32, true)].into(), + Default::default(), + ) + .unwrap(), + ) + } + + fn test_schema_xy() -> DFSchemaRef { + Arc::new( + DFSchema::from_unqualified_fields( + vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Int32, false), + ] + .into(), + Default::default(), + ) + .unwrap(), + ) + } + + #[test] + fn test_preimage_eq_rewrite() { + // Equality rewrite when preimage and column expression are available. + let schema = test_schema(); + let expr = preimage_udf_expr().eq(lit(500)); + let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_noteq_rewrite() { + // Inequality rewrite expands to disjoint ranges. + let schema = test_schema(); + let expr = preimage_udf_expr().not_eq(lit(500)); + let expected = col("x").lt(lit(100)).or(col("x").gt_eq(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_eq_rewrite_swapped() { + // Equality rewrite works when the literal appears on the left. + let schema = test_schema(); + let expr = lit(500).eq(preimage_udf_expr()); + let expected = and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_lt_rewrite() { + // Less-than comparison rewrites to the lower bound. + let schema = test_schema(); + let expr = preimage_udf_expr().lt(lit(500)); + let expected = col("x").lt(lit(100)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_lteq_rewrite() { + // Less-than-or-equal comparison rewrites to the upper bound. + let schema = test_schema(); + let expr = preimage_udf_expr().lt_eq(lit(500)); + let expected = col("x").lt(lit(200)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_gt_rewrite() { + // Greater-than comparison rewrites to the upper bound (inclusive). + let schema = test_schema(); + let expr = preimage_udf_expr().gt(lit(500)); + let expected = col("x").gt_eq(lit(200)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_gteq_rewrite() { + // Greater-than-or-equal comparison rewrites to the lower bound. + let schema = test_schema(); + let expr = preimage_udf_expr().gt_eq(lit(500)); + let expected = col("x").gt_eq(lit(100)); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_is_not_distinct_from_rewrite() { + // IS NOT DISTINCT FROM rewrites to equality plus expression not-null check + // for non-null literal RHS. + let schema = test_schema(); + let expr = is_not_distinct_from(preimage_udf_expr(), lit(500)); + let expected = col("x") + .is_not_null() + .and(col("x").gt_eq(lit(100))) + .and(col("x").lt(lit(200))); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_is_distinct_from_rewrite() { + // IS DISTINCT FROM adds an explicit NULL branch for the column. + let schema = test_schema(); + let expr = is_distinct_from(preimage_udf_expr(), lit(500)); + let expected = col("x") + .lt(lit(100)) + .or(col("x").gt_eq(lit(200))) + .or(col("x").is_null()); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_in_list_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false); + let expected = or( + and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))), + and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))), + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_not_in_list_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true); + let expected = and( + or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))), + or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))), + ); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_in_list_long_list_no_rewrite() { + let schema = test_schema(); + let expr = preimage_udf_expr().in_list((1..100).map(lit).collect(), false); + + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_preimage_non_literal_rhs_no_rewrite() { + // Non-literal RHS should not be rewritten. + let schema = test_schema_xy(); + let expr = preimage_udf_expr().eq(col("y")); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_null_literal_no_rewrite_distinct_ops() { + // NULL literal RHS should not be rewritten for DISTINCTness operators: + // - `expr IS DISTINCT FROM NULL` <=> `NOT (expr IS NULL)` + // - `expr IS NOT DISTINCT FROM NULL` <=> `expr IS NULL` + // + // For normal comparisons (=, !=, <, <=, >, >=), `expr OP NULL` evaluates to NULL + // under SQL tri-state logic, and DataFusion's simplifier constant-folds it. + // https://docs.rs/datafusion/latest/datafusion/physical_optimizer/pruning/struct.PruningPredicate.html#boolean-tri-state-logic + + let schema = test_schema(); + + let expr = is_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + + let expr = + is_not_distinct_from(preimage_udf_expr(), lit(ScalarValue::Int32(None))); + assert_eq!(optimize_test(expr.clone(), &schema), expr); + } + + #[test] + fn test_preimage_non_immutable_no_rewrite() { + // Non-immutable UDFs should not participate in preimage rewrites. + let schema = test_schema(); + let expr = non_immutable_udf_expr().eq(lit(500)); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } + + #[test] + fn test_preimage_no_preimage_no_rewrite() { + // If the UDF provides no preimage, the expression should remain unchanged. + let schema = test_schema(); + let expr = no_preimage_udf_expr().eq(lit(500)); + let expected = expr.clone(); + + assert_eq!(optimize_test(expr, &schema), expected); + } +} diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index b1f3b006e0cfc..a5b65d0d8e7a4 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -55,14 +55,14 @@ //! ``` use arrow::datatypes::DataType; -use datafusion_common::{internal_err, tree_node::Transformed}; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{lit, BinaryExpr}; -use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; +use datafusion_common::{internal_err, tree_node::Transformed}; +use datafusion_expr::{BinaryExpr, lit}; +use datafusion_expr::{Cast, Expr, Operator, TryCast, simplify::SimplifyContext}; use datafusion_expr_common::casts::{is_supported_type, try_cast_literal_to_type}; -pub(super) fn unwrap_cast_in_comparison_for_binary( - info: &S, +pub(super) fn unwrap_cast_in_comparison_for_binary( + info: &SimplifyContext, cast_expr: Expr, literal: Expr, op: Operator, @@ -104,10 +104,8 @@ pub(super) fn unwrap_cast_in_comparison_for_binary( } } -pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< - S: SimplifyInfo, ->( - info: &S, +pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary( + info: &SimplifyContext, expr: &Expr, op: Operator, literal: &Expr, @@ -142,10 +140,8 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary< } } -pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< - S: SimplifyInfo, ->( - info: &S, +pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist( + info: &SimplifyContext, expr: &Expr, list: &[Expr], ) -> bool { @@ -241,7 +237,6 @@ mod tests { use crate::simplify_expressions::ExprSimplifier; use arrow::datatypes::{Field, TimeUnit}; use datafusion_common::{DFSchema, DFSchemaRef}; - use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::{cast, col, in_list, try_cast}; @@ -592,9 +587,10 @@ mod tests { } fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { - let props = ExecutionProps::new(); let simplifier = ExprSimplifier::new( - SimplifyContext::new(&props).with_schema(Arc::clone(schema)), + SimplifyContext::builder() + .with_schema(Arc::clone(schema)) + .build(), ); simplifier.simplify(expr).unwrap() diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 35e256f3064e3..b0908b47602f7 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -18,11 +18,11 @@ //! Utility functions for expression simplification use arrow::datatypes::i256; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::{ + Case, Expr, Like, Operator, expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, or}, - Case, Expr, Like, Operator, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -290,6 +290,54 @@ pub fn is_lit(expr: &Expr) -> bool { matches!(expr, Expr::Literal(_, _)) } +/// Checks if `eq_expr` is `A = L1` and `ne_expr` is `A != L2` where L1 != L2. +/// This pattern can be simplified to just `A = L1` since if A equals L1 +/// and L1 is different from L2, then A is automatically not equal to L2. +pub fn is_eq_and_ne_with_different_literal(eq_expr: &Expr, ne_expr: &Expr) -> bool { + fn extract_var_and_literal(expr: &Expr) -> Option<(&Expr, &Expr)> { + match expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) + | Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::NotEq, + right, + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Literal(_, _), var) => Some((var, left)), + (var, Expr::Literal(_, _)) => Some((var, right)), + _ => None, + }, + _ => None, + } + } + match (eq_expr, ne_expr) { + ( + Expr::BinaryExpr(BinaryExpr { + op: Operator::Eq, .. + }), + Expr::BinaryExpr(BinaryExpr { + op: Operator::NotEq, + .. + }), + ) => { + // Check if both compare the same expression against different literals + if let (Some((var1, lit1)), Some((var2, lit2))) = ( + extract_var_and_literal(eq_expr), + extract_var_and_literal(ne_expr), + ) && var1 == var2 + && lit1 != lit2 + { + return true; + } + false + } + _ => false, + } +} + /// negate a Not clause /// input is the clause to be negated.(args of Not clause) /// For BinaryExpr, use the negation of op instead. diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 8eb4ae3976f91..00c8fab228117 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -23,15 +23,14 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{ - assert_eq_or_internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, + DataFusionError, HashSet, Result, assert_eq_or_internal_err, tree_node::Transformed, }; use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionParams; use datafusion_expr::{ - col, + Expr, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, - Expr, }; /// single distinct to group by optimizer rule @@ -56,7 +55,7 @@ pub struct SingleDistinctToGroupBy {} const SINGLE_DISTINCT_ALIAS: &str = "alias1"; impl SingleDistinctToGroupBy { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -185,7 +184,11 @@ impl OptimizerRule for SingleDistinctToGroupBy { func, params: AggregateFunctionParams { - mut args, distinct, .. + mut args, + distinct, + filter, + order_by, + null_treatment, }, }) => { if distinct { @@ -205,9 +208,9 @@ impl OptimizerRule for SingleDistinctToGroupBy { func, vec![col(SINGLE_DISTINCT_ALIAS)], false, // intentional to remove distinct here - None, - vec![], - None, + filter, + order_by, + null_treatment, ))) // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation } else { @@ -218,9 +221,9 @@ impl OptimizerRule for SingleDistinctToGroupBy { Arc::clone(&func), args, false, - None, - vec![], - None, + filter, + order_by, + null_treatment, )) .alias(&alias_str), ); @@ -288,8 +291,8 @@ mod tests { use super::*; use crate::assert_optimized_plan_eq_display_indent_snapshot; use crate::test::*; - use datafusion_expr::expr::GroupingSet; use datafusion_expr::ExprFunctionExt; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 6e0b734bb9280..2915e77be2e12 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -20,10 +20,11 @@ use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{assert_contains, Result}; -use datafusion_expr::{logical_plan::table_scan, LogicalPlan, LogicalPlanBuilder}; +use datafusion_common::{Result, assert_contains}; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, logical_plan::table_scan}; use std::sync::Arc; +pub mod udfs; pub mod user_defined; pub fn test_table_scan_fields() -> Vec { @@ -34,6 +35,28 @@ pub fn test_table_scan_fields() -> Vec { ] } +pub fn test_table_scan_with_struct_fields() -> Vec { + vec![ + Field::new("id", DataType::UInt32, false), + Field::new( + "user", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, true), + Field::new("status", DataType::Utf8, true), + ] + .into(), + ), + true, + ), + ] +} + +pub fn test_table_scan_with_struct() -> Result { + let schema = Schema::new(test_table_scan_with_struct_fields()); + table_scan(Some("test"), &schema, None)?.build() +} + /// some tests share a common table with different names pub fn test_table_scan_with_name(name: &str) -> Result { let schema = Schema::new(test_table_scan_fields()); diff --git a/datafusion/optimizer/src/test/udfs.rs b/datafusion/optimizer/src/test/udfs.rs new file mode 100644 index 0000000000000..ba71b6a04a7a2 --- /dev/null +++ b/datafusion/optimizer/src/test/udfs.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::{ + ColumnarValue, Expr, ExpressionPlacement, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, TypeSignature, +}; + +/// A configurable test UDF for optimizer tests. +/// Defaults to `MoveTowardsLeafNodes` placement. Use `with_placement()` to override. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct PlacementTestUDF { + signature: Signature, + placement: ExpressionPlacement, + id: usize, +} + +impl Default for PlacementTestUDF { + fn default() -> Self { + Self::new() + } +} + +impl PlacementTestUDF { + pub fn new() -> Self { + Self { + // Accept any one or two arguments and return UInt32 for testing purposes. + // The actual types don't matter since this UDF is not intended for execution. + signature: Signature::new( + TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]), + datafusion_expr::Volatility::Immutable, + ), + placement: ExpressionPlacement::MoveTowardsLeafNodes, + id: 0, + } + } + + /// Set the expression placement for this UDF, which is used by optimizer rules to determine where in the plan the expression should be placed. + /// This also resets the name of the UDF to a default based on the placement. + pub fn with_placement(mut self, placement: ExpressionPlacement) -> Self { + self.placement = placement; + self + } + + /// Set the id of the UDF. + /// This is an arbitrary made up field to allow creating multiple distinct UDFs with the same placement. + pub fn with_id(mut self, id: usize) -> Self { + self.id = id; + self + } +} + +impl ScalarUDFImpl for PlacementTestUDF { + fn name(&self) -> &str { + match self.placement { + ExpressionPlacement::MoveTowardsLeafNodes => "leaf_udf", + ExpressionPlacement::KeepInPlace => "keep_in_place_udf", + ExpressionPlacement::Column => "column_udf", + ExpressionPlacement::Literal => "literal_udf", + } + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::UInt32) + } + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("PlacementTestUDF: not intended for execution") + } + fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { + self.placement + } +} + +/// Create a `leaf_udf(arg)` expression with `MoveTowardsLeafNodes` placement. +pub fn leaf_udf_expr(arg: Expr) -> Expr { + let udf = ScalarUDF::new_from_impl( + PlacementTestUDF::new().with_placement(ExpressionPlacement::MoveTowardsLeafNodes), + ); + udf.call(vec![arg]) +} diff --git a/datafusion/optimizer/src/test/user_defined.rs b/datafusion/optimizer/src/test/user_defined.rs index a39f90b5da5db..878ce274d5ed6 100644 --- a/datafusion/optimizer/src/test/user_defined.rs +++ b/datafusion/optimizer/src/test/user_defined.rs @@ -19,8 +19,8 @@ use datafusion_common::DFSchemaRef; use datafusion_expr::{ - logical_plan::{Extension, UserDefinedLogicalNodeCore}, Expr, LogicalPlan, + logical_plan::{Extension, UserDefinedLogicalNodeCore}, }; use std::{ fmt::{self, Debug}, diff --git a/datafusion/optimizer/src/unions_to_filter.rs b/datafusion/optimizer/src/unions_to_filter.rs new file mode 100644 index 0000000000000..158fd358287fe --- /dev/null +++ b/datafusion/optimizer/src/unions_to_filter.rs @@ -0,0 +1,652 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rewrites `UNION DISTINCT` branches that differ only by filter predicates +//! into a single filtered branch plus `DISTINCT`. + +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::Result; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::utils::disjunction; +use datafusion_expr::{ + Distinct, Expr, Filter, LogicalPlan, Projection, SubqueryAlias, Union, +}; +use log::debug; +use std::sync::Arc; + +#[derive(Default, Debug)] +pub struct UnionsToFilter; + +impl UnionsToFilter { + #[expect(missing_docs)] + pub fn new() -> Self { + Self + } +} + +impl OptimizerRule for UnionsToFilter { + fn name(&self) -> &str { + "unions_to_filter" + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config.options().optimizer.enable_unions_to_filter { + return Ok(Transformed::no(plan)); + } + + // Fast pre-check: if the plan tree has no Distinct::All node at all we can + // skip the expensive bottom-up rewrite_with_subqueries traversal entirely. + // This matters for large UNION ALL plans (e.g. TPC-DS Q4) where the rule + // can never fire and the traversal overhead is otherwise measurable. + if !plan.exists(|p| Ok(matches!(p, LogicalPlan::Distinct(Distinct::All(_)))))? { + return Ok(Transformed::no(plan)); + } + + plan.rewrite_with_subqueries(&mut UnionsToFilterRewriter) + } +} + +struct UnionsToFilterRewriter; + +impl TreeNodeRewriter for UnionsToFilterRewriter { + type Node = LogicalPlan; + + fn f_up(&mut self, plan: LogicalPlan) -> Result> { + match &plan { + LogicalPlan::Distinct(Distinct::All(input)) => { + match try_rewrite_distinct_union(input.as_ref().clone())? { + Some(rewritten) => Ok(Transformed::yes(rewritten)), + None => Ok(Transformed::no(plan)), + } + } + _ => Ok(Transformed::no(plan)), + } + } +} + +fn try_rewrite_distinct_union(plan: LogicalPlan) -> Result> { + let LogicalPlan::Union(Union { inputs, schema }) = plan else { + debug!("unions_to_filter skipped: input is not a UNION"); + return Ok(None); + }; + + if inputs.len() < 2 { + debug!( + "unions_to_filter skipped: UNION has {} input(s), need at least 2", + inputs.len() + ); + return Ok(None); + } + + // Use a Vec instead of HashMap: union branches are typically 2-10 entries, + // so a linear scan with PartialEq is faster than recursively hashing entire + // LogicalPlan subtrees (O(N * tree_size) hashing for every insert/lookup). + let mut grouped: Vec<(GroupKey, Vec)> = Vec::new(); + let mut transformed = false; + + for input in inputs { + let Some(branch) = extract_branch(Arc::unwrap_or_clone(input))? else { + return Ok(None); + }; + + let key = GroupKey { + source: branch.source, + wrappers: branch.wrappers, + }; + if let Some((_, conds)) = grouped.iter_mut().find(|(k, _)| k == &key) { + conds.push(branch.predicate); + transformed = true; + } else { + grouped.push((key, vec![branch.predicate])); + } + } + + if !transformed { + debug!("unions_to_filter skipped: no branch groups could be merged"); + return Ok(None); + } + + let mut builder: Option = None; + for (key, predicates) in grouped { + let combined = + disjunction(predicates).expect("union branches always provide predicates"); + let branch = LogicalPlanBuilder::from(key.source) + .filter(combined)? + .build()?; + let branch = wrap_branch(branch, &key.wrappers)?; + let branch = coerce_plan_expr_for_schema(branch, &schema)?; + let branch = align_plan_to_schema(branch, Arc::clone(&schema))?; + builder = Some(match builder { + None => LogicalPlanBuilder::from(branch), + Some(builder) => builder.union(branch)?, + }); + } + + let union = builder + .expect("at least one branch after rewrite") + .build()?; + Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new(union))))) +} + +struct UnionBranch { + source: LogicalPlan, + predicate: Expr, + wrappers: Vec, +} + +fn extract_branch(plan: LogicalPlan) -> Result> { + let (wrappers, plan) = peel_wrappers(plan); + + // Volatile or subquery expressions in the projection must not be merged: + // they are evaluated once per branch in the original plan but would be + // evaluated once per combined row after the rewrite, which can change the + // output row set. + if !wrapper_projections_are_safe(&wrappers) { + debug!( + "unions_to_filter skipped: projection wrapper contains volatile expression or subquery" + ); + return Ok(None); + } + + match plan { + LogicalPlan::Filter(Filter { + predicate, input, .. + }) => { + if !is_mergeable_predicate(&predicate) { + debug!( + "unions_to_filter skipped: branch predicate contains volatility or a subquery" + ); + return Ok(None); + } + Ok(Some(UnionBranch { + source: strip_passthrough_nodes(Arc::unwrap_or_clone(input)), + predicate, + wrappers, + })) + } + // A Limit or Sort node changes the row-set semantics of the branch. + // Merging two such branches into one would silently drop the per-branch + // row restriction (LIMIT) or rely on an order guarantee that UNION does + // not preserve (ORDER BY). Bail out to leave the UNION unchanged. + LogicalPlan::Limit(_) => { + debug!("unions_to_filter skipped: branch contains LIMIT"); + Ok(None) + } + LogicalPlan::Sort(_) => { + debug!("unions_to_filter skipped: branch contains ORDER BY / SORT"); + Ok(None) + } + other => Ok(Some(UnionBranch { + source: strip_passthrough_nodes(other), + predicate: Expr::Literal( + datafusion_common::ScalarValue::Boolean(Some(true)), + None, + ), + wrappers, + })), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct GroupKey { + source: LogicalPlan, + wrappers: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum Wrapper { + Projection { + expr: Vec, + schema: datafusion_common::DFSchemaRef, + }, + SubqueryAlias { + alias: datafusion_common::TableReference, + schema: datafusion_common::DFSchemaRef, + }, +} + +fn peel_wrappers(mut plan: LogicalPlan) -> (Vec, LogicalPlan) { + let mut wrappers = vec![]; + loop { + match plan { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + .. + }) => { + wrappers.push(Wrapper::Projection { expr, schema }); + plan = Arc::unwrap_or_clone(input); + } + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema, + .. + }) => { + wrappers.push(Wrapper::SubqueryAlias { alias, schema }); + plan = Arc::unwrap_or_clone(input); + } + other => return (wrappers, other), + } + } +} + +fn wrap_branch(mut plan: LogicalPlan, wrappers: &[Wrapper]) -> Result { + for wrapper in wrappers.iter().rev() { + plan = match wrapper { + Wrapper::Projection { expr, schema } => { + LogicalPlan::Projection(Projection::try_new_with_schema( + expr.clone(), + Arc::new(plan), + Arc::clone(schema), + )?) + } + // SubqueryAlias::try_new recomputes the schema from the new input. + // This is safe because the source table is unchanged; only the + // filter predicate differs, so the recomputed schema matches the + // original one stored in peel_wrappers. + Wrapper::SubqueryAlias { alias, .. } => LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(Arc::new(plan), alias.clone())?, + ), + }; + } + Ok(plan) +} + +fn strip_passthrough_nodes(mut plan: LogicalPlan) -> LogicalPlan { + loop { + plan = match plan { + LogicalPlan::Projection(Projection { input, .. }) => { + Arc::unwrap_or_clone(input) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { + Arc::unwrap_or_clone(input) + } + other => return other, + }; + } +} + +fn align_plan_to_schema( + plan: LogicalPlan, + schema: datafusion_common::DFSchemaRef, +) -> Result { + if plan.schema() == &schema { + return Ok(plan); + } + + let expr = plan + .schema() + .iter() + .enumerate() + .map(|(i, _)| { + Expr::Column(datafusion_common::Column::from( + plan.schema().qualified_field(i), + )) + }) + .collect::>(); + + Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + expr, + Arc::new(plan), + schema, + )?)) +} + +fn is_mergeable_predicate(expr: &Expr) -> bool { + !expr.is_volatile() && !expr_contains_subquery(expr) +} + +/// Returns `true` when every projection expression in `wrappers` is both +/// non-volatile and subquery-free. +/// +/// Volatile expressions (e.g. `random()`, `now()`) or correlated subqueries +/// in the SELECT list cannot be safely merged: the original plan evaluates +/// them once per branch execution, while the rewritten plan evaluates them +/// once per combined row, which can change the set of output rows. +fn wrapper_projections_are_safe(wrappers: &[Wrapper]) -> bool { + wrappers.iter().all(|w| match w { + Wrapper::Projection { expr, .. } => expr + .iter() + .all(|e| !e.is_volatile() && !expr_contains_subquery(e)), + Wrapper::SubqueryAlias { .. } => true, + }) +} + +fn expr_contains_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true), + _ => Ok(false), + }) + .expect("boolean expression walk is infallible") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::OptimizerContext; + use crate::assert_optimized_plan_eq_snapshot; + use crate::test::test_table_scan_with_name; + use arrow::datatypes::DataType; + use datafusion_common::Result; + use datafusion_expr::{ + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, col, lit, + }; + + macro_rules! assert_optimized_plan_equal { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let mut options = datafusion_common::config::ConfigOptions::default(); + options.optimizer.enable_unions_to_filter = true; + let optimizer_ctx = OptimizerContext::new_with_config_options(Arc::new(options)) + .with_max_passes(1); + let rules: Vec> = + vec![Arc::new(UnionsToFilter::new())]; + assert_optimized_plan_eq_snapshot!( + optimizer_ctx, + rules, + $plan, + @ $expected, + ) + }}; + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct VolatileTestUdf; + + impl ScalarUDFImpl for VolatileTestUdf { + fn name(&self) -> &str { + "volatile_test" + } + + fn signature(&self) -> &Signature { + static SIGNATURE: std::sync::LazyLock = + std::sync::LazyLock::new(|| Signature::nullary(Volatility::Volatile)); + &SIGNATURE + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("VolatileTestUdf is not intended for execution") + } + } + + fn volatile_expr() -> Expr { + ScalarUDF::new_from_impl(VolatileTestUdf).call(vec![]) + } + + #[test] + fn rewrite_union_distinct_same_source_filters() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(1)))? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(2)))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Projection: t.a, t.b, t.c + Filter: t.a = Int32(1) OR t.a = Int32(2) + TableScan: t + ")?; + Ok(()) + } + + #[test] + fn keep_union_distinct_different_sources() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan_with_name("t1")?) + .filter(col("a").eq(lit(1)))? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("t2")?) + .filter(col("a").eq(lit(2)))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + Filter: t1.a = Int32(1) + TableScan: t1 + Filter: t2.a = Int32(2) + TableScan: t2 + ")?; + Ok(()) + } + + #[test] + fn keep_union_distinct_with_volatile_predicate() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(volatile_expr().gt(lit(0.5_f64)))? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(2)))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + Filter: volatile_test() > Float64(0.5) + TableScan: t + Filter: t.a = Int32(2) + TableScan: t + ")?; + Ok(()) + } + + #[test] + fn rewrite_union_distinct_with_matching_projection_prefix() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?) + .project(vec![col("a").alias("mgr"), col("b").alias("comm")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?) + .filter(col("b").eq(lit(5)))? + .project(vec![col("a").alias("mgr"), col("b").alias("comm")])? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Projection: emp.a AS mgr, emp.b AS comm + Filter: Boolean(true) OR emp.b = Int32(5) + TableScan: emp + ")?; + Ok(()) + } + + /// A volatile expression in the **projection** (SELECT list) must block the + /// rewrite. Each original branch evaluates it independently; merging them + /// would evaluate it once per combined row, changing the row set. + #[test] + fn keep_union_distinct_with_volatile_projection() -> Result<()> { + // Both branches project volatile_test() AS v over the same source. + let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(1)))? + .project(vec![volatile_expr().alias("v"), col("a")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(2)))? + .project(vec![volatile_expr().alias("v"), col("a")])? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + Projection: volatile_test() AS v, t.a + Filter: t.a = Int32(1) + TableScan: t + Projection: volatile_test() AS v, t.a + Filter: t.a = Int32(2) + TableScan: t + ")?; + Ok(()) + } + + /// A scalar subquery in the **projection** must also block the rewrite. + #[test] + fn keep_union_distinct_with_subquery_in_projection() -> Result<()> { + use datafusion_expr::scalar_subquery; + + // Build a simple scalar subquery: (SELECT t2.b FROM t2 WHERE t2.a = t.a) + let t2 = test_table_scan_with_name("t2")?; + let subquery_plan = Arc::new( + LogicalPlanBuilder::from(t2) + .filter(col("t2.a").eq(col("t.a")))? + .project(vec![col("t2.b")])? + .build()?, + ); + let sq = scalar_subquery(subquery_plan); + + let left = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(1)))? + .project(vec![sq.clone().alias("sub"), col("a")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("t")?) + .filter(col("a").eq(lit(2)))? + .project(vec![sq.alias("sub"), col("a")])? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + // Plan should be left untouched because the projection contains a subquery. + let optimized = { + let mut options = datafusion_common::config::ConfigOptions::default(); + options.optimizer.enable_unions_to_filter = true; + let optimizer_ctx = + OptimizerContext::new_with_config_options(Arc::new(options)) + .with_max_passes(1); + let rules: Vec> = + vec![Arc::new(UnionsToFilter::new())]; + crate::Optimizer::with_rules(rules).optimize( + plan.clone(), + &optimizer_ctx, + |_, _| {}, + )? + }; + // The Distinct(Union(...)) structure must be preserved. + assert!( + matches!( + &optimized, + LogicalPlan::Distinct(Distinct::All(inner)) + if matches!(inner.as_ref(), LogicalPlan::Union(_)) + ), + "expected Distinct(Union(...)) to be preserved, got:\n{optimized:?}" + ); + Ok(()) + } + + /// A UNION where both branches have a LIMIT must **not** be rewritten. + /// Each branch independently restricts the row-set; collapsing them into a + /// single branch would lose the per-branch LIMIT semantics. + #[test] + fn keep_union_distinct_with_limit_branches() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?) + .project(vec![col("a").alias("mgr"), col("b").alias("comm")])? + .limit(0, Some(2))? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?) + .project(vec![col("a").alias("mgr"), col("b").alias("comm")])? + .limit(0, Some(2))? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + Limit: skip=0, fetch=2 + Projection: emp.a AS mgr, emp.b AS comm + TableScan: emp + Limit: skip=0, fetch=2 + Projection: emp.a AS mgr, emp.b AS comm + TableScan: emp + ")?; + Ok(()) + } + + /// A UNION where both branches have an ORDER BY (Sort) must **not** be + /// rewritten. ORDER BY inside a UNION subquery does not guarantee ordering + /// in the result; merging the branches would silently discard the Sort. + #[test] + fn keep_union_distinct_with_sort_branches() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?) + .project(vec![col("a").alias("mgr"), col("b").alias("comm")])? + .sort(vec![col("a").sort(true, true)])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("emp")?) + .project(vec![col("a").alias("mgr"), col("b").alias("comm")])? + .sort(vec![col("a").sort(true, true)])? + .build()?; + + let plan = LogicalPlanBuilder::from(left) + .union_distinct(right)? + .build()?; + + assert_optimized_plan_equal!(plan, @r" + Distinct: + Union + Projection: mgr, comm + Sort: emp.a ASC NULLS FIRST + Projection: emp.a AS mgr, emp.b AS comm, emp.a + TableScan: emp + Projection: mgr, comm + Sort: emp.a ASC NULLS FIRST + Projection: emp.a AS mgr, emp.b AS comm, emp.a + TableScan: emp + ")?; + Ok(()) + } +} diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 81763fa0552fb..ad151d1ddb8e0 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -20,14 +20,15 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use crate::analyzer::type_coercion::TypeCoercionRewriter; -use arrow::array::{new_null_array, Array, RecordBatch}; +use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::TableReference; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{Column, DFSchema, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr}; +use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan}; use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; use std::sync::Arc; @@ -37,12 +38,17 @@ use std::sync::Arc; pub use datafusion_expr::expr_rewriter::NamePreserver; /// Returns true if `expr` contains all columns in `schema_cols` -pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { +pub(crate) fn has_all_column_refs( + expr: &Expr, + schema_cols: &HashSet, +) -> bool { let column_refs = expr.column_refs(); // note can't use HashSet::intersect because of different types (owned vs References) - schema_cols + column_refs .iter() - .filter(|c| column_refs.contains(c)) + .filter(|c| { + schema_cols.contains(&ColumnReference::new(c.relation.as_ref(), c.name())) + }) .count() == column_refs.len() } @@ -62,6 +68,40 @@ pub(crate) fn replace_qualified_name( replace_col(expr, &replace_map) } +/// Column reference to avoid copying string around +#[derive(PartialEq, Eq, Hash, Debug)] +pub(crate) struct ColumnReference<'a> { + pub relation: Option<&'a TableReference>, + pub name: &'a str, +} + +impl<'a> ColumnReference<'a> { + pub fn new(relation: Option<&'a TableReference>, name: &'a str) -> Self { + Self { relation, name } + } + + pub fn new_unqualified(name: &'a str) -> Self { + Self { + relation: None, + name, + } + } +} + +/// Returns references to all columns in the schema +pub(crate) fn schema_columns<'a>(schema: &'a DFSchema) -> HashSet> { + schema + .iter() + .flat_map(|(qualifier, field)| { + [ + ColumnReference::new(qualifier, field.name()), + // we need to push down filter using unqualified column as well + ColumnReference::new_unqualified(field.name()), + ] + }) + .collect::>() +} + /// Log the plan in debug/tracing mode after some part of the optimizer runs pub fn log_plan(description: &str, plan: &LogicalPlan) { debug!("{description}:\n{}\n", plan.display_indent()); @@ -154,7 +194,7 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result { #[cfg(test)] mod tests { use super::*; - use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator}; + use datafusion_expr::{Operator, binary_expr, case, col, in_list, is_null, lit}; #[test] fn expr_is_restrict_null_predicate() -> Result<()> { diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index c0f48b8ebfc40..6fad39dc33d9f 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -15,17 +15,25 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +use std::cmp::Ordering; use std::collections::HashMap; +use std::fmt::Formatter; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, Result, TableReference}; +use datafusion_common::{ + DFSchemaRef, Result, ScalarValue, TableReference, ToDFSchema, plan_err, +}; +use datafusion_expr::expr::Cast; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::test::function_stub::sum_udaf; -use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_expr::{ + AggregateUDF, Expr, Extension, LogicalPlan, ScalarUDF, SortExpr, + TableProviderFilterPushDown, TableSource, UserDefinedLogicalNodeCore, WindowUDF, col, +}; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::planner::AggregateFunctionPlanner; @@ -40,7 +48,7 @@ use datafusion_sql::sqlparser::parser::Parser; use insta::assert_snapshot; #[cfg(test)] -#[ctor::ctor] +#[ctor::ctor(unsafe)] fn init() { // enable logging so RUST_LOG works let _ = env_logger::try_init(); @@ -48,8 +56,7 @@ fn init() { #[test] fn recursive_cte_with_nested_subquery() -> Result<()> { - // Covers bailout path in `plan_contains_other_subqueries`, ensuring nested subqueries - // within recursive CTE branches prevent projection pushdown. + // projection optimization is applied to recursive CTEs even with nested subqueries let sql = r#" WITH RECURSIVE numbers(id, level) AS ( SELECT sub.id, sub.level FROM ( @@ -67,22 +74,21 @@ fn recursive_cte_with_nested_subquery() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" - SubqueryAlias: numbers - Projection: sub.id AS id, sub.level AS level - RecursiveQuery: is_distinct=false - Projection: sub.id, sub.level - SubqueryAlias: sub - Projection: test.col_int32 AS id, Int64(1) AS level - TableScan: test - Projection: t.col_int32, numbers.level + Int64(1) - Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1) - SubqueryAlias: t - Filter: CAST(test.col_int32 AS Int64) IS NOT NULL - TableScan: test - Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL - TableScan: numbers - "# + @r" + SubqueryAlias: numbers + Projection: sub.id AS id, sub.level AS level + RecursiveQuery: is_distinct=false + SubqueryAlias: sub + Projection: test.col_int32 AS id, Int64(1) AS level + TableScan: test projection=[col_int32] + Projection: t.col_int32, numbers.level + Int64(1) + Inner Join: CAST(t.col_int32 AS Int64) = CAST(numbers.id AS Int64) + Int64(1) + SubqueryAlias: t + Filter: CAST(test.col_int32 AS Int64) IS NOT NULL + TableScan: test projection=[col_int32] + Filter: CAST(numbers.id AS Int64) + Int64(1) IS NOT NULL + TableScan: numbers projection=[id, level] + " ); Ok(()) @@ -95,10 +101,10 @@ fn case_when() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" -Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END - TableScan: test projection=[col_int32] -"# + @r" + Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END + TableScan: test projection=[col_int32] + " ); let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test"; @@ -106,10 +112,10 @@ Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END assert_snapshot!( format!("{plan}"), - @r#" + @r" Projection: CASE WHEN test.col_uint32 > UInt32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_uint32 > Int64(0) THEN Int64(1) ELSE Int64(0) END TableScan: test projection=[col_uint32] - "# + " ); Ok(()) } @@ -128,15 +134,13 @@ fn subquery_filter_with_cast() -> Result<()> { assert_snapshot!( format!("{plan}"), @r#" - Projection: test.col_int32 - Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.avg(test.col_int32) - TableScan: test projection=[col_int32] - SubqueryAlias: __scalar_sq_1 - Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] - Projection: test.col_int32 - Filter: __common_expr_4 >= Date32("2002-05-08") AND __common_expr_4 <= Date32("2002-05-13") - Projection: CAST(test.col_utf8 AS Date32) AS __common_expr_4, test.col_int32 - TableScan: test projection=[col_int32, col_utf8] + Filter: CAST(test.col_int32 AS Float64) > () + Subquery: + Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]] + Projection: test.col_int32 + Filter: CAST(test.col_utf8 AS Date32) >= Date32("2002-05-08") AND CAST(test.col_utf8 AS Date32) <= Date32("2002-05-13") + TableScan: test projection=[col_int32, col_utf8] + TableScan: test projection=[col_int32] "# ); Ok(()) @@ -149,11 +153,11 @@ fn case_when_aggregate() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" - Projection: test.col_utf8, sum(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n - Aggregate: groupBy=[[test.col_utf8]], aggr=[[sum(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]] - TableScan: test projection=[col_int32, col_utf8] - "# + @r" + Projection: test.col_utf8, sum(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n + Aggregate: groupBy=[[test.col_utf8]], aggr=[[sum(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]] + TableScan: test projection=[col_int32, col_utf8] + " ); Ok(()) } @@ -165,11 +169,11 @@ fn unsigned_target_type() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" + @r" Projection: test.col_utf8 Filter: test.col_uint32 > UInt32(0) TableScan: test projection=[col_uint32, col_utf8] - "# + " ); Ok(()) } @@ -182,10 +186,10 @@ fn distribute_by() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" - Repartition: DistributeBy(test.col_utf8) - TableScan: test projection=[col_int32, col_utf8] - "# + @r" + Repartition: DistributeBy(test.col_utf8) + TableScan: test projection=[col_int32, col_utf8] + " ); Ok(()) } @@ -200,16 +204,16 @@ fn semi_join_with_join_filter() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" - Projection: test.col_utf8 - LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32 + @r" + Projection: test.col_utf8 + LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32 + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32, col_uint32, col_utf8] + SubqueryAlias: __correlated_sq_1 + SubqueryAlias: t2 Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32, col_uint32, col_utf8] - SubqueryAlias: __correlated_sq_1 - SubqueryAlias: t2 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32, col_uint32] - "# + TableScan: test projection=[col_int32, col_uint32] + " ); Ok(()) } @@ -224,15 +228,15 @@ fn anti_join_with_join_filter() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" -Projection: test.col_utf8 - LeftAnti Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32 - TableScan: test projection=[col_int32, col_uint32, col_utf8] - SubqueryAlias: __correlated_sq_1 - SubqueryAlias: t2 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32, col_uint32] -"# + @r" + Projection: test.col_utf8 + LeftAnti Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32 + TableScan: test projection=[col_int32, col_uint32, col_utf8] + SubqueryAlias: __correlated_sq_1 + SubqueryAlias: t2 + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32, col_uint32] + " ); Ok(()) } @@ -245,16 +249,16 @@ fn where_exists_distinct() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" -LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32] - SubqueryAlias: __correlated_sq_1 - Aggregate: groupBy=[[t2.col_int32]], aggr=[[]] - SubqueryAlias: t2 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32] -"# + @r" + LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32] + SubqueryAlias: __correlated_sq_1 + Aggregate: groupBy=[[t2.col_int32]], aggr=[[]] + SubqueryAlias: t2 + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32] + " ); Ok(()) @@ -269,15 +273,17 @@ fn intersect() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#" -LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8 - Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]] - LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8 - Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]] - TableScan: test projection=[col_int32, col_utf8] + @r" + LeftSemi Join: left.col_int32 = test.col_int32, left.col_utf8 = test.col_utf8 + Aggregate: groupBy=[[left.col_int32, left.col_utf8]], aggr=[[]] + LeftSemi Join: left.col_int32 = right.col_int32, left.col_utf8 = right.col_utf8 + Aggregate: groupBy=[[left.col_int32, left.col_utf8]], aggr=[[]] + SubqueryAlias: left + TableScan: test projection=[col_int32, col_utf8] + SubqueryAlias: right + TableScan: test projection=[col_int32, col_utf8] TableScan: test projection=[col_int32, col_utf8] - TableScan: test projection=[col_int32, col_utf8] -"# + " ); Ok(()) } @@ -291,11 +297,11 @@ fn between_date32_plus_interval() -> Result<()> { assert_snapshot!( format!("{plan}"), @r#" -Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: - Filter: test.col_date32 >= Date32("1998-03-18") AND test.col_date32 <= Date32("1998-06-16") - TableScan: test projection=[col_date32] -"# + Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] + Projection: + Filter: test.col_date32 >= Date32("1998-03-18") AND test.col_date32 <= Date32("1998-06-16") + TableScan: test projection=[col_date32] + "# ); Ok(()) } @@ -309,11 +315,11 @@ fn between_date64_plus_interval() -> Result<()> { assert_snapshot!( format!("{plan}"), @r#" - Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] - Projection: - Filter: test.col_date64 >= Date64("1998-03-18") AND test.col_date64 <= Date64("1998-06-16") - TableScan: test projection=[col_date64] - "# + Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] + Projection: + Filter: test.col_date64 >= Date64("1998-03-18") AND test.col_date64 <= Date64("1998-06-16") + TableScan: test projection=[col_date64] + "# ); Ok(()) } @@ -337,16 +343,16 @@ fn join_keys_in_subquery_alias() { assert_snapshot!( format!("{plan}"), - @r#" - Inner Join: a.col_int32 = b.key - SubqueryAlias: a - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] - SubqueryAlias: b - Projection: test.col_int32 AS key - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32] - "# + @r" + Inner Join: a.col_int32 = b.key + SubqueryAlias: a + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] + SubqueryAlias: b + Projection: test.col_int32 AS key + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32] + " ); } @@ -357,20 +363,20 @@ fn join_keys_in_subquery_alias_1() { assert_snapshot!( format!("{plan}"), - @r#" - Inner Join: a.col_int32 = b.key - SubqueryAlias: a + @r" + Inner Join: a.col_int32 = b.key + SubqueryAlias: a + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] + SubqueryAlias: b + Projection: test.col_int32 AS key + Inner Join: test.col_int32 = c.col_int32 Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] - SubqueryAlias: b - Projection: test.col_int32 AS key - Inner Join: test.col_int32 = c.col_int32 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32] - SubqueryAlias: c - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32] - "# + TableScan: test projection=[col_int32] + SubqueryAlias: c + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32] + " ); } @@ -381,12 +387,12 @@ fn push_down_filter_groupby_expr_contains_alias() { assert_snapshot!( format!("{plan}"), - @r#" - Projection: test.col_int32 + test.col_uint32 AS c, count(Int64(1)) AS count(*) - Aggregate: groupBy=[[CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64)]], aggr=[[count(Int64(1))]] - Filter: CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64) > Int64(3) - TableScan: test projection=[col_int32, col_uint32] - "# + @r" + Projection: test.col_int32 + test.col_uint32 AS c, count(Int64(1)) AS count(*) + Aggregate: groupBy=[[CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64)]], aggr=[[count(Int64(1))]] + Filter: CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64) > Int64(3) + TableScan: test projection=[col_int32, col_uint32] + " ); } @@ -398,14 +404,14 @@ fn test_same_name_but_not_ambiguous() { assert_snapshot!( format!("{plan}"), - @r#" - LeftSemi Join: t1.col_int32 = t2.col_int32 - Aggregate: groupBy=[[t1.col_int32]], aggr=[[]] - SubqueryAlias: t1 - TableScan: test projection=[col_int32] - SubqueryAlias: t2 - TableScan: test projection=[col_int32] - "# + @r" + LeftSemi Join: t1.col_int32 = t2.col_int32 + Aggregate: groupBy=[[t1.col_int32]], aggr=[[]] + SubqueryAlias: t1 + TableScan: test projection=[col_int32] + SubqueryAlias: t2 + TableScan: test projection=[col_int32] + " ); } @@ -420,10 +426,10 @@ fn eliminate_nested_filters() { assert_snapshot!( format!("{plan}"), - @r#" -Filter: test.col_int32 > Int32(0) - TableScan: test projection=[col_int32] - "# + @r" + Filter: test.col_int32 > Int32(0) + TableScan: test projection=[col_int32] + " ); } @@ -438,11 +444,11 @@ fn eliminate_redundant_null_check_on_count() { assert_snapshot!( format!("{plan}"), - @r#" - Projection: test.col_int32, count(Int64(1)) AS count(*) AS c - Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1))]] - TableScan: test projection=[col_int32] - "# + @r" + Projection: test.col_int32, count(Int64(1)) AS count(*) AS c + Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1))]] + TableScan: test projection=[col_int32] + " ); } @@ -466,13 +472,13 @@ fn test_propagate_empty_relation_inner_join_and_unions() { assert_snapshot!( format!("{plan}"), - @r#" -Union - TableScan: test projection=[col_int32] - TableScan: test projection=[col_int32] - Filter: test.col_int32 < Int32(0) - TableScan: test projection=[col_int32] - "#); + @r" + Union + TableScan: test projection=[col_int32] + TableScan: test projection=[col_int32] + Filter: test.col_int32 < Int32(0) + TableScan: test projection=[col_int32] + "); } #[test] @@ -483,10 +489,10 @@ fn select_wildcard_with_repeated_column_but_is_aliased() { assert_snapshot!( format!("{plan}"), - @r#" - Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64, test.col_ts_nano_none, test.col_ts_nano_utc, test.col_int32 AS col_32 - TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] - "# + @r" + Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64, test.col_ts_nano_none, test.col_ts_nano_utc, test.col_int32 AS col_32 + TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] + " ); } @@ -507,24 +513,22 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() { assert_snapshot!( format!("{plan}"), - @r#" - LeftSemi Join: test.col_int32 = __correlated_sq_1.COL_INT32 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] - SubqueryAlias: __correlated_sq_1 - SubqueryAlias: T1 - Projection: test.col_int32 AS COL_INT32 - Filter: test.col_int32 IS NOT NULL - TableScan: test projection=[col_int32] - "# + @r" + LeftSemi Join: test.col_int32 = __correlated_sq_1.COL_INT32 + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc] + SubqueryAlias: __correlated_sq_1 + SubqueryAlias: T1 + Projection: test.col_int32 AS COL_INT32 + Filter: test.col_int32 IS NOT NULL + TableScan: test projection=[col_int32] + " ); } - #[test] -fn recursive_cte_projection_pushdown() -> Result<()> { - // Test that projection pushdown works with recursive CTEs by ensuring - // only the required columns are projected from the base table, even when - // the CTE definition includes unused columns +fn recursive_cte_outer_projection_pushdown() -> Result<()> { + // projection optimization of a recursive CTE based on the outer query's projected columns is + // not done as this can lead to bugs (see: https://github.com/apache/datafusion/issues/22249). let sql = "WITH RECURSIVE nodes AS (\ SELECT col_int32 AS id, col_utf8 AS name, col_uint32 AS extra FROM test \ UNION ALL \ @@ -532,18 +536,20 @@ fn recursive_cte_projection_pushdown() -> Result<()> { ) SELECT id FROM nodes"; let plan = test_sql(sql)?; - // The optimizer successfully performs projection pushdown by only selecting the needed - // columns from the base table and recursive table, eliminating unused columns + // col_int32, col_utf8, and col_uint32 and projected from test since they are used in the + // recursive CTE, even though the outer query only requires col_int32 assert_snapshot!( format!("{plan}"), - @r#"SubqueryAlias: nodes - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS id - TableScan: test projection=[col_int32] - Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32) - Filter: nodes.id < Int32(3) - TableScan: nodes projection=[id] -"# + @r" + SubqueryAlias: nodes + Projection: id + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id, test.col_utf8 AS name, test.col_uint32 AS extra + TableScan: test projection=[col_int32, col_uint32, col_utf8] + Projection: CAST(CAST(nodes.id AS Int64) + Int64(1) AS Int32), nodes.name, nodes.extra + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id, name, extra] + " ); Ok(()) } @@ -559,42 +565,17 @@ fn recursive_cte_with_aliased_self_reference() -> Result<()> { assert_snapshot!( format!("{plan}"), - @r#"SubqueryAlias: nodes - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS id - TableScan: test projection=[col_int32] - Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32) - SubqueryAlias: child - Filter: nodes.id < Int32(3) - TableScan: nodes projection=[id]"#, - ); - Ok(()) -} - -#[test] -fn recursive_cte_with_unused_columns() -> Result<()> { - // Test projection pushdown with a recursive CTE where the base case - // includes columns that are never used in the recursive part or final result - let sql = "WITH RECURSIVE series AS (\ - SELECT 1 AS n, col_utf8, col_uint32, col_date32 FROM test WHERE col_int32 = 1 \ - UNION ALL \ - SELECT n + 1, col_utf8, col_uint32, col_date32 FROM series WHERE n < 3\ - ) SELECT n FROM series"; - let plan = test_sql(sql)?; - - // The optimizer successfully performs projection pushdown by eliminating unused columns - // even when they're defined in the CTE but not actually needed - assert_snapshot!( - format!("{plan}"), - @r#"SubqueryAlias: series - RecursiveQuery: is_distinct=false - Projection: Int64(1) AS n - Filter: test.col_int32 = Int32(1) - TableScan: test projection=[col_int32] - Projection: series.n + Int64(1) - Filter: series.n < Int64(3) - TableScan: series projection=[n] -"# + @r" + SubqueryAlias: nodes + Projection: id + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS id, test.col_utf8 AS name + TableScan: test projection=[col_int32, col_utf8] + Projection: CAST(CAST(child.id AS Int64) + Int64(1) AS Int32), child.name + SubqueryAlias: child + Filter: nodes.id < Int32(3) + TableScan: nodes projection=[id, name] + ", ); Ok(()) } @@ -618,15 +599,16 @@ fn recursive_cte_projection_pushdown_baseline() -> Result<()> { // and only the needed column is selected from the recursive table assert_snapshot!( format!("{plan}"), - @r#"SubqueryAlias: countdown - RecursiveQuery: is_distinct=false - Projection: test.col_int32 AS n - Filter: test.col_int32 = Int32(5) - TableScan: test projection=[col_int32] - Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) - Filter: countdown.n > Int32(1) - TableScan: countdown projection=[n] -"# + @r" + SubqueryAlias: countdown + RecursiveQuery: is_distinct=false + Projection: test.col_int32 AS n + Filter: test.col_int32 = Int32(5) + TableScan: test projection=[col_int32] + Projection: CAST(CAST(countdown.n AS Int64) - Int64(1) AS Int32) + Filter: countdown.n > Int32(1) + TableScan: countdown projection=[n] + " ); Ok(()) } @@ -683,6 +665,143 @@ fn test_sql(sql: &str) -> Result { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} +fn optimize_plan(plan: LogicalPlan) -> Result { + let config = OptimizerContext::new().with_skip_failing_rules(false); + let optimizer = Optimizer::new(); + optimizer.optimize(plan, &config, observe) +} + +/// Extension node that does NOT implement `necessary_children_exprs`. +/// Used to test that the optimizer still processes subtrees below such nodes. +#[derive(Debug, Hash, PartialEq, Eq)] +struct OpaqueRequirementsExtension { + input: Arc, + schema: DFSchemaRef, +} + +impl PartialOrd for OpaqueRequirementsExtension { + fn partial_cmp(&self, other: &Self) -> Option { + self.input + .partial_cmp(&other.input) + .filter(|cmp| *cmp != Ordering::Equal || self == other) + } +} + +impl UserDefinedLogicalNodeCore for OpaqueRequirementsExtension { + fn name(&self) -> &str { + "OpaqueRequirementsExtension" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + Ok(Self { + input: Arc::new(inputs.swap_remove(0)), + schema: Arc::clone(&self.schema), + }) + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "OpaqueRequirementsExtension") + } +} + +struct InexactFilterTableSource { + schema: SchemaRef, +} + +impl TableSource for InexactFilterTableSource { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } +} + +/// Reproduction of https://github.com/apache/datafusion/issues/18816 +/// Extension nodes without `necessary_children_exprs` should not prevent +/// the optimizer from pruning unnecessary columns in subtrees. +#[test] +fn extension_node_does_not_block_projection_pruning() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ])); + + let table_source: Arc = Arc::new(InexactFilterTableSource { + schema: Arc::clone(&schema), + }); + + let ts_cast = Expr::Cast(Cast::new( + Box::new(col("t.ts")), + DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())), + )); + let ts_millis_1000 = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(1000), Some("UTC".into())), + None, + ); + let ts_millis_2000 = Expr::Literal( + ScalarValue::TimestampMillisecond(Some(2000), Some("UTC".into())), + None, + ); + + let plan = LogicalPlanBuilder::scan("t", table_source, None)? + .project(vec![col("t.a"), ts_cast.alias_qualified(Some("t"), "ts")])? + .filter( + col("t.ts") + .gt(ts_millis_1000) + .and(col("t.ts").lt(ts_millis_2000)), + )? + .sort(vec![ + SortExpr::new(col("t.a"), true, true), + SortExpr::new(col("t.ts"), true, true), + ])? + .build()?; + + let df_schema = schema.to_dfschema_ref()?; + let plan = LogicalPlan::Extension(Extension { + node: Arc::new(OpaqueRequirementsExtension { + input: Arc::new(plan), + schema: df_schema, + }), + }); + + let optimized = optimize_plan(plan)?; + assert_snapshot!( + format!("{optimized}"), + @r#" + OpaqueRequirementsExtension + Sort: t.a ASC NULLS FIRST, t.ts ASC NULLS FIRST + Projection: t.a, CAST(t.ts AS Timestamp(ms, "UTC")) AS ts + Filter: __common_expr_3 > TimestampMillisecond(1000, Some("UTC")) AND __common_expr_3 < TimestampMillisecond(2000, Some("UTC")) + Projection: CAST(t.ts AS Timestamp(ms, "UTC")) AS __common_expr_3, t.a, t.ts + TableScan: t projection=[a, ts], partial_filters=[t.ts > TimestampNanosecond(1000000000, None), t.ts < TimestampNanosecond(2000000000, None), CAST(t.ts AS Timestamp(ms, "UTC")) > TimestampMillisecond(1000, Some("UTC")), CAST(t.ts AS Timestamp(ms, "UTC")) < TimestampMillisecond(2000, Some("UTC"))] + "#, + ); + + Ok(()) +} + #[derive(Default)] struct MyContextProvider { options: ConfigOptions, @@ -728,6 +847,13 @@ impl ContextProvider for MyContextProvider { None } + fn get_higher_order_meta( + &self, + _name: &str, + ) -> Option> { + None + } + fn get_aggregate_meta(&self, name: &str) -> Option> { self.udafs.get(name).cloned() } @@ -756,6 +882,10 @@ impl ContextProvider for MyContextProvider { Vec::new() } + fn higher_order_function_names(&self) -> Vec { + Vec::new() + } + fn udaf_names(&self) -> Vec { Vec::new() } @@ -774,10 +904,6 @@ struct MyTableSource { } impl TableSource for MyTableSource { - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { self.schema.clone() } diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 4602e59c422c3..d1ee7feb29db1 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -40,10 +40,28 @@ workspace = true [lib] name = "datafusion_physical_expr_common" +[features] +default = [] +# Enables the `PhysicalExpr::to_proto` hook used by `datafusion-proto`. +# Off by default so crates that never serialize plans pay nothing. +proto = ["dep:datafusion-proto-models"] + [dependencies] -ahash = { workspace = true } arrow = { workspace = true } +chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-expr-common = { workspace = true } +datafusion-proto-models = { workspace = true, optional = true } hashbrown = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } +parking_lot = { workspace = true } +pin-project = { workspace = true } + +[dev-dependencies] +criterion = { workspace = true } +rand = { workspace = true } + +[[bench]] +harness = false +name = "compare_nested" diff --git a/datafusion/physical-expr-common/benches/compare_nested.rs b/datafusion/physical-expr-common/benches/compare_nested.rs new file mode 100644 index 0000000000000..56c122fef9420 --- /dev/null +++ b/datafusion/physical-expr-common/benches/compare_nested.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, Scalar, StringArray, StructArray}; +use arrow::datatypes::{DataType, Field, Fields}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr_common::datum::compare_op_for_nested; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +/// Build a StructArray with fields {x: Int32, y: Utf8}. +fn make_struct_array(num_rows: usize, rng: &mut StdRng) -> ArrayRef { + let ints: Int32Array = (0..num_rows).map(|_| Some(rng.random::())).collect(); + + let strings: StringArray = (0..num_rows) + .map(|_| { + let s: String = (0..12) + .map(|_| rng.random_range(b'a'..=b'z') as char) + .collect(); + Some(s) + }) + .collect(); + + let fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + + Arc::new( + StructArray::try_new(fields, vec![Arc::new(ints), Arc::new(strings)], None) + .unwrap(), + ) +} + +fn criterion_benchmark(c: &mut Criterion) { + let num_rows = 8192; + let mut rng = StdRng::seed_from_u64(42); + + let lhs = make_struct_array(num_rows, &mut rng); + let rhs_array = make_struct_array(num_rows, &mut rng); + let rhs_scalar = Scalar::new(make_struct_array(1, &mut rng)); + + c.bench_function("compare_nested array_array", |b| { + b.iter(|| { + black_box(compare_op_for_nested(Operator::Eq, &lhs, &rhs_array).unwrap()) + }) + }); + + c.bench_function("compare_nested array_scalar", |b| { + b.iter(|| { + black_box(compare_op_for_nested(Operator::Eq, &lhs, &rhs_scalar).unwrap()) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index 24bc430630598..ad184d6500d56 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -18,15 +18,15 @@ //! [`ArrowBytesMap`] and [`ArrowBytesSet`] for storing maps/sets of values from //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. -use ahash::RandomState; use arrow::array::{ - cast::AsArray, - types::{ByteArrayType, GenericBinaryType, GenericStringType}, Array, ArrayRef, BufferBuilder, GenericBinaryArray, GenericStringArray, NullBufferBuilder, OffsetSizeTrait, + cast::AsArray, + types::{ByteArrayType, GenericBinaryType, GenericStringType}, }; use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow::datatypes::DataType; +use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::any::type_name; @@ -250,7 +250,7 @@ where map_size: 0, buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), offsets: vec![O::default()], // first offset is always 0 - random_state: RandomState::new(), + random_state: RandomState::default(), hashes_buffer: vec![], null: None, } @@ -389,7 +389,7 @@ where // is value is already present in the set? let entry = self.map.find_mut(hash, |header| { // compare value if hashes match - if header.len != value_len { + if header.hash != hash || header.len != value_len { return false; } // value is stored inline so no need to consult buffer @@ -427,7 +427,7 @@ where // Check if the value is already present in the set let entry = self.map.find_mut(hash, |header| { // compare value if hashes match - if header.len != value_len { + if header.hash != hash { return false; } // Need to compare the bytes in the buffer diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 2de563472c789..9d4b556393a24 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -17,16 +17,17 @@ //! [`ArrowBytesViewMap`] and [`ArrowBytesViewSet`] for storing maps/sets of values from //! `StringViewArray`/`BinaryViewArray`. -//! Much of the code is from `binary_map.rs`, but with simpler implementation because we directly use the -//! [`GenericByteViewBuilder`]. use crate::binary_map::OutputType; -use ahash::RandomState; +use arrow::array::NullBufferBuilder; use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; +use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view}; +use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; +use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::fmt::Debug; +use std::mem::size_of; use std::sync::Arc; /// HashSet optimized for storing string or binary values that can produce that @@ -113,6 +114,9 @@ impl ArrowBytesViewSet { /// This map is used by the special `COUNT DISTINCT` aggregate function to /// store the distinct values, and by the `GROUP BY` operator to store /// group values when they are a single string array. +/// Max size of the in-progress buffer before flushing to completed buffers +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; + pub struct ArrowBytesViewMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, @@ -124,8 +128,15 @@ where /// Total size of the map in bytes map_size: usize, - /// Builder for output array - builder: GenericByteViewBuilder, + /// Views for all stored values (in insertion order) + views: Vec, + /// In-progress buffer for out-of-line string data + in_progress: Vec, + /// Completed buffers containing string data + completed: Vec, + /// Tracks null values (true = null) + nulls: NullBufferBuilder, + /// random state used to generate hashes random_state: RandomState, /// buffer that stores hash values (reused across batches to save allocations) @@ -148,8 +159,11 @@ where output_type, map: hashbrown::hash_table::HashTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, - builder: GenericByteViewBuilder::new(), - random_state: RandomState::new(), + views: Vec::new(), + in_progress: Vec::new(), + completed: Vec::new(), + nulls: NullBufferBuilder::new(0), + random_state: RandomState::default(), hashes_buffer: vec![], null: None, } @@ -250,53 +264,109 @@ where // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); + // Get raw views buffer for direct comparison + let input_views = values.views(); + // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); + assert_eq!(values.len(), self.hashes_buffer.len()); - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // handle null value - let Some(value) = value else { + for i in 0..values.len() { + let view_u128 = input_views[i]; + let hash = self.hashes_buffer[i]; + + // handle null value via validity bitmap check + if values.is_null(i) { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload } else { let payload = make_payload_fn(None); - let null_index = self.builder.len(); - self.builder.append_null(); + let null_index = self.views.len(); + self.views.push(0); + self.nulls.append_null(); self.null = Some((payload, null_index)); payload }; observe_payload_fn(payload); continue; - }; - - // get the value as bytes - let value: &[u8] = value.as_ref(); + } - let entry = self.map.find_mut(hash, |header| { - let v = self.builder.get_value(header.view_idx); + // Extract length from the view (first 4 bytes of u128 in little-endian) + let len = view_u128 as u32; - if v.len() != value.len() { - return false; - } + // Check if value already exists + let maybe_payload = { + // Borrow completed and in_progress for comparison + let completed = &self.completed; + let in_progress = &self.in_progress; - v == value - }); + self.map + .find(hash, |header| { + if header.hash != hash { + return false; + } + + // Fast path: inline strings can be compared directly + if len <= 12 { + return header.view == view_u128; + } + + // For larger strings: first compare the 4-byte prefix + let stored_prefix = (header.view >> 32) as u32; + let input_prefix = (view_u128 >> 32) as u32; + if stored_prefix != input_prefix { + return false; + } + + // Prefix matched - compare full bytes + let byte_view = ByteView::from(header.view); + let stored_len = byte_view.length as usize; + let buffer_index = byte_view.buffer_index as usize; + let offset = byte_view.offset as usize; + + let stored_value = if buffer_index < completed.len() { + &completed[buffer_index].as_slice() + [offset..offset + stored_len] + } else { + &in_progress[offset..offset + stored_len] + }; + let input_value: &[u8] = values.value(i).as_ref(); + stored_value == input_value + }) + .map(|entry| entry.payload) + }; - let payload = if let Some(entry) = entry { - entry.payload + let payload = if let Some(payload) = maybe_payload { + payload } else { - // no existing value, make a new one. - let payload = make_payload_fn(Some(value)); + // no existing value, make a new one + let (new_view, payload) = if len <= 12 { + // Inline path: bytes are already packed in view_u128. + // The inline ByteView format is [len:u32 LE][data:12 bytes zero-padded], + // so extracting bytes from the u128 avoids a round-trip through + // values.value(i) (which reads the views buffer and returns the same slice). + let view_bytes = view_u128.to_le_bytes(); + let value = &view_bytes[4..4 + len as usize]; + let payload = make_payload_fn(Some(value)); + // For inline strings, the stored view is identical to the input view: + // make_view(value, 0, 0) produces the same u128 as view_u128. + // + // SAFETY: view_u128 was a valid view, and the enclosing `len <= 12` + // ensures it is inline + let new_view = unsafe { self.append_inline_view(view_u128) }; + (new_view, payload) + } else { + let value: &[u8] = values.value(i).as_ref(); + let payload = make_payload_fn(Some(value)); + let new_view = self.append_value(value); + (new_view, payload) + }; - let inner_view_idx = self.builder.len(); let new_header = Entry { - view_idx: inner_view_idx, + view: new_view, hash, payload, }; - self.builder.append_value(value); - self.map .insert_accounted(new_header, |h| h.hash, &mut self.map_size); payload @@ -311,29 +381,78 @@ where /// /// The values are guaranteed to be returned in the same order in which /// they were first seen. - pub fn into_state(self) -> ArrayRef { - let mut builder = self.builder; - match self.output_type { - OutputType::BinaryView => { - let array = builder.finish(); + pub fn into_state(mut self) -> ArrayRef { + // Flush any remaining in-progress buffer + if !self.in_progress.is_empty() { + let flushed = std::mem::take(&mut self.in_progress); + self.completed.push(Buffer::from_vec(flushed)); + } - Arc::new(array) - } + // Build null buffer if we have any nulls + let null_buffer = self.nulls.finish(); + + let views = ScalarBuffer::from(self.views); + let array = + unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) }; + + match self.output_type { + OutputType::BinaryView => Arc::new(array), OutputType::Utf8View => { - // SAFETY: - // we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out - let array = builder.finish(); + // SAFETY: all input was valid utf8 let array = unsafe { array.to_string_view_unchecked() }; Arc::new(array) } - _ => { - unreachable!("Utf8/Binary should use `ArrowBytesMap`") - } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), } } + /// Append an already-computed inline view (len <= 12) directly, bypassing + /// buffer allocation. + /// + /// Returns the view that was stored (identical to the argument). + /// + /// # Safety + /// + /// `view` must be a valid inline `ByteView`: the length field in the low + /// 32 bits must be <= 12, and the remaining 12 bytes must hold the + /// value's bytes (zero-padded if shorter). Calling with a non-inline view + /// would store a value that downstream `views` consumers interpret as + /// `[buffer_index, offset]` into the `completed`/`in_progress` buffers, + /// which is unsound for any view that didn't originate from a real + /// allocation in those buffers. + unsafe fn append_inline_view(&mut self, view: u128) -> u128 { + self.views.push(view); + self.nulls.append_non_null(); + view + } + + /// Append a value to our buffers and return the view pointing to it + fn append_value(&mut self, value: &[u8]) -> u128 { + let len = value.len(); + let view = if len <= 12 { + make_view(value, 0, 0) + } else { + // Ensure buffer is big enough + if self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { + let flushed = std::mem::replace( + &mut self.in_progress, + Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), + ); + self.completed.push(Buffer::from_vec(flushed)); + } + + let buffer_index = self.completed.len() as u32; + let offset = self.in_progress.len() as u32; + self.in_progress.extend_from_slice(value); + + make_view(value, buffer_index, offset) + }; + + self.views.push(view); + self.nulls.append_non_null(); + view + } + /// Total number of entries (including null, if present) pub fn len(&self) -> usize { self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) @@ -352,8 +471,16 @@ where /// Return the total size, in bytes, of memory used to store the data in /// this set, not including `self` pub fn size(&self) -> usize { + let views_size = self.views.len() * size_of::(); + let in_progress_size = self.in_progress.capacity(); + let completed_size: usize = self.completed.iter().map(|b| b.len()).sum(); + let nulls_size = self.nulls.allocated_size(); + self.map_size - + self.builder.allocated_size() + + views_size + + in_progress_size + + completed_size + + nulls_size + self.hashes_buffer.allocated_size() } } @@ -366,7 +493,8 @@ where f.debug_struct("ArrowBytesMap") .field("map", &"") .field("map_size", &self.map_size) - .field("view_builder", &self.builder) + .field("views_len", &self.views.len()) + .field("completed_buffers", &self.completed.len()) .field("random_state", &self.random_state) .field("hashes_buffer", &self.hashes_buffer) .finish() @@ -374,13 +502,20 @@ where } /// Entry in the hash table -- see [`ArrowBytesViewMap`] for more details +/// +/// Stores the view pointing to our internal buffers, eliminating the need +/// for a separate builder index. For inline strings (<=12 bytes), the view +/// contains the entire value. For out-of-line strings, the view contains +/// buffer_index and offset pointing directly to our storage. #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] struct Entry where V: Debug + PartialEq + Eq + Clone + Copy + Default, { - /// The idx into the views array - view_idx: usize, + /// The u128 view pointing to our internal buffers. For inline strings, + /// this contains the complete value. For larger strings, this contains + /// the buffer_index/offset into our completed/in_progress buffers. + view: u128, hash: u64, @@ -390,7 +525,7 @@ where #[cfg(test)] mod tests { - use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; + use arrow::array::{GenericByteViewArray, StringViewArray}; use datafusion_common::HashMap; use super::*; diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index c84d3afeeff6c..d23fb30db6c4a 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -16,16 +16,16 @@ // under the License. use arrow::array::BooleanArray; -use arrow::array::{make_comparator, ArrayRef, Datum}; -use arrow::buffer::NullBuffer; +use arrow::array::{ArrayRef, Datum, make_comparator}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::{ distinct, eq, gt, gt_eq, lt, lt_eq, neq, not_distinct, }; -use arrow::compute::{ilike, like, nilike, nlike, SortOptions}; +use arrow::compute::{SortOptions, ilike, like, nilike, nlike}; use arrow::error::ArrowError; -use datafusion_common::DataFusionError; -use datafusion_common::{arrow_datafusion_err, assert_or_internal_err, internal_err}; +use datafusion_common::utils::{normalize_float_zero, normalize_float_zero_scalar}; use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{arrow_datafusion_err, assert_or_internal_err, internal_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::operator::Operator; use std::sync::Arc; @@ -85,7 +85,22 @@ pub fn apply_cmp( } }; - apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) + // Arrow's comparison kernels use IEEE 754 totalOrder semantics for + // floats, which treats `-0.0` and `+0.0` as distinct. Normalize float + // operands so SQL semantics (`+0.0 == -0.0`) hold. No-op for + // non-float types. + let lhs = normalize_cmp_input(lhs); + let rhs = normalize_cmp_input(rhs); + apply(&lhs, &rhs, |l, r| Ok(Arc::new(f(l, r)?))) + } +} + +fn normalize_cmp_input(cv: &ColumnarValue) -> ColumnarValue { + match cv { + ColumnarValue::Array(a) => ColumnarValue::Array(normalize_float_zero(a)), + ColumnarValue::Scalar(s) => { + ColumnarValue::Scalar(normalize_float_zero_scalar(s.clone())) + } } } @@ -172,9 +187,9 @@ pub fn compare_op_for_nested( }; let values = match (is_l_scalar, is_r_scalar) { - (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(), - (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(), - (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(), + (false, false) => BooleanBuffer::collect_bool(len, |i| cmp_with_op(i, i)), + (true, false) => BooleanBuffer::collect_bool(len, |i| cmp_with_op(0, i)), + (false, true) => BooleanBuffer::collect_bool(len, |i| cmp_with_op(i, 0)), (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), }; @@ -190,14 +205,14 @@ pub fn compare_op_for_nested( (false, false) | (true, true) => NullBuffer::union(l.nulls(), r.nulls()), (true, false) => { // When left is null-scalar and right is array, expand left nulls to match result length - match l.nulls().filter(|nulls| !nulls.is_valid(0)) { + match l.nulls().filter(|nulls| nulls.is_null(0)) { Some(_) => Some(NullBuffer::new_null(len)), // Left scalar is null None => r.nulls().cloned(), // Left scalar is non-null } } (false, true) => { // When right is null-scalar and left is array, expand right nulls to match result length - match r.nulls().filter(|nulls| !nulls.is_valid(0)) { + match r.nulls().filter(|nulls| nulls.is_null(0)) { Some(_) => Some(NullBuffer::new_null(len)), // Right scalar is null None => l.nulls().cloned(), // Right scalar is non-null } diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index cac863ee69fb4..b6eaacdca2505 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Physical Expr Common packages for [DataFusion] @@ -35,6 +33,7 @@ pub mod binary_map; pub mod binary_view_map; pub mod datum; +pub mod metrics; pub mod physical_expr; pub mod sort_expr; pub mod tree_node; diff --git a/datafusion/physical-expr-common/src/metrics/baseline.rs b/datafusion/physical-expr-common/src/metrics/baseline.rs new file mode 100644 index 0000000000000..52ad4aac9fd98 --- /dev/null +++ b/datafusion/physical-expr-common/src/metrics/baseline.rs @@ -0,0 +1,376 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Metrics common for almost all operators + +use std::{borrow::Cow, collections::BTreeMap, sync::Arc, task::Poll}; + +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, utils::memory::get_record_batch_memory_size}; + +use super::{ + Count, ExecutionPlanMetricsSet, Metric, MetricBuilder, MetricsSet, Time, Timestamp, +}; + +const OUTPUT_ROWS_SKEW_METRIC_NAME: &str = "output_rows_skew"; + +/// Helper for creating and tracking common "baseline" metrics for +/// each operator +/// +/// Example: +/// ``` +/// use datafusion_physical_expr_common::metrics::{ +/// BaselineMetrics, ExecutionPlanMetricsSet, +/// }; +/// let metrics = ExecutionPlanMetricsSet::new(); +/// +/// let partition = 2; +/// let baseline_metrics = BaselineMetrics::new(&metrics, partition); +/// +/// // during execution, in CPU intensive operation: +/// let timer = baseline_metrics.elapsed_compute().timer(); +/// // .. do CPU intensive work +/// timer.done(); +/// +/// // when operator is finished: +/// baseline_metrics.done(); +/// ``` +#[derive(Debug, Clone)] +pub struct BaselineMetrics { + /// end_time is set when `BaselineMetrics::done()` is called + end_time: Timestamp, + + /// amount of time the operator was actively trying to use the CPU + elapsed_compute: Time, + + /// output rows: the total output rows + output_rows: Count, + + /// Memory usage of all output batches. + /// + /// Note: This value may be overestimated. If multiple output `RecordBatch` + /// instances share underlying memory buffers, their sizes will be counted + /// multiple times. + /// Issue: + output_bytes: Count, + + /// output batches: the total output batch count + output_batches: Count, + // Remember to update `docs/source/user-guide/metrics.md` when updating comments + // or adding new metrics +} + +impl BaselineMetrics { + /// Create a new BaselineMetric structure, and set `start_time` to now + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + let start_time = MetricBuilder::new(metrics).start_timestamp(partition); + start_time.record(); + + Self { + end_time: MetricBuilder::new(metrics) + .with_type(super::MetricType::Summary) + .end_timestamp(partition), + elapsed_compute: MetricBuilder::new(metrics) + .with_type(super::MetricType::Summary) + .elapsed_compute(partition), + output_rows: MetricBuilder::new(metrics) + .with_type(super::MetricType::Summary) + .output_rows(partition), + output_bytes: MetricBuilder::new(metrics) + .with_type(super::MetricType::Summary) + .output_bytes(partition), + output_batches: MetricBuilder::new(metrics) + .with_type(super::MetricType::Dev) + .output_batches(partition), + } + } + + /// Returns a [`BaselineMetrics`] that updates the same `elapsed_compute` ignoring + /// all other metrics + /// + /// This is useful when an operator offloads some of its intermediate work to separate tasks + /// that as a result won't be recorded by [`Self::record_poll`] + pub fn intermediate(&self) -> BaselineMetrics { + Self { + end_time: Default::default(), + elapsed_compute: self.elapsed_compute.clone(), + output_rows: Default::default(), + output_bytes: Default::default(), + output_batches: Default::default(), + } + } + + /// return the metric for cpu time spend in this operator + pub fn elapsed_compute(&self) -> &Time { + &self.elapsed_compute + } + + /// return the metric for the total number of output rows produced + pub fn output_rows(&self) -> &Count { + &self.output_rows + } + + /// return the metric for the total number of output batches produced + pub fn output_batches(&self) -> &Count { + &self.output_batches + } + + /// Returns a derived metric that summarizes how unevenly `output_rows` + /// are distributed across partitions. + /// + /// The score is normalized to the range `[0%, 100%]`, where `0%` + /// indicates a perfectly balanced distribution and `100%` indicates the + /// most skewed distribution. + /// + /// The calculation is: + /// `effective_parallelism = square(sum(r_i)) / sum(square(r_i))` + /// `output_rows_skew = (1 - ((effective_parallelism - 1) / (partition_count - 1))) * 100%` + /// + /// Example: for 4 partitions with output rows `[10, 10, 10, 10]`, + /// `effective_parallelism = 40^2 / (10^2 + 10^2 + 10^2 + 10^2) = 4`, + /// so `output_rows_skew = 0%`. For `[40, 0, 0, 0]`, the score is `100%`. + pub fn output_rows_skew_metric(metrics: &MetricsSet) -> Option> { + let output_rows = metrics + .iter() + .filter_map(|metric| match (metric.partition(), metric.value()) { + (Some(partition), super::MetricValue::OutputRows(count)) => { + Some((partition, count.value() as u128)) + } + _ => None, + }) + .fold( + BTreeMap::::new(), + |mut output_rows, (partition, rows)| { + *output_rows.entry(partition).or_default() += rows; + output_rows + }, + ) + .into_values() + .collect::>(); + + if output_rows.is_empty() { + return None; + } + + let ratio_metrics = super::RatioMetrics::new().with_display_raw_values(false); + if let Some(score) = output_rows_skew_score(&output_rows) { + ratio_metrics.set_part((score * 10_000.0).round() as usize); + ratio_metrics.set_total(10_000); + } + + Some(Arc::new( + Metric::new( + super::MetricValue::Ratio { + name: Cow::Borrowed(OUTPUT_ROWS_SKEW_METRIC_NAME), + ratio_metrics, + }, + None, + ) + .with_type(super::MetricType::Dev), + )) + } + + /// Records the fact that this operator's execution is complete + /// (recording the `end_time` metric). + /// + /// Note care should be taken to call `done()` manually if + /// `BaselineMetrics` is not `drop`ped immediately upon operator + /// completion, as async streams may not be dropped immediately + /// depending on the consumer. + pub fn done(&self) { + self.end_time.record() + } + + /// Record that some number of rows have been produced as output + /// + /// See the [`RecordOutput`] for conveniently recording record + /// batch output for other thing + pub fn record_output(&self, num_rows: usize) { + self.output_rows.add(num_rows); + } + + /// If not previously recorded `done()`, record + pub fn try_done(&self) { + if self.end_time.value().is_none() { + self.end_time.record() + } + } + + /// Process a poll result of a stream producing output for an operator. + /// + /// Note: this method only updates `output_rows` and `end_time` metrics. + /// Remember to update `elapsed_compute` and other metrics manually. + pub fn record_poll( + &self, + poll: Poll>>, + ) -> Poll>> { + if let Poll::Ready(maybe_batch) = &poll { + match maybe_batch { + Some(Ok(batch)) => { + batch.record_output(self); + } + Some(Err(_)) => self.done(), + None => self.done(), + } + } + poll + } +} + +impl Drop for BaselineMetrics { + fn drop(&mut self) { + self.try_done() + } +} + +/// See [`BaselineMetrics::output_rows_skew_metric`] for the algorithm. +fn output_rows_skew_score(output_rows: &[u128]) -> Option { + if output_rows.is_empty() { + return None; + } + + let partition_count = output_rows.len(); + if partition_count == 1 { + return Some(0.0); + } + + let (total_rows, sum_of_squares) = + output_rows + .iter() + .fold((0.0, 0.0), |(total_rows, sum_of_squares), rows| { + let rows = *rows as f64; + (total_rows + rows, sum_of_squares + rows.powi(2)) + }); + if total_rows == 0.0 { + return None; + } + + if sum_of_squares == 0.0 { + return None; + } + + let effective_parallelism = total_rows.powi(2) / sum_of_squares; + let balanced_score = (effective_parallelism - 1.0) / (partition_count as f64 - 1.0); + + Some((1.0 - balanced_score).clamp(0.0, 1.0)) +} + +/// Helper for creating and tracking spill-related metrics for +/// each operator +#[derive(Debug, Clone)] +pub struct SpillMetrics { + /// count of spills during the execution of the operator + pub spill_file_count: Count, + + /// total bytes actually written to disk during the execution of the operator + pub spilled_bytes: Count, + + /// total spilled rows during the execution of the operator + pub spilled_rows: Count, +} + +impl SpillMetrics { + /// Create a new SpillMetrics structure + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + spill_file_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(metrics).spilled_rows(partition), + } + } +} + +/// Metrics for tracking batch splitting activity +#[derive(Debug, Clone)] +pub struct SplitMetrics { + /// Number of times an input [`RecordBatch`] was split + pub batches_split: Count, +} + +impl SplitMetrics { + /// Create a new [`SplitMetrics`] + pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + batches_split: MetricBuilder::new(metrics) + .with_category(super::MetricCategory::Rows) + .counter("batches_split", partition), + } + } +} + +/// Trait for things that produce output rows as a result of execution. +pub trait RecordOutput { + /// Record that some number of output rows have been produced + /// + /// Meant to be composable so that instead of returning `batch` + /// the operator can return `batch.record_output(baseline_metrics)` + fn record_output(self, bm: &BaselineMetrics) -> Self; +} + +impl RecordOutput for usize { + fn record_output(self, bm: &BaselineMetrics) -> Self { + bm.record_output(self); + self + } +} + +impl RecordOutput for RecordBatch { + fn record_output(self, bm: &BaselineMetrics) -> Self { + bm.record_output(self.num_rows()); + let n_bytes = get_record_batch_memory_size(&self); + bm.output_bytes.add(n_bytes); + bm.output_batches.add(1); + self + } +} + +impl RecordOutput for &RecordBatch { + fn record_output(self, bm: &BaselineMetrics) -> Self { + bm.record_output(self.num_rows()); + let n_bytes = get_record_batch_memory_size(self); + bm.output_bytes.add(n_bytes); + bm.output_batches.add(1); + self + } +} + +impl RecordOutput for Option<&RecordBatch> { + fn record_output(self, bm: &BaselineMetrics) -> Self { + if let Some(record_batch) = &self { + record_batch.record_output(bm); + } + self + } +} + +impl RecordOutput for Option { + fn record_output(self, bm: &BaselineMetrics) -> Self { + if let Some(record_batch) = &self { + record_batch.record_output(bm); + } + self + } +} + +impl RecordOutput for Result { + fn record_output(self, bm: &BaselineMetrics) -> Self { + if let Ok(record_batch) = &self { + record_batch.record_output(bm); + } + self + } +} diff --git a/datafusion/physical-expr-common/src/metrics/builder.rs b/datafusion/physical-expr-common/src/metrics/builder.rs new file mode 100644 index 0000000000000..de9d1e03d88df --- /dev/null +++ b/datafusion/physical-expr-common/src/metrics/builder.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Builder for creating arbitrary metrics + +use std::{borrow::Cow, sync::Arc}; + +use crate::metrics::{ + MetricCategory, MetricType, + value::{PruningMetrics, RatioMergeStrategy, RatioMetrics}, +}; + +use super::{ + Count, ExecutionPlanMetricsSet, Gauge, Label, LabelValue, Metric, MetricValue, Time, + Timestamp, +}; + +/// Structure for constructing metrics, counters, timers, etc. +/// +/// Note the use of `Cow<..>` is to avoid allocations in the common +/// case of constant strings. Dynamically created label strings are shared when +/// [`Label`] values are cloned. +/// +/// ```rust +/// use datafusion_physical_expr_common::metrics::*; +/// +/// let metrics = ExecutionPlanMetricsSet::new(); +/// let partition = 1; +/// +/// // Create the standard output_rows metric +/// let output_rows = MetricBuilder::new(&metrics).output_rows(partition); +/// +/// // Create a operator specific counter with some labels +/// let num_bytes = MetricBuilder::new(&metrics) +/// .with_new_label("filename", "my_awesome_file.parquet") +/// .counter("num_bytes", partition); +/// ``` +#[derive(Clone)] +pub struct MetricBuilder<'a> { + /// Location that the metric created by this builder will be added do + metrics: &'a ExecutionPlanMetricsSet, + + /// optional partition number + partition: Option, + + /// arbitrary name=value pairs identifying this metric + labels: Vec(); + if id == TypeId::of::() { + "Utf8" + } else if id == TypeId::of::() { + "Utf8View" + } else if id == TypeId::of::() { + "Float32" + } else if id == TypeId::of::() { + "Int16" + } else if id == TypeId::of::() { + "Int32" + } else if id == TypeId::of::() { + "TimestampNs" + } else if id == TypeId::of::() { + "UInt8" + } else { + "Unknown" + } +} + +/// Builds a benchmark name from array type, list size, and null percentage. +fn bench_name(in_list_length: usize, null_percent: f64) -> String { + format!( + "in_list/{}/list={in_list_length}/nulls={}%", + array_type_name::(), + (null_percent * 100.0) as u32 + ) +} + +/// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations. +fn bench_string_type( c: &mut Criterion, - array_length: usize, - in_list_length: usize, - null_percent: f64, -) { - let mut rng = StdRng::seed_from_u64(120320); - for string_length in [5, 10, 20] { - let values: StringArray = (0..array_length) - .map(|_| { - rng.random_bool(null_percent) - .then(|| random_string(&mut rng, string_length)) - }) - .collect(); - - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) - .collect(); - - do_bench( - c, - &format!( - "in_list_utf8({string_length}) ({array_length}, {null_percent}) IN ({in_list_length}, 0)" - ), - Arc::new(values), - &in_list, - ) + rng: &mut StdRng, + make_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + for string_length in STRING_LENGTHS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| { + rng.random_bool(1.0 - null_percent) + .then(|| random_string(rng, string_length)) + }) + .collect(); + + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(random_string(rng, string_length))) + .collect(); + + do_bench( + c, + &format!( + "{}/str={string_length}", + bench_name::(in_list_length, null_percent) + ), + Arc::new(values), + &in_list, + ) + } + } } +} - let values: Float32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) - .collect(); +/// Runs in_list benchmarks for a numeric array type across all list-size × null-ratio combinations. +fn bench_numeric_type( + c: &mut Criterion, + rng: &mut StdRng, + mut gen_value: impl FnMut(&mut StdRng) -> T, + make_scalar: fn(T) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng))) + .collect(); - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.random()))) - .collect(); + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(gen_value(rng))) + .collect(); - do_bench( - c, - &format!("in_list_f32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ); + do_bench( + c, + &bench_name::(in_list_length, null_percent), + Arc::new(values), + &in_list, + ); + } + } +} + +/// Generates a random string with a length chosen from MIXED_STRING_LENGTHS. +fn random_mixed_length_string(rng: &mut StdRng) -> String { + let len = *MIXED_STRING_LENGTHS.choose(rng).unwrap(); + random_string(rng, len) +} + +/// Benchmarks realistic mixed-length IN list scenario. +/// +/// Tests with: +/// - Mixed short (≤12 bytes) and long (>12 bytes) strings in the IN list +/// - Varying prefixes (fully random strings) +/// - Configurable match rate (% of values that are in the IN list) +/// - Various IN list sizes (3, 8, 28, 100) +fn bench_realistic_mixed_strings( + c: &mut Criterion, + rng: &mut StdRng, + make_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for match_percent in [0.0, 0.25, 0.75] { + for null_percent in NULL_PERCENTS { + // Generate IN list with mixed-length random strings + let in_list_strings: Vec = (0..in_list_length) + .map(|_| random_mixed_length_string(rng)) + .collect(); + + let in_list: Vec<_> = in_list_strings + .iter() + .map(|s| make_scalar(s.clone())) + .collect(); + + // Generate values array with controlled match rate + let values: A = (0..ARRAY_LENGTH) + .map(|_| { + if !rng.random_bool(1.0 - null_percent) { + None + } else if rng.random_bool(match_percent) { + // Pick from IN list (will match) + Some(in_list_strings.choose(rng).unwrap().clone()) + } else { + // Generate new random string (unlikely to match) + Some(random_mixed_length_string(rng)) + } + }) + .collect(); - let values: Int32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) + do_bench( + c, + &format!( + "in_list/{}/mixed/list={}/match={}%/nulls={}%", + array_type_name::(), + in_list_length, + (match_percent * 100.0) as u32, + (null_percent * 100.0) as u32 + ), + Arc::new(values), + &in_list, + ); + } + } + } +} + +/// Benchmarks the column-reference evaluation path (no static filter) by including +/// a column reference in the IN list, which prevents static filter creation. +/// +/// This simulates SQL like: +/// ```sql +/// CREATE TABLE t (a INT, b0 INT, b1 INT, b2 INT); +/// SELECT * FROM t WHERE a IN (b0, b1, b2); +/// ``` +/// +/// - `values`: the "needle" column (`a`) +/// - `list_cols`: the "haystack" columns (`b0`, `b1`, …) +fn do_bench_with_columns( + c: &mut Criterion, + name: &str, + values: ArrayRef, + list_cols: &[ArrayRef], +) { + let mut fields = vec![Field::new("a", values.data_type().clone(), true)]; + let mut columns: Vec = vec![values]; + + // Build list expressions: column refs (forces non-constant evaluation path) + let schema_fields: Vec = list_cols + .iter() + .enumerate() + .map(|(i, col_arr)| { + let name = format!("b{i}"); + fields.push(Field::new(&name, col_arr.data_type().clone(), true)); + columns.push(Arc::clone(col_arr)); + Field::new(&name, col_arr.data_type().clone(), true) + }) .collect(); - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.random()))) + let schema = Schema::new(fields); + let list_exprs: Vec> = schema_fields + .iter() + .map(|f| col(f.name(), &schema).unwrap()) .collect(); - do_bench( - c, - &format!("in_list_i32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ) + let expr = in_list(col("a", &schema).unwrap(), list_exprs, &false, &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); + + c.bench_function(name, |b| { + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); } -fn criterion_benchmark(c: &mut Criterion) { - for in_list_length in [1, 3, 10, 100] { - for null_percent in [0., 0.2] { - do_benches(c, 1024, in_list_length, null_percent) +/// Benchmarks the IN list path with column references for Int32 arrays. +/// +/// Equivalent SQL: +/// ```sql +/// CREATE TABLE t (a INT, b0 INT, b1 INT, ...); +/// SELECT * FROM t WHERE a IN (b0, b1, ...); +/// ``` +fn bench_with_columns_int32(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(42); + + for list_size in LIST_WITH_COLUMNS_LENGTHS { + for match_percent in MATCH_PERCENTS { + for null_percent in NULL_PERCENTS { + // Generate the "needle" column + let values: Int32Array = (0..ARRAY_LENGTH) + .map(|_| { + rng.random_bool(1.0 - null_percent) + .then(|| rng.random_range(0..1000)) + }) + .collect(); + + // Generate list columns with controlled match rate + let list_cols: Vec = (0..list_size) + .map(|_| { + let col: Int32Array = (0..ARRAY_LENGTH) + .map(|row| { + if rng.random_bool(1.0 - null_percent) { + if rng.random_bool(match_percent) { + // Copy from values to create a match + if values.is_null(row) { + Some(rng.random_range(0..1000)) + } else { + Some(values.value(row)) + } + } else { + // Random value (unlikely to match) + Some(rng.random_range(1000..2000)) + } + } else { + None + } + }) + .collect(); + Arc::new(col) as ArrayRef + }) + .collect(); + + do_bench_with_columns( + c, + &format!( + "in_list_cols/Int32/list={}/match={}%/nulls={}%", + list_size, + (match_percent * 100.0) as u32, + (null_percent * 100.0) as u32 + ), + Arc::new(values), + &list_cols, + ); + } + } + } +} + +/// Benchmarks the IN list path with column references for Utf8 arrays. +/// +/// Equivalent SQL: +/// ```sql +/// CREATE TABLE t (a VARCHAR, b0 VARCHAR, b1 VARCHAR, ...); +/// SELECT * FROM t WHERE a IN (b0, b1, ...); +/// ``` +fn bench_with_columns_utf8(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(99); + + for list_size in LIST_WITH_COLUMNS_LENGTHS { + for match_percent in MATCH_PERCENTS { + // Generate the "needle" column + let value_strings: Vec> = (0..ARRAY_LENGTH) + .map(|_| rng.random_bool(0.8).then(|| random_string(&mut rng, 12))) + .collect(); + let values: StringArray = + value_strings.iter().map(|s| s.as_deref()).collect(); + + // Generate list columns with controlled match rate + let list_cols: Vec = (0..list_size) + .map(|_| { + let col: StringArray = (0..ARRAY_LENGTH) + .map(|row| { + if rng.random_bool(match_percent) { + // Copy from values to create a match + value_strings[row].as_deref() + } else { + Some("no_match_value_xyz") + } + }) + .collect(); + Arc::new(col) as ArrayRef + }) + .collect(); + + do_bench_with_columns( + c, + &format!( + "in_list_cols/Utf8/list={}/match={}%", + list_size, + (match_percent * 100.0) as u32, + ), + Arc::new(values), + &list_cols, + ); } } } -criterion_group!(benches, criterion_benchmark); +/// Entry point: registers in_list benchmarks for string and numeric array types. +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(120320); + + // Benchmarks for string array types (Utf8, Utf8View) + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8(Some(s))); + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8View(Some(s))); + + // Realistic mixed-length string benchmarks (TPC-H style) + bench_realistic_mixed_strings::(c, &mut rng, |s| { + ScalarValue::Utf8(Some(s)) + }); + bench_realistic_mixed_strings::(c, &mut rng, |s| { + ScalarValue::Utf8View(Some(s)) + }); + + // Benchmarks for numeric types + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::UInt8(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Int16(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Float32(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::TimestampNanosecond(Some(v), None), + ); + + // Column-reference path benchmarks (non-constant list expressions) + bench_with_columns_int32(c); + bench_with_columns_utf8(c); +} + +criterion_group! { + name = benches; + config = Criterion::default() + .warm_up_time(Duration::from_millis(100)) + .measurement_time(Duration::from_millis(500)); + targets = criterion_benchmark +} criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/in_list_strategy.rs b/datafusion/physical-expr/benches/in_list_strategy.rs new file mode 100644 index 0000000000000..5c4922fdcf8a9 --- /dev/null +++ b/datafusion/physical-expr/benches/in_list_strategy.rs @@ -0,0 +1,1037 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Focused benchmarks for `InList` cases. +//! +//! This benchmark file adds targeted coverage for representative `IN LIST` +//! workloads with controlled parameters: +//! +//! - **Controlled match rates**: Exercises both hit-heavy and miss-heavy paths +//! - **List size scaling**: Measures behavior across small and large `IN` lists +//! - **Type coverage**: Covers primitive, string, string-view, dictionary, and +//! fixed-size-binary inputs +//! - **Shared-prefix strings**: Adds collision-heavy string cases where values +//! only differ late in the string +//! - **Mixed-length strings**: Covers inputs that combine short and long values +//! - **Null handling**: Includes representative `NULL` and `NOT IN` cases +//! +//! # Case Coverage +//! +//! | Case | Types | Characteristics | List Sizes Tested | +//! |------|-------|-----------------|-------------------| +//! | Narrow integer cases | UInt8 | small value domain | 4, 16 | +//! | Narrow integer cases | Int16 | larger value domain | 4, 64, 256 | +//! | 32-bit primitive cases | Int32, Float32 | small and large lists | 4, 32, 64, 256 | +//! | 64-bit primitive cases | Int64, TimestampNs | small and large lists | 4, 16, 32, 128 | +//! | Utf8 short-string cases | Utf8 | 8-byte strings | 4, 64, 256 | +//! | Utf8 long-string cases | Utf8 | 24-byte strings | 4, 64, 256 | +//! | Utf8View short-string cases | Utf8View | 8-byte strings | 4, 16, 64, 256 | +//! | Utf8View length-12 cases | Utf8View | 12-byte strings | 16, 64 | +//! | Utf8View long-string cases | Utf8View | 24-byte strings | 4, 16, 64, 256 | +//! | Shared-prefix string cases | Utf8, Utf8View | same prefix, different suffix | 16, 32, 64 | +//! | Fixed-size binary cases | FixedSizeBinary(16) | fixed-width binary values | 4, 64, 256, 10000 | + +use arrow::array::*; +use arrow::datatypes::{Field, Int32Type, Schema}; +use arrow::record_batch::RecordBatch; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_physical_expr::expressions::{col, in_list, lit}; +use rand::distr::Alphanumeric; +use rand::prelude::*; +use std::sync::Arc; + +const ARRAY_SIZE: usize = 8192; + +/// Match rates to test both code paths (miss-heavy and balanced) +const MATCH_RATES: [u32; 2] = [0, 50]; + +// ============================================================================= +// NUMERIC BENCHMARK HELPERS +// ============================================================================= + +/// Configuration for numeric benchmarks, grouping test parameters. +struct NumericBenchConfig { + list_size: usize, + match_rate: f64, + null_rate: f64, + make_value: fn(&mut StdRng) -> T, + to_scalar: fn(T) -> ScalarValue, + negated: bool, +} + +impl NumericBenchConfig { + fn new( + list_size: usize, + match_rate: f64, + make_value: fn(&mut StdRng) -> T, + to_scalar: fn(T) -> ScalarValue, + ) -> Self { + Self { + list_size, + match_rate, + null_rate: 0.0, + make_value, + to_scalar, + negated: false, + } + } + + fn with_null_rate(mut self, null_rate: f64) -> Self { + self.null_rate = null_rate; + self + } + + fn with_negated(mut self) -> Self { + self.negated = true; + self + } +} + +/// Creates and runs a benchmark for numeric types with controlled match rate. +/// Uses a seed derived from list_size to avoid subset correlation between sizes. +fn bench_numeric( + c: &mut Criterion, + group: &str, + name: &str, + cfg: &NumericBenchConfig, +) where + T: Clone, + A: Array + FromIterator> + 'static, +{ + // Use different seed per list_size to avoid subset correlation + let seed = 0xDEAD_BEEF_u64.wrapping_add(cfg.list_size as u64 * 0x1234_5678); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list values + let haystack: Vec = (0..cfg.list_size) + .map(|_| (cfg.make_value)(&mut rng)) + .collect(); + + // Generate array with controlled match rate and null rate + let values: A = (0..ARRAY_SIZE) + .map(|_| { + if cfg.null_rate > 0.0 && rng.random_bool(cfg.null_rate) { + None + } else if !haystack.is_empty() && rng.random_bool(cfg.match_rate) { + Some(haystack.choose(&mut rng).unwrap().clone()) + } else { + Some((cfg.make_value)(&mut rng)) + } + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v: &T| lit((cfg.to_scalar)(v.clone()))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &cfg.negated, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} + +// ============================================================================= +// STRING BENCHMARK HELPERS +// ============================================================================= + +fn random_string(rng: &mut StdRng, len: usize) -> String { + String::from_utf8(rng.sample_iter(&Alphanumeric).take(len).collect()).unwrap() +} + +/// Creates a set of strings that share a common prefix but differ in suffix. +/// Uses random alphanumeric suffix to avoid bench-maxing on numeric patterns. +fn strings_with_shared_prefix( + rng: &mut StdRng, + count: usize, + prefix_len: usize, +) -> Vec { + let prefix = random_string(rng, prefix_len); + (0..count) + .map(|_| format!("{}{}", prefix, random_string(rng, 8))) // prefix + random 8-char suffix + .collect() +} + +/// Configuration for string benchmarks, grouping test parameters. +struct StringBenchConfig { + list_size: usize, + match_rate: f64, + null_rate: f64, + string_len: usize, + to_scalar: fn(String) -> ScalarValue, + negated: bool, +} + +impl StringBenchConfig { + fn new( + list_size: usize, + match_rate: f64, + string_len: usize, + to_scalar: fn(String) -> ScalarValue, + ) -> Self { + Self { + list_size, + match_rate, + null_rate: 0.0, + string_len, + to_scalar, + negated: false, + } + } + + fn with_null_rate(mut self, null_rate: f64) -> Self { + self.null_rate = null_rate; + self + } + + fn with_negated(mut self) -> Self { + self.negated = true; + self + } +} + +/// Creates and runs a benchmark for string types with controlled match rate. +/// Uses a seed derived from list_size and string_len to avoid correlation. +fn bench_string(c: &mut Criterion, group: &str, name: &str, cfg: &StringBenchConfig) +where + A: Array + FromIterator> + 'static, +{ + // Use different seed per (list_size, string_len) to avoid correlation + let seed = 0xCAFE_BABE_u64 + .wrapping_add(cfg.list_size as u64 * 0x1111) + .wrapping_add(cfg.string_len as u64 * 0x2222); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list values + let haystack: Vec = (0..cfg.list_size) + .map(|_| random_string(&mut rng, cfg.string_len)) + .collect(); + + // Generate array with controlled match rate and null rate + let values: A = (0..ARRAY_SIZE) + .map(|_| { + if cfg.null_rate > 0.0 && rng.random_bool(cfg.null_rate) { + None + } else if !haystack.is_empty() && rng.random_bool(cfg.match_rate) { + Some(haystack.choose(&mut rng).unwrap().clone()) + } else { + Some(random_string(&mut rng, cfg.string_len)) + } + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit((cfg.to_scalar)(v.clone()))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &cfg.negated, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} + +/// Benchmarks strings with shared prefixes and different suffixes. +/// Uses variable prefix lengths and random suffixes to avoid bench-maxing. +fn bench_string_shared_prefix( + c: &mut Criterion, + group: &str, + name: &str, + list_size: usize, + match_rate: f64, + prefix_len: usize, + to_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + let seed = 0xFEED_FACE_u64 + .wrapping_add(list_size as u64 * 0x3333) + .wrapping_add(prefix_len as u64 * 0x4444); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list with a shared prefix. + let haystack = strings_with_shared_prefix(&mut rng, list_size, prefix_len); + + // Generate non-matching strings with the same prefix to keep misses close + // to the matching set. + let non_match_pool = strings_with_shared_prefix(&mut rng, 100, prefix_len); + + // Generate array with controlled match rate + let values: A = (0..ARRAY_SIZE) + .map(|_| { + Some(if !haystack.is_empty() && rng.random_bool(match_rate) { + haystack.choose(&mut rng).unwrap().clone() + } else { + non_match_pool.choose(&mut rng).unwrap().clone() + }) + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack.iter().map(|v| lit(to_scalar(v.clone()))).collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} + +/// Benchmarks mixed-length strings (some short <= 12, some long > 12). +/// Uses a more realistic length distribution than the fixed-width cases. +fn bench_string_mixed_lengths( + c: &mut Criterion, + group: &str, + name: &str, + list_size: usize, + match_rate: f64, + to_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + let seed = 0xABCD_EF01_u64.wrapping_add(list_size as u64 * 0x5555); + let mut rng = StdRng::seed_from_u64(seed); + + // Mixed lengths: some short (<= 12), some long (> 12) + let lengths = [4, 8, 12, 16, 20, 24]; + + // Generate IN list with mixed lengths + let haystack: Vec = (0..list_size) + .map(|_| { + let len = *lengths.choose(&mut rng).unwrap(); + random_string(&mut rng, len) + }) + .collect(); + + // Generate array with controlled match rate and mixed lengths + let values: A = (0..ARRAY_SIZE) + .map(|_| { + Some(if !haystack.is_empty() && rng.random_bool(match_rate) { + haystack.choose(&mut rng).unwrap().clone() + } else { + let len = *lengths.choose(&mut rng).unwrap(); + random_string(&mut rng, len) + }) + }) + .collect(); + + let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); + let exprs: Vec<_> = haystack.iter().map(|v| lit(to_scalar(v.clone()))).collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new(group, name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} + +// ============================================================================= +// NARROW INTEGER CASE BENCHMARKS +// ============================================================================= + +fn bench_narrow_integer(c: &mut Criterion) { + // UInt8: small value domain + // NOTE: With 256 possible values, list_size=16 covers 6.25% of value space, + // so even "match=0%" has ~6% accidental matches from random data. + for list_size in [4, 16] { + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "narrow_integer", + &format!("u8/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::UInt8(Some(v)), + ), + ); + } + } + + // Int16: larger value domain with wider list sizes + for list_size in [4, 64, 256] { + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "narrow_integer", + &format!("i16/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::Int16(Some(v)), + ), + ); + } + } +} + +// ============================================================================= +// PRIMITIVE SIZE-SCALING BENCHMARKS +// ============================================================================= + +fn bench_primitive(c: &mut Criterion) { + // Int32: small and larger list sizes + for list_size in [4, 32, 64, 256] { + let list_case = if list_size <= 32 { + "small_list" + } else { + "large_list" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "primitive", + &format!("i32/{list_case}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ), + ); + } + } + + // Int64: small and larger list sizes + for list_size in [4, 16, 32, 128] { + let list_case = if list_size <= 16 { + "small_list" + } else { + "large_list" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "primitive", + &format!("i64/{list_case}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random(), + |v| ScalarValue::Int64(Some(v)), + ), + ); + } + } + + // NOT IN benchmark: test negated path + bench_numeric::( + c, + "primitive", + "i32/small_list/list=16/match=50%/NOT_IN", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_negated(), + ); +} + +// ============================================================================= +// FLOAT AND TIMESTAMP CASE BENCHMARKS +// ============================================================================= + +fn bench_f32(c: &mut Criterion) { + // Float32: uses the same list sizes as the Int32 cases. + for list_size in [4, 32, 64] { + let list_case = if list_size <= 32 { + "small_list" + } else { + "large_list" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "f32", + &format!("{list_case}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random::() * 1000.0, + |v| ScalarValue::Float32(Some(v)), + ), + ); + } + } +} + +fn bench_timestamp_ns(c: &mut Criterion) { + // TimestampNanosecond: uses the same list sizes as the Int64-style cases. + for list_size in [4, 16, 32] { + let list_case = if list_size <= 16 { + "small_list" + } else { + "large_list" + }; + for match_pct in MATCH_RATES { + bench_numeric::( + c, + "timestamp_ns", + &format!("{list_case}/list={list_size}/match={match_pct}%"), + &NumericBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + |rng| rng.random::().abs(), + |v| ScalarValue::TimestampNanosecond(Some(v), None), + ), + ); + } + } +} + +// ============================================================================= +// UTF8 STRING CASE BENCHMARKS +// ============================================================================= + +fn bench_utf8(c: &mut Criterion) { + let to_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8(Some(s)); + + // Short strings (8 bytes) + for list_size in [4, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8", + &format!("short_8b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 8, + to_scalar, + ), + ); + } + } + + // Long strings (24 bytes) + for list_size in [4, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8", + &format!("long_24b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 24, + to_scalar, + ), + ); + } + } + + // Mixed-length strings: realistic distribution + for list_size in [16, 64] { + for match_pct in MATCH_RATES { + bench_string_mixed_lengths::( + c, + "utf8", + &format!("mixed_len/list={list_size}/match={match_pct}%"), + list_size, + match_pct as f64 / 100.0, + to_scalar, + ); + } + } + + // Shared-prefix strings: same prefix, different suffix + bench_string_shared_prefix::( + c, + "utf8", + "shared_prefix/pfx=12/list=32/match=50%", + 32, + 0.5, + 12, + to_scalar, + ); + + // NOT IN benchmark + bench_string::( + c, + "utf8", + "short_8b/list=16/match=50%/NOT_IN", + &StringBenchConfig::new(16, 0.5, 8, to_scalar).with_negated(), + ); +} + +// ============================================================================= +// UTF8VIEW STRING CASE BENCHMARKS +// ============================================================================= + +fn bench_utf8view(c: &mut Criterion) { + let to_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8View(Some(s)); + + // Short strings (8 bytes) + for list_size in [4, 16, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8view", + &format!("short_8b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 8, + to_scalar, + ), + ); + } + } + + // Length-12 strings + for list_size in [16, 64] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8view", + &format!("len_12b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 12, + to_scalar, + ), + ); + } + } + + // Long strings (24 bytes) + for list_size in [4, 16, 64, 256] { + for match_pct in MATCH_RATES { + bench_string::( + c, + "utf8view", + &format!("long_24b/list={list_size}/match={match_pct}%"), + &StringBenchConfig::new( + list_size, + match_pct as f64 / 100.0, + 24, + to_scalar, + ), + ); + } + } + + // Mixed-length strings: realistic distribution + for list_size in [16, 64] { + for match_pct in MATCH_RATES { + bench_string_mixed_lengths::( + c, + "utf8view", + &format!("mixed_len/list={list_size}/match={match_pct}%"), + list_size, + match_pct as f64 / 100.0, + to_scalar, + ); + } + } + + // Shared-prefix strings with varying prefix lengths + for (prefix_len, list_size) in [(8, 16), (12, 32), (16, 64)] { + for match_pct in MATCH_RATES { + bench_string_shared_prefix::( + c, + "utf8view", + &format!( + "shared_prefix/pfx={prefix_len}/list={list_size}/match={match_pct}%" + ), + list_size, + match_pct as f64 / 100.0, + prefix_len, + to_scalar, + ); + } + } +} + +// ============================================================================= +// DICTIONARY ARRAY BENCHMARKS +// ============================================================================= + +/// Helper to benchmark dictionary-encoded Int32 arrays +fn bench_dict_int32( + c: &mut Criterion, + name: &str, + dict_size: usize, + list_size: usize, + negated: bool, +) { + let seed = 0xD1C7_0000_u64 + .wrapping_add(dict_size as u64 * 0x1111) + .wrapping_add(list_size as u64 * 0x2222); + let mut rng = StdRng::seed_from_u64(seed); + + let dict_values: Vec = (0..dict_size).map(|_| rng.random()).collect(); + let haystack: Vec = dict_values.iter().take(list_size).cloned().collect(); + + let indices: Vec = (0..ARRAY_SIZE) + .map(|_| rng.random_range(0..dict_size as i32)) + .collect(); + let indices_array = Int32Array::from(indices); + let values_array = Int32Array::from(dict_values); + let dict_array = + DictionaryArray::::try_new(indices_array, Arc::new(values_array)) + .unwrap(); + + let schema = Schema::new(vec![Field::new("a", dict_array.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &negated, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict_array) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new("dictionary", name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} + +/// Helper to benchmark dictionary-encoded string arrays +fn bench_dict_string( + c: &mut Criterion, + name: &str, + dict_size: usize, + list_size: usize, + string_len: usize, +) { + let seed = 0xD1C7_5778_u64 + .wrapping_add(dict_size as u64 * 0x3333) + .wrapping_add(string_len as u64 * 0x4444); + let mut rng = StdRng::seed_from_u64(seed); + + let dict_values: Vec = (0..dict_size) + .map(|_| random_string(&mut rng, string_len)) + .collect(); + let haystack: Vec = dict_values.iter().take(list_size).cloned().collect(); + + let indices: Vec = (0..ARRAY_SIZE) + .map(|_| rng.random_range(0..dict_size as i32)) + .collect(); + let indices_array = Int32Array::from(indices); + let values_array = StringArray::from(dict_values); + let dict_array = + DictionaryArray::::try_new(indices_array, Arc::new(values_array)) + .unwrap(); + + let schema = Schema::new(vec![Field::new("a", dict_array.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit(ScalarValue::Utf8(Some(v.clone())))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict_array) as ArrayRef]) + .unwrap(); + + c.bench_with_input(BenchmarkId::new("dictionary", name), &batch, |b, batch| { + b.iter(|| expr.evaluate(batch).unwrap()) + }); +} + +fn bench_dictionary(c: &mut Criterion) { + // Int32 dictionary: varying list sizes across dictionary values + // Dictionary with 100 unique values + for list_size in [4, 16, 64] { + bench_dict_int32( + c, + &format!("i32/dict=100/list={list_size}"), + 100, + list_size, + false, + ); + } + + // Int32 dictionary: varying dictionary cardinality + for dict_size in [10, 1000] { + bench_dict_int32( + c, + &format!("i32/dict={dict_size}/list=16"), + dict_size, + 16, + false, + ); + } + + // Int32 dictionary: NOT IN path + bench_dict_int32(c, "i32/dict=100/list=16/NOT_IN", 100, 16, true); + + // String dictionary: short strings (<= 12 bytes, common for codes/categories) + for list_size in [8, 32] { + bench_dict_string( + c, + &format!("utf8_short/dict=50/list={list_size}"), + 50, + list_size, + 8, + ); + } + + // String dictionary: long strings (>12 bytes) + bench_dict_string(c, "utf8_long/dict=100/list=16", 100, 16, 24); + + // String dictionary: large cardinality (realistic category counts) + bench_dict_string(c, "utf8_short/dict=500/list=20", 500, 20, 10); +} + +// ============================================================================= +// NULL HANDLING BENCHMARKS +// ============================================================================= +// +// Tests representative null-containing inputs across primitive and string cases. + +fn bench_nulls(c: &mut Criterion) { + // ========================================================================= + // PRIMITIVE CASES + // ========================================================================= + + // UInt8 case with nulls + bench_numeric::( + c, + "nulls", + "narrow_integer/u8/list=16/match=50%/nulls=20%", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::UInt8(Some(v)), + ) + .with_null_rate(0.2), + ); + + // Int32 small-list case with nulls + bench_numeric::( + c, + "nulls", + "primitive/i32/small_list/list=16/match=50%/nulls=20%", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.2), + ); + + // Int32 large-list case with nulls + bench_numeric::( + c, + "nulls", + "primitive/i32/large_list/list=64/match=50%/nulls=20%", + &NumericBenchConfig::new( + 64, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.2), + ); + + // ========================================================================= + // STRING CASES + // ========================================================================= + + let utf8_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8(Some(s)); + let utf8view_scalar: fn(String) -> ScalarValue = |s| ScalarValue::Utf8View(Some(s)); + + // Utf8 short-string case with nulls + bench_string::( + c, + "nulls", + "utf8/short_8b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 8, utf8_scalar).with_null_rate(0.2), + ); + + // Utf8 long-string case with nulls + bench_string::( + c, + "nulls", + "utf8/long_24b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 24, utf8_scalar).with_null_rate(0.2), + ); + + // Utf8View short-string case with nulls + bench_string::( + c, + "nulls", + "utf8view/short_8b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 8, utf8view_scalar).with_null_rate(0.2), + ); + + // Utf8View long-string case with nulls + bench_string::( + c, + "nulls", + "utf8view/long_24b/list=16/match=50%/nulls=20%", + &StringBenchConfig::new(16, 0.5, 24, utf8view_scalar).with_null_rate(0.2), + ); + + // ========================================================================= + // NOT IN CASES WITH NULLS + // ========================================================================= + + // Primitive NOT IN case with nulls + bench_numeric::( + c, + "nulls", + "primitive/i32/small_list/list=16/match=50%/nulls=20%/NOT_IN", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.2) + .with_negated(), + ); + + // String NOT IN case with nulls + bench_string::( + c, + "nulls", + "utf8view/short_8b/list=16/match=50%/nulls=20%/NOT_IN", + &StringBenchConfig::new(16, 0.5, 8, utf8view_scalar) + .with_null_rate(0.2) + .with_negated(), + ); + + // ========================================================================= + // HIGH NULL-RATE CASES + // ========================================================================= + + // 50% nulls - half the array is null + bench_numeric::( + c, + "nulls", + "primitive/i32/small_list/list=16/match=50%/nulls=50%", + &NumericBenchConfig::new( + 16, + 0.5, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ) + .with_null_rate(0.5), + ); + + bench_string::( + c, + "nulls", + "utf8view/short_8b/list=16/match=50%/nulls=50%", + &StringBenchConfig::new(16, 0.5, 8, utf8view_scalar).with_null_rate(0.5), + ); +} + +// ============================================================================= +// FIXED SIZE BINARY BENCHMARKS (FixedSizeBinary<16>, e.g. UUIDs) +// ============================================================================= + +/// Generates a random 16-byte value (UUID-sized). +fn random_fixed_binary_16(rng: &mut StdRng) -> Vec { + let mut buf = vec![0u8; 16]; + rng.fill(&mut buf[..]); + buf +} + +/// Benchmarks FixedSizeBinary(16) IN list evaluation. +/// FixedSizeBinary doesn't use the generic numeric helpers since its array +/// construction differs from primitive types. +fn bench_fixed_size_binary_inner( + c: &mut Criterion, + name: &str, + list_size: usize, + match_rate: f64, +) { + let seed = 0xF1ED_B1A7_u64.wrapping_add(list_size as u64 * 0x6666); + let mut rng = StdRng::seed_from_u64(seed); + + // Generate IN list values (16-byte each) + let haystack: Vec> = (0..list_size) + .map(|_| random_fixed_binary_16(&mut rng)) + .collect(); + + // Generate array with controlled match rate + let values: Vec> = (0..ARRAY_SIZE) + .map(|_| { + if !haystack.is_empty() && rng.random_bool(match_rate) { + haystack.choose(&mut rng).unwrap().clone() + } else { + random_fixed_binary_16(&mut rng) + } + }) + .collect(); + + let refs: Vec<&[u8]> = values.iter().map(|v| v.as_slice()).collect(); + let array = FixedSizeBinaryArray::from(refs); + + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]); + let exprs: Vec<_> = haystack + .iter() + .map(|v| lit(ScalarValue::FixedSizeBinary(16, Some(v.clone())))) + .collect(); + let expr = in_list(col("a", &schema).unwrap(), exprs, &false, &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array) as ArrayRef]) + .unwrap(); + + c.bench_with_input( + BenchmarkId::new("fixed_size_binary", name), + &batch, + |b, batch| b.iter(|| expr.evaluate(batch).unwrap()), + ); +} + +fn bench_fixed_size_binary(c: &mut Criterion) { + for list_size in [4, 64, 256, 10000] { + for match_pct in MATCH_RATES { + bench_fixed_size_binary_inner( + c, + &format!("fsb16/list={list_size}/match={match_pct}%"), + list_size, + match_pct as f64 / 100.0, + ); + } + } +} + +// ============================================================================= +// CRITERION SETUP +// ============================================================================= + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = bench_narrow_integer, bench_primitive, bench_f32, bench_timestamp_ns, bench_utf8, bench_utf8view, bench_dictionary, bench_nulls, bench_fixed_size_binary +} + +criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs index 80b2907a9e989..0637ade1b3eec 100644 --- a/datafusion/physical-expr/benches/is_null.rs +++ b/datafusion/physical-expr/benches/is_null.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{builder::Int32Builder, RecordBatch}; +use arrow::array::{RecordBatch, builder::Int32Builder}; use arrow::datatypes::{DataType, Field, Schema}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_physical_expr::expressions::{Column, IsNotNullExpr, IsNullExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::hint::black_box; diff --git a/datafusion/physical-expr/benches/simplify.rs b/datafusion/physical-expr/benches/simplify.rs new file mode 100644 index 0000000000000..cc00c710004e8 --- /dev/null +++ b/datafusion/physical-expr/benches/simplify.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This is an attempt at reproducing some predicates generated by TPC-DS query #76, +//! and trying to figure out how long it takes to simplify them. + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use std::hint::black_box; +use std::sync::Arc; + +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, Column, IsNullExpr, Literal, +}; + +fn catalog_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("cs_sold_date_sk", DataType::Int64, true), // 0 + Field::new("cs_sold_time_sk", DataType::Int64, true), // 1 + Field::new("cs_ship_date_sk", DataType::Int64, true), // 2 + Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3 + Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4 + Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5 + Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6 + Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7 + Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8 + Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9 + Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10 + Field::new("cs_call_center_sk", DataType::Int64, true), // 11 + Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12 + Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13 + Field::new("cs_warehouse_sk", DataType::Int64, true), // 14 + Field::new("cs_item_sk", DataType::Int64, true), // 15 + Field::new("cs_promo_sk", DataType::Int64, true), // 16 + Field::new("cs_order_number", DataType::Int64, true), // 17 + Field::new("cs_quantity", DataType::Int64, true), // 18 + Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +fn web_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("ws_sold_date_sk", DataType::Int64, true), + Field::new("ws_sold_time_sk", DataType::Int64, true), + Field::new("ws_ship_date_sk", DataType::Int64, true), + Field::new("ws_item_sk", DataType::Int64, true), + Field::new("ws_bill_customer_sk", DataType::Int64, true), + Field::new("ws_bill_cdemo_sk", DataType::Int64, true), + Field::new("ws_bill_hdemo_sk", DataType::Int64, true), + Field::new("ws_bill_addr_sk", DataType::Int64, true), + Field::new("ws_ship_customer_sk", DataType::Int64, true), + Field::new("ws_ship_cdemo_sk", DataType::Int64, true), + Field::new("ws_ship_hdemo_sk", DataType::Int64, true), + Field::new("ws_ship_addr_sk", DataType::Int64, true), + Field::new("ws_web_page_sk", DataType::Int64, true), + Field::new("ws_web_site_sk", DataType::Int64, true), + Field::new("ws_ship_mode_sk", DataType::Int64, true), + Field::new("ws_warehouse_sk", DataType::Int64, true), + Field::new("ws_promo_sk", DataType::Int64, true), + Field::new("ws_order_number", DataType::Int64, true), + Field::new("ws_quantity", DataType::Int64, true), + Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +// Helper to create a literal +fn lit_i64(val: i64) -> Arc { + Arc::new(Literal::new(ScalarValue::Int64(Some(val)))) +} + +fn lit_i32(val: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(val)))) +} + +fn lit_bool(val: bool) -> Arc { + Arc::new(Literal::new(ScalarValue::Boolean(Some(val)))) +} + +// Helper to create binary expressions +fn and( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::And, right)) +} + +fn gte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::GtEq, right)) +} + +fn lte( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::LtEq, right)) +} + +fn modulo( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Modulo, right)) +} + +fn eq( + left: Arc, + right: Arc, +) -> Arc { + Arc::new(BinaryExpr::new(left, Operator::Eq, right)) +} + +/// Build a predicate similar to TPC-DS q76 catalog_sales filter. +/// Uses placeholder columns instead of hash expressions. +pub fn catalog_sales_predicate(num_partitions: usize) -> Arc { + let cs_sold_date_sk: Arc = + Arc::new(Column::new("cs_sold_date_sk", 0)); + let cs_ship_addr_sk: Arc = + Arc::new(Column::new("cs_ship_addr_sk", 10)); + let cs_item_sk: Arc = Arc::new(Column::new("cs_item_sk", 15)); + + // Use a simple modulo expression as placeholder for hash + let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // cs_ship_addr_sk IS NULL + let is_null_expr: Arc = Arc::new(IsNullExpr::new(cs_ship_addr_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_item_sk.clone(), lit_i64(partition as i64)), + lte(cs_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(cs_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + // Final: is_null AND item_case AND date_case + and(and(is_null_expr, item_case_expr), date_case_expr) +} +/// Build a predicate similar to TPC-DS q76 web_sales filter. +/// Uses placeholder columns instead of hash expressions. +fn web_sales_predicate(num_partitions: usize) -> Arc { + let ws_sold_date_sk: Arc = + Arc::new(Column::new("ws_sold_date_sk", 0)); + let ws_item_sk: Arc = Arc::new(Column::new("ws_item_sk", 3)); + let ws_ship_customer_sk: Arc = + Arc::new(Column::new("ws_ship_customer_sk", 8)); + + // Use simple modulo expression as placeholder for hash + let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // ws_ship_customer_sk IS NULL + let is_null_expr: Arc = + Arc::new(IsNullExpr::new(ws_ship_customer_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_item_sk.clone(), lit_i64(partition as i64)), + lte(ws_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc, Arc)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(ws_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + and(and(is_null_expr, item_case_expr), date_case_expr) +} + +/// Measures how long `PhysicalExprSimplifier::simplify` takes for a given expression. +fn bench_simplify( + c: &mut Criterion, + name: &str, + schema: &Schema, + expr: &Arc, +) { + let simplifier = PhysicalExprSimplifier::new(schema); + c.bench_function(name, |b| { + b.iter(|| black_box(simplifier.simplify(black_box(Arc::clone(expr))).unwrap())) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let cs_schema = catalog_sales_schema(); + let ws_schema = web_sales_schema(); + + for num_partitions in [16, 128] { + bench_simplify( + c, + &format!("tpc-ds/q76/cs/{num_partitions}"), + &cs_schema, + &catalog_sales_predicate(num_partitions), + ); + bench_simplify( + c, + &format!("tpc-ds/q76/ws/{num_partitions}"), + &ws_schema, + &web_sales_predicate(num_partitions), + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/string_concat.rs b/datafusion/physical-expr/benches/string_concat.rs new file mode 100644 index 0000000000000..23f54c7637bdd --- /dev/null +++ b/datafusion/physical-expr/benches/string_concat.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::StringViewArray; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_expr::Operator; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{BinaryExpr, Column}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const NUM_ROWS: usize = 8192; +const SEED: u64 = 42; + +fn create_string_view_array( + num_rows: usize, + str_len: usize, + null_density: f64, + seed: u64, +) -> StringViewArray { + let mut rng = StdRng::seed_from_u64(seed); + let values: Vec> = (0..num_rows) + .map(|_| { + if rng.random::() < null_density { + None + } else { + let s: String = (0..str_len) + .map(|_| rng.random_range(b'a'..=b'z') as char) + .collect(); + Some(s) + } + }) + .collect(); + StringViewArray::from_iter(values) +} + +fn bench_concat_utf8view(c: &mut Criterion) { + let mut group = c.benchmark_group("concat_utf8view"); + + let schema = Arc::new(Schema::new(vec![ + Field::new("left", DataType::Utf8View, true), + Field::new("right", DataType::Utf8View, true), + ])); + + // left || right + let expr = BinaryExpr::new( + Arc::new(Column::new("left", 0)), + Operator::StringConcat, + Arc::new(Column::new("right", 1)), + ); + + for null_density in [0.0, 0.1, 0.5] { + let left = create_string_view_array(NUM_ROWS, 16, null_density, SEED); + let right = create_string_view_array(NUM_ROWS, 16, null_density, SEED + 1); + + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(left), Arc::new(right)]) + .unwrap(); + + let label = format!("nulls_{}", (null_density * 100.0) as u32); + group.bench_with_input( + BenchmarkId::new("concat", &label), + &null_density, + |b, _| { + b.iter(|| { + black_box(expr.evaluate(black_box(&batch)).unwrap()); + }) + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, bench_concat_utf8view); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index f16895b44bf5e..e5d55aba4f51c 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -16,12 +16,12 @@ // under the License. pub(crate) mod groups_accumulator { - #[allow(unused_imports)] + #[expect(unused_imports)] pub(crate) mod accumulate { pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; } pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ - accumulate::NullState, GroupsAccumulatorAdapter, + GroupsAccumulatorAdapter, accumulate::NullState, }; } pub(crate) mod stats { @@ -29,8 +29,8 @@ pub(crate) mod stats { } pub mod utils { pub use datafusion_functions_aggregate_common::utils::{ - get_accum_scalar_values_as_arrays, get_sort_options, ordering_fields, - DecimalAverager, Hashable, + DecimalAverager, Hashable, get_accum_scalar_values_as_arrays, get_sort_options, + ordering_fields, }; } @@ -38,13 +38,20 @@ use std::fmt::Debug; use std::sync::Arc; use crate::expressions::Column; +use crate::physical_expr::create_physical_sort_exprs; +use crate::planner::{create_physical_expr, create_physical_exprs}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ - assert_or_internal_err, internal_err, not_impl_err, Result, ScalarValue, + DFSchema, Result, ScalarValue, assert_or_internal_err, internal_err, not_impl_err, }; -use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::expr::{ + AggregateFunction, AggregateFunctionParams, NullTreatment, physical_name, +}; +use datafusion_expr::{AggregateUDF, Expr, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::GroupsAccumulator; use datafusion_expr_common::type_coercion::aggregates::check_arg_count; @@ -55,6 +62,57 @@ use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +#[derive(Debug, Clone)] +struct AggregateHumanDisplay { + expression: String, + alias: Option, +} + +impl AggregateHumanDisplay { + fn try_new( + expression: Option, + alias: Option, + name: &str, + ) -> Result> { + let alias = alias.filter(|alias| !alias.is_empty()); + let Some(expression) = expression else { + if alias.is_some() { + return internal_err!( + "AggregateExprBuilder::human_display must be provided when human_display_alias is set" + ); + } + return Ok(None); + }; + + if expression.is_empty() { + if alias.is_some() { + return internal_err!( + "AggregateExprBuilder::human_display must be non-empty when human_display_alias is set" + ); + } + return Ok(None); + } + + if let Some(alias) = alias.as_deref() + && alias != name + { + return internal_err!( + "aggregate human_display_alias must match aggregate name `{name}`: {alias}" + ); + } + + Ok(Some(Self { expression, alias })) + } + + fn expression(&self) -> &str { + &self.expression + } + + fn alias(&self) -> Option<&str> { + self.alias.as_deref() + } +} + /// Builder for physical [`AggregateFunctionExpr`] /// /// `AggregateFunctionExpr` contains the information necessary to call @@ -65,8 +123,11 @@ pub struct AggregateExprBuilder { /// Physical expressions of the aggregate function args: Vec>, alias: Option, + output_metadata: Option, /// A human readable name - human_display: String, + human_display: Option, + /// Optional visible output alias for `human_display`. + human_display_alias: Option, /// Arrow Schema for the aggregate function schema: SchemaRef, /// The physical order by expressions @@ -85,7 +146,9 @@ impl AggregateExprBuilder { fun, args, alias: None, - human_display: String::default(), + output_metadata: None, + human_display: None, + human_display_alias: None, schema: Arc::new(Schema::empty()), order_bys: vec![], ignore_nulls: false, @@ -126,10 +189,6 @@ impl AggregateExprBuilder { /// # } /// # /// # impl AggregateUDFImpl for FirstValueUdf { - /// # fn as_any(&self) -> &dyn Any { - /// # unimplemented!() - /// # } - /// # /// # fn name(&self) -> &str { /// # unimplemented!() /// # } @@ -193,7 +252,9 @@ impl AggregateExprBuilder { fun, args, alias, + output_metadata, human_display, + human_display_alias, schema, order_bys, ignore_nulls, @@ -220,17 +281,23 @@ impl AggregateExprBuilder { &fun.signature().type_signature, )?; - let return_field = fun.return_field(&input_exprs_fields)?; + let mut return_field = fun.return_field(&input_exprs_fields)?; + if let Some(output_metadata) = output_metadata { + return_field = output_metadata.add_to_field_ref(return_field); + } let is_nullable = fun.is_nullable(); let name = match alias { None => { return internal_err!( "AggregateExprBuilder::alias must be provided prior to calling build" - ) + ); } Some(alias) => alias, }; + let human_display = + AggregateHumanDisplay::try_new(human_display, human_display_alias, &name)?; + let arg_fields = args .iter() .map(|e| e.return_field(schema.as_ref())) @@ -259,8 +326,24 @@ impl AggregateExprBuilder { self } - pub fn human_display(mut self, name: String) -> Self { - self.human_display = name; + fn output_metadata(mut self, metadata: Option) -> Self { + self.output_metadata = metadata; + self + } + + pub fn human_display(mut self, name: impl Into) -> Self { + let name = name.into(); + self.human_display = (!name.is_empty()).then_some(name); + if self.human_display.is_none() { + self.human_display_alias = None; + } + self + } + + #[doc(hidden)] + pub fn human_display_alias(mut self, alias: impl Into) -> Self { + let alias = alias.into(); + self.human_display_alias = (!alias.is_empty()).then_some(alias); self } @@ -305,6 +388,234 @@ impl AggregateExprBuilder { } } +#[derive(Debug, Clone)] +struct LoweredAggregateHumanDisplay { + expression: String, + alias: Option, +} + +/// Result of lowering a logical aggregate expression into physical aggregate +/// planning pieces. +#[derive(Debug, Clone)] +pub struct LoweredAggregate { + /// Physical aggregate expression that can be used by an aggregate execution + /// plan. + pub aggregate: Arc, + /// Optional physical filter expression for `FILTER (WHERE ...)`. + pub filter: Option>, + /// Physical ordering expressions from aggregate `ORDER BY`. + pub order_bys: Vec, +} + +/// Builder for converting a logical aggregate [`Expr`] into physical aggregate +/// planning pieces. +/// +/// This builder handles the logical-to-physical work needed for aggregate +/// planning: unwrapping aggregate aliases, choosing the output name, preserving +/// user-facing display text, lowering aggregate arguments, lowering the optional +/// filter, and lowering aggregate `ORDER BY` expressions. +pub struct LoweredAggregateBuilder<'a> { + expr: &'a Expr, + name: Option, + human_display: Option, + output_metadata: Option, + preserve_alias_metadata: bool, + logical_input_schema: &'a DFSchema, + physical_input_schema: &'a Schema, + execution_props: &'a ExecutionProps, +} + +impl<'a> LoweredAggregateBuilder<'a> { + /// Create a builder for lowering `expr`. + /// + /// `logical_input_schema` is used to resolve logical expressions such as + /// columns, while `physical_input_schema` is the input schema used by the + /// physical aggregate expression. + pub fn new( + expr: &'a Expr, + logical_input_schema: &'a DFSchema, + physical_input_schema: &'a Schema, + execution_props: &'a ExecutionProps, + ) -> Self { + Self { + expr, + name: None, + human_display: None, + output_metadata: None, + preserve_alias_metadata: true, + logical_input_schema, + physical_input_schema, + execution_props, + } + } + + /// Override the output column name for the aggregate. + /// + /// If this is not set, the builder uses the alias from `expr` when present, + /// or derives the physical name from the aggregate expression. + pub fn with_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + + /// Override the human-readable display text for the aggregate. + /// + /// This is useful when a caller has already computed the exact display text + /// it wants to preserve. When this override is used, aliases with metadata + /// are still unwrapped for planning, but alias metadata is not copied to the + /// aggregate output field. + pub fn with_human_display(mut self, human_display: impl Into) -> Self { + self.human_display = Some(LoweredAggregateHumanDisplay { + expression: human_display.into(), + alias: None, + }); + self.preserve_alias_metadata = false; + self + } + + /// Lower the logical aggregate expression into physical aggregate pieces. + pub fn build(self) -> Result { + let Self { + expr, + name, + human_display, + output_metadata, + preserve_alias_metadata, + logical_input_schema, + physical_input_schema, + execution_props, + } = self; + + let (name, human_display, output_metadata, expr) = lower_aggregate_display( + expr, + name, + human_display, + output_metadata, + preserve_alias_metadata, + ); + + let Expr::AggregateFunction(AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, + }) = &expr + else { + return internal_err!("Invalid aggregate expression '{expr:?}'"); + }; + + let name = if let Some(name) = name { + name + } else { + physical_name(&expr)? + }; + + let physical_args = + create_physical_exprs(args, logical_input_schema, execution_props)?; + let filter = filter + .as_ref() + .map(|filter| { + create_physical_expr(filter, logical_input_schema, execution_props) + }) + .transpose()?; + let order_bys = + create_physical_sort_exprs(order_by, logical_input_schema, execution_props)?; + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) + == NullTreatment::IgnoreNulls; + + let mut builder = AggregateExprBuilder::new(func.to_owned(), physical_args) + .order_by(order_bys.clone()) + .schema(Arc::new(physical_input_schema.to_owned())) + .alias(name) + .output_metadata(output_metadata) + .with_ignore_nulls(ignore_nulls) + .with_distinct(*distinct); + + if let Some(human_display) = human_display { + builder = builder.human_display(human_display.expression); + if let Some(alias) = human_display.alias { + builder = builder.human_display_alias(alias); + } + } + + Ok(LoweredAggregate { + aggregate: Arc::new(builder.build()?), + filter, + order_bys, + }) + } +} + +fn lower_aggregate_display( + expr: &Expr, + name: Option, + human_display: Option, + output_metadata: Option, + preserve_alias_metadata: bool, +) -> ( + Option, + Option, + Option, + Expr, +) { + let mut expr = expr.clone(); + let mut alias_name = None; + let mut alias_metadata = None; + while let Expr::Alias(alias) = expr { + if alias_name.is_none() { + alias_name = Some(alias.name); + alias_metadata = alias.metadata; + } + expr = *alias.expr; + } + + let output_metadata = if preserve_alias_metadata { + output_metadata.or(alias_metadata) + } else { + output_metadata + }; + + if human_display.is_some() { + return (name.or(alias_name), human_display, output_metadata, expr); + } + + match &expr { + Expr::AggregateFunction(_) => { + if let Some(alias_name) = alias_name { + let name = name.unwrap_or(alias_name); + let expression = expr.human_display().to_string(); + let human_display = if expression.is_empty() || expression == name { + LoweredAggregateHumanDisplay { + expression: name.clone(), + alias: None, + } + } else { + LoweredAggregateHumanDisplay { + expression, + alias: Some(name.clone()), + } + }; + + return (Some(name), Some(human_display), output_metadata, expr); + } + + let name = name.unwrap_or_else(|| expr.schema_name().to_string()); + let human_display = LoweredAggregateHumanDisplay { + expression: expr.human_display().to_string(), + alias: None, + }; + + (Some(name), Some(human_display), output_metadata, expr) + } + _ => (name.or(alias_name), None, output_metadata, expr), + } +} + /// Physical aggregate expression of a UDAF. /// /// Instances are constructed via [`AggregateExprBuilder`]. @@ -319,7 +630,7 @@ pub struct AggregateFunctionExpr { /// Output column name that this expression creates name: String, /// Simplified name for `tree` explain. - human_display: String, + human_display: Option, schema: Schema, // The physical order by expressions order_bys: Vec, @@ -351,8 +662,22 @@ impl AggregateFunctionExpr { } /// Simplified name for `tree` explain. - pub fn human_display(&self) -> &str { - &self.human_display + pub fn human_display(&self) -> Option<&str> { + self.human_display + .as_ref() + .map(AggregateHumanDisplay::expression) + } + + #[doc(hidden)] + pub fn human_display_alias(&self) -> Option<&str> { + self.human_display + .as_ref() + .and_then(AggregateHumanDisplay::alias) + } + + fn return_field_metadata(&self) -> Option { + let metadata = FieldMetadata::from(self.return_field.as_ref()); + (!metadata.is_empty()).then_some(metadata) } /// Return if the aggregation is distinct @@ -460,15 +785,22 @@ impl AggregateFunctionExpr { return Ok(None); }; - AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) - .order_by(self.order_bys.clone()) - .schema(Arc::new(self.schema.clone())) - .alias(self.name().to_string()) - .with_ignore_nulls(self.ignore_nulls) - .with_distinct(self.is_distinct) - .with_reversed(self.is_reversed) - .build() - .map(Some) + let mut builder = + AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) + .order_by(self.order_bys.clone()) + .schema(Arc::new(self.schema.clone())) + .alias(self.name().to_string()) + .output_metadata(self.return_field_metadata()) + .with_ignore_nulls(self.ignore_nulls) + .with_distinct(self.is_distinct) + .with_reversed(self.is_reversed); + if let Some(human_display) = self.human_display() { + builder = builder.human_display(human_display); + } + if let Some(alias) = self.human_display_alias() { + builder = builder.human_display_alias(alias); + } + builder.build().map(Some) } /// Creates accumulator implementation that supports retract @@ -586,23 +918,54 @@ impl AggregateFunctionExpr { ReversedUDAF::NotSupported => None, ReversedUDAF::Identical => Some(self.clone()), ReversedUDAF::Reversed(reverse_udf) => { + let was_aliased = self.human_display_alias().is_some(); let mut name = self.name().to_string(); + let mut human_display = self.human_display.clone(); + // Reversing display follows two paths: + // - aliased display keeps the output `name` unchanged and rewrites only + // the lowered expression in `human_display`. + // - non-aliased display rewrites the canonical `name`, and rewrites + // `human_display` only when present. // If the function is changed, we need to reverse order_by clause as well // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) - if self.fun().name() != reverse_udf.name() { + if !was_aliased && self.fun().name() != reverse_udf.name() { replace_order_by_clause(&mut name); } - replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); - - AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) - .order_by(self.order_bys.iter().map(|e| e.reverse()).collect()) - .schema(Arc::new(self.schema.clone())) - .alias(name) - .with_ignore_nulls(self.ignore_nulls) - .with_distinct(self.is_distinct) - .with_reversed(!self.is_reversed) - .build() - .ok() + if !was_aliased { + replace_fn_name_clause( + &mut name, + self.fun.name(), + reverse_udf.name(), + ); + } + + if let Some(human_display) = human_display.as_mut() { + if self.fun().name() != reverse_udf.name() { + replace_order_by_clause(&mut human_display.expression); + } + replace_fn_name_clause( + &mut human_display.expression, + self.fun.name(), + reverse_udf.name(), + ); + } + + let mut builder = + AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) + .order_by(self.order_bys.iter().map(|e| e.reverse()).collect()) + .schema(Arc::new(self.schema.clone())) + .alias(name) + .output_metadata(self.return_field_metadata()) + .with_ignore_nulls(self.ignore_nulls) + .with_distinct(self.is_distinct) + .with_reversed(!self.is_reversed); + if let Some(human_display) = human_display { + builder = builder.human_display(human_display.expression); + if let Some(alias) = human_display.alias { + builder = builder.human_display_alias(alias); + } + } + builder.build().ok() } } } @@ -739,23 +1102,104 @@ fn replace_order_by_clause(order_by: &mut String) { (" ASC NULLS LAST]", " DESC NULLS FIRST]"), ]; - if let Some(start) = order_by.find("ORDER BY [") { - if let Some(end) = order_by[start..].find(']') { - let order_by_start = start + 9; - let order_by_end = start + end; - - let column_order = &order_by[order_by_start..=order_by_end]; - for (suffix, replacement) in suffixes { - if column_order.ends_with(suffix) { - let new_order = column_order.replace(suffix, replacement); - order_by.replace_range(order_by_start..=order_by_end, &new_order); - break; - } + if let Some(start) = order_by.find("ORDER BY [") + && let Some(end) = order_by[start..].find(']') + { + let order_by_start = start + 9; + let order_by_end = start + end; + + let column_order = &order_by[order_by_start..=order_by_end]; + for (suffix, replacement) in suffixes { + if column_order.ends_with(suffix) { + let new_order = column_order.replace(suffix, replacement); + order_by.replace_range(order_by_start..=order_by_end, &new_order); + break; } } } } fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) { - *aggr_name = aggr_name.replace(fn_name_old, fn_name_new); + if let Some(rest) = aggr_name.strip_prefix(fn_name_old) { + *aggr_name = format!("{fn_name_new}{rest}"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::collections::HashMap; + + use arrow::datatypes::Field; + use datafusion_common::metadata::FieldMetadata; + use datafusion_expr::{col, test::function_stub::sum}; + + fn aggregate_test_schema() -> Result<(Schema, DFSchema)> { + let schema = Schema::new(vec![Field::new("column1", DataType::Int64, true)]); + let logical_schema = DFSchema::try_from(schema.clone())?; + Ok((schema, logical_schema)) + } + + fn test_metadata() -> FieldMetadata { + FieldMetadata::from(HashMap::from([( + "some_key".to_string(), + "some_value".to_string(), + )])) + } + + fn aggregate_alias_with_metadata() -> Expr { + sum(col("column1")).alias_with_metadata("agg", Some(test_metadata())) + } + + #[test] + fn lowered_aggregate_builder_unwraps_alias_with_metadata() -> Result<()> { + let (schema, logical_schema) = aggregate_test_schema()?; + let expr = aggregate_alias_with_metadata(); + + let lowered = LoweredAggregateBuilder::new( + &expr, + &logical_schema, + &schema, + &ExecutionProps::new(), + ) + .build()?; + + assert_eq!(lowered.aggregate.name(), "agg"); + assert_eq!(lowered.aggregate.human_display_alias(), Some("agg")); + assert_eq!( + lowered.aggregate.field().metadata().get("some_key"), + Some(&"some_value".to_string()) + ); + + Ok(()) + } + + #[test] + fn lowered_aggregate_builder_display_override_skips_alias_metadata() -> Result<()> { + let (schema, logical_schema) = aggregate_test_schema()?; + let expr = aggregate_alias_with_metadata(); + + let lowered = LoweredAggregateBuilder::new( + &expr, + &logical_schema, + &schema, + &ExecutionProps::new(), + ) + .with_human_display(expr.human_display().to_string()) + .build()?; + + assert_eq!(lowered.aggregate.name(), "agg"); + assert_eq!(lowered.aggregate.human_display_alias(), None); + assert!( + lowered + .aggregate + .field() + .metadata() + .get("some_key") + .is_none() + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 166e639966f13..1dca36b75f9f5 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -20,18 +20,18 @@ use std::fmt::Debug; use std::sync::Arc; +use crate::PhysicalExpr; use crate::expressions::Column; use crate::intervals::cp_solver::{ExprIntervalGraph, PropagationResult}; use crate::utils::collect_columns; -use crate::PhysicalExpr; use arrow::datatypes::Schema; use datafusion_common::stats::Precision; use datafusion_common::{ - assert_or_internal_err, internal_datafusion_err, internal_err, ColumnStatistics, - Result, ScalarValue, + ColumnStatistics, Result, ScalarValue, assert_or_internal_err, + internal_datafusion_err, internal_err, }; -use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; +use datafusion_expr::interval_arithmetic::{Interval, cardinality_ratio}; /// The shared context used during the analysis of an expression. Includes /// the boundaries for all known columns. @@ -167,6 +167,7 @@ pub fn analyze( schema: &Schema, ) -> Result { let initial_boundaries = &context.boundaries; + if initial_boundaries .iter() .all(|bound| bound.interval.is_none()) @@ -178,7 +179,7 @@ pub fn analyze( "ExprBoundaries has a non-zero distinct count although it represents an empty table" ); assert_or_internal_err!( - context.selectivity == Some(0.0), + context.selectivity.unwrap_or(0.0) == 0.0, "AnalysisContext has a non-zero selectivity although it represents an empty table" ); Ok(context) @@ -187,8 +188,8 @@ pub fn analyze( .any(|bound| bound.interval.is_none()) { internal_err!( - "AnalysisContext is an inconsistent state. Some columns represent empty table while others don't" - ) + "AnalysisContext is an inconsistent state. Some columns represent empty table while others don't" + ) } else { let mut target_boundaries = context.boundaries; let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr), schema)?; @@ -201,14 +202,13 @@ pub fn analyze( let target_expr_and_indices = graph.gather_node_indices(columns.as_slice()); for (expr, index) in &target_expr_and_indices { - if let Some(column) = expr.as_any().downcast_ref::() { - if let Some(bound) = + if let Some(column) = expr.downcast_ref::() + && let Some(bound) = target_boundaries.iter().find(|b| b.column == *column) - { - // Now, it's safe to unwrap - target_indices_and_boundaries - .push((*index, bound.interval.as_ref().unwrap().clone())); - } + { + // Now, it's safe to unwrap + target_indices_and_boundaries + .push((*index, bound.interval.as_ref().unwrap().clone())); } } @@ -241,14 +241,13 @@ fn shrink_boundaries( ) -> Result { let initial_boundaries = target_boundaries.clone(); target_expr_and_indices.iter().for_each(|(expr, i)| { - if let Some(column) = expr.as_any().downcast_ref::() { - if let Some(bound) = target_boundaries + if let Some(column) = expr.downcast_ref::() + && let Some(bound) = target_boundaries .iter_mut() .find(|bound| bound.column.eq(column)) - { - bound.interval = Some(graph.get_interval(*i)); - }; - } + { + bound.interval = Some(graph.get_interval(*i)); + }; }); let selectivity = calculate_selectivity(&target_boundaries, &initial_boundaries)?; @@ -261,6 +260,44 @@ fn shrink_boundaries( Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity)) } +/// Returns `Some(1.0 / distinct_count)` when the filter demonstrably collapsed +/// a non-singleton interval down to a single point, i.e. an equality predicate +/// was applied. Returns `None` in all other cases, signalling that the caller +/// should fall back to [`cardinality_ratio`]. +/// +/// The `initial_interval` guard prevents double-counting selectivity when the +/// column statistics already described a singleton before any filter was +/// applied: if the initial interval was already the same single point, no +/// additional selectivity has been gained and the `1 / NDV` shortcut must not +/// fire. +fn singleton_selectivity( + initial_interval: &Interval, + target_interval: &Interval, + distinct_count: usize, +) -> Option { + // The target must have collapsed to a single non-null value. + if distinct_count == 0 + || target_interval.lower().is_null() + || target_interval.lower() != target_interval.upper() + { + return None; + } + + // Only treat this as a newly-applied equality filter when the initial + // interval was not already that same singleton. If it was, the stats + // already encoded this restriction and applying 1/NDV again would + // under-estimate the row count. + let initial_is_same_singleton = !initial_interval.lower().is_null() + && initial_interval.lower() == initial_interval.upper() + && initial_interval.lower() == target_interval.lower(); + + if initial_is_same_singleton { + return None; + } + + Some(1.0 / distinct_count as f64) +} + /// This function calculates the filter predicate's selectivity by comparing /// the initial and pruned column boundaries. Selectivity is defined as the /// ratio of rows in a table that satisfy the filter's predicate. @@ -279,13 +316,24 @@ fn calculate_selectivity( let mut acc: f64 = 1.0; for (initial, target) in initial_boundaries.iter().zip(target_boundaries) { match (initial.interval.as_ref(), target.interval.as_ref()) { - (Some(initial), Some(target)) => { - acc *= cardinality_ratio(initial, target); + (Some(initial_interval), Some(target_interval)) => { + if let Precision::Exact(distinct_count) + | Precision::Inexact(distinct_count) = target.distinct_count + && let Some(s) = singleton_selectivity( + initial_interval, + target_interval, + distinct_count, + ) + { + acc *= s; + continue; + } + acc *= cardinality_ratio(initial_interval, target_interval); } (None, Some(_)) => { return internal_err!( - "Initial boundary cannot be None while having a Some() target boundary" - ); + "Initial boundary cannot be None while having a Some() target boundary" + ); } _ => return Ok(0.0), } @@ -299,14 +347,14 @@ mod tests { use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{assert_contains, DFSchema}; + use datafusion_common::{DFSchema, ScalarValue, assert_contains, stats::Precision}; use datafusion_expr::{ - col, execution_props::ExecutionProps, interval_arithmetic::Interval, lit, Expr, + Expr, col, execution_props::ExecutionProps, interval_arithmetic::Interval, lit, }; - use crate::{create_physical_expr, AnalysisContext}; + use crate::{AnalysisContext, create_physical_expr, expressions::Column}; - use super::{analyze, ExprBoundaries}; + use super::{ExprBoundaries, analyze, calculate_selectivity, singleton_selectivity}; fn make_field(name: &str, data_type: DataType) -> Field { let nullable = false; @@ -373,7 +421,9 @@ mod tests { ) .unwrap(); let Some(actual) = &analysis_result.boundaries[0].interval else { - panic!("The analysis result should contain non-empty intervals for all columns"); + panic!( + "The analysis result should contain non-empty intervals for all columns" + ); }; let expected = Interval::make(lower, upper).unwrap(); assert_eq!( @@ -435,4 +485,92 @@ mod tests { .unwrap_err(); assert_contains!(analysis_error.to_string(), expected_error); } + + // --------------------------------------------------------------------------- + // Unit tests for singleton_selectivity and calculate_selectivity + // --------------------------------------------------------------------------- + + fn make_boundary(lower: i32, upper: i32, distinct_count: usize) -> ExprBoundaries { + ExprBoundaries { + column: Column::new("a", 0), + interval: Some( + Interval::try_new( + ScalarValue::Int32(Some(lower)), + ScalarValue::Int32(Some(upper)), + ) + .unwrap(), + ), + distinct_count: Precision::Exact(distinct_count), + } + } + + /// When the initial interval is already the same singleton as the target, + /// `singleton_selectivity` must return `None` so we do not double-apply + /// 1/NDV selectivity. + #[test] + fn test_singleton_selectivity_skipped_when_initial_is_same_singleton() { + let singleton = + Interval::try_new(ScalarValue::Int32(Some(5)), ScalarValue::Int32(Some(5))) + .unwrap(); + // Both initial and target are [5, 5] — no new equality filter was applied. + assert_eq!( + singleton_selectivity(&singleton, &singleton, 10), + None, + "shortcut must not fire when initial interval was already the same singleton" + ); + } + + /// When the initial interval is a broader range and the target collapses to + /// a singleton, `singleton_selectivity` must return `Some(1/NDV)`. + #[test] + fn test_singleton_selectivity_applied_when_range_collapses() { + let initial = + Interval::try_new(ScalarValue::Int32(Some(1)), ScalarValue::Int32(Some(100))) + .unwrap(); + let target = + Interval::try_new(ScalarValue::Int32(Some(5)), ScalarValue::Int32(Some(5))) + .unwrap(); + let result = singleton_selectivity(&initial, &target, 10); + assert_eq!( + result, + Some(0.1), + "shortcut must return 1/NDV when a range collapses to a singleton" + ); + } + + /// Regression test: `calculate_selectivity` must not apply the `1/NDV` + /// shortcut when the column statistics already describe a singleton interval + /// (i.e. before the filter, the column only ever held one value). In that + /// case the target and initial intervals are the same singleton, so the + /// cardinality ratio is 1.0 and the overall selectivity should remain 1.0. + #[test] + fn test_calculate_selectivity_already_singleton_initial_interval() { + let already_singleton = make_boundary(7, 7, 1); + + let selectivity = calculate_selectivity( + std::slice::from_ref(&already_singleton), + std::slice::from_ref(&already_singleton), + ) + .unwrap(); + + let wide_initial = make_boundary(1, 100, 50); + let same_singleton_target = make_boundary(7, 7, 50); + let selectivity_new = + calculate_selectivity(&[same_singleton_target], &[wide_initial]).unwrap(); + assert!( + (selectivity_new - 0.02).abs() < 1e-10, + "expected selectivity 1/NDV = 0.02, got {selectivity_new}" + ); + + let singleton_initial = make_boundary(7, 7, 50); + let singleton_target = make_boundary(7, 7, 50); + let selectivity_no_new_filter = + calculate_selectivity(&[singleton_target], &[singleton_initial]).unwrap(); + assert!( + (selectivity_no_new_filter - 1.0).abs() < 1e-10, + "expected selectivity 1.0 when initial was already the same singleton, got {selectivity_no_new_filter}" + ); + + let _ = selectivity; // silence unused warning + } } diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index f1833666d6bbe..5612e63b530e7 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -19,14 +19,13 @@ use crate::ScalarFunctionExpr; use arrow::array::RecordBatch; use arrow::compute::concat; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; -use datafusion_common::config::ConfigOptions; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_common::{internal_err, not_impl_err}; -use datafusion_expr::async_udf::AsyncScalarUDF; use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr::async_udf::AsyncScalarUDF; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use std::any::Any; use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -68,7 +67,7 @@ impl AsyncFuncExpr { func: Arc, schema: &Schema, ) -> Result { - let Some(_) = func.as_any().downcast_ref::() else { + let Some(_) = func.downcast_ref::() else { return internal_err!( "unexpected function type, expected ScalarFunctionExpr, got: {:?}", func @@ -99,12 +98,10 @@ impl AsyncFuncExpr { /// Return the ideal batch size for this function pub fn ideal_batch_size(&self) -> Result> { - if let Some(expr) = self.func.as_any().downcast_ref::() { - if let Some(udf) = - expr.fun().inner().as_any().downcast_ref::() - { - return Ok(udf.ideal_batch_size()); - } + if let Some(expr) = self.func.downcast_ref::() + && let Some(udf) = expr.fun().inner().downcast_ref::() + { + return Ok(udf.ideal_batch_size()); } not_impl_err!("Can't get ideal_batch_size from {:?}", self.func) } @@ -117,8 +114,7 @@ impl AsyncFuncExpr { batch: &RecordBatch, config_options: Arc, ) -> Result { - let Some(scalar_function_expr) = - self.func.as_any().downcast_ref::() + let Some(scalar_function_expr) = self.func.downcast_ref::() else { return internal_err!( "unexpected function type, expected ScalarFunctionExpr, got: {:?}", @@ -129,7 +125,6 @@ impl AsyncFuncExpr { let Some(async_udf) = scalar_function_expr .fun() .inner() - .as_any() .downcast_ref::() else { return not_impl_err!( @@ -212,10 +207,6 @@ impl AsyncFuncExpr { } impl PhysicalExpr for AsyncFuncExpr { - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, input_schema: &Schema) -> Result { self.func.data_type(input_schema) } diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 5b64884f65bb8..d00a4a32278f0 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::fmt::Display; use std::ops::Deref; use std::sync::Arc; @@ -27,7 +28,7 @@ use crate::projection::ProjectionTargets; use crate::{PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{HashMap, JoinType, Result, ScalarValue}; +use datafusion_common::{JoinType, Result, ScalarValue}; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; use indexmap::{IndexMap, IndexSet}; @@ -153,7 +154,7 @@ impl From> for ConstExpr { // By default, assume constant expressions are not same across partitions. // However, if we have a literal, it will have a single value that is the // same across all partitions. - let across = if let Some(lit) = expr.as_any().downcast_ref::() { + let across = if let Some(lit) = expr.downcast_ref::() { AcrossPartitions::Uniform(Some(lit.value().clone())) } else { AcrossPartitions::Heterogeneous @@ -201,7 +202,7 @@ impl EquivalenceClass { /// Insert the expression into this class, meaning it is known to be equal to /// all other expressions in this class. pub fn push(&mut self, expr: Arc) { - if let Some(lit) = expr.as_any().downcast_ref::() { + if let Some(lit) = expr.downcast_ref::() { let expr_across = AcrossPartitions::Uniform(Some(lit.value().clone())); if let Some(across) = self.constant.as_mut() { // TODO: Return an error if constant values do not agree. @@ -303,7 +304,7 @@ type AugmentedMapping<'a> = IndexMap< #[derive(Clone, Debug, Default)] pub struct EquivalenceGroup { /// A mapping from expressions to their equivalence class key. - map: HashMap, usize>, + map: IndexMap, usize>, /// The equivalence classes in this group. classes: Vec, } @@ -436,7 +437,7 @@ impl EquivalenceGroup { let cls = self.classes.swap_remove(idx); // Remove its entries from the lookup table: for expr in cls.iter() { - self.map.remove(expr); + self.map.swap_remove(expr); } // Update the lookup table for the moved class: if idx < self.classes.len() { @@ -448,7 +449,7 @@ impl EquivalenceGroup { /// Updates the entry in lookup table for the given equivalence class with /// the given index. fn update_lookup_table( - map: &mut HashMap, usize>, + map: &mut IndexMap, usize>, cls: &EquivalenceClass, idx: usize, ) { @@ -591,7 +592,7 @@ impl EquivalenceGroup { expr: &Arc, ) -> Option> { // Literals don't need to be projected - if expr.as_any().downcast_ref::().is_some() { + if expr.downcast_ref::().is_some() { return Some(Arc::clone(expr)); } @@ -734,13 +735,13 @@ impl EquivalenceGroup { &self, expr: &Arc, ) -> Option { - if let Some(lit) = expr.as_any().downcast_ref::() { + if let Some(lit) = expr.downcast_ref::() { return Some(AcrossPartitions::Uniform(Some(lit.value().clone()))); } - if let Some(cls) = self.get_equivalence_class(expr) { - if cls.constant.is_some() { - return cls.constant.clone(); - } + if let Some(cls) = self.get_equivalence_class(expr) + && cls.constant.is_some() + { + return cls.constant.clone(); } // TODO: This function should be able to return values of non-literal // complex constants as well; e.g. it should return `8` for the @@ -819,15 +820,15 @@ impl EquivalenceGroup { // Check if expressions are equivalent through equivalence classes // We need to check both directions since expressions might be in different classes - if let Some(left_class) = self.get_equivalence_class(left) { - if left_class.contains(right) { - return true; - } + if let Some(left_class) = self.get_equivalence_class(left) + && left_class.contains(right) + { + return true; } - if let Some(right_class) = self.get_equivalence_class(right) { - if right_class.contains(left) { - return true; - } + if let Some(right_class) = self.get_equivalence_class(right) + && right_class.contains(left) + { + return true; } // For non-leaf nodes, check structural equality @@ -841,7 +842,7 @@ impl EquivalenceGroup { } // Type equality check through reflection - if left.as_any().type_id() != right.as_any().type_id() { + if (left as &dyn Any).type_id() != (right as &dyn Any).type_id() { return false; } @@ -910,10 +911,9 @@ impl From> for EquivalenceGroup { mod tests { use super::*; use crate::equivalence::tests::create_test_params; - use crate::expressions::{binary, col, lit, BinaryExpr, Column, Literal}; + use crate::expressions::{BinaryExpr, Column, binary, col, lit}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Operator; #[test] @@ -1082,8 +1082,7 @@ mod tests { left: Arc::clone(&col_a), right: Arc::clone(&col_b), expected: false, - description: - "Columns in different equivalence classes should not be equal", + description: "Columns in different equivalence classes should not be equal", }, // Literal tests TestCase { @@ -1111,8 +1110,7 @@ mod tests { Arc::clone(&col_y), )) as _, expected: true, - description: - "Binary expressions with equivalent operands should be equal", + description: "Binary expressions with equivalent operands should be equal", }, TestCase { left: Arc::new(BinaryExpr::new( @@ -1126,8 +1124,7 @@ mod tests { Arc::clone(&col_a), )) as _, expected: false, - description: - "Binary expressions with non-equivalent operands should not be equal", + description: "Binary expressions with non-equivalent operands should not be equal", }, TestCase { left: Arc::new(BinaryExpr::new( diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index a7289103806b8..64bb62901310f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -31,9 +31,9 @@ pub use class::{AcrossPartitions, ConstExpr, EquivalenceClass, EquivalenceGroup} pub use ordering::OrderingEquivalenceClass; // Re-export for backwards compatibility, we recommend importing from // datafusion_physical_expr::projection instead -pub use crate::projection::{project_ordering, project_orderings, ProjectionMapping}; +pub use crate::projection::{ProjectionMapping, project_ordering, project_orderings}; pub use properties::{ - calculate_union, join_equivalence_properties, EquivalenceProperties, + EquivalenceProperties, calculate_union, join_equivalence_properties, }; // Convert each tuple to a `PhysicalSortExpr` and construct a vector. @@ -57,10 +57,9 @@ pub fn convert_to_orderings>>( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, Column}; + use crate::expressions::{Column, col}; use crate::{LexRequirement, PhysicalSortExpr}; - use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::Result; use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement; diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index aa65c4a80ae9a..2ce8a8d246fe7 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::vec::IntoIter; use crate::expressions::with_new_schema; -use crate::{add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr}; +use crate::{LexOrdering, PhysicalExpr, add_offset_to_physical_sort_exprs}; use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; @@ -326,10 +326,10 @@ mod tests { use crate::equivalence::tests::create_test_schema; use crate::equivalence::{ - convert_to_orderings, convert_to_sort_exprs, EquivalenceClass, EquivalenceGroup, - EquivalenceProperties, OrderingEquivalenceClass, + EquivalenceClass, EquivalenceGroup, EquivalenceProperties, + OrderingEquivalenceClass, convert_to_orderings, convert_to_sort_exprs, }; - use crate::expressions::{col, BinaryExpr, Column}; + use crate::expressions::{BinaryExpr, Column, col}; use crate::utils::tests::TestScalarUDF; use crate::{ AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, @@ -338,8 +338,8 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::config::ConfigOptions; use datafusion_common::Result; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{Operator, ScalarUDF}; #[test] @@ -639,8 +639,9 @@ mod tests { ]; for (orderings, eq_group, constants, reqs, expected) in test_cases { - let err_msg = - format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); + let err_msg = format!( + "error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}" + ); let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_orderings(orderings); diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 8945d18be430f..2ebc71559fcf4 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -383,14 +383,13 @@ pub fn generate_dependency_orderings( #[cfg(test)] mod tests { use std::ops::Not; - use std::sync::Arc; use super::*; use crate::equivalence::tests::{ convert_to_sort_reqs, create_test_params, create_test_schema, parse_sort_expr, }; - use crate::equivalence::{convert_to_sort_exprs, ProjectionMapping}; - use crate::expressions::{col, BinaryExpr, CastExpr, Column}; + use crate::equivalence::{ProjectionMapping, convert_to_sort_exprs}; + use crate::expressions::{BinaryExpr, CastExpr, Column, col}; use crate::projection::tests::output_schema; use crate::{ConstExpr, EquivalenceProperties, ScalarFunctionExpr}; @@ -398,10 +397,9 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::{Constraint, Constraints, Result}; - use datafusion_expr::sort_properties::SortProperties; use datafusion_expr::Operator; + use datafusion_expr::sort_properties::SortProperties; use datafusion_functions::string::concat; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ LexRequirement, PhysicalSortRequirement, }; @@ -933,33 +931,35 @@ mod tests { struct TestCase { name: &'static str, constants: Vec>, - equal_conditions: Vec<[Arc; 2]>, - sort_columns: &'static [&'static str], + equal_condition: [Arc; 2], should_satisfy_ordering: bool, } let col_a = col("a", schema.as_ref())?; let col_b = col("b", schema.as_ref())?; let col_c = col("c", schema.as_ref())?; - let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)) as _; + let cast_c = Arc::new(CastExpr::new_with_target_field( + col_c, + Arc::new(Field::new("c", DataType::Date32, true)), + None, + )) as _; + let required_sort = vec![PhysicalSortExpr::new_default(col("c", &schema)?)]; let cases = vec![ TestCase { - name: "(a, b, c) -> (c)", + name: "cast_c = a", // b is constant, so it should be removed from the sort order constants: vec![Arc::clone(&col_b)], - equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], - sort_columns: &["c"], + equal_condition: [Arc::clone(&cast_c), Arc::clone(&col_a)], should_satisfy_ordering: true, }, // Same test with above test, where equality order is swapped. // Algorithm shouldn't depend on this order. TestCase { - name: "(a, b, c) -> (c)", + name: "a = cast_c", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[Arc::clone(&col_a), Arc::clone(&cast_c)]], - sort_columns: &["c"], + equal_condition: [Arc::clone(&col_a), Arc::clone(&cast_c)], should_satisfy_ordering: true, }, TestCase { @@ -967,8 +967,7 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[Arc::clone(&cast_c), Arc::clone(&col_a)]], - sort_columns: &["c"], + equal_condition: [Arc::clone(&cast_c), Arc::clone(&col_a)], should_satisfy_ordering: false, }, ]; @@ -981,9 +980,8 @@ mod tests { // Equal conditions before constants { let mut properties = base_properties.clone(); - for [left, right] in case.equal_conditions.clone() { - properties.add_equal_conditions(left, right)? - } + let [left, right] = case.equal_condition.clone(); + properties.add_equal_conditions(left, right)?; properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), )?; @@ -995,20 +993,13 @@ mod tests { properties.add_constants( case.constants.iter().cloned().map(ConstExpr::from), )?; - for [left, right] in case.equal_conditions { - properties.add_equal_conditions(left, right)? - } + let [left, right] = case.equal_condition; + properties.add_equal_conditions(left, right)?; properties }, ] { - let sort = case - .sort_columns - .iter() - .map(|&name| col(name, &schema).map(PhysicalSortExpr::new_default)) - .collect::>>()?; - assert_eq!( - properties.ordering_satisfy(sort)?, + properties.ordering_satisfy(required_sort.clone())?, case.should_satisfy_ordering, "failed test '{}'", case.name @@ -1528,4 +1519,102 @@ mod tests { Ok(()) } + + /// Test that orderings propagate through struct-producing projections. + /// + /// When a projection creates a struct via `named_struct('a', col_a, ...)`, + /// the output should preserve the ordering of `col_a` as an ordering on + /// `get_field(col("s"), "a")`. This enables sort elimination when the + /// framework sorts by a struct field that corresponds to an already-sorted + /// input column. + #[test] + fn test_ordering_propagation_through_named_struct() -> Result<()> { + use crate::expressions::Literal; + use datafusion_common::ScalarValue; + use datafusion_functions::core::{get_field, named_struct}; + + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let col_a = col("a", &input_schema)?; + let col_b = col("b", &input_schema)?; + let config = Arc::new(ConfigOptions::new()); + + // Build: named_struct('a', col_a, 'b', col_b) AS s + let named_struct_udf = named_struct(); + let named_struct_expr = Arc::new(ScalarFunctionExpr::new( + "named_struct", + named_struct_udf, + vec![ + Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))), + Arc::clone(&col_a), + Arc::new(Literal::new(ScalarValue::Utf8(Some("b".to_string())))), + Arc::clone(&col_b), + ], + Arc::new(Field::new( + "named_struct", + DataType::Struct( + vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ] + .into(), + ), + true, + )), + Arc::clone(&config), + )) as Arc; + + // Projection: named_struct(...) AS s + let proj_exprs = vec![(named_struct_expr, "s".to_string())]; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; + + // Input is ordered by [a ASC] + let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); + let sort_a = PhysicalSortExpr::new( + Arc::clone(&col_a), + SortOptions { + descending: false, + nulls_first: false, + }, + ); + input_properties.add_orderings([vec![sort_a]]); + + // Project through the named_struct + let out_schema = output_schema(&projection_mapping, &input_schema)?; + let out_properties = input_properties.project(&projection_mapping, out_schema); + + // Build the sort expression: get_field(col("s"), "a") + // This is what the framework would generate for ORDER BY s.a + let get_field_udf = get_field(); + let col_s = Arc::new(Column::new("s", 0)) as Arc; + let get_field_expr = Arc::new(ScalarFunctionExpr::new( + "get_field", + get_field_udf, + vec![ + Arc::clone(&col_s), + Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))), + ], + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::clone(&config), + )) as Arc; + + let sort_get_field_a = PhysicalSortExpr::new( + get_field_expr, + SortOptions { + descending: false, + nulls_first: false, + }, + ); + + // The output should satisfy ordering by get_field(s, "a") + assert!( + out_properties.ordering_satisfy(vec![sort_get_field_a])?, + "Output should be ordered by get_field(s, 'a') since input is ordered by col_a" + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/equivalence/properties/joins.rs b/datafusion/physical-expr/src/equivalence/properties/joins.rs index 485b11d586397..536badba435d3 100644 --- a/datafusion/physical-expr/src/equivalence/properties/joins.rs +++ b/datafusion/physical-expr/src/equivalence/properties/joins.rs @@ -16,7 +16,7 @@ // under the License. use super::EquivalenceProperties; -use crate::{equivalence::OrderingEquivalenceClass, PhysicalExprRef}; +use crate::{PhysicalExprRef, equivalence::OrderingEquivalenceClass}; use arrow::datatypes::SchemaRef; use datafusion_common::{JoinSide, JoinType, Result}; @@ -140,7 +140,6 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Fields, Schema}; - use datafusion_common::Result; #[test] fn test_join_equivalence_properties() -> Result<()> { diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index c13618feb8aa2..bb74cd1d9c7b3 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -27,21 +27,21 @@ use std::mem; use std::sync::Arc; use self::dependency::{ - construct_prefix_orderings, generate_dependency_orderings, referred_dependencies, - Dependencies, DependencyMap, + Dependencies, DependencyMap, construct_prefix_orderings, + generate_dependency_orderings, referred_dependencies, }; use crate::equivalence::{ AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; +use crate::expressions::{CastExpr, Column, Literal, with_new_schema}; use crate::{ ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Constraint, Constraints, HashMap, Result}; +use datafusion_common::{Constraint, Constraints, HashMap, Result, plan_err}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_physical_expr_common::sort_expr::options_compatible; @@ -195,6 +195,27 @@ impl OrderingEquivalenceCache { } impl EquivalenceProperties { + /// Helper used by the ordering equivalence rule when considering whether a + /// cast-bearing expression can replace an existing sort key without + /// invalidating the ordering. + /// + /// The substitution is only allowed when the cast wraps the very same child + /// expression that the original sort used and the casted type is a + /// widening/order-preserving conversion. Without those restrictions, a + /// narrowing cast could collapse distinct values and violate the existing + /// sort order. + fn substitute_cast_ordering( + r_expr: Arc, + sort_expr: &PhysicalSortExpr, + expr_type: &DataType, + ) -> Option { + let cast_expr = r_expr.downcast_ref::()?; + + (cast_expr.expr().eq(&sort_expr.expr) + && CastExpr::check_bigger_cast(cast_expr.cast_type(), expr_type)) + .then(|| PhysicalSortExpr::new(r_expr, sort_expr.options)) + } + /// Creates an empty `EquivalenceProperties` object. pub fn new(schema: SchemaRef) -> Self { Self { @@ -207,8 +228,13 @@ impl EquivalenceProperties { } /// Adds constraints to the properties. - pub fn with_constraints(mut self, constraints: Constraints) -> Self { + pub fn set_constraints(&mut self, constraints: Constraints) { self.constraints = constraints; + } + + /// Adds constraints to the properties. + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.set_constraints(constraints); self } @@ -712,8 +738,7 @@ impl EquivalenceProperties { // Build a map of column positions in the ordering: let mut col_positions = HashMap::with_capacity(length); for (pos, req) in ordering.iter().enumerate() { - if let Some(col) = req.expr.as_any().downcast_ref::() - { + if let Some(col) = req.expr.downcast_ref::() { let nullable = col.nullable(&self.schema).unwrap_or(true); col_positions.insert(col.index(), (pos, nullable)); } @@ -757,8 +782,7 @@ impl EquivalenceProperties { // Build a map of column positions in the ordering: let mut col_positions = HashMap::with_capacity(length); for (pos, req) in ordering.iter().enumerate() { - if let Some(col) = req.expr.as_any().downcast_ref::() - { + if let Some(col) = req.expr.downcast_ref::() { let nullable = col.nullable(&self.schema).unwrap_or(true); col_positions.insert(col.index(), (pos, nullable)); } @@ -828,35 +852,25 @@ impl EquivalenceProperties { order .into_iter() .map(|sort_expr| { - let referring_exprs = mapping - .iter() - .map(|(source, _target)| source) - .filter(|source| expr_refers(source, &sort_expr.expr)) - .cloned(); - let mut result = vec![]; // The sort expression comes from this schema, so the // following call to `unwrap` is safe. let expr_type = sort_expr.expr.data_type(schema).unwrap(); + let original_sort_expr = sort_expr.clone(); // TODO: Add one-to-one analysis for ScalarFunctions. - for r_expr in referring_exprs { - // We check whether this expression is substitutable. - if let Some(cast_expr) = - r_expr.as_any().downcast_ref::() - { - // For casts, we need to know whether the cast - // expression matches: - if cast_expr.expr.eq(&sort_expr.expr) - && cast_expr.is_bigger_cast(&expr_type) - { - result.push(PhysicalSortExpr::new( - r_expr, - sort_expr.options, - )); - } - } - } - result.push(sort_expr); - result + mapping + .iter() + .map(|(source, _target)| source) + .filter(|source| expr_refers(source, &original_sort_expr.expr)) + .cloned() + .filter_map(|r_expr| { + Self::substitute_cast_ordering( + r_expr, + &original_sort_expr, + &expr_type, + ) + }) + .chain(std::iter::once(sort_expr)) + .collect::>() }) // Generate all valid orderings given substituted expressions: .multi_cartesian_product() @@ -1118,7 +1132,7 @@ impl EquivalenceProperties { .iter() .flat_map(|(_, targets)| { targets.iter().flat_map(|(target, _)| { - target.as_any().downcast_ref::().map(|c| c.index()) + target.downcast_ref::().map(|c| c.index()) }) }) .collect::>(); @@ -1277,7 +1291,7 @@ impl EquivalenceProperties { // Rewriting equivalence properties in terms of new schema is not // safe when schemas are not aligned: return plan_err!( - "Schemas have to be aligned to rewrite equivalences:\n Old schema: {:?}\n New schema: {:?}", + "Schemas have to be aligned to rewrite equivalences:\n Old schema: {}\n New schema: {}", self.schema, schema ); @@ -1376,10 +1390,10 @@ fn update_properties( // We have an intermediate (non-leaf) node, account for its children: let children_props = node.children.iter().map(|c| c.data.clone()).collect_vec(); node.data = node.expr.get_properties(&children_props)?; - } else if node.expr.as_any().is::() { + } else if node.expr.is::() { // We have a Literal, which is one of the two possible leaf node types: node.data = node.expr.get_properties(&[])?; - } else if node.expr.as_any().is::() { + } else if node.expr.is::() { // We have a Column, which is the other possible leaf node type: node.data.range = Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? @@ -1450,13 +1464,13 @@ fn get_expr_properties( range: Interval::make_unbounded(&expr.data_type(schema)?)?, preserves_lex_ordering: false, }) - } else if expr.as_any().downcast_ref::().is_some() { + } else if expr.downcast_ref::().is_some() { Ok(ExprProperties { sort_properties: SortProperties::Unordered, range: Interval::make_unbounded(&expr.data_type(schema)?)?, preserves_lex_ordering: false, }) - } else if let Some(literal) = expr.as_any().downcast_ref::() { + } else if let Some(literal) = expr.downcast_ref::() { Ok(ExprProperties { sort_properties: SortProperties::Singleton, range: literal.value().into(), diff --git a/datafusion/physical-expr/src/equivalence/properties/union.rs b/datafusion/physical-expr/src/equivalence/properties/union.rs index efbefd0d39bfb..d77129472a8ba 100644 --- a/datafusion/physical-expr/src/equivalence/properties/union.rs +++ b/datafusion/physical-expr/src/equivalence/properties/union.rs @@ -23,7 +23,7 @@ use crate::equivalence::class::AcrossPartitions; use crate::{ConstExpr, PhysicalSortExpr}; use arrow::datatypes::SchemaRef; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_physical_expr_common::sort_expr::LexOrdering; /// Computes the union (in the sense of `UnionExec`) `EquivalenceProperties` @@ -307,9 +307,9 @@ fn advance_if_matches_constant<'a>( #[cfg(test)] mod tests { use super::*; + use crate::PhysicalExpr; use crate::equivalence::tests::{create_test_schema, parse_sort_expr}; use crate::expressions::col; - use crate::PhysicalExpr; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 1e5c7e7024405..6f0b60556a751 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,26 +17,31 @@ mod kernels; -use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; +use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use std::hash::Hash; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, or_kleene}; -use arrow::compute::kernels::concat_elements::concat_elements_utf8; -use arrow::compute::{cast, filter_record_batch, SlicesIterator}; +use arrow::compute::kernels::concat_elements::{ + concat_element_binary, concat_elements_utf8, +}; +use arrow::compute::{SlicesIterator, cast, filter_record_batch}; use arrow::datatypes::*; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, internal_err, not_impl_err}; + use datafusion_expr::binary::BinaryTypeCoercer; -use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; +use datafusion_expr::interval_arithmetic::{Interval, apply_operator}; use datafusion_expr::sort_properties::ExprProperties; +#[expect(deprecated)] use datafusion_expr::statistics::Distribution::{Bernoulli, Gaussian}; +#[expect(deprecated)] use datafusion_expr::statistics::{ - combine_bernoullis, combine_gaussians, create_bernoulli_from_comparison, - new_generic_from_binary_op, Distribution, + Distribution, combine_bernoullis, combine_gaussians, + create_bernoulli_from_comparison, new_generic_from_binary_op, }; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::{apply, apply_cmp}; @@ -45,7 +50,8 @@ use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, bitwise_shift_right_dyn_scalar, bitwise_xor_dyn, bitwise_xor_dyn_scalar, - concat_elements_utf8view, regex_match_dyn, regex_match_dyn_scalar, + concat_elements_binary_view_array, concat_elements_utf8view, regex_match_dyn, + regex_match_dyn_scalar, }; /// Binary expression @@ -129,7 +135,7 @@ impl std::fmt::Display for BinaryExpr { expr: &dyn PhysicalExpr, precedence: u8, ) -> std::fmt::Result { - if let Some(child) = expr.as_any().downcast_ref::() { + if let Some(child) = expr.downcast_ref::() { let p = child.op.precedence(); if p == 0 || p < precedence { write!(f, "({child})")?; @@ -162,12 +168,113 @@ fn boolean_op( op(ll, rr).map(|t| Arc::new(t) as _) } -impl PhysicalExpr for BinaryExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self +/// Returns true if both operands are Date types (Date32 or Date64) +/// Used to detect Date - Date operations which should return Int64 (days difference) +fn is_date_minus_date(lhs: &DataType, rhs: &DataType) -> bool { + matches!( + (lhs, rhs), + (DataType::Date32, DataType::Date32) | (DataType::Date64, DataType::Date64) + ) +} + +/// Milliseconds per day, used for Date64 subtraction. +const MILLIS_PER_DAY: i64 = 86_400_000; + +/// Evaluates `Date32 - Date32` or `Date64 - Date64`, returning the difference in +/// whole days as `Int64`. +/// +/// This matches the behavior of PostgreSQL, DuckDB, and MySQL, where +/// `date - date` yields an integer day count rather than an interval. +fn apply_date_subtraction( + lhs: &ColumnarValue, + rhs: &ColumnarValue, +) -> Result { + match (lhs.data_type(), rhs.data_type()) { + (DataType::Date32, DataType::Date32) => { + subtract_date_to_days::(lhs, rhs, |l, r| l - r) + } + (DataType::Date64, DataType::Date64) => { + subtract_date_to_days::(lhs, rhs, |l, r| { + l.wrapping_sub(r) / MILLIS_PER_DAY + }) + } + (_, _) => unreachable!("apply_date_subtraction called with non-date types"), + } +} + +/// Generic date subtraction: operates directly on the native primitive values +/// of `T` (i32 for Date32, i64 for Date64), applying `day_diff_fn` to produce +/// an Int64 day count. +fn subtract_date_to_days( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + day_diff_fn: impl Fn(i64, i64) -> i64, +) -> Result +where + T::Native: Copy + Into, +{ + /// Extract the date value as `i64`. Returns `None` for null scalars. + fn date_scalar_to_i64( + scalar: &ScalarValue, + ) -> Result> { + match scalar { + ScalarValue::Date32(value) if P::DATA_TYPE == DataType::Date32 => { + Ok(value.map(i64::from)) + } + ScalarValue::Date64(value) if P::DATA_TYPE == DataType::Date64 => Ok(*value), + other => { + internal_err!( + "{} date scalar expected, got: {}", + P::DATA_TYPE, + other.data_type() + ) + } + } + } + + match (lhs, rhs) { + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { + let left = left.as_primitive::(); + let right = right.as_primitive::(); + let result: Int64Array = + arrow::compute::binary::<_, _, _, Int64Type>(left, right, |l, r| { + day_diff_fn(l.into(), r.into()) + })?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => { + let left = left.as_primitive::(); + match date_scalar_to_i64::(right)? { + Some(right_val) => { + let result: Int64Array = + left.unary(|l| day_diff_fn(l.into(), right_val)); + Ok(ColumnarValue::Array(Arc::new(result))) + } + None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), + } + } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => { + let right = right.as_primitive::(); + match date_scalar_to_i64::(left)? { + Some(left_val) => { + let result: Int64Array = + right.unary(|r| day_diff_fn(left_val, r.into())); + Ok(ColumnarValue::Array(Arc::new(result))) + } + None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))), + } + } + (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { + let left_val = date_scalar_to_i64::(left)?; + let right_val = date_scalar_to_i64::(right)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64( + left_val.zip(right_val).map(|(l, r)| day_diff_fn(l, r)), + ))) + } } +} +impl PhysicalExpr for BinaryExpr { fn data_type(&self, input_schema: &Schema) -> Result { BinaryTypeCoercer::new( &self.left.data_type(input_schema)?, @@ -205,11 +312,11 @@ impl PhysicalExpr for BinaryExpr { ColumnarValue::Array(array) => { // When the array on the right is all true or all false, skip the scatter process let boolean_array = array.as_boolean(); - let true_count = boolean_array.true_count(); - let length = boolean_array.len(); - if true_count == length { + if boolean_array.null_count() == 0 && !boolean_array.has_false() { return Ok(lhs); - } else if true_count == 0 && boolean_array.null_count() == 0 { + } else if boolean_array.null_count() == 0 + && !boolean_array.has_true() + { // If the right-hand array is returned at this point,the lengths will be inconsistent; // returning a scalar can avoid this issue return Ok(ColumnarValue::Scalar(ScalarValue::Boolean( @@ -251,6 +358,11 @@ impl PhysicalExpr for BinaryExpr { match self.op { Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add), Operator::Plus => return apply(&lhs, &rhs, add_wrapping), + // Special case: Date - Date returns Int64 (days difference) + // This aligns with PostgreSQL, DuckDB, and MySQL behavior + Operator::Minus if is_date_minus_date(&left_data_type, &right_data_type) => { + return apply_date_subtraction(&lhs, &rhs); + } Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul), @@ -278,18 +390,14 @@ impl PhysicalExpr for BinaryExpr { let result_type = self.data_type(input_schema)?; // If the left-hand side is an array and the right-hand side is a non-null scalar, try the optimized kernel. - if let (ColumnarValue::Array(array), ColumnarValue::Scalar(ref scalar)) = - (&lhs, &rhs) + if let (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) = (&lhs, &rhs) + && !scalar.is_null() + && let Some(result_array) = + self.evaluate_array_scalar(array, scalar.clone())? { - if !scalar.is_null() { - if let Some(result_array) = - self.evaluate_array_scalar(array, scalar.clone())? - { - let final_array = result_array - .and_then(|a| to_result_type_array(&self.op, a, &result_type)); - return final_array.map(ColumnarValue::Array); - } - } + let final_array = result_array + .and_then(|a| to_result_type_array(&self.op, a, &result_type)); + return final_array.map(ColumnarValue::Array); } // if both arrays or both literals - extract arrays and continue execution @@ -413,6 +521,7 @@ impl PhysicalExpr for BinaryExpr { } } + #[expect(deprecated)] fn evaluate_statistics(&self, children: &[&Distribution]) -> Result { let (left, right) = (children[0], children[1]); @@ -420,10 +529,10 @@ impl PhysicalExpr for BinaryExpr { // We might be able to construct the output statistics more accurately, // without falling back to an unknown distribution, if we are dealing // with Gaussian distributions and numerical operations. - if let (Gaussian(left), Gaussian(right)) = (left, right) { - if let Some(result) = combine_gaussians(&self.op, left, right)? { - return Ok(Gaussian(result)); - } + if let (Gaussian(left), Gaussian(right)) = (left, right) + && let Some(result) = combine_gaussians(&self.op, left, right)? + { + return Ok(Gaussian(result)); } } else if self.op.is_logic_operator() { // If we are dealing with logical operators, we expect (and can only @@ -500,7 +609,7 @@ impl PhysicalExpr for BinaryExpr { expr: &dyn PhysicalExpr, precedence: u8, ) -> std::fmt::Result { - if let Some(child) = expr.as_any().downcast_ref::() { + if let Some(child) = expr.downcast_ref::() { let p = child.op.precedence(); if p == 0 || p < precedence { write!(f, "(")?; @@ -519,6 +628,108 @@ impl PhysicalExpr for BinaryExpr { write!(f, " {} ", self.op)?; write_child(f, self.right.as_ref(), precedence) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + // Linearize a nested binary expression tree of the same operator + // into a flat vector of operands to avoid deep recursion in proto. + let op = self.op; + let mut operand_refs: Vec<&Arc> = vec![&self.right]; + let mut current_expr: &BinaryExpr = self; + loop { + match current_expr.left.downcast_ref::() { + Some(bin) if bin.op == op => { + operand_refs.push(&bin.right); + current_expr = bin; + } + _ => { + operand_refs.push(¤t_expr.left); + break; + } + } + } + // Reverse so operands are ordered from left innermost to right outermost. + operand_refs.reverse(); + + let operands = ctx.encode_children_expressions(operand_refs)?; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( + Box::new(protobuf::PhysicalBinaryExprNode { + l: None, + r: None, + op: format!("{op:?}"), + operands, + }), + )), + })) + } +} + +#[cfg(feature = "proto")] +impl BinaryExpr { + /// Reconstruct a [`BinaryExpr`] (or a left-deep tree of them when the proto + /// uses the linearized `operands` form) from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`] — the exact inverse of what + /// [`PhysicalExpr::try_to_proto`] produces — so every expression's + /// `try_from_proto` shares one signature. The operator string is parsed + /// via the canonical [`Operator::from_proto_name`] mapping, so no `op` + /// argument needs to be threaded in by the caller. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + /// [`PhysicalExpr::try_to_proto`]: datafusion_physical_expr_common::physical_expr::PhysicalExpr::try_to_proto + /// [`PhysicalExprDecodeCtx::decode`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx::decode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + let node = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::BinaryExpr, + "BinaryExpr", + ); + let op = Operator::from_proto_name(&node.op).ok_or_else(|| { + datafusion_common::DataFusionError::Internal(format!( + "Unsupported binary operator '{}'", + node.op + )) + })?; + + if !node.operands.is_empty() { + // New linearized format: reduce the flat operands list back into + // a nested binary expression tree. + let operands = ctx.decode_children_expressions(&node.operands)?; + + if operands.len() < 2 { + return internal_err!( + "A binary expression must always have at least 2 operands" + ); + } + + Ok(operands + .into_iter() + .reduce(|left, right| { + Arc::new(BinaryExpr::new(left, op, right)) as Arc + }) + .expect("Binary expression could not be reduced to a single expression.")) + } else { + // Legacy format with l/r fields. + let left = + ctx.decode_required_expression(node.l.as_deref(), "BinaryExpr", "left")?; + let right = + ctx.decode_required_expression(node.r.as_deref(), "BinaryExpr", "right")?; + Ok(Arc::new(BinaryExpr::new(left, op, right))) + } + } } /// Casts dictionary array to result type for binary numerical operators. Such operators @@ -540,8 +751,8 @@ fn to_result_type_array( Ok(cast(&array, result_type)?) } else { internal_err!( - "Incompatible Dictionary value type {value_type} with result type {result_type} of Binary operator {op:?}" - ) + "Incompatible Dictionary value type {value_type} with result type {result_type} of Binary operator {op:?}" + ) } } _ => Ok(array), @@ -625,7 +836,7 @@ impl BinaryExpr { StringConcat => concat_elements(&left, &right), AtArrow | ArrowAt | Arrow | LongArrow | HashArrow | HashLongArrow | AtAt | HashMinus | AtQuestion | Question | QuestionAnd | QuestionPipe - | IntegerDivide => { + | IntegerDivide | Colon => { not_impl_err!( "Binary operator '{:?}' is not supported in the physical expr", self.op @@ -856,6 +1067,18 @@ fn concat_elements(left: &ArrayRef, right: &ArrayRef) -> Result { left.as_string_view(), right.as_string_view(), )?), + DataType::Binary => Arc::new(concat_element_binary::( + left.as_binary(), + right.as_binary(), + )?), + DataType::LargeBinary => Arc::new(concat_element_binary::( + left.as_binary(), + right.as_binary(), + )?), + DataType::BinaryView => Arc::new(concat_elements_binary_view_array( + left.as_binary_view(), + right.as_binary_view(), + )?), other => { return internal_err!( "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays" @@ -895,7 +1118,7 @@ pub fn similar_to( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, try_cast, Column, Literal}; + use crate::expressions::{Column, Literal, col, lit, try_cast}; use datafusion_expr::lit as expr_lit; use datafusion_common::plan_datafusion_err; @@ -1018,7 +1241,8 @@ mod tests { ]); let a = $A_ARRAY::from($A_VEC); let b = $B_ARRAY::from($B_VEC); - let (lhs, rhs) = BinaryTypeCoercer::new(&$A_TYPE, &$OP, &$B_TYPE).get_input_types()?; + let (lhs, rhs) = + BinaryTypeCoercer::new(&$A_TYPE, &$OP, &$B_TYPE).get_input_types()?; let left = try_cast(col("a", &schema)?, &schema, lhs)?; let right = try_cast(col("b", &schema)?, &schema, rhs)?; @@ -1034,7 +1258,10 @@ mod tests { assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute - let result = expression.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); + let result = expression + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); // verify that the array's data_type is correct assert_eq!(*result.data_type(), $C_TYPE); @@ -1048,8 +1275,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { let v = result.value(i); assert_eq!( - v, - *x, + v, *x, "Unexpected output at position {i}:\n\nActual:\n{v}\n\nExpected:\n{x}" ); } @@ -1804,6 +2030,82 @@ mod tests { Ok(()) } + #[test] + fn date32_minus_date32_returns_int64_days() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Date32, true), + Field::new("b", DataType::Date32, true), + ])); + let a = Arc::new(Date32Array::from(vec![ + Some(18_901), + Some(18_901), + None, + Some(18_900), + ])); + let b = Arc::new(Date32Array::from(vec![ + Some(18_898), + Some(18_904), + Some(18_900), + None, + ])); + + apply_arithmetic::( + schema, + vec![a, b], + Operator::Minus, + Int64Array::from(vec![Some(3), Some(-3), None, None]), + )?; + + Ok(()) + } + + #[test] + fn date64_minus_date64_returns_int64_days() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Date64, true), + Field::new("b", DataType::Date64, true), + ])); + let a = Arc::new(Date64Array::from(vec![ + Some(18_901 * MILLIS_PER_DAY), + Some(18_901 * MILLIS_PER_DAY), + None, + Some(18_900 * MILLIS_PER_DAY), + ])); + let b = Arc::new(Date64Array::from(vec![ + Some(18_898 * MILLIS_PER_DAY), + Some(18_904 * MILLIS_PER_DAY), + Some(18_900 * MILLIS_PER_DAY), + None, + ])); + + apply_arithmetic::( + schema, + vec![a, b], + Operator::Minus, + Int64Array::from(vec![Some(3), Some(-3), None, None]), + )?; + + Ok(()) + } + + #[test] + fn date32_minus_null_scalar_returns_int64_null_scalar() -> Result<()> { + let result = apply_date_subtraction( + &ColumnarValue::Array(Arc::new(Date32Array::from(vec![ + Some(18_901), + Some(18_900), + ]))), + &ColumnarValue::Scalar(ScalarValue::Date32(None)), + )?; + + assert!(matches!( + result, + ColumnarValue::Scalar(ScalarValue::Int64(None)) + )); + + Ok(()) + } + #[test] fn minus_op_dict() -> Result<()> { let schema = Schema::new(vec![ @@ -4426,11 +4728,13 @@ mod tests { // evaluate expression let result = expr.evaluate(&batch); - assert!(result - .err() - .unwrap() - .to_string() - .contains("Overflow happened on: 2147483647 + 1")); + assert!( + result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: 2147483647 + 1") + ); Ok(()) } @@ -4455,11 +4759,13 @@ mod tests { // evaluate expression let result = expr.evaluate(&batch); - assert!(result - .err() - .unwrap() - .to_string() - .contains("Overflow happened on: -2147483648 - 1")); + assert!( + result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: -2147483648 - 1") + ); Ok(()) } @@ -4484,11 +4790,13 @@ mod tests { // evaluate expression let result = expr.evaluate(&batch); - assert!(result - .err() - .unwrap() - .to_string() - .contains("Overflow happened on: 2147483647 * 2")); + assert!( + result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: 2147483647 * 2") + ); Ok(()) } @@ -4557,7 +4865,6 @@ mod tests { schema: &Schema, ) -> Result { Ok(binary_op(left, op, right, schema)? - .as_any() .downcast_ref::() .unwrap() .clone()) @@ -4565,6 +4872,7 @@ mod tests { /// Test for Uniform-Uniform, Unknown-Uniform, Uniform-Unknown and Unknown-Unknown evaluation. #[test] + #[expect(deprecated)] fn test_evaluate_statistics_combination_of_range_holders() -> Result<()> { let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]); let a = Arc::new(Column::new("a", 0)) as _; @@ -4632,6 +4940,7 @@ mod tests { } #[test] + #[expect(deprecated)] fn test_evaluate_statistics_bernoulli() -> Result<()> { let schema = &Schema::new(vec![ Field::new("a", DataType::Int64, false), @@ -4667,6 +4976,7 @@ mod tests { } #[test] + #[expect(deprecated)] fn test_propagate_statistics_combination_of_range_holders_arithmetic() -> Result<()> { let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]); let a = Arc::new(Column::new("a", 0)) as _; @@ -4736,6 +5046,7 @@ mod tests { } #[test] + #[expect(deprecated)] fn test_propagate_statistics_combination_of_range_holders_comparison() -> Result<()> { let schema = &Schema::new(vec![Field::new("a", DataType::Float64, false)]); let a = Arc::new(Column::new("a", 0)) as _; @@ -4797,9 +5108,10 @@ mod tests { let child_refs = child_view.iter().collect::>(); for op in &ops { let expr = binary_expr(Arc::clone(&a), *op, Arc::clone(&b), schema)?; - assert!(expr - .propagate_statistics(&parent, child_refs.as_slice())? - .is_some()); + assert!( + expr.propagate_statistics(&parent, child_refs.as_slice())? + .is_some() + ); } } diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index ad44b00212039..e573d7ece2afa 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -18,6 +18,7 @@ //! This module contains computation kernels that are specific to //! datafusion and not (yet) targeted to port upstream to arrow use arrow::array::*; +use arrow::buffer::{MutableBuffer, NullBuffer}; use arrow::compute::kernels::bitwise::{ bitwise_and, bitwise_and_scalar, bitwise_or, bitwise_or_scalar, bitwise_shift_left, bitwise_shift_left_scalar, bitwise_shift_right, bitwise_shift_right_scalar, @@ -27,8 +28,8 @@ use arrow::compute::kernels::boolean::not; use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{internal_err, plan_err}; use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err}; use std::sync::Arc; @@ -108,16 +109,35 @@ macro_rules! call_scalar_kernel { /// downcasts left / right to the appropriate integral type and calls the kernel macro_rules! create_left_integral_dyn_scalar_kernel { ($FUNC:ident, $KERNEL:ident) => { - pub(crate) fn $FUNC(array: &dyn Array, scalar: ScalarValue) -> Option> { + pub(crate) fn $FUNC( + array: &dyn Array, + scalar: ScalarValue, + ) -> Option> { let result = match array.data_type() { - DataType::Int8 => call_scalar_kernel!(array, scalar, $KERNEL, Int8Array, i8), - DataType::Int16 => call_scalar_kernel!(array, scalar, $KERNEL, Int16Array, i16), - DataType::Int32 => call_scalar_kernel!(array, scalar, $KERNEL, Int32Array, i32), - DataType::Int64 => call_scalar_kernel!(array, scalar, $KERNEL, Int64Array, i64), - DataType::UInt8 => call_scalar_kernel!(array, scalar, $KERNEL, UInt8Array, u8), - DataType::UInt16 => call_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), - DataType::UInt32 => call_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), - DataType::UInt64 => call_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), + DataType::Int8 => { + call_scalar_kernel!(array, scalar, $KERNEL, Int8Array, i8) + } + DataType::Int16 => { + call_scalar_kernel!(array, scalar, $KERNEL, Int16Array, i16) + } + DataType::Int32 => { + call_scalar_kernel!(array, scalar, $KERNEL, Int32Array, i32) + } + DataType::Int64 => { + call_scalar_kernel!(array, scalar, $KERNEL, Int64Array, i64) + } + DataType::UInt8 => { + call_scalar_kernel!(array, scalar, $KERNEL, UInt8Array, u8) + } + DataType::UInt16 => { + call_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16) + } + DataType::UInt32 => { + call_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32) + } + DataType::UInt64 => { + call_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64) + } other => plan_err!( "Data type {} not supported for binary operation '{}' on dyn arrays", other, @@ -141,11 +161,11 @@ create_left_integral_dyn_scalar_kernel!( bitwise_shift_left_scalar ); -/// Concatenates two `StringViewArray`s element-wise. +/// Concatenates two `StringViewArray`s element-wise. /// If either element is `Null`, the result element is also `Null`. /// /// # Errors -/// - Returns an error if the input arrays have different lengths. +/// - Returns an error if the input arrays have different lengths. /// - Returns an error if any concatenated string exceeds `u32::MAX` (≈4 GB) in length. pub fn concat_elements_utf8view( left: &StringViewArray, @@ -158,24 +178,71 @@ pub fn concat_elements_utf8view( right.len() ))); } - let capacity = left.len(); - let mut result = StringViewBuilder::with_capacity(capacity); + let mut result = StringViewBuilder::with_capacity(left.len()); - // Avoid reallocations by writing to a reused buffer (note we - // could be even more efficient r by creating the view directly - // here and avoid the buffer but that would be more complex) + // Avoid reallocations by writing to a reused buffer (note we could be even + // more efficient by creating the view directly here and avoid the buffer + // but that would be more complex) let mut buffer = String::new(); - for (left, right) in left.iter().zip(right.iter()) { - if let (Some(left), Some(right)) = (left, right) { - use std::fmt::Write; + // Pre-compute combined null bitmap, so the per-row NULL check is more + // efficient + let nulls = NullBuffer::union(left.nulls(), right.nulls()); + + for i in 0..left.len() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + result.append_null(); + } else { + let l = left.value(i); + let r = right.value(i); buffer.clear(); - write!(&mut buffer, "{left}{right}") - .expect("writing into string buffer failed"); + buffer.push_str(l); + buffer.push_str(r); result.try_append_value(&buffer)?; + } + } + Ok(result.finish()) +} + +/// Concatenates two `BinaryViewArray`s element-wise. +/// If either element is `Null`, the result element is also `Null`. +/// +/// # Errors +/// - Returns an error if the input arrays have different lengths. +/// - Returns an error if any concatenated string exceeds `u32::MAX` in length. +pub fn concat_elements_binary_view_array( + left: &BinaryViewArray, + right: &BinaryViewArray, +) -> std::result::Result { + if left.len() != right.len() { + return Err(ArrowError::ComputeError(format!( + "Arrays must have the same length: {} != {}", + left.len(), + right.len() + ))); + } + let mut result = BinaryViewBuilder::with_capacity(left.len()); + + // Avoid reallocations by writing to a reused buffer (note we could be even + // more efficient by creating the view directly here and avoid the buffer + // but that would be more complex) + let mut buffer = MutableBuffer::new(0); + + // Pre-compute combined null bitmap, so the per-row NULL check is more + // efficient + let nulls = NullBuffer::union(left.nulls(), right.nulls()); + + for i in 0..left.len() { + if nulls.as_ref().is_some_and(|n| n.is_null(i)) { + result.append_null(); } else { - // at least one of the values is null, so the output is also null - result.append_null() + let l = left.value(i); + let r = right.value(i); + buffer.clear(); + buffer.extend_from_slice(l); + buffer.extend_from_slice(r); + // No try-version of append_value + result.try_append_value(&buffer)?; } } Ok(result.finish()) @@ -296,8 +363,8 @@ pub(crate) fn regex_match_dyn_scalar( ) } other => internal_err!( - "Data type {} not supported for operation 'regex_match_dyn_scalar' on string array", - other + "Data type {} not supported for operation 'regex_match_dyn_scalar' on string array", + other ), }; Some(result) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index a3c181368d5f3..8a0f15467c47b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,34 +15,39 @@ // specific language governing permissions and limitations // under the License. +mod literal_lookup_table; + use super::{Column, Literal}; -use crate::expressions::case::ResultState::{Complete, Empty, Partial}; -use crate::expressions::{lit, try_cast}; use crate::PhysicalExpr; +use crate::expressions::{LambdaVariable, lit, try_cast}; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ - is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, - SlicesIterator, + FilterBuilder, FilterPredicate, is_not_null, not, nullif, prep_null_mask_filter, }; use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode}; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ - assert_or_internal_err, exec_err, internal_datafusion_err, internal_err, - DataFusionError, HashMap, HashSet, Result, ScalarValue, + DataFusionError, Result, ScalarValue, assert_or_internal_err, exec_err, + internal_datafusion_err, internal_err, }; use datafusion_expr::ColumnarValue; +use indexmap::IndexMap; use std::borrow::Cow; +use std::collections::BTreeSet; use std::hash::Hash; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; +use crate::expressions::case::literal_lookup_table::LiteralLookupTable; +use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; +use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; use std::fmt::{Debug, Formatter}; -type WhenThen = (Arc, Arc); +pub(super) type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] enum EvalMethod { @@ -61,7 +66,7 @@ enum EvalMethod { /// for expressions that are infallible and can be cheaply computed for the entire /// record batch rather than just for the rows where the predicate is true. /// - /// CASE WHEN condition THEN column [ELSE NULL] END + /// CASE WHEN condition THEN infallible_expression [ELSE NULL] END InfallibleExprOrNull, /// This is a specialization for a specific use case where we can take a fast path /// if there is just one when/then pair and both the `then` and `else` expressions @@ -69,12 +74,45 @@ enum EvalMethod { /// CASE WHEN condition THEN literal ELSE literal END ScalarOrScalar, /// This is a specialization for a specific use case where we can take a fast path - /// if there is just one when/then pair and both the `then` and `else` are expressions + /// if there is just one when/then pair, the `then` is an expression, and `else` is either + /// an expression, literal NULL or absent. + /// + /// In contrast to [`EvalMethod::InfallibleExprOrNull`], this specialization can handle fallible + /// `then` expressions. /// - /// CASE WHEN condition THEN expression ELSE expression END + /// CASE WHEN condition THEN expression [ELSE expression] END ExpressionOrExpression(ProjectedCaseBody), + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable), +} + +/// Implementing hash so we can use `derive` on [`EvalMethod`]. +/// +/// not implementing actual [`Hash`] as it is not dyn compatible so we cannot implement it for +/// `dyn` [`literal_lookup_table::WhenLiteralIndexMap`]. +/// +/// So implementing empty hash is still valid as the data is derived from `PhysicalExpr` s which are already hashed +impl Hash for LiteralLookupTable { + fn hash(&self, _state: &mut H) {} +} + +/// Implementing Equal so we can use `derive` on [`EvalMethod`]. +/// +/// not implementing actual [`PartialEq`] as it is not dyn compatible so we cannot implement it for +/// `dyn` [`literal_lookup_table::WhenLiteralIndexMap`]. +/// +/// So we always return true as the data is derived from `PhysicalExpr` s which are already compared +impl PartialEq for LiteralLookupTable { + fn eq(&self, _other: &Self) -> bool { + true + } } +impl Eq for LiteralLookupTable {} + /// The body of a CASE expression which consists of an optional base expression, the "when/then" /// branches and an optional "else" branch. #[derive(Debug, Hash, PartialEq, Eq)] @@ -91,11 +129,16 @@ impl CaseBody { /// Derives a [ProjectedCaseBody] from this [CaseBody]. fn project(&self) -> Result { // Determine the set of columns that are used in all the expressions of the case body. - let mut used_column_indices = HashSet::::new(); + // Use an ordered set so lambda variables continue to be positioned after columns + let mut used_column_indices = BTreeSet::::new(); let mut collect_column_indices = |expr: &Arc| { expr.apply(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { + if let Some(column) = expr.downcast_ref::() { used_column_indices.insert(column.index()); + } else if let Some(lambda_variable) = + expr.downcast_ref::() + { + used_column_indices.insert(lambda_variable.index()); } Ok(TreeNodeRecursion::Continue) }) @@ -118,14 +161,14 @@ impl CaseBody { .iter() .enumerate() .map(|(projected, original)| (*original, projected)) - .collect::>(); + .collect::>(); // Construct the projected body by rewriting each expression from the original body // using the column index mapping. let project = |expr: &Arc| -> Result> { Arc::clone(expr) .transform_down(|e| { - if let Some(column) = e.as_any().downcast_ref::() { + if let Some(column) = e.downcast_ref::() { let original = column.index(); let projected = *column_index_map.get(&original).unwrap(); if projected != original { @@ -134,6 +177,17 @@ impl CaseBody { projected, )))); } + } else if let Some(lambda_variable) = + e.downcast_ref::() + { + let original = lambda_variable.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(LambdaVariable::new( + projected, + Arc::clone(lambda_variable.field()), + )))); + } } Ok(Transformed::no(e)) }) @@ -214,7 +268,7 @@ struct ProjectedCaseBody { /// [WHEN ...] /// [ELSE result] /// END -#[derive(Debug, Hash, PartialEq, Eq)] +#[derive(Debug)] pub struct CaseExpr { /// The case expression body body: CaseBody, @@ -222,6 +276,23 @@ pub struct CaseExpr { eval_method: EvalMethod, } +// eval_method is functionally derived from body, so excluding it from +// Hash/Eq avoids redundantly hashing the expression tree twice. For +// nested CASE chains this prevents exponential blowup (see #22173). +impl Hash for CaseExpr { + fn hash(&self, state: &mut H) { + self.body.hash(state); + } +} + +impl PartialEq for CaseExpr { + fn eq(&self, other: &Self) -> bool { + self.body == other.body + } +} + +impl Eq for CaseExpr {} + impl std::fmt::Display for CaseExpr { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "CASE ")?; @@ -244,7 +315,7 @@ impl std::fmt::Display for CaseExpr { /// this is limited to use with Column expressions but could potentially be used for other /// expressions in the future fn is_cheap_and_infallible(expr: &Arc) -> bool { - expr.as_any().is::() + expr.is::() } /// Creates a [FilterPredicate] from a boolean array. @@ -305,189 +376,6 @@ fn filter_array( filter.filter(array) } -fn merge( - mask: &BooleanArray, - truthy: ColumnarValue, - falsy: ColumnarValue, -) -> std::result::Result { - let (truthy, truthy_is_scalar) = match truthy { - ColumnarValue::Array(a) => (a, false), - ColumnarValue::Scalar(s) => (s.to_array()?, true), - }; - let (falsy, falsy_is_scalar) = match falsy { - ColumnarValue::Array(a) => (a, false), - ColumnarValue::Scalar(s) => (s.to_array()?, true), - }; - - if truthy_is_scalar && falsy_is_scalar { - return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy)); - } - - let falsy = falsy.to_data(); - let truthy = truthy.to_data(); - - let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, truthy.len()); - - // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to - // fill with falsy values - - // keep track of how much is filled - let mut filled = 0; - let mut falsy_offset = 0; - let mut truthy_offset = 0; - - SlicesIterator::new(mask).for_each(|(start, end)| { - // the gap needs to be filled with falsy values - if start > filled { - if falsy_is_scalar { - for _ in filled..start { - // Copy the first item from the 'falsy' array into the output buffer. - mutable.extend(1, 0, 1); - } - } else { - let falsy_length = start - filled; - let falsy_end = falsy_offset + falsy_length; - mutable.extend(1, falsy_offset, falsy_end); - falsy_offset = falsy_end; - } - } - // fill with truthy values - if truthy_is_scalar { - for _ in start..end { - // Copy the first item from the 'truthy' array into the output buffer. - mutable.extend(0, 0, 1); - } - } else { - let truthy_length = end - start; - let truthy_end = truthy_offset + truthy_length; - mutable.extend(0, truthy_offset, truthy_end); - truthy_offset = truthy_end; - } - filled = end; - }); - // the remaining part is falsy - if filled < mask.len() { - if falsy_is_scalar { - for _ in filled..mask.len() { - // Copy the first item from the 'falsy' array into the output buffer. - mutable.extend(1, 0, 1); - } - } else { - let falsy_length = mask.len() - filled; - let falsy_end = falsy_offset + falsy_length; - mutable.extend(1, falsy_offset, falsy_end); - } - } - - let data = mutable.freeze(); - Ok(make_array(data)) -} - -/// Merges elements by index from a list of [`ArrayData`], creating a new [`ColumnarValue`] from -/// those values. -/// -/// Each element in `indices` is the index of an array in `values`. The `indices` array is processed -/// sequentially. The first occurrence of index value `n` will be mapped to the first -/// value of the array at index `n`. The second occurrence to the second value, and so on. -/// An index value where `PartialResultIndex::is_none` is `true` is used to indicate null values. -/// -/// # Implementation notes -/// -/// This algorithm is similar in nature to both `zip` and `interleave`, but there are some important -/// differences. -/// -/// In contrast to `zip`, this function supports multiple input arrays. Instead of a boolean -/// selection vector, an index array is to take values from the input arrays, and a special marker -/// value is used to indicate null values. -/// -/// In contrast to `interleave`, this function does not use pairs of indices. The values in -/// `indices` serve the same purpose as the first value in the pairs passed to `interleave`. -/// The index in the array is implicit and is derived from the number of times a particular array -/// index occurs. -/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values -/// in contiguous slices. In the example below, the two subsequent elements from array `2` can be -/// copied in a single operation from the source array instead of copying them one by one. -/// Long spans of null values are also especially cheap because they do not need to be represented -/// in an input array. -/// -/// # Safety -/// -/// This function does not check that the number of occurrences of any particular array index matches -/// the length of the corresponding input array. If an array contains more values than required, the -/// spurious values will be ignored. If an array contains fewer values than necessary, this function -/// will panic. -/// -/// # Example -/// -/// ```text -/// ┌───────────┐ ┌─────────┐ ┌─────────┐ -/// │┌─────────┐│ │ None │ │ NULL │ -/// ││ A ││ ├─────────┤ ├─────────┤ -/// │└─────────┘│ │ 1 │ │ B │ -/// │┌─────────┐│ ├─────────┤ ├─────────┤ -/// ││ B ││ │ 0 │ merge(values, indices) │ A │ -/// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤ -/// │┌─────────┐│ │ None │ │ NULL │ -/// ││ C ││ ├─────────┤ ├─────────┤ -/// │├─────────┤│ │ 2 │ │ C │ -/// ││ D ││ ├─────────┤ ├─────────┤ -/// │└─────────┘│ │ 2 │ │ D │ -/// └───────────┘ └─────────┘ └─────────┘ -/// values indices result -/// ``` -fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { - #[cfg(debug_assertions)] - for ix in indices { - if let Some(index) = ix.index() { - assert!( - index < values.len(), - "Index out of bounds: {} >= {}", - index, - values.len() - ); - } - } - - let data_refs = values.iter().collect(); - let mut mutable = MutableArrayData::new(data_refs, true, indices.len()); - - // This loop extends the mutable array by taking slices from the partial results. - // - // take_offsets keeps track of how many values have been taken from each array. - let mut take_offsets = vec![0; values.len() + 1]; - let mut start_row_ix = 0; - loop { - let array_ix = indices[start_row_ix]; - - // Determine the length of the slice to take. - let mut end_row_ix = start_row_ix + 1; - while end_row_ix < indices.len() && indices[end_row_ix] == array_ix { - end_row_ix += 1; - } - let slice_length = end_row_ix - start_row_ix; - - // Extend mutable with either nulls or with values from the array. - match array_ix.index() { - None => mutable.extend_nulls(slice_length), - Some(index) => { - let start_offset = take_offsets[index]; - let end_offset = start_offset + slice_length; - mutable.extend(index, start_offset, end_offset); - take_offsets[index] = end_offset; - } - } - - if end_row_ix == indices.len() { - break; - } else { - // Set the start_row_ix for the next slice. - start_row_ix = end_row_ix; - } - } - - Ok(make_array(mutable.freeze())) -} - /// An index into the partial results array that's more compact than `usize`. /// /// `u32::MAX` is reserved as a special 'none' value. This is used instead of @@ -530,7 +418,9 @@ impl PartialResultIndex { fn is_none(&self) -> bool { self.index == NONE_VALUE } +} +impl MergeIndex for PartialResultIndex { /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise. fn index(&self) -> Option { if self.is_none() { @@ -558,7 +448,7 @@ enum ResultState { Partial { // A `Vec` of partial results that should be merged. // `partial_result_indices` contains indexes into this vec. - arrays: Vec, + arrays: Vec, // Indicates per result row from which array in `partial_results` a value should be taken. indices: Vec, }, @@ -591,7 +481,7 @@ impl ResultBuilder { Self { data_type: data_type.clone(), row_count, - state: Empty, + state: ResultState::Empty, } } @@ -639,7 +529,7 @@ impl ResultBuilder { } else if row_indices.len() == self.row_count { self.set_complete_result(ColumnarValue::Array(a)) } else { - self.add_partial_result(row_indices, a.to_data()) + self.add_partial_result(row_indices, a) } } ColumnarValue::Scalar(s) => { @@ -648,7 +538,7 @@ impl ResultBuilder { } else { self.add_partial_result( row_indices, - s.to_array_of_size(row_indices.len())?.to_data(), + s.to_array_of_size(row_indices.len())?, ) } } @@ -663,7 +553,7 @@ impl ResultBuilder { fn add_partial_result( &mut self, row_indices: &ArrayRef, - row_values: ArrayData, + row_values: ArrayRef, ) -> Result<()> { assert_or_internal_err!( row_indices.null_count() == 0, @@ -671,21 +561,21 @@ impl ResultBuilder { ); match &mut self.state { - Empty => { + ResultState::Empty => { let array_index = PartialResultIndex::zero(); let mut indices = vec![PartialResultIndex::none(); self.row_count]; for row_ix in row_indices.as_primitive::().values().iter() { indices[*row_ix as usize] = array_index; } - self.state = Partial { + self.state = ResultState::Partial { arrays: vec![row_values], indices, }; Ok(()) } - Partial { arrays, indices } => { + ResultState::Partial { arrays, indices } => { let array_index = PartialResultIndex::try_new(arrays.len())?; arrays.push(row_values); @@ -705,7 +595,7 @@ impl ResultBuilder { } Ok(()) } - Complete(_) => internal_err!( + ResultState::Complete(_) => internal_err!( "Cannot add a partial result when complete result is already set" ), } @@ -718,23 +608,23 @@ impl ResultBuilder { /// without any merging overhead. fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { match &self.state { - Empty => { - self.state = Complete(value); + ResultState::Empty => { + self.state = ResultState::Complete(value); Ok(()) } - Partial { .. } => { + ResultState::Partial { .. } => { internal_err!( "Cannot set a complete result when there are already partial results" ) } - Complete(_) => internal_err!("Complete result already set"), + ResultState::Complete(_) => internal_err!("Complete result already set"), } } /// Finishes building the result and returns the final array. fn finish(self) -> Result { match self.state { - Empty => { + ResultState::Empty => { // No complete result and no partial results. // This can happen for case expressions with no else branch where no rows // matched. @@ -742,11 +632,12 @@ impl ResultBuilder { &self.data_type, )?)) } - Partial { arrays, indices } => { + ResultState::Partial { arrays, indices } => { // Merge partial results into a single array. - Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?)) + let array_refs = arrays.iter().map(|a| a.as_ref()).collect::>(); + Ok(ColumnarValue::Array(merge_n(&array_refs, &indices)?)) } - Complete(v) => { + ResultState::Complete(v) => { // If we have a complete result, we can just return it. Ok(v) } @@ -764,7 +655,7 @@ impl CaseExpr { // normalize null literals to None in the else_expr (this already happens // during SQL planning, but not necessarily for other use cases) let else_expr = match &else_expr { - Some(e) => match e.as_any().downcast_ref::() { + Some(e) => match e.downcast_ref::() { Some(lit) if lit.value().is_null() => None, _ => else_expr, }, @@ -781,28 +672,40 @@ impl CaseExpr { else_expr, }; - let eval_method = if body.expr.is_some() { - EvalMethod::WithExpression(body.project()?) - } else if body.when_then_expr.len() == 1 - && is_cheap_and_infallible(&(body.when_then_expr[0].1)) - && body.else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if body.when_then_expr.len() == 1 - && body.when_then_expr[0].1.as_any().is::() - && body.else_expr.is_some() - && body.else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { - EvalMethod::ExpressionOrExpression(body.project()?) - } else { - EvalMethod::NoExpression(body.project()?) - }; + let eval_method = Self::find_best_eval_method(&body)?; Ok(Self { body, eval_method }) } + fn find_best_eval_method(body: &CaseBody) -> Result { + if body.expr.is_some() { + if let Some(mapping) = LiteralLookupTable::maybe_new(body) { + return Ok(EvalMethod::WithExprScalarLookupTable(mapping)); + } + + return Ok(EvalMethod::WithExpression(body.project()?)); + } + + Ok( + if body.when_then_expr.len() == 1 + && is_cheap_and_infallible(&(body.when_then_expr[0].1)) + && body.else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if body.when_then_expr.len() == 1 + && body.when_then_expr[0].1.is::() + && body.else_expr.is_some() + && body.else_expr.as_ref().unwrap().is::() + { + EvalMethod::ScalarOrScalar + } else if body.when_then_expr.len() == 1 { + EvalMethod::ExpressionOrExpression(body.project()?) + } else { + EvalMethod::NoExpression(body.project()?) + }, + ) + } + /// Optional base expression that can be compared to literal values in the "when" expressions pub fn expr(&self) -> Option<&Arc> { self.body.expr.as_ref() @@ -831,10 +734,10 @@ impl CaseBody { } } // if all then results are null, we use data type of else expr instead if possible. - if data_type.equals_datatype(&DataType::Null) { - if let Some(e) = &self.else_expr { - data_type = e.data_type(input_schema)?; - } + if data_type.equals_datatype(&DataType::Null) + && let Some(e) = &self.else_expr + { + data_type = e.data_type(input_schema)?; } Ok(data_type) @@ -924,17 +827,15 @@ impl CaseBody { } }?; - // `true_count` ignores `true` values where the validity bit is not set, so there's - // no need to call `prep_null_mask_filter`. - let when_true_count = when_value.true_count(); - - // If the 'when' predicate did not match any rows, continue to the next branch immediately - if when_true_count == 0 { + // If the 'when' predicate did not match any rows, continue to the next branch immediately. + // Only counts valid slots that are true (masked-null predicate slots are ignored), + // so no `prep_null_mask_filter` needed here. + if !when_value.has_true() { continue; } // If the 'when' predicate matched all remaining rows, there is no need to filter - if when_true_count == remainder_batch.num_rows() { + if when_value.null_count() == 0 && !when_value.has_false() { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&remainder_batch)?; result_builder.add_branch_result(&remainder_rows, then_value)?; @@ -1013,17 +914,15 @@ impl CaseBody { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; - // `true_count` ignores `true` values where the validity bit is not set, so there's - // no need to call `prep_null_mask_filter`. - let when_true_count = when_value.true_count(); - - // If the 'when' predicate did not match any rows, continue to the next branch immediately - if when_true_count == 0 { + // If the 'when' predicate did not match any rows, continue to the next branch immediately. + // Only counts valid slots that are true (masked-null predicate slots are ignored) + // so no `prep_null_mask_filter` needed here. + if !when_value.has_true() { continue; } // If the 'when' predicate matched all remaining rows, there is no need to filter - if when_true_count == remainder_batch.num_rows() { + if when_value.null_count() == 0 && !when_value.has_false() { let then_expression = &self.when_then_expr[i].1; let then_value = then_expression.evaluate(&remainder_batch)?; result_builder.add_branch_result(&remainder_rows, then_value)?; @@ -1097,23 +996,40 @@ impl CaseBody { let then_batch = filter_record_batch(batch, &when_filter)?; let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?; - let else_selection = not(&when_value)?; - let else_filter = create_filter(&else_selection, optimize_filter); - let else_batch = filter_record_batch(batch, &else_filter)?; - - // keep `else_expr`'s data type and return type consistent - let e = self.else_expr.as_ref().unwrap(); - let return_type = self.data_type(&batch.schema())?; - let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| Arc::clone(e)); - - let else_value = else_expr.evaluate(&else_batch)?; - - Ok(ColumnarValue::Array(merge( - &when_value, - then_value, - else_value, - )?)) + match &self.else_expr { + None => { + let then_array = then_value.to_array(when_value.true_count())?; + scatter(&when_value, then_array.as_ref()).map(ColumnarValue::Array) + } + Some(else_expr) => { + let else_selection = not(&when_value)?; + let else_filter = create_filter(&else_selection, optimize_filter); + let else_batch = filter_record_batch(batch, &else_filter)?; + + // keep `else_expr`'s data type and return type consistent + let return_type = self.data_type(&batch.schema())?; + let else_expr = + try_cast(Arc::clone(else_expr), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(else_expr)); + + let else_value = else_expr.evaluate(&else_batch)?; + + Ok(ColumnarValue::Array(match (then_value, else_value) { + (ColumnarValue::Array(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t, &e) + } + (ColumnarValue::Scalar(t), ColumnarValue::Array(e)) => { + merge(&when_value, &t.to_scalar()?, &e) + } + (ColumnarValue::Array(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t, &e.to_scalar()?) + } + (ColumnarValue::Scalar(t), ColumnarValue::Scalar(e)) => { + merge(&when_value, &t.to_scalar()?, &e.to_scalar()?) + } + }?)) + } + } } } @@ -1131,8 +1047,15 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_columns() { - let projected_batch = batch.project(&projected.projection)?; + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + let projected_batch = batch.project(&projection)?; projected .body .case_when_with_expr(&projected_batch, &return_type) @@ -1154,8 +1077,15 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_columns() { - let projected_batch = batch.project(&projected.projection)?; + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + let projected_batch = batch.project(&projection)?; projected .body .case_when_no_expr(&projected_batch, &return_type) @@ -1258,31 +1188,64 @@ impl CaseExpr { ) })?; - let true_count = when_value.true_count(); - if true_count == when_value.len() { + if when_value.null_count() == 0 && !when_value.has_false() { // All input rows are true, just call the 'then' expression self.body.when_then_expr[0].1.evaluate(batch) - } else if true_count == 0 { + } else if !when_value.has_true() { // All input rows are false/null, just call the 'else' expression - self.body.else_expr.as_ref().unwrap().evaluate(batch) - } else if projected.projection.len() < batch.num_columns() { - // The case expressions do not use all the columns of the input batch. - // Project first to reduce time spent filtering. - let projected_batch = batch.project(&projected.projection)?; - projected.body.expr_or_expr(&projected_batch, when_value) + match &self.body.else_expr { + Some(else_expr) => else_expr.evaluate(batch), + None => { + let return_type = self.data_type(&batch.schema())?; + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &return_type, + )?)) + } + } } else { - // All columns are used in the case expressions, so there is no need to project. - self.body.expr_or_expr(batch, when_value) + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + // The case expressions do not use all the columns of the input batch. + // Project first to reduce time spent filtering. + let projected_batch = batch.project(&projection)?; + projected.body.expr_or_expr(&projected_batch, when_value) + } else { + // All columns are used in the case expressions, so there is no need to project. + self.body.expr_or_expr(batch, when_value) + } } } -} -impl PhysicalExpr for CaseExpr { - /// Return a reference to Any that can be used for down-casting - fn as_any(&self) -> &dyn Any { - self + fn with_lookup_table( + &self, + batch: &RecordBatch, + lookup_table: &LiteralLookupTable, + ) -> Result { + let expr = self.body.expr.as_ref().unwrap(); + let evaluated_expression = expr.evaluate(batch)?; + + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); + let evaluated_expression = evaluated_expression.to_array(1)?; + + let values = lookup_table.map_keys_to_values(&evaluated_expression)?; + + let result = if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?) + } else { + ColumnarValue::Array(values) + }; + + Ok(result) } +} +impl PhysicalExpr for CaseExpr { fn data_type(&self, input_schema: &Schema) -> Result { self.body.data_type(input_schema) } @@ -1370,6 +1333,9 @@ impl PhysicalExpr for CaseExpr { } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p), + EvalMethod::WithExprScalarLookupTable(lookup_table) => { + self.with_lookup_table(batch, lookup_table) + } } } @@ -1444,6 +1410,86 @@ impl PhysicalExpr for CaseExpr { } write!(f, "END") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Case(Box::new( + protobuf::PhysicalCaseNode { + expr: self + .expr() + .map(|expr| ctx.encode_child(expr).map(Box::new)) + .transpose()?, + when_then_expr: self + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + Ok(protobuf::PhysicalWhenThen { + when_expr: Some(ctx.encode_child(when_expr)?), + then_expr: Some(ctx.encode_child(then_expr)?), + }) + }) + .collect::>>()?, + else_expr: self + .else_expr() + .map(|expr| ctx.encode_child(expr).map(Box::new)) + .transpose()?, + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl CaseExpr { + /// Reconstruct a [`CaseExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let case = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::Case, + "CaseExpr", + ); + + Ok(Arc::new(CaseExpr::try_new( + case.expr + .as_deref() + .map(|expr| ctx.decode(expr)) + .transpose()?, + case.when_then_expr + .iter() + .map(|when_then| { + Ok(( + ctx.decode_required_expression( + when_then.when_expr.as_ref(), + "CaseExpr", + "when_expr", + )?, + ctx.decode_required_expression( + when_then.then_expr.as_ref(), + "CaseExpr", + "then_expr", + )?, + )) + }) + .collect::>>()?, + case.else_expr + .as_deref() + .map(|expr| ctx.decode(expr)) + .transpose()?, + )?)) + } } /// Attempts to const evaluate the given `predicate`. @@ -1505,16 +1551,17 @@ mod tests { use super::*; use crate::expressions; - use crate::expressions::{binary, cast, col, is_not_null, lit, BinaryExpr}; + use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; - use datafusion_expr::type_coercion::binary::comparison_coercion; + use datafusion_expr::type_coercion::binary::type_union_coercion; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use half::f16; #[test] fn case_with_expr() -> Result<()> { @@ -1585,6 +1632,86 @@ mod tests { Ok(()) } + // Make sure we are not failing when got literal in case when but input is dictionary encoded + #[test] + fn case_with_expr_primitive_dictionary() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::UInt64)), + true, + )]); + let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]); + let values = UInt64Array::from(vec![Some(10), Some(20), None, Some(30)]); + let dictionary = DictionaryArray::new(keys, Arc::new(values)); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?; + + let schema = batch.schema(); + + // CASE a WHEN 10 THEN 123 WHEN 30 THEN 456 END + let when1 = lit(10_u64); + let then1 = lit(123_i32); + let when2 = lit(30_u64); + let then2 = lit(456_i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + + // Make sure we are not failing when got literal in case when but input is dictionary encoded + #[test] + fn case_with_expr_boolean_dictionary() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Boolean)), + true, + )]); + let keys = UInt8Array::from(vec![0u8, 1u8, 2u8, 3u8]); + let values = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); + let dictionary = DictionaryArray::new(keys, Arc::new(values)); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dictionary)])?; + + let schema = batch.schema(); + + // CASE a WHEN true THEN 123 WHEN false THEN 456 END + let when1 = lit(true); + let then1 = lit(123i32); + let when2 = lit(false); + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + Some(col("a", &schema)?), + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![Some(123), Some(456), None, Some(123)]); + + assert_eq!(expected, result); + + Ok(()) + } + #[test] fn case_with_expr_all_null_dictionary() -> Result<()> { let schema = Schema::new(vec![Field::new( @@ -2207,7 +2334,7 @@ mod tests { let expr2 = Arc::clone(&expr) .transform(|e| { - let transformed = match e.as_any().downcast_ref::() { + let transformed = match e.downcast_ref::() { Some(lit_value) => match lit_value.value() { ScalarValue::Utf8(Some(str_value)) => { Some(lit(str_value.to_uppercase())) @@ -2227,7 +2354,7 @@ mod tests { let expr3 = Arc::clone(&expr) .transform_down(|e| { - let transformed = match e.as_any().downcast_ref::() { + let transformed = match e.downcast_ref::() { Some(lit_value) => match lit_value.value() { ScalarValue::Utf8(Some(str_value)) => { Some(lit(str_value.to_uppercase())) @@ -2279,7 +2406,7 @@ mod tests { make_lit_i32(250), )); let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; - assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + assert_eq!(expr.eval_method, EvalMethod::InfallibleExprOrNull); match expr.evaluate(&batch)? { ColumnarValue::Array(array) => { assert_eq!(1000, array.len()); @@ -2381,9 +2508,7 @@ mod tests { thens_type .iter() .try_fold(else_type, |left_type, right_type| { - // TODO: now just use the `equal` coercion rule for case when. If find the issue, and - // refactor again. - comparison_coercion(&left_type, right_type) + type_union_coercion(&left_type, right_type) }) } @@ -2418,57 +2543,6 @@ mod tests { Ok(()) } - #[test] - fn test_merge_n() { - let a1 = StringArray::from(vec![Some("A")]).to_data(); - let a2 = StringArray::from(vec![Some("B")]).to_data(); - let a3 = StringArray::from(vec![Some("C"), Some("D")]).to_data(); - - let indices = vec![ - PartialResultIndex::none(), - PartialResultIndex::try_new(1).unwrap(), - PartialResultIndex::try_new(0).unwrap(), - PartialResultIndex::none(), - PartialResultIndex::try_new(2).unwrap(), - PartialResultIndex::try_new(2).unwrap(), - ]; - - let merged = merge_n(&[a1, a2, a3], &indices).unwrap(); - let merged = merged.as_string::(); - - assert_eq!(merged.len(), indices.len()); - assert!(!merged.is_valid(0)); - assert!(merged.is_valid(1)); - assert_eq!(merged.value(1), "B"); - assert!(merged.is_valid(2)); - assert_eq!(merged.value(2), "A"); - assert!(!merged.is_valid(3)); - assert!(merged.is_valid(4)); - assert_eq!(merged.value(4), "C"); - assert!(merged.is_valid(5)); - assert_eq!(merged.value(5), "D"); - } - - #[test] - fn test_merge() { - let a1 = Arc::new(StringArray::from(vec![Some("A"), Some("C")])); - let a2 = Arc::new(StringArray::from(vec![Some("B")])); - - let mask = BooleanArray::from(vec![true, false, true]); - - let merged = - merge(&mask, ColumnarValue::Array(a1), ColumnarValue::Array(a2)).unwrap(); - let merged = merged.as_string::(); - - assert_eq!(merged.len(), mask.len()); - assert!(merged.is_valid(0)); - assert_eq!(merged.value(0), "A"); - assert!(merged.is_valid(1)); - assert_eq!(merged.value(1), "B"); - assert!(merged.is_valid(2)); - assert_eq!(merged.value(2), "C"); - } - fn when_then_else( when: &Arc, then: &Arc, @@ -2646,4 +2720,730 @@ mod tests { assert_not_nullable(expr, schema); } } + + // Test Lookup evaluation + + fn test_case_when_literal_lookup( + values: ArrayRef, + lookup_map: &[(ScalarValue, ScalarValue)], + else_value: Option, + expected: ArrayRef, + ) { + // Create lookup + // CASE + // WHEN THEN + // WHEN THEN + // [ ELSE ] + + let schema = Schema::new(vec![Field::new( + "a", + values.data_type().clone(), + values.is_nullable(), + )]); + let schema = Arc::new(schema); + + let batch = RecordBatch::try_new(schema, vec![values]) + .expect("failed to create RecordBatch"); + + let schema = batch.schema_ref(); + let case = col("a", schema).expect("failed to create col"); + + let when_then = lookup_map + .iter() + .map(|(when, then)| { + ( + Arc::new(Literal::new(when.clone())) as _, + Arc::new(Literal::new(then.clone())) as _, + ) + }) + .collect::>(); + + let else_expr = else_value.map(|else_value| { + Arc::new(Literal::new(else_value)) as Arc + }); + let expr = CaseExpr::try_new(Some(case), when_then, else_expr) + .expect("failed to create case"); + + // Assert that we are testing what we intend to assert + assert!( + matches!( + expr.eval_method, + EvalMethod::WithExprScalarLookupTable { .. } + ), + "we should use the expected eval method" + ); + + let actual = expr + .evaluate(&batch) + .expect("failed to evaluate case") + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + + assert_eq!( + actual.data_type(), + expected.data_type(), + "Data type mismatch" + ); + + assert_eq!( + actual.as_ref(), + expected.as_ref(), + "actual (left) does not match expected (right)" + ); + } + + fn create_lookup( + when_then_pairs: impl IntoIterator, + ) -> Vec<(ScalarValue, ScalarValue)> + where + ScalarValue: From, + ScalarValue: From, + { + when_then_pairs + .into_iter() + .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then))) + .collect() + } + + fn create_input_and_expected( + input_and_expected_pairs: impl IntoIterator, + ) -> (Input, Expected) + where + Input: Array + From>, + Expected: Array + From>, + { + let (input_items, expected_items): (Vec, Vec) = + input_and_expected_pairs.into_iter().unzip(); + + (Input::from(input_items), Expected::from(expected_items)) + } + + fn test_lookup_eval_with_and_without_else( + lookup_map: &[(ScalarValue, ScalarValue)], + input_values: ArrayRef, + expected: StringArray, + ) { + // Testing without ELSE should fallback to None + test_case_when_literal_lookup( + Arc::clone(&input_values), + lookup_map, + None, + Arc::new(expected.clone()), + ); + + // Testing with Else + let else_value = "___fallback___"; + + // Changing each expected None to be fallback + let expected_with_else = expected + .iter() + .map(|item| item.unwrap_or(else_value)) + .map(Some) + .collect::(); + + // Test case + test_case_when_literal_lookup( + input_values, + lookup_map, + Some(ScalarValue::Utf8(Some(else_value.to_string()))), + Arc::new(expected_with_else), + ); + } + + #[test] + fn test_case_when_literal_lookup_int32_to_string() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (Some(2), Some("two")), + (Some(3), Some("three")), + (Some(1), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_none_case_should_never_match() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (None, Some("none")), + (Some(2), Some("two")), + (Some(1), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some(1), Some("one")), + (Some(5), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some(2), Some("two")), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some(2), Some("two")), + (Some(5), None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (Some(4), Some("no 4")), + (Some(2), Some("two")), + (Some(2), Some("no 2")), + (Some(3), Some("three")), + (Some(3), Some("no 3")), + (Some(2), Some("no 2")), + (Some(4), Some("no 4")), + (Some(2), Some("no 2")), + (Some(3), Some("no 3")), + (Some(4), Some("no 4")), + (Some(2), Some("no 2")), + (Some(3), Some("no 3")), + (Some(3), Some("no 3")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, None), // No match in WHEN + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases() + { + let lookup_map = create_lookup([ + (Some(4.0), Some("four point zero")), + (Some(f32::NAN), Some("NaN")), + (Some(3.2), Some("three point two")), + // Duplicate case to make sure it is not used + (Some(f32::NAN), Some("should not use this NaN branch")), + (Some(f32::INFINITY), Some("Infinity")), + (Some(0.0), Some("zero")), + // Duplicate case to make sure it is not used + ( + Some(f32::INFINITY), + Some("should not use this Infinity branch"), + ), + (Some(1.1), Some("one point one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1.1, Some("one point one")), + (f32::NAN, Some("NaN")), + (3.2, Some("three point two")), + (3.2, Some("three point two")), + (0.0, Some("zero")), + (f32::INFINITY, Some("Infinity")), + (3.2, Some("three point two")), + (f32::NEG_INFINITY, None), // No match in WHEN + (f32::NEG_INFINITY, None), // No match in WHEN + (3.2, Some("three point two")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f16_to_string_with_special_values() { + let lookup_map = create_lookup([ + ( + ScalarValue::Float16(Some(f16::from_f32(3.2))), + Some("3 dot 2"), + ), + (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")), + ( + ScalarValue::Float16(Some(f16::from_f32(17.4))), + Some("17 dot 4"), + ), + (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")), + (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (f16::from_f32(3.2), Some("3 dot 2")), + (f16::NAN, Some("NaN")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::INFINITY, Some("Infinity")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::NEG_INFINITY, None), // No match in WHEN + (f16::NEG_INFINITY, None), // No match in WHEN + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::NEG_ZERO, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f32_to_string_with_special_values() { + let lookup_map = create_lookup([ + (3.2, Some("3 dot 2")), + (f32::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (f32::INFINITY, Some("Infinity")), + (f32::ZERO, Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (3.2, Some("3 dot 2")), + (f32::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (17.4, Some("17 dot 4")), + (f32::INFINITY, Some("Infinity")), + (17.4, Some("17 dot 4")), + (f32::NEG_INFINITY, None), // No match in WHEN + (f32::NEG_INFINITY, None), // No match in WHEN + (17.4, Some("17 dot 4")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f64_to_string_with_special_values() { + let lookup_map = create_lookup([ + (3.2, Some("3 dot 2")), + (f64::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (f64::INFINITY, Some("Infinity")), + (f64::ZERO, Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (3.2, Some("3 dot 2")), + (f64::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (17.4, Some("17 dot 4")), + (f64::INFINITY, Some("Infinity")), + (17.4, Some("17 dot 4")), + (f64::NEG_INFINITY, None), // No match in WHEN + (f64::NEG_INFINITY, None), // No match in WHEN + (17.4, Some("17 dot 4")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + // Test that we don't lose the decimal precision and scale info + #[test] + fn test_decimal_with_non_default_precision_and_scale() { + let lookup_map = create_lookup([ + (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")), + (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")), + (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")), + (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + let input_values = input_values + .with_precision_and_scale(3, 2) + .expect("must be able to set precision and scale"); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + // Test that we don't lose the timezone info + #[test] + fn test_timestamp_with_non_default_timezone() { + let timezone: Option> = Some("-10:00".into()); + let lookup_map = create_lookup([ + ( + ScalarValue::TimestampMillisecond(Some(4), timezone.clone()), + Some("four"), + ), + ( + ScalarValue::TimestampMillisecond(Some(2), timezone.clone()), + Some("two"), + ), + ( + ScalarValue::TimestampMillisecond(Some(3), timezone.clone()), + Some("three"), + ), + ( + ScalarValue::TimestampMillisecond(Some(1), timezone.clone()), + Some("one"), + ), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + let input_values = input_values.with_timezone_opt(timezone); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_with_strings_to_int32() { + let lookup_map = create_lookup([ + (Some("why"), Some(42)), + (Some("what"), Some(22)), + (Some("when"), Some(17)), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some("why"), Some(42)), + (Some("5"), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (Some("5"), None), // No match in WHEN + ]); + + let input_values = Arc::new(input_values) as ArrayRef; + + // Testing without ELSE should fallback to None + test_case_when_literal_lookup( + Arc::clone(&input_values), + &lookup_map, + None, + Arc::new(expected.clone()), + ); + + // Testing with Else + let else_value = 101; + + // Changing each expected None to be fallback + let expected_with_else = expected + .iter() + .map(|item| item.unwrap_or(else_value)) + .map(Some) + .collect::(); + + // Test case + test_case_when_literal_lookup( + input_values, + &lookup_map, + Some(ScalarValue::Int32(Some(else_value))), + Arc::new(expected_with_else), + ); + } + + /// Reproduces https://github.com/apache/datafusion/issues/22173 + /// + /// Nested self-referential CASE chains (common in rewrite-style projections) + /// should not cause exponential hashing work during physical planning. + #[test] + fn nested_self_referential_case_hash_stays_bounded() -> Result<()> { + use std::hash::Hasher; + + #[derive(Default)] + struct CountingHasher { + write_calls: usize, + bytes_written: usize, + } + + impl Hasher for CountingHasher { + fn finish(&self) -> u64 { + 0 + } + + fn write(&mut self, bytes: &[u8]) { + self.write_calls += 1; + self.bytes_written += bytes.len(); + } + } + + let schema = + Arc::new(Schema::new(vec![Field::new("kind", DataType::Utf8, true)])); + + let kind = col("kind", &schema)?; + let mut label = Arc::clone(&kind); + + let num_levels = 18; + for idx in 0..num_levels { + let predicate = Arc::new(BinaryExpr::new( + Arc::clone(&kind), + Operator::Eq, + lit(idx.to_string()), + )) as Arc; + + label = case(None, vec![(predicate, lit("label"))], Some(label))?; + } + + let mut hasher = CountingHasher::default(); + label.hash(&mut hasher); + + assert!( + hasher.write_calls < 50_000, + "hashing nested CASE expression took {} hasher writes and {} bytes", + hasher.write_calls, + hasher.bytes_written + ); + + Ok(()) + } +} + +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::col; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf; + use datafusion_proto_models::protobuf::{PhysicalExprNode, PhysicalWhenThen}; + + fn proto_case_fixture() -> CaseExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + CaseExpr::try_new( + Some(col("a", &schema).unwrap()), + vec![(lit(true), lit(1_i32))], + Some(lit(0_i32)), + ) + .unwrap() + } + + fn proto_when_then( + when_expr: Option, + then_expr: Option, + ) -> PhysicalWhenThen { + PhysicalWhenThen { + when_expr, + then_expr, + } + } + + fn proto_case_node( + expr: Option>, + when_then_expr: Vec, + else_expr: Option>, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Case(Box::new( + protobuf::PhysicalCaseNode { + expr, + when_then_expr, + else_expr, + }, + ))), + } + } + + #[test] + fn try_to_proto_encodes_case_expr() { + let case = proto_case_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = case + .try_to_proto(&ctx) + .unwrap() + .expect("CaseExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let case_node = match node.expr_type { + Some(protobuf::physical_expr_node::ExprType::Case(boxed)) => *boxed, + other => panic!("expected a CaseExpr node, got {other:?}"), + }; + assert!(case_node.expr.is_some()); + assert_eq!(case_node.when_then_expr.len(), 1); + assert!(case_node.when_then_expr[0].when_expr.is_some()); + assert!(case_node.when_then_expr[0].then_expr.is_some()); + assert!(case_node.else_expr.is_some()); + } + + #[test] + fn try_to_proto_propagates_child_encode_error() { + let case = proto_case_fixture(); + // Call 1 is the optional CASE expr, call 2 is the WHEN expr. + let encoder = StubEncoder::failing_on(2); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let err = case.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 2"))); + } + + #[test] + fn try_from_proto_decodes_case_expr() { + let node = proto_case_node( + Some(Box::new(column_node("case"))), + vec![proto_when_then( + Some(column_node("when")), + Some(column_node("then")), + )], + Some(Box::new(column_node("else"))), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = CaseExpr::try_from_proto(&node, &ctx).unwrap(); + let case = decoded + .downcast_ref::() + .expect("decoded expr should be a CaseExpr"); + + assert!(case.expr().is_some()); + assert_eq!(case.when_then_expr().len(), 1); + assert!(case.else_expr().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_case_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CaseExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a CaseExpr")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_when_expr() { + let node = proto_case_node( + None, + vec![proto_when_then(None, Some(column_node("then")))], + None, + ); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CaseExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("CaseExpr is missing required field 'when_expr'")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_then_expr() { + let node = proto_case_node( + None, + vec![proto_when_then(Some(column_node("when")), None)], + None, + ); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CaseExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("CaseExpr is missing required field 'then_expr'")) + ); + } + + #[test] + fn try_from_proto_propagates_child_decode_error() { + let node = proto_case_node( + Some(Box::new(column_node("case"))), + vec![proto_when_then( + Some(column_node("when")), + Some(column_node("then")), + )], + Some(Box::new(column_node("else"))), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(2); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CaseExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 2"))); + } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs new file mode 100644 index 0000000000000..15b3d04955b2e --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::datatypes::DataType; +use datafusion_common::{ScalarValue, internal_err}; + +#[derive(Clone, Debug)] +pub(super) struct BooleanIndexMap { + true_index: Option, + false_index: Option, +} + +impl BooleanIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec, + ) -> datafusion_common::Result { + let mut true_index: Option = None; + let mut false_index: Option = None; + + for (index, literal) in unique_non_null_literals.into_iter().enumerate() { + match literal { + ScalarValue::Boolean(Some(true)) => { + if true_index.is_some() { + return internal_err!( + "Duplicate true literal found in literals for BooleanIndexMap" + ); + } + true_index = Some(index as u32); + } + ScalarValue::Boolean(Some(false)) => { + if false_index.is_some() { + return internal_err!( + "Duplicate false literal found in literals for BooleanIndexMap" + ); + } + false_index = Some(index as u32); + } + ScalarValue::Boolean(None) => { + return internal_err!( + "Null literal found in non-null literals for BooleanIndexMap" + ); + } + _ => { + return internal_err!( + "Non-boolean literal found in literals for BooleanIndexMap" + ); + } + } + } + + Ok(Self { + true_index, + false_index, + }) + } + + fn map_boolean_array_to_when_indices( + &self, + array: &BooleanArray, + else_index: u32, + ) -> datafusion_common::Result> { + let true_index = self.true_index.unwrap_or(else_index); + let false_index = self.false_index.unwrap_or(else_index); + + Ok(array + .into_iter() + .map(|value| match value { + Some(true) => true_index, + Some(false) => false_index, + None => else_index, + }) + .collect::>()) + } +} + +impl WhenLiteralIndexMap for BooleanIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result> { + match array.data_type() { + DataType::Boolean => { + self.map_boolean_array_to_when_indices(array.as_boolean(), else_index) + } + // We support dictionary boolean array as we create the lookup table in `CaseWhen` expression + // creation when we don't know the schema, so we may receive dictionary encoded boolean arrays at execution time. + DataType::Dictionary(_, value_type) + if value_type.as_ref() == &DataType::Boolean => + { + // Since it is not common to have dictionary encoded boolean arrays + // at all than it is ok to do the cast here to simplify the implementation. + let converted = arrow::compute::cast(array.as_ref(), &DataType::Boolean)?; + self.map_boolean_array_to_when_indices(converted.as_boolean(), else_index) + } + _ => internal_err!( + "Expected boolean array for BooleanIndexMap, got {:?}", + array.data_type() + ), + } + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs new file mode 100644 index 0000000000000..e5cf3f84fd919 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -0,0 +1,223 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, DictionaryArray, + FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, downcast_integer, +}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, StringViewType, +}; +use datafusion_common::{HashMap, ScalarValue, internal_err, plan_datafusion_err}; +use std::fmt::Debug; + +/// Map from byte-like literal values to their first occurrence index +/// +/// This is a wrapper for handling different kinds of literal maps +#[derive(Clone, Debug)] +pub(super) struct BytesLikeIndexMap { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap, u32>, +} + +impl BytesLikeIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec, + ) -> datafusion_common::Result { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.logical_null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map: HashMap, u32> = try_get_bytes_iterator(&input)? + // Flattening Option<&[u8]> to &[u8] as literals cannot contain nulls + .flatten() + .enumerate() + .map(|(map_index, value)| (value.to_vec(), map_index as u32)) + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .collect(); + + Ok(Self { map }) + } +} + +impl WhenLiteralIndexMap for BytesLikeIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result> { + let indices = try_get_bytes_iterator(array)? + .map(|value| match value { + Some(value) => self.map.get(value).copied().unwrap_or(else_index), + None => else_index, + }) + .collect::>(); + + Ok(indices) + } +} + +fn try_get_bytes_iterator( + array: &ArrayRef, +) -> datafusion_common::Result> + '_>> { + Ok(match array.data_type() { + DataType::Utf8 => Box::new(array.as_string::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })), + + DataType::LargeUtf8 => { + Box::new(array.as_string::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } + + DataType::Binary => Box::new(array.as_binary::().into_iter()), + + DataType::LargeBinary => Box::new(array.as_binary::().into_iter()), + + DataType::FixedSizeBinary(_) => Box::new(array.as_binary::().into_iter()), + + DataType::Utf8View => Box::new( + array + .as_byte_view::() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + DataType::BinaryView => { + Box::new(array.as_byte_view::().into_iter()) + } + + DataType::Dictionary(key, _) => { + macro_rules! downcast_dictionary_array_helper { + ($t:ty) => {{ get_bytes_iterator_for_dictionary(array.as_dictionary::<$t>())? }}; + } + + downcast_integer! { + key.as_ref() => (downcast_dictionary_array_helper), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + t => { + return Err(plan_datafusion_err!( + "Unsupported data type for bytes lookup table: {}", + t + )); + } + }) +} + +fn get_bytes_iterator_for_dictionary( + array: &DictionaryArray, +) -> datafusion_common::Result> + '_>> { + Ok(match array.values().data_type() { + DataType::Utf8 => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + + DataType::LargeUtf8 => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + + DataType::Binary => { + Box::new(array.downcast_dict::().unwrap().into_iter()) + } + + DataType::LargeBinary => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter(), + ), + + DataType::FixedSizeBinary(_) => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter(), + ), + + DataType::Utf8View => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + DataType::BinaryView => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter(), + ), + + t => { + return Err(plan_datafusion_err!( + "Unsupported data type for lookup table dictionary value: {}", + t + )); + } + }) +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs new file mode 100644 index 0000000000000..0d4291ccc934b --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::Literal; +use crate::expressions::case::CaseBody; +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use arrow::array::{Array, ArrayRef, UInt32Array, downcast_primitive}; +use arrow::datatypes::DataType; +use datafusion_common::{ScalarValue, arrow_datafusion_err, plan_datafusion_err}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// ELSE +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. +/// ```sql +/// -- Before +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( in (, ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// +/// -- After +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Box, + + else_index: u32, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + /// + /// This will be used to take from based on the indices returned by the lookup table to build the final output + then_and_else_values: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new(body: &CaseBody) -> Option { + // We can't use the optimization if we don't have any when then pairs + if body.when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if body.when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = body + .when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.downcast_ref::(); + let then_maybe_literal = then.downcast_ref::(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE) + let (when, then): (Vec, Vec) = { + let mut map = IndexMap::with_capacity(body.when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr { + let literal = else_expr.downcast_ref::()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else { + return None; + }; + + null_scalar + }; + + { + let when_data_type = when[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization + if when.iter().any(|l| l.data_type() != when_data_type) { + return None; + } + } + + { + let data_type = then[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_value.data_type() != data_type { + return None; + } + } + + let then_and_else_values = ScalarValue::iter_to_array( + then.iter() + // The else is in the end + .chain(std::iter::once(&else_value)) + .cloned(), + ) + .ok()?; + // The else expression is in the end + let else_index = then_and_else_values.len() as u32 - 1; + + let lookup = try_creating_lookup_table(when).ok()?; + + Some(Self { + lookup, + then_and_else_values, + else_index, + }) + } + + pub(in super::super) fn map_keys_to_values( + &self, + keys_array: &ArrayRef, + ) -> datafusion_common::Result { + let take_indices = self + .lookup + .map_to_when_indices(keys_array, self.else_index)?; + + // Zero-copy conversion + let take_indices = UInt32Array::from(take_indices); + + // An optimize version would depend on the type of the values_to_take_from + // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = + arrow::compute::take(&self.then_and_else_values, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(output) + } +} + +/// Map values that match the WHEN literal to the index of their corresponding WHEN clause +/// +/// For example, for this CASE expression: +/// +/// ```sql +/// CASE +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// ELSE +/// END +/// ``` +/// +/// this will map to 0, to 1, to 2, to 3 +pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { + /// Given an array of values, returns a vector of WHEN clause indices corresponding to each value in the provided array. + /// + /// For example, for this CASE expression: + /// + /// ```sql + /// CASE + /// WHEN THEN + /// WHEN THEN + /// WHEN THEN + /// WHEN THEN + /// ELSE + /// END + /// ``` + /// + /// the array will be the evaluated values of `` + /// and if that array is: + /// - `[, , , , ]` + /// + /// the returned vector will be: + /// - `[0, 2, else_index, 1, 0]` + /// + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result>; +} + +fn try_creating_lookup_table( + unique_non_null_literals: Vec, +) -> datafusion_common::Result> { + assert_ne!( + unique_non_null_literals.len(), + 0, + "Must have at least one literal" + ); + match unique_non_null_literals[0].data_type() { + DataType::Boolean => { + let lookup_table = BooleanIndexMap::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + } + + data_type if data_type.is_primitive() => { + macro_rules! create_matching_map { + ($t:ty) => {{ + let lookup_table = + PrimitiveIndexMap::<$t>::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + }}; + } + + downcast_primitive! { + data_type => (create_matching_map), + _ => Err(plan_datafusion_err!( + "Unsupported field type for primitive: {:?}", + data_type + )), + } + } + + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::FixedSizeBinary(_) + | DataType::Utf8View + | DataType::BinaryView => { + let lookup_table = BytesLikeIndexMap::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + } + + DataType::Dictionary(_key, value) + if matches!( + value.as_ref(), + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::FixedSizeBinary(_) + | DataType::Utf8View + | DataType::BinaryView + ) => + { + let lookup_table = BytesLikeIndexMap::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + } + + _ => Err(plan_datafusion_err!( + "Unsupported data type for lookup table: {}", + unique_non_null_literals[0].data_type() + )), + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs new file mode 100644 index 0000000000000..36d282c2a402b --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, PrimitiveArray, +}; +use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, i256}; +use datafusion_common::{HashMap, ScalarValue, internal_err}; +use half::f16; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Clone)] +pub(super) struct PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + data_type: DataType, + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap<::HashableKey, u32>, +} + +impl Debug for PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveIndexMap") + .field("map", &self.map) + .finish() + } +} + +impl PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Try creating a new lookup table from the given literals and else index. + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec, + ) -> datafusion_common::Result { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map = input + .as_primitive::() + .values() + .iter() + .enumerate() + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .map(|(map_index, value)| (value.into_hashable_key(), map_index as u32)) + .collect(); + + Ok(Self { + map, + data_type: input.data_type().clone(), + }) + } + + fn map_primitive_array_to_when_indices( + &self, + array: &PrimitiveArray, + else_index: u32, + ) -> datafusion_common::Result> { + let indices = array + .into_iter() + .map(|value| match value { + Some(value) => self + .map + .get(&value.into_hashable_key()) + .copied() + .unwrap_or(else_index), + + None => else_index, + }) + .collect::>(); + + Ok(indices) + } +} + +impl WhenLiteralIndexMap for PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result> { + match array.data_type() { + dt if dt == &self.data_type => { + let primitive_array = array.as_primitive::(); + + self.map_primitive_array_to_when_indices(primitive_array, else_index) + } + // We support dictionary primitive array as we create the lookup table in `CaseWhen` expression + // creation when we don't know the schema, so we may receive dictionary encoded primitive arrays at execution time. + DataType::Dictionary(_, value_type) + if value_type.as_ref() == &self.data_type => + { + // Cast here to simplify the implementation. + let converted = arrow::compute::cast(array.as_ref(), &self.data_type)?; + self.map_primitive_array_to_when_indices( + converted.as_primitive::(), + else_index, + ) + } + _ => internal_err!( + "PrimitiveIndexMap expected array of type {:?} but got {:?}", + self.data_type, + array.data_type() + ), + } + } +} + +// TODO - We need to port it to arrow so that it can be reused in other places + +/// Trait that help convert a value to a key that is hashable and equatable +/// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly +pub(super) trait ToHashableKey: ArrowNativeTypeOp { + /// The type that is hashable and equatable + /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self + /// this is just a helper trait so you can reuse the same code for all arrow native types + type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; + + /// Converts self to a hashable key + /// the result of this value can be used as the key in hash maps/sets + fn into_hashable_key(self) -> Self::HashableKey; +} + +macro_rules! impl_to_hashable_key { + (@single_already_hashable | $t:ty) => { + impl ToHashableKey for $t { + type HashableKey = $t; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self + } + } + }; + (@already_hashable | $($t:ty),+ $(,)?) => { + $( + impl_to_hashable_key!(@single_already_hashable | $t); + )+ + }; + (@float | $t:ty => $hashable:ty) => { + impl ToHashableKey for $t { + type HashableKey = $hashable; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self.to_bits() + } + } + }; +} + +impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano); +impl_to_hashable_key!(@float | f16 => u16); +impl_to_hashable_key!(@float | f32 => u32); +impl_to_hashable_key!(@float | f64 => u64); + +#[cfg(test)] +mod tests { + use super::ToHashableKey; + use arrow::array::downcast_primitive; + + // This test ensure that all arrow primitive types implement ToHashableKey + // otherwise the code will not compile + #[test] + fn should_implement_to_hashable_key_for_all_primitives() { + #[derive(Debug, Default)] + struct ExampleSet + where + T: arrow::datatypes::ArrowPrimitiveType, + T::Native: ToHashableKey, + { + _map: std::collections::HashSet<::HashableKey>, + } + + macro_rules! create_matching_set { + ($t:ty) => {{ + let _lookup_table = ExampleSet::<$t> { + _map: Default::default(), + }; + + return; + }}; + } + + let data_type = arrow::datatypes::DataType::Float16; + + downcast_primitive! { + data_type => (create_matching_set), + _ => panic!("not implemented for {data_type}"), + } + } +} diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 0419161b532ce..26f06b546ad1d 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::fmt; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use arrow::compute::{can_cast_types, CastOptions}; +use arrow::compute::{CastOptions, can_cast_types}; use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::nested_struct::{ + requires_nested_struct_cast, validate_data_type_compatibility, +}; +use datafusion_common::{Result, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; @@ -41,13 +44,23 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; +/// Check if name-based struct casting is allowed by validating field compatibility. +/// +/// This function applies the same validation rules as execution time to ensure +/// planning-time validation matches runtime validation, enabling fail-fast behavior +/// instead of deferring errors to execution. Handles structs at any nesting level +/// (e.g., `List`, `Dictionary<_, Struct>`). +fn can_cast_named_struct_types(source: &DataType, target: &DataType) -> bool { + validate_data_type_compatibility("", source, target).is_ok() +} + /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug, Clone, Eq)] pub struct CastExpr { /// The expression to cast pub expr: Arc, - /// The data type to cast to - cast_type: DataType, + /// Field metadata describing the desired output after casting + target_field: FieldRef, /// Cast options cast_options: CastOptions<'static>, } @@ -56,7 +69,7 @@ pub struct CastExpr { impl PartialEq for CastExpr { fn eq(&self, other: &Self) -> bool { self.expr.eq(&other.expr) - && self.cast_type.eq(&other.cast_type) + && self.target_field.eq(&other.target_field) && self.cast_options.eq(&other.cast_options) } } @@ -64,21 +77,55 @@ impl PartialEq for CastExpr { impl Hash for CastExpr { fn hash(&self, state: &mut H) { self.expr.hash(state); - self.cast_type.hash(state); + self.target_field.hash(state); self.cast_options.hash(state); } } impl CastExpr { - /// Create a new CastExpr + /// Create a new `CastExpr` using only a `DataType`. + /// + /// This constructor is provided for compatibility with existing call sites + /// that only know the target type. It synthesizes a ``Field`` with the + /// given type (**nullable by default**) and no name metadata. Callers that + /// already have a `FieldRef` (for example, coming from schema inference or a + /// resolved column) should prefer [`CastExpr::new_with_target_field`], which + /// preserves the field's name, nullability, and other metadata. In other + /// words: + /// + /// * use `new()` when only a `DataType` is available and you want the legacy + /// semantics of a type-only cast + /// * use `new_with_target_field()` when you need explicit field + /// metadata/name/nullability preserved pub fn new( expr: Arc, cast_type: DataType, cast_options: Option>, + ) -> Self { + Self::new_with_target_field( + expr, + cast_type.into_nullable_field_ref(), + cast_options, + ) + } + + /// Create a new `CastExpr` with an explicit target `FieldRef`. + /// + /// The provided `target_field` is used verbatim for the expression's + /// return schema, so the field's name, nullability, and other metadata are + /// preserved. This is the preferred constructor when the caller already + /// has field information (for example, during logical-to-physical planning). + /// + /// See [`CastExpr::new`] for the compatibility constructor that only accepts + /// a `DataType`. + pub fn new_with_target_field( + expr: Arc, + target_field: FieldRef, + cast_options: Option>, ) -> Self { Self { expr, - cast_type, + target_field, cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), } } @@ -90,7 +137,12 @@ impl CastExpr { /// The data type to cast to pub fn cast_type(&self) -> &DataType { - &self.cast_type + self.target_field.data_type() + } + + /// Field metadata describing the output column after casting. + pub fn target_field(&self) -> &FieldRef { + &self.target_field } /// The cast options @@ -98,13 +150,29 @@ impl CastExpr { &self.cast_options } - /// Check if the cast is a widening cast (e.g. from `Int8` to `Int16`). - pub fn is_bigger_cast(&self, src: &DataType) -> bool { - if self.cast_type.eq(src) { + fn resolved_target_field(&self, input_schema: &Schema) -> Result { + if is_default_target_field(&self.target_field) { + self.expr.return_field(input_schema).map(|field| { + Arc::new( + field + .as_ref() + .clone() + .with_data_type(self.cast_type().clone()), + ) + }) + } else { + Ok(Arc::clone(&self.target_field)) + } + } + + /// Check if casting from the specified source type to the target type is a + /// widening cast (e.g. from `Int8` to `Int16`). + pub fn check_bigger_cast(cast_type: &DataType, src: &DataType) -> bool { + if cast_type.eq(src) { return true; } matches!( - (src, &self.cast_type), + (src, cast_type), (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) @@ -119,41 +187,69 @@ impl CastExpr { | (Utf8, LargeUtf8) ) } + + /// Check if the cast is a widening cast (e.g. from `Int8` to `Int16`). + pub fn is_bigger_cast(&self, src: &DataType) -> bool { + Self::check_bigger_cast(self.cast_type(), src) + } +} + +fn is_default_target_field(target_field: &FieldRef) -> bool { + target_field.name().is_empty() + && target_field.is_nullable() + && target_field.metadata().is_empty() +} + +pub(crate) fn is_order_preserving_cast_family( + source_type: &DataType, + target_type: &DataType, +) -> bool { + (source_type.is_numeric() || *source_type == Boolean) && target_type.is_numeric() + || source_type.is_temporal() && target_type.is_temporal() + || source_type.eq(target_type) +} + +pub(crate) fn cast_expr_properties( + child: &ExprProperties, + target_type: &DataType, +) -> Result { + let unbounded = Interval::make_unbounded(target_type)?; + if is_order_preserving_cast_family(&child.range.data_type(), target_type) { + Ok(child.clone().with_range(unbounded)) + } else { + Ok(ExprProperties::new_unknown().with_range(unbounded)) + } } impl fmt::Display for CastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "CAST({} AS {:?})", self.expr, self.cast_type) + write!(f, "CAST({} AS {})", self.expr, self.cast_type()) } } impl PhysicalExpr for CastExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.cast_type.clone()) + Ok(self.cast_type().clone()) } fn nullable(&self, input_schema: &Schema) -> Result { - self.expr.nullable(input_schema) + // A cast is nullable if **either** the child is nullable or the + // target field allows nulls. This conservative rule prevents + // optimizers from assuming a non-null result when a null input could + // still propagate. `return_field()` continues to expose the exact + // target metadata separately. + let child_nullable = self.expr.nullable(input_schema)?; + let target_nullable = self.resolved_target_field(input_schema)?.is_nullable(); + Ok(child_nullable || target_nullable) } fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - value.cast_to(&self.cast_type, Some(&self.cast_options)) + value.cast_to(self.cast_type(), Some(&self.cast_options)) } fn return_field(&self, input_schema: &Schema) -> Result { - Ok(self - .expr - .return_field(input_schema)? - .as_ref() - .clone() - .with_data_type(self.cast_type.clone()) - .into()) + self.resolved_target_field(input_schema) } fn children(&self) -> Vec<&Arc> { @@ -164,16 +260,16 @@ impl PhysicalExpr for CastExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(CastExpr::new( + Ok(Arc::new(CastExpr::new_with_target_field( Arc::clone(&children[0]), - self.cast_type.clone(), + Arc::clone(&self.target_field), Some(self.cast_options.clone()), ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { // Cast current node's interval to the right type: - children[0].cast_to(&self.cast_type, &self.cast_options) + children[0].cast_to(self.cast_type(), &self.cast_options) } fn propagate_constraints( @@ -185,35 +281,78 @@ impl PhysicalExpr for CastExpr { // Get child's datatype: let cast_type = child_interval.data_type(); Ok(Some(vec![ - interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)? + interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?, ])) } /// A [`CastExpr`] preserves the ordering of its child if the cast is done /// under the same datatype family. fn get_properties(&self, children: &[ExprProperties]) -> Result { - let source_datatype = children[0].range.data_type(); - let target_type = &self.cast_type; - - let unbounded = Interval::make_unbounded(target_type)?; - if (source_datatype.is_numeric() || source_datatype == Boolean) - && target_type.is_numeric() - || source_datatype.is_temporal() && target_type.is_temporal() - || source_datatype.eq(target_type) - { - Ok(children[0].clone().with_range(unbounded)) - } else { - Ok(ExprProperties::new_unknown().with_range(unbounded)) - } + cast_expr_properties(&children[0], self.cast_type()) } fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "CAST(")?; self.expr.fmt_sql(f)?; - write!(f, " AS {:?}", self.cast_type)?; + write!(f, " AS {:?}", self.cast_type())?; write!(f, ")") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { + expr: Some(Box::new(ctx.encode_child(self.expr())?)), + arrow_type: Some(self.cast_type().try_into()?), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl CastExpr { + /// Reconstruct a [`CastExpr`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`] so the decode signature matches + /// other migrated expressions and can inspect outer-node metadata if + /// needed in the future. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_common::internal_datafusion_err; + use datafusion_common::internal_err; + use datafusion_proto_models::protobuf; + + let cast_expr = match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::Cast(cast_expr)) => { + cast_expr.as_ref() + } + _ => return internal_err!("PhysicalExprNode is not a CastExpr"), + }; + + let expr = ctx.decode_required_expression( + cast_expr.expr.as_deref(), + "CastExpr", + "expr", + )?; + let arrow_type = cast_expr.arrow_type.as_ref().ok_or_else(|| { + internal_datafusion_err!("CastExpr is missing required field 'arrow_type'") + })?; + + Ok(Arc::new(CastExpr::new(expr, arrow_type.try_into()?, None))) + } } /// Return a PhysicalExpression representing `expr` casted to @@ -225,15 +364,55 @@ pub fn cast_with_options( input_schema: &Schema, cast_type: DataType, cast_options: Option>, +) -> Result> { + cast_with_target_field( + expr, + input_schema, + cast_type.into_nullable_field_ref(), + cast_options, + ) +} + +/// Return a PhysicalExpression representing `expr` casted to `target_field`, +/// preserving any explicit field semantics such as name, nullability, and +/// metadata. +/// +/// If the input expression already has the same data type, this helper still +/// preserves an explicit `target_field` by constructing a field-aware +/// [`CastExpr`]. Only the default synthesized field created by the legacy +/// type-only API is elided back to the original child expression. +pub fn cast_with_target_field( + expr: Arc, + input_schema: &Schema, + target_field: FieldRef, + cast_options: Option>, ) -> Result> { let expr_type = expr.data_type(input_schema)?; - if expr_type == cast_type { - Ok(Arc::clone(&expr)) - } else if can_cast_types(&expr_type, &cast_type) { - Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) + let cast_type = target_field.data_type(); + if expr_type == *cast_type && is_default_target_field(&target_field) { + return Ok(Arc::clone(&expr)); + } + + let can_build_cast = if requires_nested_struct_cast(&expr_type, cast_type) { + // Allow casts involving structs (including nested inside Lists, Dictionaries, + // etc.) that pass name-based compatibility validation. This validation is + // applied at planning time (now) to fail fast, rather than deferring errors + // to execution time. The name-based casting logic will be executed at runtime + // via ColumnarValue::cast_to. + can_cast_named_struct_types(&expr_type, cast_type) } else { - not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") + can_cast_types(&expr_type, cast_type) + }; + + if !can_build_cast { + return not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}"); } + + Ok(Arc::new(CastExpr::new_with_target_field( + expr, + target_field, + cast_options, + ))) } /// Return a PhysicalExpression representing `expr` casted to @@ -256,14 +435,45 @@ mod tests { use arrow::{ array::{ - Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, StringArray, Time64NanosecondArray, - TimestampNanosecondArray, UInt32Array, + Array, ArrayRef, Decimal128Array, Float32Array, Float64Array, Int8Array, + Int16Array, Int32Array, Int64Array, StringArray, StructArray, + Time64NanosecondArray, TimestampNanosecondArray, UInt32Array, }, datatypes::*, }; + use datafusion_common::ScalarValue; + use datafusion_common::cast::{ + as_boolean_array, as_int64_array, as_string_array, as_struct_array, + as_uint8_array, + }; use datafusion_physical_expr_common::physical_expr::fmt_sql; use insta::assert_snapshot; + use std::collections::HashMap; + + fn make_struct_array(fields: Fields, arrays: Vec) -> StructArray { + StructArray::new(fields, arrays, None) + } + + fn cast_struct_array( + column: &str, + input_field: Field, + target_field: Field, + input_array: StructArray, + ) -> Result { + let schema = Arc::new(Schema::new(vec![input_field])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(input_array) as ArrayRef], + )?; + let expr = CastExpr::new_with_target_field( + col(column, schema.as_ref())?, + Arc::new(target_field), + None, + ); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + Ok(as_struct_array(result.as_ref())?.clone()) + } // runs an end-to-end test of physical type cast // 1. construct a record batch with a column "a" of type A @@ -283,10 +493,7 @@ mod tests { cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!( - format!("CAST(a@0 AS {:?})", $TYPE), - format!("{}", expression) - ); + assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression)); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -310,7 +517,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; @@ -335,10 +542,7 @@ mod tests { cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?; // verify that its display is correct - assert_eq!( - format!("CAST(a@0 AS {:?})", $TYPE), - format!("{}", expression) - ); + assert_eq!(format!("CAST(a@0 AS {})", $TYPE), format!("{}", expression)); // verify that the expression's type is correct assert_eq!(expression.data_type(&schema)?, $TYPE); @@ -365,7 +569,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; @@ -740,6 +944,9 @@ mod tests { Ok(()) } + // Tests for timestamp timezone casting have been moved to timestamps.slt + // See the "Casting between timestamp with and without timezone" section + #[test] fn invalid_cast() { // Ensure a useful error happens at plan time if invalid casts are used @@ -765,14 +972,209 @@ mod tests { match result { Ok(_) => panic!("expected error"), Err(e) => { - assert!(e - .to_string() - .contains("Cannot cast string '9.1' to value of Int32 type")) + assert!( + e.to_string() + .contains("Cannot cast string '9.1' to value of Int32 type") + ) } } Ok(()) } + #[test] + fn field_aware_cast_preserves_target_field_semantics() -> Result<()> { + let metadata = HashMap::from([("target_meta".to_string(), "1".to_string())]); + + for (child_nullable, target_nullable) in [(true, false), (false, true)] { + let schema = Schema::new(vec![Field::new("a", Int32, child_nullable)]); + let expr = CastExpr::new_with_target_field( + col("a", &schema)?, + Arc::new( + Field::new("cast_target", Int64, target_nullable) + .with_metadata(metadata.clone()), + ), + None, + ); + + let field = expr.return_field(&schema)?; + assert_eq!(field.name(), "cast_target"); + assert_eq!(field.data_type(), &Int64); + assert_eq!(field.is_nullable(), target_nullable); + assert_eq!( + field.metadata().get("target_meta").map(String::as_str), + Some("1") + ); + assert_eq!(expr.nullable(&schema)?, child_nullable || target_nullable); + } + + Ok(()) + } + + #[test] + fn type_only_cast_preserves_legacy_field_name_and_nullability() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", Int32, false)]); + let expr = CastExpr::new(col("a", &schema)?, Int64, None); + + let field = expr.return_field(&schema)?; + + assert_eq!(field.name(), "a"); + assert_eq!(field.data_type(), &Int64); + assert!(!field.is_nullable()); + assert!(!expr.nullable(&schema)?); + + Ok(()) + } + + #[test] + fn struct_cast_validation_uses_nested_target_fields() -> Result<()> { + let source_type = Struct(Fields::from(vec![ + Arc::new(Field::new("x", Int32, true)), + Arc::new(Field::new("y", Utf8, true)), + ])); + let schema = Schema::new(vec![Field::new("a", source_type.clone(), true)]); + + let valid_target = Struct(Fields::from(vec![ + Arc::new(Field::new("y", Utf8, true)), + Arc::new(Field::new("x", Int64, true)), + ])); + cast_with_options(col("a", &schema)?, &schema, valid_target, None)?; + + let invalid_target = Struct(Fields::from(vec![ + Arc::new(Field::new("y", Utf8, true)), + Arc::new(Field::new("missing", Int64, false)), + ])); + let err = cast_with_options(col("a", &schema)?, &schema, invalid_target, None) + .expect_err("missing required struct field should fail"); + + assert!(err.to_string().contains("Unsupported CAST")); + + Ok(()) + } + + #[test] + fn field_aware_cast_struct_array_missing_child() -> Result<()> { + let source_a = Field::new("a", Int32, true); + let source_b = Field::new("b", Utf8, true); + let target_field = Field::new( + "s", + Struct( + vec![ + Arc::new(Field::new("a", Int64, true)), + Arc::new(Field::new("c", Utf8, true)), + ] + .into(), + ), + true, + ); + + let struct_array = cast_struct_array( + "s", + Field::new( + "s", + Struct( + vec![Arc::new(source_a.clone()), Arc::new(source_b.clone())].into(), + ), + true, + ), + target_field, + make_struct_array( + vec![Arc::new(source_a), Arc::new(source_b)].into(), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None])) as ArrayRef, + Arc::new(StringArray::from(vec![Some("alpha"), Some("beta")])) + as ArrayRef, + ], + ), + )?; + let cast_a = as_int64_array(struct_array.column_by_name("a").unwrap().as_ref())?; + assert_eq!(cast_a.value(0), 1); + assert!(cast_a.is_null(1)); + + let cast_c = as_string_array(struct_array.column_by_name("c").unwrap().as_ref())?; + assert!(cast_c.is_null(0)); + assert!(cast_c.is_null(1)); + Ok(()) + } + + #[test] + fn field_aware_cast_nested_struct_array() -> Result<()> { + let inner_source = Field::new( + "inner", + Struct(vec![Arc::new(Field::new("x", Int32, true))].into()), + true, + ); + let inner_target = Field::new( + "inner", + Struct( + vec![ + Arc::new(Field::new("x", Int64, true)), + Arc::new(Field::new("y", Boolean, true)), + ] + .into(), + ), + true, + ); + let target_field = + Field::new("root", Struct(vec![Arc::new(inner_target)].into()), true); + + let inner_struct = make_struct_array( + vec![Arc::new(Field::new("x", Int32, true))].into(), + vec![Arc::new(Int32Array::from(vec![Some(7), None])) as ArrayRef], + ); + let outer_struct = make_struct_array( + vec![Arc::new(inner_source.clone())].into(), + vec![Arc::new(inner_struct) as ArrayRef], + ); + let struct_array = cast_struct_array( + "root", + Field::new("root", Struct(vec![Arc::new(inner_source)].into()), true), + target_field, + outer_struct, + )?; + let inner = + as_struct_array(struct_array.column_by_name("inner").unwrap().as_ref())?; + let x = as_int64_array(inner.column_by_name("x").unwrap().as_ref())?; + assert_eq!(x.value(0), 7); + assert!(x.is_null(1)); + let y = as_boolean_array(inner.column_by_name("y").unwrap().as_ref())?; + assert!(y.is_null(0)); + assert!(y.is_null(1)); + Ok(()) + } + + #[test] + fn field_aware_cast_struct_scalar() -> Result<()> { + let source_field = Field::new("a", Int32, true); + let target_field = Field::new( + "s", + Struct(vec![Arc::new(Field::new("a", UInt8, true))].into()), + true, + ); + + let schema = Arc::new(Schema::new(vec![Field::new( + "s", + Struct(vec![Arc::new(source_field.clone())].into()), + true, + )])); + let scalar_struct = make_struct_array( + vec![Arc::new(source_field)].into(), + vec![Arc::new(Int32Array::from(vec![Some(9)])) as ArrayRef], + ); + let literal = Arc::new(crate::expressions::Literal::new(ScalarValue::Struct( + Arc::new(scalar_struct), + ))); + let expr = CastExpr::new_with_target_field(literal, Arc::new(target_field), None); + + let batch = RecordBatch::new_empty(schema); + let result = expr.evaluate(&batch)?; + let ColumnarValue::Scalar(ScalarValue::Struct(array)) = result else { + panic!("expected struct scalar"); + }; + let casted = as_uint8_array(array.column_by_name("a").unwrap().as_ref())?; + assert_eq!(casted.value(0), 9); + Ok(()) + } + #[test] #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396 fn test_cast_decimal() -> Result<()> { @@ -807,3 +1209,164 @@ mod tests { Ok(()) } } + +/// Tests for the `try_to_proto` / `try_from_proto` hooks. +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::datafusion_common::ArrowType; + use datafusion_proto_models::protobuf::{ + PhysicalCastNode, PhysicalExprNode, physical_expr_node, + }; + + /// A `CastExpr` over an `Int32` column, casting to `Int64`. + fn proto_cast_fixture() -> CastExpr { + let schema = Schema::new(vec![Field::new("a", Int32, false)]); + CastExpr::new(col("a", &schema).unwrap(), Int64, None) + } + + fn proto_int64_arrow_type() -> ArrowType { + (&Int64).try_into().unwrap() + } + + /// Build a `CastExpr` proto node with the given child and target type. + fn proto_cast_node( + expr: Option>, + arrow_type: Option, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::Cast(Box::new( + PhysicalCastNode { expr, arrow_type }, + ))), + } + } + + #[test] + fn try_to_proto_encodes_cast_expr() { + let cast = proto_cast_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = cast + .try_to_proto(&ctx) + .unwrap() + .expect("CastExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let cast_node = match node.expr_type { + Some(physical_expr_node::ExprType::Cast(cast_node)) => *cast_node, + other => panic!("expected a Cast node, got {other:?}"), + }; + assert!(cast_node.expr.is_some()); + + let arrow_type = cast_node + .arrow_type + .as_ref() + .expect("cast type should be encoded"); + let data_type: DataType = arrow_type.try_into().unwrap(); + assert_eq!(data_type, Int64); + } + + #[test] + fn try_to_proto_propagates_child_encode_error() { + let cast = proto_cast_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let err = cast.try_to_proto(&ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("call 1") + )); + } + + #[test] + fn try_from_proto_decodes_cast_expr() { + let node = proto_cast_node( + Some(Box::new(column_node("a"))), + Some(proto_int64_arrow_type()), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = CastExpr::try_from_proto(&node, &ctx).unwrap(); + let cast = decoded + .downcast_ref::() + .expect("decoded expr should be a CastExpr"); + + assert_eq!(cast.cast_type(), &Int64); + assert!(cast.expr().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_cast_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) + if msg.contains("PhysicalExprNode is not a CastExpr") + )); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = proto_cast_node(None, Some(proto_int64_arrow_type())); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) + if msg.contains("CastExpr is missing required field 'expr'") + )); + } + + #[test] + fn try_from_proto_rejects_missing_arrow_type() { + let node = proto_cast_node(Some(Box::new(column_node("a"))), None); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) + if msg.contains("CastExpr is missing required field 'arrow_type'") + )); + } + + #[test] + fn try_from_proto_propagates_child_decode_error() { + let node = proto_cast_node( + Some(Box::new(column_node("a"))), + Some(proto_int64_arrow_type()), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = CastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("call 1") + )); + } +} diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 9ca464b304306..0a96b00444850 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -17,7 +17,6 @@ //! Physical column reference: [`Column`] -use std::any::Any; use std::hash::Hash; use std::sync::Arc; @@ -28,8 +27,9 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::ColumnarValue; +use datafusion_expr_common::placement::ExpressionPlacement; /// Represents the column at a given index in a RecordBatch /// @@ -105,11 +105,6 @@ impl std::fmt::Display for Column { } impl PhysicalExpr for Column { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - /// Get the data type of this expression, given the schema of the input fn data_type(&self, input_schema: &Schema) -> Result { self.bounds_check(input_schema)?; @@ -129,6 +124,7 @@ impl PhysicalExpr for Column { } fn return_field(&self, input_schema: &Schema) -> Result { + self.bounds_check(input_schema)?; Ok(input_schema.field(self.index).clone().into()) } @@ -146,6 +142,55 @@ impl PhysicalExpr for Column { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Column + } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: self.name.clone(), + index: self.index as u32, + }, + )), + })) + } +} + +#[cfg(feature = "proto")] +impl Column { + /// Reconstruct a [`Column`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`] — the exact inverse of what + /// [`PhysicalExpr::try_to_proto`] produces — so every expression's + /// `try_from_proto` shares one signature. The decode context is currently + /// unused, but is threaded through so that future expressions with child + /// sub-expressions can recurse via [`PhysicalExprDecodeCtx::decode`]. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + /// [`PhysicalExpr::try_to_proto`]: datafusion_physical_expr_common::physical_expr::PhysicalExpr::try_to_proto + /// [`PhysicalExprDecodeCtx::decode`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx::decode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + let protobuf::PhysicalColumn { name, index } = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::Column, + "Column", + ); + Ok(Arc::new(Column::new(name, *index as usize))) + } } impl Column { @@ -158,7 +203,11 @@ impl Column { self.name, self.index, input_schema.fields.len(), - input_schema.fields().iter().map(|f| f.name()).collect::>() + input_schema + .fields() + .iter() + .map(|f| f.name()) + .collect::>() ) } } @@ -180,7 +229,7 @@ pub fn with_new_schema( ) -> Result> { Ok(expr .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { + if let Some(col) = expr.downcast_ref::() { let idx = col.index(); let Some(field) = schema.fields().get(idx) else { return plan_err!( diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters/mod.rs b/datafusion/physical-expr/src/expressions/dynamic_filters/mod.rs new file mode 100644 index 0000000000000..9fe3feb58603c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/dynamic_filters/mod.rs @@ -0,0 +1,1189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use parking_lot::RwLock; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::{fmt::Display, hash::Hash, sync::Arc}; +use tokio::sync::watch; + +use crate::PhysicalExpr; +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{ + Result, + tree_node::{Transformed, TransformedResult, TreeNode}, +}; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::DynHash; + +mod tracker; +pub use tracker::{DynamicFilterTracker, DynamicFilterTracking}; + +/// State of a dynamic filter, tracking both updates and completion. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FilterState { + /// Filter is in progress and may receive more updates. + InProgress { generation: u64 }, + /// Filter is complete and will not receive further updates. + Complete { generation: u64 }, +} + +impl FilterState { + fn generation(&self) -> u64 { + match self { + FilterState::InProgress { generation } + | FilterState::Complete { generation } => *generation, + } + } +} + +/// A dynamic [`PhysicalExpr`] that can be updated by anyone with a reference to it. +/// +/// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also +/// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where +/// the same `ExecutionPlan` is reused with different data. +/// +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters +#[derive(Debug)] +pub struct DynamicFilterPhysicalExpr { + /// The original children of this PhysicalExpr, if any. + /// This is necessary because the dynamic filter may be initialized with a placeholder (e.g. `lit(true)`) + /// and later remapped to the actual expressions that are being filtered. + /// But we need to know the children (e.g. columns referenced in the expression) ahead of time to evaluate the expression correctly. + children: Vec>, + /// If any of the children were remapped / modified (e.g. to adjust for projections) we need to keep track of the new children + /// so that when we update `current()` in subsequent iterations we can re-apply the replacements. + remapped_children: Option>>, + /// The source of dynamic filters. + inner: Arc>, + /// Broadcasts filter state (updates and completion) to all waiters. + state_watch: watch::Sender, + /// For testing purposes track the data type and nullability to make sure they don't change. + /// If they do, there's a bug in the implementation. + /// But this can have overhead in production, so it's only included in our tests. + data_type: Arc>>, + nullable: Arc>>, +} + +/// Atomic internal state of a [`DynamicFilterPhysicalExpr`]. +/// +/// `expression_id` lives here because it identifies the actual filter expression `expr`. +/// Derived `DynamicFilterPhysicalExpr`s (e.g. via [`PhysicalExpr::with_new_children`]) are +/// the same logical filter and must report the same `expression_id`. +/// +/// **Warning:** exposed publicly solely so that proto (de)serialization in +/// `datafusion-proto` can read and rebuild this state. Do not treat this type +/// or its layout as a stable API. +#[derive(Clone, Debug)] +pub struct Inner { + /// A unique identifier for the expression. + pub expression_id: u64, + /// A counter that gets incremented every time the expression is updated so that we can track changes cheaply. + /// This is used for [`PhysicalExpr::snapshot_generation`] to have a cheap check for changes. + pub generation: u64, + pub expr: Arc, + /// Flag for quick synchronous check if filter is complete. + /// This is redundant with the watch channel state, but allows us to return immediately + /// from `wait_complete()` without subscribing if already complete. + pub is_complete: bool, +} + +impl Inner { + fn new(expr: Arc) -> Self { + Self { + expression_id: EXPR_ID_SOURCE.next(), + // Start with generation 1 which gives us a different result for [`PhysicalExpr::generation`] than the default 0. + // This is not currently used anywhere but it seems useful to have this simple distinction. + generation: 1, + expr, + is_complete: false, + } + } + + /// Clone the inner expression. + fn expr(&self) -> &Arc { + &self.expr + } +} + +impl Hash for DynamicFilterPhysicalExpr { + fn hash(&self, state: &mut H) { + // Use pointer identity of the inner Arc for stable hashing. + // This is stable across update() calls and consistent with Eq. + // See issue #19641 for details on why content-based hashing violates + // the Hash/Eq contract when the underlying expression can change. + Arc::as_ptr(&self.inner).hash(state); + self.children.dyn_hash(state); + self.remapped_children.dyn_hash(state); + } +} + +impl PartialEq for DynamicFilterPhysicalExpr { + fn eq(&self, other: &Self) -> bool { + // Two dynamic filters are equal if they share the same inner source + // AND have the same children configuration. + // This is consistent with Hash using Arc::as_ptr. + // See issue #19641 for details on the Hash/Eq contract violation fix. + Arc::ptr_eq(&self.inner, &other.inner) + && self.children == other.children + && self.remapped_children == other.remapped_children + } +} + +impl Eq for DynamicFilterPhysicalExpr {} + +impl Display for DynamicFilterPhysicalExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.render(f, |expr, f| write!(f, "{expr}")) + } +} + +impl DynamicFilterPhysicalExpr { + /// Create a new [`DynamicFilterPhysicalExpr`] + /// from an initial expression and a list of children. + /// The list of children is provided separately because + /// the initial expression may not have the same children. + /// For example, if the initial expression is just `true` + /// it will not reference any columns, but we may know that + /// we are going to replace this expression with a real one + /// that does reference certain columns. + /// In this case you **must** pass in the columns that will be + /// used in the final expression as children to this function + /// since DataFusion is generally not compatible with dynamic + /// *children* in expressions. + /// + /// To determine the children you can: + /// + /// - Use [`collect_columns`] to collect the columns from the expression. + /// - Use existing information, such as the sort columns in a `SortExec`. + /// + /// Generally the important bit is that the *leaf children that reference columns + /// do not change* since those will be used to determine what columns need to read or projected + /// when evaluating the expression. + /// + /// Any `ExecutionPlan` that uses this expression and holds a reference to it internally should probably also + /// implement `ExecutionPlan::reset_state` to remain compatible with recursive queries and other situations where + /// the same `ExecutionPlan` is reused with different data. + /// + /// [`collect_columns`]: crate::utils::collect_columns + pub fn new( + children: Vec>, + inner: Arc, + ) -> Self { + let (state_watch, _) = watch::channel(FilterState::InProgress { generation: 1 }); + Self { + children, + remapped_children: None, // Initially no remapped children + inner: Arc::new(RwLock::new(Inner::new(inner))), + state_watch, + data_type: Arc::new(RwLock::new(None)), + nullable: Arc::new(RwLock::new(None)), + } + } + + fn remap_children( + children: &[Arc], + remapped_children: Option<&Vec>>, + expr: Arc, + ) -> Result> { + if let Some(remapped_children) = remapped_children { + // Remap the children to the new children + // of the expression. + expr.transform_up(|child| { + // Check if this is any of our original children + if let Some(pos) = + children.iter().position(|c| c.as_ref() == child.as_ref()) + { + // If so, remap it to the current children + // of the expression. + let new_child = Arc::clone(&remapped_children[pos]); + Ok(Transformed::yes(new_child)) + } else { + // Otherwise, just return the expression + Ok(Transformed::no(child)) + } + }) + .data() + } else { + // If we don't have any remapped children, just return the expression + Ok(Arc::clone(&expr)) + } + } + + /// Get the current generation of the expression. + fn current_generation(&self) -> u64 { + self.inner.read().generation + } + + /// Get the current expression. + /// This will return the current expression with any children + /// remapped to match calls to [`PhysicalExpr::with_new_children`]. + pub fn current(&self) -> Result> { + let expr = Arc::clone(self.inner.read().expr()); + Self::remap_children(&self.children, self.remapped_children.as_ref(), expr) + } + + /// Update the current expression and notify all waiters. + /// Any children of this expression must be a subset of the original children + /// passed to the constructor. + /// This should be called e.g.: + /// - When we've computed the probe side's hash table in a HashJoinExec + /// - After every batch is processed if we update the TopK heap in a SortExec using a TopK approach. + pub fn update(&self, new_expr: Arc) -> Result<()> { + // Remap the children of the new expression to match the original children + // We still do this again in `current()` but doing it preventively here + // reduces the work needed in some cases if `current()` is called multiple times + // and the same externally facing `PhysicalExpr` is used for both `with_new_children` and `update()`.` + let new_expr = Self::remap_children( + &self.children, + self.remapped_children.as_ref(), + new_expr, + )?; + + // Load the current inner, increment generation, and store the new one + let mut current = self.inner.write(); + let new_generation = current.generation + 1; + *current = Inner { + // Preserve the expression id across updates. + expression_id: current.expression_id, + generation: new_generation, + expr: new_expr, + is_complete: current.is_complete, + }; + drop(current); // Release the lock before broadcasting + + // Broadcast the new state to all waiters + let _ = self.state_watch.send(FilterState::InProgress { + generation: new_generation, + }); + Ok(()) + } + + /// Mark this dynamic filter as complete and broadcast to all waiters. + /// + /// This signals that all expected updates have been received. + /// Waiters using [`Self::wait_complete`] will be notified. + pub fn mark_complete(&self) { + let mut current = self.inner.write(); + let current_generation = current.generation; + current.is_complete = true; + drop(current); + + // Broadcast completion to all waiters + let _ = self.state_watch.send(FilterState::Complete { + generation: current_generation, + }); + } + + /// Wait asynchronously for any update to this filter. + /// + /// This method will return when [`Self::update`] is called and the generation increases. + /// It does not guarantee that the filter is complete. + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. + pub async fn wait_update(&self) { + let mut rx = self.state_watch.subscribe(); + // Get the current generation + let current_gen = rx.borrow_and_update().generation(); + + // Wait until generation increases + let _ = rx.wait_for(|state| state.generation() > current_gen).await; + } + + /// Wait asynchronously until this dynamic filter is marked as complete. + /// + /// This method returns immediately if the filter is already complete. + /// Otherwise, it waits until [`Self::mark_complete`] is called. + /// + /// Unlike [`Self::wait_update`], this method guarantees that when it returns, + /// the filter is fully complete with no more updates expected. + /// + /// Producers (e.g.) HashJoinExec may never update the expression or mark it as completed if there are no consumers. + /// If you call this method on a dynamic filter created by such a producer and there are no consumers registered this method would wait indefinitely. + /// This should not happen under normal operation and would indicate a programming error either in your producer or in DataFusion if the producer is a built in node. + pub async fn wait_complete(&self) { + if self.inner.read().is_complete { + return; + } + + let mut rx = self.state_watch.subscribe(); + let _ = rx + .wait_for(|state| matches!(state, FilterState::Complete { .. })) + .await; + } + + /// Returns `true` if this filter has been marked complete via + /// [`Self::mark_complete`] and will therefore never change again. + pub(crate) fn is_complete(&self) -> bool { + self.inner.read().is_complete + } + + /// Subscribe to this filter's updates for cheap, synchronous change + /// detection. + /// + /// The returned [`DynamicFilterSubscription`] lets a consumer poll whether + /// the filter's expression has advanced since it last looked, without + /// re-walking a predicate tree or re-deriving a generation on every check. + /// This is the building block used by [`DynamicFilterTracker`] to watch + /// every dynamic filter inside a (possibly composite) predicate. + pub(crate) fn subscribe(&self) -> DynamicFilterSubscription { + let mut receiver = self.state_watch.subscribe(); + // Mark the current state as already-seen so the first `observe()` only + // reports updates that happen *after* subscription. + let last_generation = receiver.borrow_and_update().generation(); + DynamicFilterSubscription { + receiver, + last_generation, + } + } + + /// Check if this dynamic filter is being actively used by any consumers. + /// + /// Returns `true` if there are references beyond the producer (e.g., the HashJoinExec + /// that created the filter). This is useful to avoid computing expensive filter + /// expressions when no consumer will actually use them. + /// + /// # Implementation Details + /// + /// We check both Arc counts to handle two cases: + /// - Transformed filters (via `with_new_children`) share the inner Arc (inner count > 1) + /// - Direct clones (via `Arc::clone`) increment the outer count (outer count > 1) + pub fn is_used(self: &Arc) -> bool { + // Strong count > 1 means at least one consumer is holding a reference beyond the producer. + Arc::strong_count(self) > 1 || Arc::strong_count(&self.inner) > 1 + } + + fn render( + &self, + f: &mut std::fmt::Formatter<'_>, + render_expr: impl FnOnce( + Arc, + &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result, + ) -> std::fmt::Result { + let inner = self.current().map_err(|_| std::fmt::Error)?; + let current_generation = self.current_generation(); + write!(f, "DynamicFilter [ ")?; + if current_generation == 1 { + write!(f, "empty")?; + } else { + render_expr(inner, f)?; + } + + write!(f, " ]") + } + + /// Return the filter's original children (before any remapping). + /// + /// **Warning:** intended only for `datafusion-proto` (de)serialization. + /// Not a stable API. + pub fn original_children(&self) -> &[Arc] { + &self.children + } + + /// Return the filter's remapped children, if any have been set via + /// [`PhysicalExpr::with_new_children`]. + /// + /// **Warning:** intended only for `datafusion-proto` (de)serialization. + /// Not a stable API. + pub fn remapped_children(&self) -> Option<&[Arc]> { + self.remapped_children.as_deref() + } + + /// Rebuild a `DynamicFilterPhysicalExpr` from its stored parts. Used by + /// proto deserialization. + /// + /// **Warning:** intended only for `datafusion-proto` (de)serialization. + /// Not a stable API. + pub fn from_parts( + children: Vec>, + remapped_children: Option>>, + inner: Inner, + ) -> Self { + let state = if inner.is_complete { + FilterState::Complete { + generation: inner.generation, + } + } else { + FilterState::InProgress { + generation: inner.generation, + } + }; + let (state_watch, _) = watch::channel(state); + + Self { + children, + remapped_children, + inner: Arc::new(RwLock::new(inner)), + state_watch, + data_type: Arc::new(RwLock::new(None)), + nullable: Arc::new(RwLock::new(None)), + } + } + + /// Return a clone of the atomically-captured `Inner` state. + /// + /// **Warning:** intended only for `datafusion-proto` (de)serialization. + /// Not a stable API. + pub fn inner(&self) -> Inner { + self.inner.read().clone() + } +} + +impl PhysicalExpr for DynamicFilterPhysicalExpr { + fn children(&self) -> Vec<&Arc> { + self.remapped_children + .as_ref() + .unwrap_or(&self.children) + .iter() + .collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self { + children: self.children.clone(), + remapped_children: Some(children), + // Note: expression_id is preserved + inner: Arc::clone(&self.inner), + state_watch: self.state_watch.clone(), + data_type: Arc::clone(&self.data_type), + nullable: Arc::clone(&self.nullable), + })) + } + + fn data_type(&self, input_schema: &Schema) -> Result { + let res = self.current()?.data_type(input_schema)?; + #[cfg(test)] + { + use datafusion_common::internal_err; + // Check if the data type has changed. + let mut data_type_lock = self.data_type.write(); + + if let Some(existing) = &*data_type_lock { + if existing != &res { + // If the data type has changed, we have a bug. + return internal_err!( + "DynamicFilterPhysicalExpr data type has changed unexpectedly. \ + Expected: {existing:?}, Actual: {res:?}" + ); + } + } else { + *data_type_lock = Some(res.clone()); + } + } + Ok(res) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + let res = self.current()?.nullable(input_schema)?; + #[cfg(test)] + { + use datafusion_common::internal_err; + // Check if the nullability has changed. + let mut nullable_lock = self.nullable.write(); + if let Some(existing) = *nullable_lock { + if existing != res { + // If the nullability has changed, we have a bug. + return internal_err!( + "DynamicFilterPhysicalExpr nullability has changed unexpectedly. \ + Expected: {existing}, Actual: {res}" + ); + } + } else { + *nullable_lock = Some(res); + } + } + Ok(res) + } + + fn evaluate( + &self, + batch: &arrow::record_batch::RecordBatch, + ) -> Result { + let current = self.current()?; + #[cfg(test)] + { + // Ensure that we are not evaluating after the expression has changed. + let schema = batch.schema(); + self.nullable(&schema)?; + self.data_type(&schema)?; + }; + current.evaluate(batch) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.render(f, |expr, f| expr.fmt_sql(f)) + } + + fn snapshot(&self) -> Result>> { + // Return the current expression as a snapshot. + Ok(Some(self.current()?)) + } + + fn snapshot_generation(&self) -> u64 { + // Return the current generation of the expression. + self.inner.read().generation + } + + fn expression_id(&self) -> Option { + Some(self.inner.read().expression_id) + } +} + +/// The result of polling a [`DynamicFilterSubscription`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) struct DynamicFilterChange { + /// The filter's expression advanced since the previous observation. + pub(crate) changed: bool, + /// The filter has been marked complete; it will never change again and the + /// subscription can be dropped. + pub(crate) complete: bool, +} + +/// A cheap, synchronous handle for observing updates to a single +/// [`DynamicFilterPhysicalExpr`]. +/// +/// Obtained via [`DynamicFilterPhysicalExpr::subscribe`]. Steady-state polling +/// via [`Self::observe`] is a single atomic load (the underlying +/// [`tokio::sync::watch`] version counter); the lock is only taken when the +/// filter has actually been updated. +#[derive(Debug)] +pub(crate) struct DynamicFilterSubscription { + receiver: watch::Receiver, + /// Last generation we reported as "seen". Used to distinguish a real + /// expression update from a bare [`DynamicFilterPhysicalExpr::mark_complete`] + /// (which re-broadcasts the current generation without changing the + /// expression). + last_generation: u64, +} + +impl DynamicFilterSubscription { + /// Observe the latest state of the filter. + /// + /// Reports whether the filter's expression advanced since the previous call + /// and whether it has since been marked complete. Cheap when nothing has + /// changed: a single atomic comparison with no lock acquisition. + pub(crate) fn observe(&mut self) -> DynamicFilterChange { + match self.receiver.has_changed() { + Ok(true) => { + let state = *self.receiver.borrow_and_update(); + let changed = state.generation() > self.last_generation; + if changed { + self.last_generation = state.generation(); + } + DynamicFilterChange { + changed, + complete: matches!(state, FilterState::Complete { .. }), + } + } + Ok(false) => DynamicFilterChange { + changed: false, + complete: false, + }, + // The watch sender lives inside the predicate's + // `DynamicFilterPhysicalExpr`, which the owner of this subscription + // keeps alive, so observing a dropped sender signals a bug rather + // than normal completion. Flag it loudly in debug builds; in release + // degrade to "complete" (no further updates are possible) instead of + // silently masking it. + Err(_) => { + debug_assert!( + false, + "DynamicFilterSubscription observed a dropped watch sender; \ + the owning predicate should keep it alive" + ); + DynamicFilterChange { + changed: false, + complete: true, + } + } + } + } +} + +/// An atomic counter used to generate monotonic u64 ids. +struct ExpressionIdAtomicCounter { + inner: AtomicU64, +} + +impl ExpressionIdAtomicCounter { + const fn new() -> Self { + Self { + inner: AtomicU64::new(0), + } + } + + /// Returns a fresh `expression_id` by incrementing the internal counter. + fn next(&self) -> u64 { + self.inner.fetch_add(1, Ordering::Relaxed) + } +} + +/// Process-wide source of deterministic `expression_id`s for [`DynamicFilterPhysicalExpr`]. +/// +/// Currently, no other [`PhysicalExpr`]s use this source. If needed, it can be moved out of this +/// file and be made public for other expressions to use. +static EXPR_ID_SOURCE: ExpressionIdAtomicCounter = ExpressionIdAtomicCounter::new(); + +#[cfg(test)] +mod test { + use crate::{ + expressions::{BinaryExpr, col, lit}, + utils::reassign_expr_columns, + }; + use arrow::{ + array::RecordBatch, + datatypes::{DataType, Field, Schema}, + }; + use datafusion_common::ScalarValue; + + use super::*; + + #[test] + fn test_remap_children() { + let table_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let expr = Arc::new(BinaryExpr::new( + col("a", &table_schema).unwrap(), + datafusion_expr::Operator::Eq, + lit(42) as Arc, + )); + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &table_schema).unwrap()], + expr as Arc, + )); + // Simulate two `ParquetSource` files with different filter schemas + // Both of these should hit the same inner `PhysicalExpr` even after `update()` is called + // and be able to remap children independently. + let filter_schema_1 = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let filter_schema_2 = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + // Each ParquetExec calls `with_new_children` on the DynamicFilterPhysicalExpr + // and remaps the children to the file schema. + let dynamic_filter_1 = reassign_expr_columns( + Arc::clone(&dynamic_filter) as Arc, + &filter_schema_1, + ) + .unwrap(); + let snap = dynamic_filter_1.snapshot().unwrap().unwrap(); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32 } }, fail_on_overflow: false }"#); + let dynamic_filter_2 = reassign_expr_columns( + Arc::clone(&dynamic_filter) as Arc, + &filter_schema_2, + ) + .unwrap(); + let snap = dynamic_filter_2.snapshot().unwrap().unwrap(); + insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), field: Field { name: "lit", data_type: Int32 } }, fail_on_overflow: false }"#); + // Both filters allow evaluating the same expression + let batch_1 = RecordBatch::try_new( + Arc::clone(&filter_schema_1), + vec![ + // a + ScalarValue::Int32(Some(42)).to_array_of_size(1).unwrap(), + // b + ScalarValue::Int32(Some(43)).to_array_of_size(1).unwrap(), + ], + ) + .unwrap(); + let batch_2 = RecordBatch::try_new( + Arc::clone(&filter_schema_2), + vec![ + // b + ScalarValue::Int32(Some(43)).to_array_of_size(1).unwrap(), + // a + ScalarValue::Int32(Some(42)).to_array_of_size(1).unwrap(), + ], + ) + .unwrap(); + // Evaluate the expression on both batches + let result_1 = dynamic_filter_1.evaluate(&batch_1).unwrap(); + let result_2 = dynamic_filter_2.evaluate(&batch_2).unwrap(); + // Check that the results are the same + let ColumnarValue::Array(arr_1) = result_1 else { + panic!("Expected ColumnarValue::Array"); + }; + let ColumnarValue::Array(arr_2) = result_2 else { + panic!("Expected ColumnarValue::Array"); + }; + assert!(arr_1.eq(&arr_2)); + let expected = ScalarValue::Boolean(Some(true)) + .to_array_of_size(1) + .unwrap(); + assert!(arr_1.eq(&expected)); + // Now lets update the expression + // Note that we update the *original* expression and that should be reflected in both the derived expressions + let new_expr = Arc::new(BinaryExpr::new( + col("a", &table_schema).unwrap(), + datafusion_expr::Operator::Gt, + lit(43) as Arc, + )); + dynamic_filter + .update(Arc::clone(&new_expr) as Arc) + .expect("Failed to update expression"); + // Now we should be able to evaluate the new expression on both batches + let result_1 = dynamic_filter_1.evaluate(&batch_1).unwrap(); + let result_2 = dynamic_filter_2.evaluate(&batch_2).unwrap(); + // Check that the results are the same + let ColumnarValue::Array(arr_1) = result_1 else { + panic!("Expected ColumnarValue::Array"); + }; + let ColumnarValue::Array(arr_2) = result_2 else { + panic!("Expected ColumnarValue::Array"); + }; + assert!(arr_1.eq(&arr_2)); + let expected = ScalarValue::Boolean(Some(false)) + .to_array_of_size(1) + .unwrap(); + assert!(arr_1.eq(&expected)); + } + + #[test] + fn test_snapshot() { + let expr = lit(42) as Arc; + let dynamic_filter = DynamicFilterPhysicalExpr::new(vec![], Arc::clone(&expr)); + + // Take a snapshot of the current expression + let snapshot = dynamic_filter.snapshot().unwrap(); + assert_eq!(snapshot, Some(expr)); + + // Update the current expression + let new_expr = lit(100) as Arc; + dynamic_filter.update(Arc::clone(&new_expr)).unwrap(); + // Take another snapshot + let snapshot = dynamic_filter.snapshot().unwrap(); + assert_eq!(snapshot, Some(new_expr)); + } + + #[test] + fn test_dynamic_filter_physical_expr_misbehaves_data_type_nullable() { + let dynamic_filter = + DynamicFilterPhysicalExpr::new(vec![], lit(42) as Arc); + + // First call to data_type and nullable should set the initial values. + let initial_data_type = dynamic_filter.data_type(&Schema::empty()).unwrap(); + let initial_nullable = dynamic_filter.nullable(&Schema::empty()).unwrap(); + + // Call again and expect no change. + let second_data_type = dynamic_filter.data_type(&Schema::empty()).unwrap(); + let second_nullable = dynamic_filter.nullable(&Schema::empty()).unwrap(); + assert_eq!( + initial_data_type, second_data_type, + "Data type should not change on second call." + ); + assert_eq!( + initial_nullable, second_nullable, + "Nullability should not change on second call." + ); + + // Now change the current expression to something else. + dynamic_filter + .update(lit(ScalarValue::Utf8(None)) as Arc) + .expect("Failed to update expression"); + // Check that we error if we call data_type, nullable or evaluate after changing the expression. + assert!( + dynamic_filter.data_type(&Schema::empty()).is_err(), + "Expected err when data_type is called after changing the expression." + ); + assert!( + dynamic_filter.nullable(&Schema::empty()).is_err(), + "Expected err when nullable is called after changing the expression." + ); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + assert!( + dynamic_filter.evaluate(&batch).is_err(), + "Expected err when evaluate is called after changing the expression." + ); + } + + #[tokio::test] + async fn test_wait_complete_already_complete() { + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![], + lit(42) as Arc, + )); + + // Mark as complete immediately + dynamic_filter.mark_complete(); + + // wait_complete should return immediately + dynamic_filter.wait_complete().await; + } + + #[test] + fn test_with_new_children_independence() { + // Create a schema with columns a, b, c, d + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + ])); + + // Create expression col(a) + col(b) + let col_a = col("a", &schema).unwrap(); + let col_b = col("b", &schema).unwrap(); + let col_c = col("c", &schema).unwrap(); + let col_d = col("d", &schema).unwrap(); + + let expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::Plus, + Arc::clone(&col_b), + )); + + // Create DynamicFilterPhysicalExpr with children [col_a, col_b] + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + expr as Arc, + )); + + // Clone the Arc (two references to the same DynamicFilterPhysicalExpr) + let clone_1 = Arc::clone(&dynamic_filter); + let clone_2 = Arc::clone(&dynamic_filter); + + // Call with_new_children with different children on each clone + // clone_1: replace [a, b] with [b, c] -> expression becomes b + c + let remapped_1 = clone_1 + .with_new_children(vec![Arc::clone(&col_b), Arc::clone(&col_c)]) + .unwrap(); + + // clone_2: replace [a, b] with [b, d] -> expression becomes b + d + let remapped_2 = clone_2 + .with_new_children(vec![Arc::clone(&col_b), Arc::clone(&col_d)]) + .unwrap(); + + // Create a RecordBatch with columns a=1,2,3 b=10,20,30 c=100,200,300 d=1000,2000,3000 + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), // a + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), // b + Arc::new(arrow::array::Int32Array::from(vec![100, 200, 300])), // c + Arc::new(arrow::array::Int32Array::from(vec![1000, 2000, 3000])), // d + ], + ) + .unwrap(); + + // Evaluate both remapped expressions + let result_1 = remapped_1.evaluate(&batch).unwrap(); + let result_2 = remapped_2.evaluate(&batch).unwrap(); + + // Extract arrays from results + let ColumnarValue::Array(arr_1) = result_1 else { + panic!("Expected ColumnarValue::Array for result_1"); + }; + let ColumnarValue::Array(arr_2) = result_2 else { + panic!("Expected ColumnarValue::Array for result_2"); + }; + + // Verify result_1 = b + c = [110, 220, 330] + let expected_1: Arc = + Arc::new(arrow::array::Int32Array::from(vec![110, 220, 330])); + assert!( + arr_1.eq(&expected_1), + "Expected b + c = [110, 220, 330], got {arr_1:?}", + ); + + // Verify result_2 = b + d = [1010, 2020, 3030] + let expected_2: Arc = + Arc::new(arrow::array::Int32Array::from(vec![1010, 2020, 3030])); + assert!( + arr_2.eq(&expected_2), + "Expected b + d = [1010, 2020, 3030], got {arr_2:?}", + ); + } + + #[test] + fn test_is_used() { + let filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![], + lit(true) as Arc, + )); + + // Initially, only one reference to the inner Arc exists + assert!( + !filter.is_used(), + "Filter should not be used with only one inner reference" + ); + + // Simulate a consumer created via transformation (what happens during filter pushdown). + // When filters are pushed down and transformed via reassign_expr_columns/transform_down, + // with_new_children() is called which creates a new outer Arc but clones the inner Arc. + let consumer1_expr = Arc::clone(&filter).with_new_children(vec![]).unwrap(); + let _consumer1 = consumer1_expr + .downcast_ref::() + .expect("Should be DynamicFilterPhysicalExpr"); + + // Now the inner Arc is shared (inner_count = 2) + assert!( + filter.is_used(), + "Filter should be used when inner Arc is shared with transformed consumer" + ); + + // Create another transformed consumer + let consumer2_expr = Arc::clone(&filter).with_new_children(vec![]).unwrap(); + let _consumer2 = consumer2_expr + .downcast_ref::() + .expect("Should be DynamicFilterPhysicalExpr"); + + assert!( + filter.is_used(), + "Filter should still be used with multiple consumers" + ); + } + + /// Test that verifies the Hash/Eq contract is now satisfied (issue #19641 fix). + /// + /// After the fix, Hash uses Arc::as_ptr(&self.inner) which is stable across + /// update() calls, fixing the HashMap key instability issue. + #[test] + fn test_hash_stable_after_update() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + // Create filter with initial value + let filter = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + // Compute hash BEFORE update + let mut hasher_before = DefaultHasher::new(); + filter.hash(&mut hasher_before); + let hash_before = hasher_before.finish(); + + // Update changes the underlying expression + filter + .update(lit(false) as Arc) + .expect("Update should succeed"); + + // Compute hash AFTER update + let mut hasher_after = DefaultHasher::new(); + filter.hash(&mut hasher_after); + let hash_after = hasher_after.finish(); + + // FIXED: Hash should now be STABLE after update() because we use + // Arc::as_ptr for identity-based hashing instead of expression content. + assert_eq!( + hash_before, hash_after, + "Hash should be stable after update() - fix for issue #19641" + ); + + // Self-equality should still hold + assert!(filter.eq(&filter), "Self-equality should hold"); + } + + /// Test that verifies separate DynamicFilterPhysicalExpr instances + /// with the same expression are NOT equal (identity-based comparison). + #[test] + fn test_identity_based_equality() { + // Create two separate filters with identical initial expressions + let filter1 = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + let filter2 = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + // Different instances should NOT be equal even with same expression + // because they have independent inner Arcs (different update lifecycles) + assert!( + !filter1.eq(&filter2), + "Different instances should not be equal (identity-based)" + ); + + // Self-equality should hold + assert!(filter1.eq(&filter1), "Self-equality should hold"); + } + + /// Test that hash is stable for the same filter instance. + /// After the fix, hash uses Arc::as_ptr which is pointer-based. + #[test] + fn test_hash_stable_for_same_instance() { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let filter = + DynamicFilterPhysicalExpr::new(vec![], lit(true) as Arc); + + // Compute hash twice for the same instance + let hash1 = { + let mut h = DefaultHasher::new(); + filter.hash(&mut h); + h.finish() + }; + let hash2 = { + let mut h = DefaultHasher::new(); + filter.hash(&mut h); + h.finish() + }; + + assert_eq!(hash1, hash2, "Same instance should have stable hash"); + + // Update the expression + filter + .update(lit(false) as Arc) + .expect("Update should succeed"); + + // Hash should STILL be the same (identity-based) + let hash3 = { + let mut h = DefaultHasher::new(); + filter.hash(&mut h); + h.finish() + }; + + assert_eq!( + hash1, hash3, + "Hash should be stable after update (identity-based)" + ); + } + + /// Verifies that `from_parts` rebuilds a `DynamicFilterPhysicalExpr` + /// whose observable state (original children, remapped children, + /// expression id, inner generation/expr/is_complete) matches the source + /// filter. + #[test] + fn test_from_parts_preserves_state() { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a = col("a", &schema).unwrap(); + + // Create a dynamic filter with children + let expr = Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + datafusion_expr::Operator::Gt, + lit(10) as Arc, + )); + let filter = DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + expr as Arc, + ); + + // Add remapped children. + let reassigned_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + let reassigned = reassign_expr_columns( + Arc::new(filter) as Arc, + &reassigned_schema, + ) + .expect("reassign_expr_columns should succeed"); + let reassigned = reassigned + .downcast_ref::() + .expect("Expected dynamic filter after reassignment"); + + reassigned + .update(lit(42) as Arc) + .expect("Update should succeed"); + reassigned.mark_complete(); + + // Capture the parts and reconstruct. `expression_id` rides in `inner`. + let reconstructed = DynamicFilterPhysicalExpr::from_parts( + reassigned.original_children().to_vec(), + reassigned.remapped_children().map(|r| r.to_vec()), + reassigned.inner(), + ); + + assert_eq!( + reassigned.original_children(), + reconstructed.original_children(), + ); + assert_eq!( + reassigned.remapped_children(), + reconstructed.remapped_children(), + ); + assert_eq!(reassigned.expression_id(), reconstructed.expression_id()); + let r = reassigned.inner(); + let c = reconstructed.inner(); + assert_eq!(r.generation, c.generation); + assert_eq!(r.is_complete, c.is_complete); + assert_eq!(format!("{:?}", r.expr), format!("{:?}", c.expr)); + } + + #[tokio::test] + async fn test_expression_id() { + let source_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a = col("a", &source_schema).unwrap(); + + // Create a source filter + let source = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true) as Arc, + )); + let source_clone = Arc::clone(&source); + + // Create a derived filter by reassigning the source filter to a different schema. + let derived_schema = Arc::new(Schema::new(vec![ + Field::new("x", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + let derived = reassign_expr_columns( + Arc::clone(&source) as Arc, + &derived_schema, + ) + .expect("reassign_expr_columns should succeed"); + + let derived_expression_id = derived + .expression_id() + .expect("derived filter should have an expression id"); + let source_expression_id = source + .expression_id() + .expect("source filter should have an expression id"); + let source_clone_expression_id = source_clone + .expression_id() + .expect("source clone should have an expression id"); + + assert_eq!( + source_clone_expression_id, source_expression_id, + "cloned filter should preserve its expression id", + ); + + assert_eq!( + derived_expression_id, source_expression_id, + "derived filters should carry forward the source expression id", + ); + + // `update()` rewrites the entire `Inner` struct in place; pin down + // that the rewrite preserves `expression_id`. + source + .update(lit(99) as Arc) + .expect("update should succeed"); + assert_eq!( + source.expression_id().unwrap(), + source_expression_id, + "update() must not change expression_id", + ); + + // `mark_complete()` also touches `Inner`; same invariant. + source.mark_complete(); + assert_eq!( + source.expression_id().unwrap(), + source_expression_id, + "mark_complete() must not change expression_id", + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters/tracker.rs b/datafusion/physical-expr/src/expressions/dynamic_filters/tracker.rs new file mode 100644 index 0000000000000..fd4c18b07e2cd --- /dev/null +++ b/datafusion/physical-expr/src/expressions/dynamic_filters/tracker.rs @@ -0,0 +1,331 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tracking changes to the dynamic filters inside a predicate. +//! +//! Several operators (Parquet file/row-group pruning, remote execution, ...) +//! hold a predicate that *may* contain one or more +//! [`DynamicFilterPhysicalExpr`] nodes which are updated during execution +//! (e.g. a `TopK` tightening its threshold, or a `HashJoinExec` publishing the +//! build-side bounds). These consumers repeatedly ask two questions: +//! +//! 1. *"Does this predicate contain anything that can still change?"* — to +//! decide whether it is worth setting up runtime re-pruning at all. +//! 2. *"Has it changed since I last looked?"* — to decide whether to rebuild an +//! expensive derived artifact (e.g. a `PruningPredicate`). +//! +//! Historically each call site answered these by recursively folding +//! [`PhysicalExpr::snapshot_generation`] over the whole tree on *every* check +//! and diffing the resulting `u64`. [`DynamicFilterTracker`] replaces that with +//! a single up-front walk that subscribes to each still-incomplete dynamic +//! filter; subsequent checks only poll the (shrinking) set of subscriptions, +//! each of which is a cheap atomic load in the common "nothing changed" case. +//! +//! [`PhysicalExpr::snapshot_generation`]: crate::PhysicalExpr::snapshot_generation + +use std::sync::Arc; + +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; + +use crate::PhysicalExpr; + +use super::{DynamicFilterPhysicalExpr, DynamicFilterSubscription}; + +/// Classification of a predicate according to the dynamic filters it contains. +/// +/// Produced by [`DynamicFilterTracking::classify`] with a single tree walk so +/// callers can answer both "is it worth pruning at all?" and "do I need to keep +/// watching?" without traversing the predicate twice. +#[derive(Debug)] +pub enum DynamicFilterTracking { + /// The predicate contains no [`DynamicFilterPhysicalExpr`] at all. It is + /// fully static and will never change. + Static, + /// The predicate contains one or more dynamic filters, but all of them have + /// already been marked complete. Their *current* values may differ from + /// what was known at planning time (so a one-shot prune is still + /// worthwhile), but they will not change again — there is nothing to watch. + AllComplete, + /// The predicate contains at least one dynamic filter that can still change. + /// The embedded [`DynamicFilterTracker`] should be polled to detect updates. + Watching(DynamicFilterTracker), +} + +impl DynamicFilterTracking { + /// Walk `predicate` once and classify its dynamic-filter content, + /// subscribing to every filter that is not yet complete. + pub fn classify(predicate: &Arc) -> Self { + let mut subscriptions = Vec::new(); + let mut found_any = false; + predicate + .apply(|expr| { + if let Some(filter) = expr.downcast_ref::() { + found_any = true; + // Already-complete filters can never change again, so there + // is no point subscribing to them. + if !filter.is_complete() { + subscriptions.push(filter.subscribe()); + } + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal closure is infallible"); + + if !found_any { + DynamicFilterTracking::Static + } else if subscriptions.is_empty() { + DynamicFilterTracking::AllComplete + } else { + DynamicFilterTracking::Watching(DynamicFilterTracker { subscriptions }) + } + } + + /// `true` if the predicate contains any dynamic filter (complete or not), + /// i.e. its value may differ from what was known at planning time and is + /// therefore worth re-evaluating at least once. + pub fn contains_dynamic_filter(&self) -> bool { + !matches!(self, DynamicFilterTracking::Static) + } + + /// Mutable access to the underlying tracker when there is still something to + /// watch. + pub fn watcher(&mut self) -> Option<&mut DynamicFilterTracker> { + match self { + DynamicFilterTracking::Watching(tracker) => Some(tracker), + _ => None, + } + } +} + +/// Watches every still-incomplete [`DynamicFilterPhysicalExpr`] reachable from a +/// predicate and reports, cheaply, whether any of them has been updated since +/// the last check. +/// +/// Obtain one from [`DynamicFilterTracking::classify`] via +/// [`DynamicFilterTracking::watcher`]; the `Watching` variant carries it only +/// when there is at least one dynamic filter that can still change. +#[derive(Debug)] +pub struct DynamicFilterTracker { + /// Subscriptions to the not-yet-complete dynamic filters. Entries are + /// dropped as their filters complete, so the set only shrinks. + subscriptions: Vec, +} + +impl DynamicFilterTracker { + /// Returns `true` if any watched filter's expression has advanced since the + /// previous call. + /// + /// Filters that have completed are dropped from the watch set as they are + /// observed; once every filter has completed this is a no-op that always + /// returns `false`. + pub fn changed(&mut self) -> bool { + let mut changed = false; + self.subscriptions.retain_mut(|subscription| { + let change = subscription.observe(); + changed |= change.changed; + // Keep the subscription only while the filter can still change. + !change.complete + }); + changed + } +} + +#[cfg(test)] +impl DynamicFilterTracker { + /// Build a tracker directly, or `None` if `predicate` has no dynamic filter + /// that can still change. Test-only; production builds a tracker via + /// [`DynamicFilterTracking::classify`] + [`DynamicFilterTracking::watcher`]. + fn try_new(predicate: &Arc) -> Option { + match DynamicFilterTracking::classify(predicate) { + DynamicFilterTracking::Watching(tracker) => Some(tracker), + DynamicFilterTracking::Static | DynamicFilterTracking::AllComplete => None, + } + } + + /// `true` once every watched filter has completed and been dropped. + fn is_exhausted(&self) -> bool { + self.subscriptions.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::expressions::{BinaryExpr, col, lit}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::Operator; + + /// `col > ` where the dynamic filter starts as `lit(true)`. + fn dynamic_predicate() -> (Arc, Arc) { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let column = col("a", &schema).unwrap(); + let filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&column)], + lit(true), + )); + let predicate = Arc::new(BinaryExpr::new( + column, + Operator::Gt, + Arc::clone(&filter) as Arc, + )) as Arc; + (predicate, filter) + } + + #[test] + fn static_predicate_is_not_watched() { + let predicate = lit(true); + assert!(matches!( + DynamicFilterTracking::classify(&predicate), + DynamicFilterTracking::Static + )); + assert!(DynamicFilterTracker::try_new(&predicate).is_none()); + } + + #[test] + fn already_complete_filter_is_not_watched() { + let (predicate, filter) = dynamic_predicate(); + filter.mark_complete(); + + match DynamicFilterTracking::classify(&predicate) { + DynamicFilterTracking::AllComplete => {} + other => panic!("expected AllComplete, got {other:?}"), + } + // Still reported as dynamic (worth a one-shot prune)... + assert!(DynamicFilterTracking::classify(&predicate).contains_dynamic_filter()); + // ...but there is nothing to watch. + assert!(DynamicFilterTracker::try_new(&predicate).is_none()); + } + + #[test] + fn detects_update_exactly_once() { + let (predicate, filter) = dynamic_predicate(); + let mut tracker = DynamicFilterTracker::try_new(&predicate) + .expect("predicate has an incomplete dynamic filter"); + + // No update yet. + assert!(!tracker.changed()); + + filter.update(lit(false)).unwrap(); + // The update is reported once... + assert!(tracker.changed()); + // ...and not repeatedly. + assert!(!tracker.changed()); + } + + #[test] + fn update_before_subscribe_is_not_reported() { + let (predicate, filter) = dynamic_predicate(); + + // An update that happens *before* the tracker subscribes must not be + // reported on the first poll: `subscribe()` snapshots the current + // generation via `borrow_and_update()`, so only post-subscription + // updates count. + filter.update(lit(false)).unwrap(); + + let mut tracker = DynamicFilterTracker::try_new(&predicate) + .expect("predicate has an incomplete dynamic filter"); + assert!(!tracker.changed()); + + // A subsequent update is still reported. + filter.update(lit(true)).unwrap(); + assert!(tracker.changed()); + } + + #[test] + fn mark_complete_does_not_count_as_a_change() { + let (predicate, filter) = dynamic_predicate(); + let mut tracker = DynamicFilterTracker::try_new(&predicate).unwrap(); + + filter.update(lit(false)).unwrap(); + assert!(tracker.changed()); + + // `mark_complete()` re-broadcasts the current generation without + // changing the expression: it must not trigger a spurious rebuild. + filter.mark_complete(); + assert!(!tracker.changed()); + // The filter has completed, so the tracker drains itself. + assert!(tracker.is_exhausted()); + } + + #[test] + fn coalesced_update_then_complete_is_one_change() { + let (predicate, filter) = dynamic_predicate(); + let mut tracker = DynamicFilterTracker::try_new(&predicate).unwrap(); + + // Update and complete before the tracker gets a chance to observe. + // The watch channel only retains the latest value, so the tracker sees + // `Complete` directly; it must still report the (final) change once. + filter.update(lit(false)).unwrap(); + filter.mark_complete(); + + assert!(tracker.changed()); + assert!(tracker.is_exhausted()); + assert!(!tracker.changed()); + } + + #[test] + fn watches_multiple_filters_independently() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let col_a = col("a", &schema).unwrap(); + let col_b = col("b", &schema).unwrap(); + let filter_a = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true), + )); + let filter_b = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_b)], + lit(true), + )); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + col_a, + Operator::Gt, + Arc::clone(&filter_a) as Arc, + )), + Operator::And, + Arc::new(BinaryExpr::new( + col_b, + Operator::Lt, + Arc::clone(&filter_b) as Arc, + )), + )) as Arc; + + let mut tracker = DynamicFilterTracker::try_new(&predicate).unwrap(); + assert!(!tracker.changed()); + + filter_a.update(lit(false)).unwrap(); + assert!(tracker.changed()); + assert!(!tracker.changed()); + + filter_b.update(lit(false)).unwrap(); + assert!(tracker.changed()); + assert!(!tracker.changed()); + + // Completing one filter leaves the other still watched. + filter_a.mark_complete(); + assert!(!tracker.changed()); + assert!(!tracker.is_exhausted()); + + filter_b.mark_complete(); + assert!(!tracker.changed()); + assert!(tracker.is_exhausted()); + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index bb033aac03ed6..1d3e244d73971 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -17,38 +17,32 @@ //! Implementation of `InList` expressions: [`InListExpr`] -use std::any::Any; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; +use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; -use arrow::buffer::BooleanBuffer; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::SortOptions; use arrow::compute::kernels::boolean::{not, or_kleene}; -use arrow::compute::{take, SortOptions}; +use arrow::compute::kernels::cmp::eq as arrow_eq; use arrow::datatypes::*; -use arrow::util::bit_iterator::BitIndexIterator; -use datafusion_common::hash_utils::with_hashes; + use datafusion_common::{ - assert_or_internal_err, exec_datafusion_err, exec_err, DFSchema, HashSet, Result, - ScalarValue, + DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err, }; -use datafusion_expr::{expr_vec_fmt, ColumnarValue}; - -use ahash::RandomState; -use datafusion_common::HashMap; -use hashbrown::hash_map::RawEntryMut; +use datafusion_expr::{ColumnarValue, expr_vec_fmt}; -/// Trait for InList static filters -trait StaticFilter { - fn null_count(&self) -> usize; +mod array_static_filter; +mod primitive_filter; +mod static_filter; +mod strategy; - /// Checks if values in `v` are contained in the filter - fn contains(&self, v: &dyn Array, negated: bool) -> Result; -} +use static_filter::StaticFilter; +use strategy::instantiate_static_filter; /// InList pub struct InListExpr { @@ -68,195 +62,18 @@ impl Debug for InListExpr { } } -/// Static filter for InList that stores the array and hash set for O(1) lookups -#[derive(Debug, Clone)] -struct ArrayStaticFilter { - in_array: ArrayRef, - state: RandomState, - /// Used to provide a lookup from value to in list index - /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, -} - -impl StaticFilter for ArrayStaticFilter { - fn null_count(&self) -> usize { - self.in_array.null_count() - } - - /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Null type comparisons always return null (SQL three-valued logic) - if v.data_type() == &DataType::Null - || self.in_array.data_type() == &DataType::Null - { - return Ok(BooleanArray::from(vec![None; v.len()])); - } - - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let needle_nulls = v.logical_nulls(); - let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = self.in_array.null_count() != 0; - - with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; - Ok((0..v.len()) - .map(|i| { - // SQL three-valued logic: null IN (...) is always null - if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { - return None; - } - - let hash = hashes[i]; - let contains = self - .map - .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) - .is_some(); - - match contains { - true => Some(!negated), - false if haystack_has_nulls => None, - false => Some(negated), - } - }) - .collect()) - }) - } -} - -fn instantiate_static_filter( - in_array: ArrayRef, -) -> Result> { - match in_array.data_type() { - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } - } -} - -impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. - /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave - fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set - if in_array.data_type() == &DataType::Null { - return Ok(ArrayStaticFilter { - in_array, - state: RandomState::new(), - map: HashMap::with_hasher(()), - }); - } - - let state = RandomState::new(); - let mut map: HashMap = HashMap::with_hasher(()); - - with_hashes([&in_array], &state, |hashes| -> Result<()> { - let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; - - match in_array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) - } - None => (0..in_array.len()).for_each(insert_value), - } - - Ok(()) - })?; - - Ok(Self { - in_array, - state, - map, - }) - } -} - -struct Int32StaticFilter { - null_count: usize, - values: HashSet, -} - -impl Int32StaticFilter { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } - - Ok(Self { null_count, values }) - } -} - -impl StaticFilter for Int32StaticFilter { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - let v = v - .as_primitive_opt::() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - - let result = match (v.null_count() > 0, negated) { - (true, false) => { - // has nulls, not negated" - BooleanArray::from_iter( - v.iter().map(|value| Some(self.values.contains(&value?))), - ) - } - (true, true) => { - // has nulls, negated - BooleanArray::from_iter( - v.iter().map(|value| Some(!self.values.contains(&value?))), - ) - } - (false, false) => { - //no null, not negated - BooleanArray::from_iter( - v.values().iter().map(|value| self.values.contains(value)), - ) - } - (false, true) => { - // no null, negated - BooleanArray::from_iter( - v.values().iter().map(|value| !self.values.contains(value)), - ) - } - }; - Ok(result) +/// Returns true if Arrow's vectorized `eq` kernel supports this data type. +/// +/// Supported: primitives, boolean, strings (Utf8/LargeUtf8/Utf8View), +/// binary (Binary/LargeBinary/BinaryView/FixedSizeBinary), Null, and +/// Dictionary-encoded variants of the above. +/// Unsupported: nested types (Struct, List, Map, Union) and RunEndEncoded. +fn supports_arrow_eq(dt: &DataType) -> bool { + use DataType::*; + match dt { + Boolean | Binary | LargeBinary | BinaryView | FixedSizeBinary(_) => true, + Dictionary(_, v) => supports_arrow_eq(v.as_ref()), + _ => dt.is_primitive() || dt.is_null() || dt.is_string(), } } @@ -284,15 +101,43 @@ fn evaluate_list( /// Try to evaluate a list of expressions as constants. /// -/// Returns an ArrayRef if all expressions are constants (can be evaluated on an -/// empty RecordBatch), otherwise returns an error. This is used to detect when -/// a list contains only literals, casts of literals, or other constant expressions. +/// Returns: +/// - `Ok(Some(ArrayRef))` if all expressions are constants (can be evaluated on an empty RecordBatch) +/// - `Ok(None)` if the list contains non-constant expressions +/// - `Err(...)` only for actual errors (not for non-constant expressions) +/// +/// This is used to detect when a list contains only literals, casts of literals, +/// or other constant expressions. fn try_evaluate_constant_list( list: &[Arc], schema: &Schema, -) -> Result { +) -> Result> { let batch = RecordBatch::new_empty(Arc::new(schema.clone())); - evaluate_list(list, &batch) + match evaluate_list(list, &batch) { + Ok(array) => Ok(Some(array)), + Err(_) => { + // Non-constant expressions can't be evaluated on an empty batch + // This is not an error, just means we can't use a static filter + Ok(None) + } + } +} + +/// Asserts that the InList expression's data type matches a list element's +/// data type. `DataType::Null` list elements are accepted unconditionally so +/// that null literals and `NullArray` haystacks remain compatible with any +/// expression type. +fn assert_inlist_data_types_match( + expr_data_type: &DataType, + list_data_type: &DataType, +) -> Result<()> { + if *list_data_type != DataType::Null { + assert_or_internal_err!( + DFSchema::datatype_is_logically_equal(expr_data_type, list_data_type), + "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_data_type}" + ); + } + Ok(()) } impl InListExpr { @@ -321,6 +166,14 @@ impl InListExpr { &self.list } + pub fn is_empty(&self) -> bool { + self.list.is_empty() + } + + pub fn len(&self) -> usize { + self.list.len() + } + /// Is this negated e.g. NOT IN LIST pub fn negated(&self) -> bool { self.negated @@ -328,20 +181,28 @@ impl InListExpr { /// Create a new InList expression directly from an array, bypassing expression evaluation. /// - /// This is more efficient than `in_list()` when you already have the list as an array, - /// as it avoids the conversion: `ArrayRef -> Vec -> ArrayRef -> StaticFilter`. - /// Instead it goes directly: `ArrayRef -> StaticFilter`. + /// This is more efficient than [`InListExpr::try_new`] when you already have the list + /// as an array, as it builds the static filter directly from the array instead of + /// reconstructing an intermediate array from literal expressions. + /// + /// The `list` field is populated with literal expressions extracted from + /// the array, and the array is used to build a static filter for + /// efficient set membership evaluation. /// - /// The `list` field will be empty when using this constructor, as the array is stored - /// directly in the static filter. + /// The `array` may be dictionary-encoded — it will be flattened to its + /// value type such that specialized filters are used. /// - /// This does not make the expression any more performant at runtime, but it does make it slightly - /// cheaper to build. + /// Returns an error if the expression's data type and the array's data type + /// are not logically equal. Null arrays are always accepted. pub fn try_new_from_array( expr: Arc, array: ArrayRef, negated: bool, + schema: &Schema, ) -> Result { + let expr_data_type = expr.data_type(schema)?; + assert_inlist_data_types_match(&expr_data_type, array.data_type())?; + let list = (0..array.len()) .map(|i| { let scalar = ScalarValue::try_from_array(array.as_ref(), i)?; @@ -355,6 +216,62 @@ impl InListExpr { Some(instantiate_static_filter(array)?), )) } + + /// Create a new InList expression, using a static filter when possible. + /// + /// This validates data types and attempts to create a static filter for constant + /// list expressions. Uses specialized StaticFilter implementations for better + /// performance (e.g., Int32StaticFilter for Int32). + /// + /// Returns an error if data types don't match. If the list contains non-constant + /// expressions, falls back to dynamic evaluation at runtime. + pub fn try_new( + expr: Arc, + list: Vec>, + negated: bool, + schema: &Schema, + ) -> Result { + // Check the data types match + let expr_data_type = expr.data_type(schema)?; + for list_expr in list.iter() { + let list_expr_data_type = list_expr.data_type(schema)?; + assert_inlist_data_types_match(&expr_data_type, &list_expr_data_type)?; + } + + // Try to create a static filter if all list expressions are constants + let static_filter = match try_evaluate_constant_list(&list, schema)? { + Some(in_array) => Some(instantiate_static_filter(in_array)?), + None => None, // Non-constant expressions, fall back to dynamic evaluation + }; + + Ok(Self::new(expr, list, negated, static_filter)) + } + + #[cfg(feature = "proto")] + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let node = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::InList, + "InList", + ); + + let expr = + ctx.decode_required_expression(node.expr.as_deref(), "InListExpr", "expr")?; + let list = ctx.decode_children_expressions(&node.list)?; + + Ok(Arc::new(InListExpr::try_new( + expr, + list, + node.negated, + ctx.schema(), + )?)) + } } impl std::fmt::Display for InListExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -375,11 +292,6 @@ impl std::fmt::Display for InListExpr { } impl PhysicalExpr for InListExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Boolean) } @@ -414,8 +326,12 @@ impl PhysicalExpr for InListExpr { if scalar.is_null() { // SQL three-valued logic: null IN (...) is always null // The code below would handle this correctly but this is a faster path + let nulls = NullBuffer::new_null(num_rows); return Ok(ColumnarValue::Array(Arc::new( - BooleanArray::from(vec![None; num_rows]), + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ), ))); } // Use a 1 row array to avoid code duplication/branching @@ -426,73 +342,93 @@ impl PhysicalExpr for InListExpr { // Broadcast the single result to all rows // Must check is_null() to preserve NULL values (SQL three-valued logic) if result_array.is_null(0) { - BooleanArray::from(vec![None; num_rows]) + let nulls = NullBuffer::new_null(num_rows); + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ) + } else if result_array.value(0) { + BooleanArray::new(BooleanBuffer::new_set(num_rows), None) } else { - BooleanArray::from_iter(std::iter::repeat_n( - result_array.value(0), - num_rows, - )) + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None) } } } } None => { - // No static filter: iterate through each expression, compare, and OR results + // No static filter: iterate through each expression, compare, and OR results. + // Use Arrow's vectorized eq kernel for types it supports (primitive, + // boolean, string, binary, dictionary), falling back to row-by-row + // comparator for unsupported types (nested, RunEndEncoded, etc.). let value = value.into_array(num_rows)?; - let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( - BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), - |result, expr| -> Result { - let rhs = match expr? { - ColumnarValue::Array(array) => { + let lhs_supports_arrow_eq = supports_arrow_eq(value.data_type()); + + // Helper: compare value against a single list expression + let compare_one = |expr: &Arc| -> Result { + match expr.evaluate(batch)? { + ColumnarValue::Array(array) => { + if lhs_supports_arrow_eq + && supports_arrow_eq(array.data_type()) + { + Ok(arrow_eq(&value, &array)?) + } else { let cmp = make_comparator( value.as_ref(), array.as_ref(), SortOptions::default(), )?; - (0..num_rows) - .map(|i| { - if value.is_null(i) || array.is_null(i) { - return None; - } - Some(cmp(i, i).is_eq()) - }) - .collect::() + let buffer = BooleanBuffer::collect_bool(num_rows, |i| { + cmp(i, i).is_eq() + }); + let nulls = + NullBuffer::union(value.nulls(), array.nulls()); + Ok(BooleanArray::new(buffer, nulls)) } - ColumnarValue::Scalar(scalar) => { - // Check if scalar is null once, before the loop - if scalar.is_null() { - // If scalar is null, all comparisons return null - BooleanArray::from(vec![None; num_rows]) - } else { - // Convert scalar to 1-element array - let array = scalar.to_array()?; - let cmp = make_comparator( - value.as_ref(), - array.as_ref(), - SortOptions::default(), - )?; - // Compare each row of value with the single scalar element - (0..num_rows) - .map(|i| { - if value.is_null(i) { - None - } else { - Some(cmp(i, 0).is_eq()) - } - }) - .collect::() - } + } + ColumnarValue::Scalar(scalar) => { + // Check if scalar is null once, before the loop + if scalar.is_null() { + // If scalar is null, all comparisons return null + Ok(BooleanArray::from(vec![None; num_rows])) + } else if lhs_supports_arrow_eq { + let scalar_datum = scalar.to_scalar()?; + Ok(arrow_eq(&value, &scalar_datum)?) + } else { + // Convert scalar to 1-element array + let array = scalar.to_array()?; + let cmp = make_comparator( + value.as_ref(), + array.as_ref(), + SortOptions::default(), + )?; + // Compare each row of value with the single scalar element + let buffer = BooleanBuffer::collect_bool(num_rows, |i| { + cmp(i, 0).is_eq() + }); + Ok(BooleanArray::new(buffer, value.nulls().cloned())) } - }; - Ok(or_kleene(&result, &rhs)?) - }, - )?; + } + } + }; - if self.negated { - not(&found)? + // Evaluate first expression directly to avoid a redundant + // or_kleene with an all-false accumulator. + let mut found = if let Some(first) = self.list.first() { + compare_one(first)? } else { - found + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None) + }; + + for expr in self.list.iter().skip(1) { + // Short-circuit: if every non-null row is already true, + // no further list items can change the result. + if found.null_count() == 0 && !found.has_false() { + break; + } + found = or_kleene(&found, &compare_one(expr)?)?; } + + if self.negated { not(&found)? } else { found } } }; Ok(ColumnarValue::Array(Arc::new(r))) @@ -532,6 +468,25 @@ impl PhysicalExpr for InListExpr { } write!(f, ")") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(ctx.encode_child(&self.expr)?)), + list: ctx.encode_children_expressions(&self.list)?, + negated: self.negated, + }, + ))), + })) + } } impl PartialEq for InListExpr { @@ -560,38 +515,14 @@ pub fn in_list( negated: &bool, schema: &Schema, ) -> Result> { - // check the data type - let expr_data_type = expr.data_type(schema)?; - for list_expr in list.iter() { - let list_expr_data_type = list_expr.data_type(schema)?; - assert_or_internal_err!( - DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type), - "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" - ); - } - - // Try to create a static filter for constant expressions - let static_filter = try_evaluate_constant_list(&list, schema) - .and_then(ArrayStaticFilter::try_new) - .ok() - .map(|static_filter| { - Arc::new(static_filter) as Arc - }); - - Ok(Arc::new(InListExpr::new( - expr, - list, - *negated, - static_filter, - ))) + Ok(Arc::new(InListExpr::try_new(expr, list, *negated, schema)?)) } #[cfg(test)] mod tests { use super::*; - use crate::expressions; use crate::expressions::{col, lit, try_cast}; - use arrow::buffer::NullBuffer; + use arrow::datatypes::{IntervalDayTime, IntervalMonthDayNano, i256}; use datafusion_common::plan_err; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_physical_expr_common::physical_expr::fmt_sql; @@ -632,14 +563,6 @@ mod tests { } } - fn try_cast_static_filter_to_set( - list: &[Arc], - schema: &Schema, - ) -> Result { - let array = try_evaluate_constant_list(list, schema)?; - ArrayStaticFilter::try_new(array) - } - // Attempts to coerce the types of `list_type` to be comparable with the // `expr_type` fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option { @@ -691,124 +614,504 @@ mod tests { /// and list expressions are already the correct types and don't require casting. macro_rules! in_list_raw { ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ - let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap(); + let col_expr = $COL; + let expr = in_list(Arc::clone(&col_expr), $LIST, $NEGATED, $SCHEMA).unwrap(); let result = expr .evaluate(&$BATCH)? .into_array($BATCH.num_rows()) .expect("Failed to convert to array"); let result = as_boolean_array(&result); let expected = &BooleanArray::from($EXPECTED); - assert_eq!(expected, result); + assert_eq!( + expected, + result, + "Failed for: {}\n{}: {:?}", + fmt_sql(expr.as_ref()), + fmt_sql(col_expr.as_ref()), + col_expr + .evaluate(&$BATCH)? + .into_array($BATCH.num_rows()) + .unwrap() + ); }}; } + /// Test case for primitive types following the standard IN LIST pattern. + /// + /// Each test case represents a data type with: + /// - `value_in`: A value that appears in both the test array and the IN list (matches → true) + /// - `value_not_in`: A value that appears in the test array but NOT in the IN list (doesn't match → false) + /// - `other_list_values`: Additional values in the IN list besides `value_in` + /// - `null_value`: Optional null scalar value for NULL handling tests. When None, tests + /// without nulls are run, exercising the `(false, false)` and `(false, true)` branches. + struct InListPrimitiveTestCase { + name: &'static str, + value_in: ScalarValue, + value_not_in: ScalarValue, + other_list_values: Vec, + null_value: Option, + } + + /// Generic test data struct for primitive types. + /// + /// Holds test values needed for IN LIST tests, allowing the data + /// to be declared explicitly and reused across multiple types. + #[derive(Clone)] + struct PrimitiveTestCaseData { + value_in: T, + value_not_in: T, + other_list_values: Vec, + } + + /// Helper to create test cases for any primitive type using generic data. + /// + /// Uses TryInto for flexible type conversion, allowing test data to be + /// declared in any convertible type (e.g., i32 for all integer types). + /// Creates a test case WITH null support (for null handling tests). + fn primitive_test_case( + name: &'static str, + constructor: F, + data: PrimitiveTestCaseData, + ) -> InListPrimitiveTestCase + where + D: TryInto + Clone, + >::Error: Debug, + F: Fn(Option) -> ScalarValue, + T: Clone, + { + InListPrimitiveTestCase { + name, + value_in: constructor(Some(data.value_in.try_into().unwrap())), + value_not_in: constructor(Some(data.value_not_in.try_into().unwrap())), + other_list_values: data + .other_list_values + .into_iter() + .map(|v| constructor(Some(v.try_into().unwrap()))) + .collect(), + null_value: Some(constructor(None)), + } + } + + /// Helper to create test cases WITHOUT null support. + /// These test cases exercise the `(false, true)` branch (no nulls, negated). + fn primitive_test_case_no_nulls( + name: &'static str, + constructor: F, + data: PrimitiveTestCaseData, + ) -> InListPrimitiveTestCase + where + D: TryInto + Clone, + >::Error: Debug, + F: Fn(Option) -> ScalarValue, + T: Clone, + { + InListPrimitiveTestCase { + name, + value_in: constructor(Some(data.value_in.try_into().unwrap())), + value_not_in: constructor(Some(data.value_not_in.try_into().unwrap())), + other_list_values: data + .other_list_values + .into_iter() + .map(|v| constructor(Some(v.try_into().unwrap()))) + .collect(), + null_value: None, + } + } + + /// Runs test cases for multiple types, providing detailed SQL error messages on failure. + /// + /// For each test case, runs IN LIST scenarios based on whether null_value is Some or None: + /// - With null_value (Some): 4 tests including null handling + /// - Without null_value (None): 2 tests exercising the no-nulls paths + fn run_test_cases(test_cases: Vec) -> Result<()> { + for test_case in test_cases { + let test_name = test_case.name; + + // Get the data type from the scalar value + let data_type = test_case.value_in.data_type(); + + // Build the base list: [value_in, ...other_list_values] + let build_base_list = || -> Vec> { + let mut list = vec![lit(test_case.value_in.clone())]; + list.extend(test_case.other_list_values.iter().map(|v| lit(v.clone()))); + list + }; + + match &test_case.null_value { + Some(null_val) => { + // Tests WITH nulls in the needle array + let schema = + Schema::new(vec![Field::new("a", data_type.clone(), true)]); + + // Create array from scalar values: [value_in, value_not_in, None] + let array = ScalarValue::iter_to_array(vec![ + test_case.value_in.clone(), + test_case.value_not_in.clone(), + null_val.clone(), + ])?; + + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::clone(&array)], + )?; + + // Test 1: a IN (list) → [true, false, null] + let list = build_base_list(); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // Test 2: a NOT IN (list) → [false, true, null] + let list = build_base_list(); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // Test 3: a IN (list, NULL) → [true, null, null] + let mut list = build_base_list(); + list.push(lit(null_val.clone())); + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // Test 4: a NOT IN (list, NULL) → [false, null, null] + let mut list = build_base_list(); + list.push(lit(null_val.clone())); + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + } + None => { + // Tests WITHOUT nulls - exercises the (false, false) and (false, true) branches + let schema = + Schema::new(vec![Field::new("a", data_type.clone(), false)]); + + // Create array from scalar values: [value_in, value_not_in] (no NULL) + let array = ScalarValue::iter_to_array(vec![ + test_case.value_in.clone(), + test_case.value_not_in.clone(), + ])?; + + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::clone(&array)], + )?; + + // Test 1: a IN (list) → [true, false] - exercises (false, false) branch + let list = build_base_list(); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false)], + Arc::clone(&col_a), + &schema + ); + + // Test 2: a NOT IN (list) → [false, true] - exercises (false, true) branch + let list = build_base_list(); + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true)], + Arc::clone(&col_a), + &schema + ); + + eprintln!( + "Test '{test_name}': exercised (false, true) branch (no nulls, negated)", + ); + } + } + } + + Ok(()) + } + + /// Test IN LIST for all integer types (Int8/16/32/64, UInt8/16/32/64). + /// + /// Test data: 0 (in list), 2 (not in list), [1, 3, 5] (other list values) #[test] - fn in_list_utf8() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let a = StringArray::from(vec![Some("a"), Some("d"), None]); + fn in_list_int_types() -> Result<()> { + let int_data = PrimitiveTestCaseData { + value_in: 0, + value_not_in: 2, + other_list_values: vec![1, 3, 5], + }; + + run_test_cases(vec![ + // Tests WITH nulls + primitive_test_case("int8", ScalarValue::Int8, int_data.clone()), + primitive_test_case("int16", ScalarValue::Int16, int_data.clone()), + primitive_test_case("int32", ScalarValue::Int32, int_data.clone()), + primitive_test_case("int64", ScalarValue::Int64, int_data.clone()), + primitive_test_case("uint8", ScalarValue::UInt8, int_data.clone()), + primitive_test_case("uint16", ScalarValue::UInt16, int_data.clone()), + primitive_test_case("uint32", ScalarValue::UInt32, int_data.clone()), + primitive_test_case("uint64", ScalarValue::UInt64, int_data.clone()), + // Tests WITHOUT nulls - exercises (false, true) branch + primitive_test_case_no_nulls("int32_no_nulls", ScalarValue::Int32, int_data), + ]) + } + + /// Test IN LIST for all string types (Utf8, LargeUtf8, Utf8View). + /// + /// Test data: "a" (in list), "d" (not in list), ["b", "c"] (other list values) + #[test] + fn in_list_string_types() -> Result<()> { + let string_data = PrimitiveTestCaseData { + value_in: "a", + value_not_in: "d", + other_list_values: vec!["b", "c"], + }; + + run_test_cases(vec![ + primitive_test_case("utf8", ScalarValue::Utf8, string_data.clone()), + primitive_test_case( + "large_utf8", + ScalarValue::LargeUtf8, + string_data.clone(), + ), + primitive_test_case("utf8_view", ScalarValue::Utf8View, string_data), + ]) + } + + /// Test IN LIST for all binary types (Binary, LargeBinary, BinaryView). + /// + /// Test data: [1,2,3] (in list), [1,2,2] (not in list), [[4,5,6], [7,8,9]] (other list values) + #[test] + fn in_list_binary_types() -> Result<()> { + let binary_data = PrimitiveTestCaseData { + value_in: vec![1_u8, 2, 3], + value_not_in: vec![1_u8, 2, 2], + other_list_values: vec![vec![4_u8, 5, 6], vec![7_u8, 8, 9]], + }; + + run_test_cases(vec![ + primitive_test_case("binary", ScalarValue::Binary, binary_data.clone()), + primitive_test_case( + "large_binary", + ScalarValue::LargeBinary, + binary_data.clone(), + ), + primitive_test_case("binary_view", ScalarValue::BinaryView, binary_data), + ]) + } + + /// Test IN LIST for date types (Date32, Date64). + /// + /// Test data: 0 (in list), 2 (not in list), [1, 3] (other list values) + #[test] + fn in_list_date_types() -> Result<()> { + let date_data = PrimitiveTestCaseData { + value_in: 0, + value_not_in: 2, + other_list_values: vec![1, 3], + }; + + run_test_cases(vec![ + primitive_test_case("date32", ScalarValue::Date32, date_data.clone()), + primitive_test_case("date64", ScalarValue::Date64, date_data), + ]) + } + + /// Test IN LIST for Decimal128 type. + /// + /// Test data: 0 (in list), 200 (not in list), [100, 300] (other list values) with precision=10, scale=2 + #[test] + fn in_list_decimal() -> Result<()> { + run_test_cases(vec![InListPrimitiveTestCase { + name: "decimal128", + value_in: ScalarValue::Decimal128(Some(0), 10, 2), + value_not_in: ScalarValue::Decimal128(Some(200), 10, 2), + other_list_values: vec![ + ScalarValue::Decimal128(Some(100), 10, 2), + ScalarValue::Decimal128(Some(300), 10, 2), + ], + null_value: Some(ScalarValue::Decimal128(None, 10, 2)), + }]) + } + + /// Test IN LIST for timestamp types. + /// + /// Test data: 0 (in list), 2000 (not in list), [1000, 3000] (other list values) + #[test] + fn in_list_timestamp_types() -> Result<()> { + run_test_cases(vec![ + InListPrimitiveTestCase { + name: "timestamp_nanosecond", + value_in: ScalarValue::TimestampNanosecond(Some(0), None), + value_not_in: ScalarValue::TimestampNanosecond(Some(2000), None), + other_list_values: vec![ + ScalarValue::TimestampNanosecond(Some(1000), None), + ScalarValue::TimestampNanosecond(Some(3000), None), + ], + null_value: Some(ScalarValue::TimestampNanosecond(None, None)), + }, + InListPrimitiveTestCase { + name: "timestamp_millisecond_with_tz", + value_in: ScalarValue::TimestampMillisecond( + Some(1500000), + Some("+05:00".into()), + ), + value_not_in: ScalarValue::TimestampMillisecond( + Some(2500000), + Some("+05:00".into()), + ), + other_list_values: vec![ScalarValue::TimestampMillisecond( + Some(3500000), + Some("+05:00".into()), + )], + null_value: Some(ScalarValue::TimestampMillisecond( + None, + Some("+05:00".into()), + )), + }, + InListPrimitiveTestCase { + name: "timestamp_millisecond_mixed_tz", + value_in: ScalarValue::TimestampMillisecond( + Some(1500000), + Some("+05:00".into()), + ), + value_not_in: ScalarValue::TimestampMillisecond( + Some(2500000), + Some("+05:00".into()), + ), + other_list_values: vec![ + ScalarValue::TimestampMillisecond( + Some(3500000), + Some("+01:00".into()), + ), + ScalarValue::TimestampMillisecond(Some(4500000), Some("UTC".into())), + ], + null_value: Some(ScalarValue::TimestampMillisecond( + None, + Some("+05:00".into()), + )), + }, + ]) + } + + #[test] + fn in_list_float64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let a = Float64Array::from(vec![ + Some(0.0), + Some(0.2), + None, + Some(f64::NAN), + Some(-f64::NAN), + ]); let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // expression: "a in ("a", "b")" - let list = vec![lit("a"), lit("b")]; + // expression: "a in (0.0, 0.1)" + let list = vec![lit(0.0f64), lit(0.1f64)]; in_list!( batch, list, &false, - vec![Some(true), Some(false), None], + vec![Some(true), Some(false), None, Some(false), Some(false)], Arc::clone(&col_a), &schema ); - // expression: "a not in ("a", "b")" - let list = vec![lit("a"), lit("b")]; + // expression: "a not in (0.0, 0.1)" + let list = vec![lit(0.0f64), lit(0.1f64)]; in_list!( batch, list, &true, - vec![Some(false), Some(true), None], + vec![Some(false), Some(true), None, Some(true), Some(true)], Arc::clone(&col_a), &schema ); - // expression: "a in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; + // expression: "a in (0.0, 0.1, NULL)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; in_list!( batch, list, &false, - vec![Some(true), None, None], + vec![Some(true), None, None, None, None], Arc::clone(&col_a), &schema ); - // expression: "a not in ("a", "b", null)" - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; + // expression: "a not in (0.0, 0.1, NULL)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; in_list!( batch, list, &true, - vec![Some(false), None, None], + vec![Some(false), None, None, None, None], Arc::clone(&col_a), &schema ); - Ok(()) - } - - #[test] - fn in_list_binary() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]); - let a = BinaryArray::from(vec![ - Some([1, 2, 3].as_slice()), - Some([1, 2, 2].as_slice()), - None, - ]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ([1, 2, 3], [4, 5, 6])" - let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + // expression: "a in (0.0, 0.1, NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; in_list!( batch, - list.clone(), + list, &false, - vec![Some(true), Some(false), None], + vec![Some(true), Some(false), None, Some(true), Some(false)], Arc::clone(&col_a), &schema ); - // expression: "a not in ([1, 2, 3], [4, 5, 6])" + // expression: "a not in (0.0, 0.1, NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; in_list!( batch, list, &true, - vec![Some(false), Some(true), None], + vec![Some(false), Some(true), None, Some(false), Some(true)], Arc::clone(&col_a), &schema ); - // expression: "a in ([1, 2, 3], [4, 5, 6], null)" - let list = vec![ - lit([1, 2, 3].as_slice()), - lit([4, 5, 6].as_slice()), - lit(ScalarValue::Binary(None)), - ]; + // expression: "a in (0.0, 0.1, -NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; in_list!( batch, - list.clone(), + list, &false, - vec![Some(true), None, None], + vec![Some(true), Some(false), None, Some(false), Some(true)], Arc::clone(&col_a), &schema ); - // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + // expression: "a not in (0.0, 0.1, -NaN)" + let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; in_list!( batch, list, &true, - vec![Some(false), None, None], + vec![Some(false), Some(true), None, Some(true), Some(false)], Arc::clone(&col_a), &schema ); @@ -817,52 +1120,52 @@ mod tests { } #[test] - fn in_list_int64() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); - let a = Int64Array::from(vec![Some(0), Some(2), None]); + fn in_list_bool() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let a = BooleanArray::from(vec![Some(true), None]); let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // expression: "a in (0, 1)" - let list = vec![lit(0i64), lit(1i64)]; + // expression: "a in (true)" + let list = vec![lit(true)]; in_list!( batch, list, &false, - vec![Some(true), Some(false), None], + vec![Some(true), None], Arc::clone(&col_a), &schema ); - // expression: "a not in (0, 1)" - let list = vec![lit(0i64), lit(1i64)]; + // expression: "a not in (true)" + let list = vec![lit(true)]; in_list!( batch, list, &true, - vec![Some(false), Some(true), None], + vec![Some(false), None], Arc::clone(&col_a), &schema ); - // expression: "a in (0, 1, NULL)" - let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)]; + // expression: "a in (true, NULL)" + let list = vec![lit(true), lit(ScalarValue::Null)]; in_list!( batch, list, &false, - vec![Some(true), None, None], + vec![Some(true), None], Arc::clone(&col_a), &schema ); - // expression: "a not in (0, 1, NULL)" - let list = vec![lit(0i64), lit(1i64), lit(ScalarValue::Null)]; + // expression: "a not in (true, NULL)" + let list = vec![lit(true), lit(ScalarValue::Null)]; in_list!( batch, list, &true, - vec![Some(false), None, None], + vec![Some(false), None], Arc::clone(&col_a), &schema ); @@ -870,158 +1173,89 @@ mod tests { Ok(()) } + macro_rules! test_nullable { + ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{ + let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; + let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap(); + let result = expr.nullable($SCHEMA)?; + assert_eq!($EXPECTED, result); + }}; + } + #[test] - fn in_list_float64() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - let a = Float64Array::from(vec![ - Some(0.0), - Some(0.2), - None, - Some(f64::NAN), - Some(-f64::NAN), + fn in_list_nullable() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1_nullable", DataType::Int64, true), + Field::new("c2_non_nullable", DataType::Int64, false), ]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in (0.0, 0.1)" - let list = vec![lit(0.0f64), lit(0.1f64)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None, Some(false), Some(false)], - Arc::clone(&col_a), - &schema - ); - - // expression: "a not in (0.0, 0.1)" - let list = vec![lit(0.0f64), lit(0.1f64)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None, Some(true), Some(true)], - Arc::clone(&col_a), - &schema - ); - // expression: "a in (0.0, 0.1, NULL)" - let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None, None, None, None], - Arc::clone(&col_a), - &schema - ); + let c1_nullable = col("c1_nullable", &schema)?; + let c2_non_nullable = col("c2_non_nullable", &schema)?; - // expression: "a not in (0.0, 0.1, NULL)" - let list = vec![lit(0.0f64), lit(0.1f64), lit(ScalarValue::Null)]; - in_list!( - batch, - list, - &true, - vec![Some(false), None, None, None, None], - Arc::clone(&col_a), - &schema - ); + // static_filter has no nulls + let list = vec![lit(1_i64), lit(2_i64)]; + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); - // expression: "a in (0.0, 0.1, NaN)" - let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None, Some(true), Some(false)], - Arc::clone(&col_a), - &schema - ); + // static_filter has nulls + let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - // expression: "a not in (0.0, 0.1, NaN)" - let list = vec![lit(0.0f64), lit(0.1f64), lit(f64::NAN)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None, Some(false), Some(true)], - Arc::clone(&col_a), - &schema - ); + let list = vec![Arc::clone(&c1_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - // expression: "a in (0.0, 0.1, -NaN)" - let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None, Some(false), Some(true)], - Arc::clone(&col_a), - &schema - ); + let list = vec![Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); - // expression: "a not in (0.0, 0.1, -NaN)" - let list = vec![lit(0.0f64), lit(0.1f64), lit(-f64::NAN)]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None, Some(true), Some(false)], - Arc::clone(&col_a), - &schema - ); + let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); Ok(()) } #[test] - fn in_list_bool() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); - let a = BooleanArray::from(vec![Some(true), None]); - let col_a = col("a", &schema)?; + fn in_list_no_cols() -> Result<()> { + // test logic when the in_list expression doesn't have any columns + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(1), Some(2), None]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // expression: "a in (true)" - let list = vec![lit(true)]; - in_list!( - batch, - list, - &false, - vec![Some(true), None], - Arc::clone(&col_a), - &schema - ); + let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; - // expression: "a not in (true)" - let list = vec![lit(true)]; + // 1 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(1))); in_list!( batch, - list, - &true, - vec![Some(false), None], - Arc::clone(&col_a), + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![Some(true), Some(true), Some(true)], + expr, &schema ); - // expression: "a in (true, NULL)" - let list = vec![lit(true), lit(ScalarValue::Null)]; + // 2 IN (1, 6) + let expr = lit(ScalarValue::Int32(Some(2))); in_list!( batch, - list, + list.clone(), &false, - vec![Some(true), None], - Arc::clone(&col_a), + // should have three outputs, as the input batch has three rows + vec![Some(false), Some(false), Some(false)], + expr, &schema ); - // expression: "a not in (true, NULL)" - let list = vec![lit(true), lit(ScalarValue::Null)]; + // NULL IN (1, 6) + let expr = lit(ScalarValue::Int32(None)); in_list!( batch, - list, - &true, - vec![Some(false), None], - Arc::clone(&col_a), + list.clone(), + &false, + // should have three outputs, as the input batch has three rows + vec![None, None, None], + expr, &schema ); @@ -1029,66 +1263,218 @@ mod tests { } #[test] - fn in_list_date64() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]); - let a = Date64Array::from(vec![Some(0), Some(2), None]); + fn in_list_utf8_with_dict_types() -> Result<()> { + fn dict_lit(key_type: DataType, value: &str) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::new_utf8(value.to_string())), + )) + } + + fn null_dict_lit(key_type: DataType) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::Utf8(None)), + )) + } + + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + )]); + let a: UInt16DictionaryArray = + vec![Some("a"), Some("d"), None].into_iter().collect(); let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // expression: "a in (0, 1)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), + // expression: "a in ("a", "b")" + let lists = [ + vec![lit("a"), lit("b")], + vec![ + dict_lit(DataType::Int8, "a"), + dict_lit(DataType::UInt16, "b"), + ], ]; - in_list!( - batch, - list, - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + } - // expression: "a not in (0, 1)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - ]; - in_list!( - batch, - list, - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); + // expression: "a not in ("a", "b")" + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + } - // expression: "a in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - lit(ScalarValue::Null), + // expression: "a in ("a", "b", null)" + let lists = [ + vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], + vec![ + dict_lit(DataType::Int8, "a"), + dict_lit(DataType::UInt16, "b"), + null_dict_lit(DataType::UInt16), + ], ]; - in_list!( + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + } + + // expression: "a not in ("a", "b", null)" + for list in lists.iter() { + in_list_raw!( + batch, + list.clone(), + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + } + + Ok(()) + } + + #[test] + fn test_fmt_sql_1() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; + + // Test: a IN ('a', 'b') + let list = vec![lit("a"), lit("b")]; + let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?; + let sql_string = fmt_sql(expr.as_ref()).to_string(); + let display_string = expr.to_string(); + assert_snapshot!(sql_string, @"a IN (a, b)"); + assert_snapshot!(display_string, @"a@0 IN (SET) ([a, b])"); + Ok(()) + } + + #[test] + fn test_fmt_sql_2() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; + + // Test: a NOT IN ('a', 'b') + let list = vec![lit("a"), lit("b")]; + let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?; + let sql_string = fmt_sql(expr.as_ref()).to_string(); + let display_string = expr.to_string(); + + assert_snapshot!(sql_string, @"a NOT IN (a, b)"); + assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b])"); + Ok(()) + } + + #[test] + fn test_fmt_sql_3() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; + // Test: a IN ('a', 'b', NULL) + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; + let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?; + let sql_string = fmt_sql(expr.as_ref()).to_string(); + let display_string = expr.to_string(); + + assert_snapshot!(sql_string, @"a IN (a, b, NULL)"); + assert_snapshot!(display_string, @"a@0 IN (SET) ([a, b, NULL])"); + Ok(()) + } + + #[test] + fn test_fmt_sql_4() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let col_a = col("a", &schema)?; + // Test: a NOT IN ('a', 'b', NULL) + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; + let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?; + let sql_string = fmt_sql(expr.as_ref()).to_string(); + let display_string = expr.to_string(); + assert_snapshot!(sql_string, @"a NOT IN (a, b, NULL)"); + assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b, NULL])"); + Ok(()) + } + + #[test] + fn in_list_struct() -> Result<()> { + // Create schema with a struct column + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data: array of structs + let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create literal structs for the IN list + // Struct {x: 1, y: "a"} + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + // Struct {x: 3, y: "c"} + let struct3 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![3])), + Arc::new(StringArray::from(vec!["c"])), + ], + None, + ))); + + // Test: a IN ({1, "a"}, {3, "c"}) + let list = vec![lit(struct1.clone()), lit(struct3.clone())]; + in_list_raw!( batch, - list, + list.clone(), &false, - vec![Some(true), None, None], + vec![Some(true), Some(false), Some(true)], Arc::clone(&col_a), &schema ); - // expression: "a not in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date64(Some(0))), - lit(ScalarValue::Date64(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( + // Test: a NOT IN ({1, "a"}, {3, "c"}) + in_list_raw!( batch, list, &true, - vec![Some(false), None, None], + vec![Some(false), Some(true), Some(false)], Arc::clone(&col_a), &schema ); @@ -1097,62 +1483,116 @@ mod tests { } #[test] - fn in_list_date32() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Date32, true)]); - let a = Date32Array::from(vec![Some(0), Some(2), None]); + fn in_list_struct_with_nulls() -> Result<()> { + // Create schema with a struct column + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data with a null struct + let x_array = Arc::new(Int32Array::from(vec![1, 2])); + let y_array = Arc::new(StringArray::from(vec!["a", "b"])); + let struct_array = StructArray::new( + struct_fields.clone(), + vec![x_array, y_array], + Some(NullBuffer::from(vec![true, false])), + ); + let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; - // expression: "a in (0, 1)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - ]; - in_list!( + // Create literal struct for the IN list + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + // Test: a IN ({1, "a"}) + let list = vec![lit(struct1.clone())]; + in_list_raw!( batch, - list, + list.clone(), &false, - vec![Some(true), Some(false), None], + vec![Some(true), None], Arc::clone(&col_a), &schema ); - // expression: "a not in (0, 1)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - ]; - in_list!( + // Test: a NOT IN ({1, "a"}) + in_list_raw!( batch, list, &true, - vec![Some(false), Some(true), None], + vec![Some(false), None], Arc::clone(&col_a), &schema ); - // expression: "a in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( + Ok(()) + } + + #[test] + fn in_list_struct_with_null_in_list() -> Result<()> { + // Create schema with a struct column + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data + let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create literal structs including a NULL + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + let null_struct = ScalarValue::Struct(Arc::new(StructArray::new_null( + struct_fields.clone(), + 1, + ))); + + // Test: a IN ({1, "a"}, NULL) + let list = vec![lit(struct1), lit(null_struct.clone())]; + in_list_raw!( batch, - list, + list.clone(), &false, vec![Some(true), None, None], Arc::clone(&col_a), &schema ); - // expression: "a not in (0, 1, NULL)" - let list = vec![ - lit(ScalarValue::Date32(Some(0))), - lit(ScalarValue::Date32(Some(1))), - lit(ScalarValue::Null), - ]; - in_list!( + // Test: a NOT IN ({1, "a"}, NULL) + in_list_raw!( batch, list, &true, @@ -1165,99 +1605,271 @@ mod tests { } #[test] - fn in_list_decimal() -> Result<()> { - // Now, we can check the NULL type - let schema = - Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]); - let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)] - .into_iter() - .collect::(); - let array = array.with_precision_and_scale(13, 4).unwrap(); - let col_a = col("a", &schema)?; + fn in_list_nested_struct() -> Result<()> { + // Create nested struct schema + let inner_struct_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + let outer_struct_fields = Fields::from(vec![ + Field::new( + "inner", + DataType::Struct(inner_struct_fields.clone()), + false, + ), + Field::new("c", DataType::Int32, false), + ]); + let schema = Schema::new(vec![Field::new( + "x", + DataType::Struct(outer_struct_fields.clone()), + true, + )]); + + // Create test data with nested structs + let inner1 = Arc::new(StructArray::new( + inner_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + None, + )); + let c_array = Arc::new(Int32Array::from(vec![10, 20])); + let outer_array = + StructArray::new(outer_struct_fields.clone(), vec![inner1, c_array], None); + + let col_x = col("x", &schema)?; let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?; + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(outer_array)])?; - // expression: "a in (100,200), the data type of list is INT32 - let list = vec![lit(100i32), lit(200i32)]; - in_list!( + // Create a nested struct literal matching the first row + let inner_match = Arc::new(StructArray::new( + inner_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["x"])), + ], + None, + )); + let outer_match = ScalarValue::Struct(Arc::new(StructArray::new( + outer_struct_fields.clone(), + vec![inner_match, Arc::new(Int32Array::from(vec![10]))], + None, + ))); + + // Test: x IN ({{1, "x"}, 10}) + let list = vec![lit(outer_match)]; + in_list_raw!( batch, - list, + list.clone(), &false, - vec![Some(true), None, Some(false)], - Arc::clone(&col_a), + vec![Some(true), Some(false)], + Arc::clone(&col_x), &schema ); - // expression: "a not in (100,200) - let list = vec![lit(100i32), lit(200i32)]; - in_list!( + + // Test: x NOT IN ({{1, "x"}, 10}) + in_list_raw!( batch, list, &true, - vec![Some(false), None, Some(true)], - Arc::clone(&col_a), + vec![Some(false), Some(true)], + Arc::clone(&col_x), &schema ); - // expression: "a in (200,NULL), the data type of list is INT32 AND NULL - let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)]; - in_list!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema + Ok(()) + } + + #[test] + fn in_list_struct_with_exprs_not_array() -> Result<()> { + // Test InList using expressions (not the array constructor) with structs + // By using InListExpr::new directly, we bypass the array optimization + // and use the Exprs variant, testing the expression evaluation path + + // Create schema with a struct column {x: Int32, y: Utf8} + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(struct_fields.clone()), + true, + )]); + + // Create test data: array of structs [{1, "a"}, {2, "b"}, {3, "c"}] + let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); + let struct_array = + StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + + // Create struct literals with the SAME shape (so types are compatible) + // Struct {x: 1, y: "a"} + let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + ))); + + // Struct {x: 3, y: "c"} + let struct3 = ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![3])), + Arc::new(StringArray::from(vec!["c"])), + ], + None, + ))); + + // Create list of struct expressions + let list = vec![lit(struct1), lit(struct3)]; + + // Use InListExpr::new directly (not in_list()) to bypass array optimization + // This creates an InList without a static filter + let expr = Arc::new(InListExpr::new(Arc::clone(&col_a), list, false, None)); + + // Verify that the expression doesn't have a static filter + // by checking the display string does NOT contain "(SET)" + let display_string = expr.to_string(); + assert!( + !display_string.contains("(SET)"), + "Expected display string to NOT contain '(SET)' (should use Exprs variant), but got: {display_string}", ); - // expression: "a not in (200,NULL), the data type of list is INT32 AND NULL - in_list!( - batch, - list, - &true, - vec![Some(false), None, None], + + // Evaluate the expression + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + + // Expected: first row {1, "a"} matches struct1, + // second row {2, "b"} doesn't match, + // third row {3, "c"} matches struct3 + let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); + assert_eq!(result, &expected); + + // Test NOT IN as well + let expr_not = Arc::new(InListExpr::new( Arc::clone(&col_a), - &schema - ); + vec![ + lit(ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), + ], + None, + )))), + lit(ScalarValue::Struct(Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![3])), + Arc::new(StringArray::from(vec!["c"])), + ], + None, + )))), + ], + true, + None, + )); + + let result_not = expr_not.evaluate(&batch)?.into_array(batch.num_rows())?; + let result_not = as_boolean_array(&result_not); + + let expected_not = BooleanArray::from(vec![Some(false), Some(true), Some(false)]); + assert_eq!(result_not, &expected_not); + + Ok(()) + } + + #[test] + fn test_in_list_null_handling_comprehensive() -> Result<()> { + // Comprehensive test demonstrating SQL three-valued logic for IN expressions + // This test explicitly shows all possible outcomes: true, false, and null + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + + // Test data: [1, 2, 3, null] + // - 1 will match in both lists + // - 2 will not match in either list + // - 3 will not match in either list + // - null is always null + let a = Int64Array::from(vec![Some(1), Some(2), Some(3), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // expression: "a in (200.5, 100), the data type of list is FLOAT32 and INT32 - let list = vec![lit(200.50f32), lit(100i32)]; + // Case 1: List WITHOUT null - demonstrates true/false/null outcomes + // "a IN (1, 4)" - 1 matches, 2 and 3 don't match, null is null + let list = vec![lit(1i64), lit(4i64)]; in_list!( batch, list, &false, - vec![Some(true), None, Some(true)], + vec![ + Some(true), // 1 is in the list → true + Some(false), // 2 is not in the list → false + Some(false), // 3 is not in the list → false + None, // null IN (...) → null (SQL three-valued logic) + ], Arc::clone(&col_a), &schema ); - // expression: "a not in (200.5, 100), the data type of list is FLOAT32 and INT32 - let list = vec![lit(200.50f32), lit(101i32)]; + // Case 2: List WITH null - demonstrates null propagation for non-matches + // "a IN (1, NULL)" - 1 matches (true), 2/3 don't match but list has null (null), null is null + let list = vec![lit(1i64), lit(ScalarValue::Int64(None))]; in_list!( batch, list, - &true, - vec![Some(true), None, Some(false)], + &false, + vec![ + Some(true), // 1 is in the list → true (found match) + None, // 2 is not in list, but list has NULL → null (might match NULL) + None, // 3 is not in list, but list has NULL → null (might match NULL) + None, // null IN (...) → null (SQL three-valued logic) + ], Arc::clone(&col_a), &schema ); - // test the optimization: set - // expression: "a in (99..300), the data type of list is INT32 - let list = (99i32..300).map(lit).collect::>(); + Ok(()) + } + + #[test] + fn test_in_list_with_only_nulls() -> Result<()> { + // Edge case: IN list contains ONLY null values + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let a = Int64Array::from(vec![Some(1), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // "a IN (NULL, NULL)" - list has only nulls + let list = vec![lit(ScalarValue::Int64(None)), lit(ScalarValue::Int64(None))]; + // All results should be NULL because: + // - Non-null values (1, 2) can't match anything concrete, but list might contain matching value + // - NULL value is always NULL in IN expressions in_list!( batch, list.clone(), &false, - vec![Some(true), None, Some(false)], + vec![None, None, None], Arc::clone(&col_a), &schema ); + // "a NOT IN (NULL, NULL)" - list has only nulls + // All results should still be NULL due to three-valued logic in_list!( batch, list, &true, - vec![Some(false), None, Some(true)], + vec![None, None, None], Arc::clone(&col_a), &schema ); @@ -1266,150 +1878,134 @@ mod tests { } #[test] - fn test_cast_static_filter_to_set() -> Result<()> { - // random schema - let schema = - Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]); - - // list of phy expr - let mut phy_exprs = vec![ - lit(1i64), - expressions::cast(lit(2i32), &schema, DataType::Int64)?, - try_cast(lit(3.13f32), &schema, DataType::Int64)?, - ]; - let static_filter = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - - let array = Int64Array::from(vec![1, 2, 3, 4]); - let r = static_filter.contains(&array, false).unwrap(); - assert_eq!(r, BooleanArray::from(vec![true, true, true, false])); + fn test_in_list_multiple_nulls_deduplication() -> Result<()> { + // Test that multiple NULLs in the list are handled correctly + // This verifies deduplication doesn't break null handling + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let col_a = col("a", &schema)?; - try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - // cast(cast(lit())), but the cast to the same data type, one case will be ignored - phy_exprs.push(expressions::cast( - expressions::cast(lit(2i32), &schema, DataType::Int64)?, - &schema, - DataType::Int64, - )?); - try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); - - phy_exprs.clear(); + // Create array with multiple nulls: [1, 2, NULL, NULL, 3, NULL] + let array = Arc::new(Int64Array::from(vec![ + Some(1), + Some(2), + None, + None, + Some(3), + None, + ])) as ArrayRef; - // case(cast(lit())), the cast to the diff data type - phy_exprs.push(expressions::cast( - expressions::cast(lit(2i32), &schema, DataType::Int64)?, + // Create InListExpr from array + let expr = Arc::new(InListExpr::try_new_from_array( + Arc::clone(&col_a), + array, + false, &schema, - DataType::Int32, - )?); - try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); + )?) as Arc; + + // Create test data: [1, 2, 3, 4, null] + let a = Int64Array::from(vec![Some(1), Some(2), Some(3), Some(4), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // Evaluate the expression + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); - // column - phy_exprs.push(col("a", &schema)?); - assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); + // Expected behavior with multiple NULLs in list: + // - Values in the list (1,2,3) → true + // - Values not in the list (4) → NULL (because list contains NULL) + // - NULL input → NULL + let expected = BooleanArray::from(vec![ + Some(true), // 1 is in list + Some(true), // 2 is in list + Some(true), // 3 is in list + None, // 4 not in list, but list has NULLs + None, // NULL input + ]); + assert_eq!(result, &expected); Ok(()) } #[test] - fn in_list_timestamp() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Timestamp(TimeUnit::Microsecond, None), - true, - )]); - let a = TimestampMicrosecondArray::from(vec![ - Some(1388588401000000000), - Some(1288588501000000000), - None, - ]); + fn test_not_in_null_handling_comprehensive() -> Result<()> { + // Comprehensive test demonstrating SQL three-valued logic for NOT IN expressions + // This test explicitly shows all possible outcomes for NOT IN: true, false, and null + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + + // Test data: [1, 2, 3, null] + let a = Int64Array::from(vec![Some(1), Some(2), Some(3), None]); let col_a = col("a", &schema)?; let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let list = vec![ - lit(ScalarValue::TimestampMicrosecond( - Some(1388588401000000000), - None, - )), - lit(ScalarValue::TimestampMicrosecond( - Some(1388588401000000001), - None, - )), - lit(ScalarValue::TimestampMicrosecond( - Some(1388588401000000002), - None, - )), - ]; - + // Case 1: List WITHOUT null - demonstrates true/false/null outcomes for NOT IN + // "a NOT IN (1, 4)" - 1 matches (false), 2 and 3 don't match (true), null is null + let list = vec![lit(1i64), lit(4i64)]; in_list!( batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], + list, + &true, + vec![ + Some(false), // 1 is in the list → NOT IN returns false + Some(true), // 2 is not in the list → NOT IN returns true + Some(true), // 3 is not in the list → NOT IN returns true + None, // null NOT IN (...) → null (SQL three-valued logic) + ], Arc::clone(&col_a), &schema ); + // Case 2: List WITH null - demonstrates null propagation for NOT IN + // "a NOT IN (1, NULL)" - 1 matches (false), 2/3 don't match but list has null (null), null is null + let list = vec![lit(1i64), lit(ScalarValue::Int64(None))]; in_list!( batch, - list.clone(), + list, &true, - vec![Some(false), Some(true), None], + vec![ + Some(false), // 1 is in the list → NOT IN returns false + None, // 2 is not in known values, but list has NULL → null (can't prove it's not in list) + None, // 3 is not in known values, but list has NULL → null (can't prove it's not in list) + None, // null NOT IN (...) → null (SQL three-valued logic) + ], Arc::clone(&col_a), &schema ); + Ok(()) } #[test] - fn in_expr_with_multiple_element_in_list() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - Field::new("c", DataType::Float64, true), - ]); - let a = Float64Array::from(vec![ - Some(0.0), - Some(1.0), - Some(2.0), - Some(f64::NAN), - Some(-f64::NAN), - ]); - let b = Float64Array::from(vec![ - Some(8.0), - Some(1.0), - Some(5.0), - Some(f64::NAN), - Some(3.0), - ]); - let c = Float64Array::from(vec![ - Some(6.0), - Some(7.0), - None, - Some(5.0), - Some(-f64::NAN), - ]); + fn test_in_list_null_type_column() -> Result<()> { + // Test with a column that has DataType::Null (not just nullable values) + // All values in a NullArray are null by definition + let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]); + let a = NullArray::new(3); let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; - let col_c = col("c", &schema)?; - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(a), Arc::new(b), Arc::new(c)], - )?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)]; + // "null_column IN (1, 2)" - comparing Null type against Int64 list + // Note: This tests type coercion behavior between Null and Int64 + let list = vec![lit(1i64), lit(2i64)]; + + // All results should be NULL because: + // - Every value in the column is null (DataType::Null) + // - null IN (anything) always returns null per SQL three-valued logic in_list!( batch, list.clone(), &false, - vec![Some(false), Some(true), None, Some(true), Some(true)], + vec![None, None, None], Arc::clone(&col_a), &schema ); + // "null_column NOT IN (1, 2)" + // Same behavior for NOT IN - null NOT IN (anything) is still null in_list!( batch, list, &true, - vec![Some(true), Some(false), None, Some(false), Some(false)], + vec![None, None, None], Arc::clone(&col_a), &schema ); @@ -1417,667 +2013,884 @@ mod tests { Ok(()) } - macro_rules! test_nullable { - ($COL:expr, $LIST:expr, $SCHEMA:expr, $EXPECTED:expr) => {{ - let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST, $SCHEMA)?; - let expr = in_list(cast_expr, cast_list_exprs, &false, $SCHEMA).unwrap(); - let result = expr.nullable($SCHEMA)?; - assert_eq!($EXPECTED, result); - }}; - } - #[test] - fn in_list_nullable() -> Result<()> { - let schema = Schema::new(vec![ - Field::new("c1_nullable", DataType::Int64, true), - Field::new("c2_non_nullable", DataType::Int64, false), - ]); - - let c1_nullable = col("c1_nullable", &schema)?; - let c2_non_nullable = col("c2_non_nullable", &schema)?; - - // static_filter has no nulls - let list = vec![lit(1_i64), lit(2_i64)]; - test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); - test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); - - // static_filter has nulls - let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; - test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); - test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); + fn test_in_list_null_type_list() -> Result<()> { + // Test with a list that has DataType::Null + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let a = Int64Array::from(vec![Some(1), Some(2), None]); + let col_a = col("a", &schema)?; - let list = vec![Arc::clone(&c1_nullable)]; - test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); + // Create a NullArray as the list + let null_array = Arc::new(NullArray::new(2)) as ArrayRef; - let list = vec![Arc::clone(&c2_non_nullable)]; - test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + // Try to create InListExpr with a NullArray list + // This tests whether try_new_from_array can handle Null type arrays + let expr = Arc::new(InListExpr::try_new_from_array( + Arc::clone(&col_a), + null_array, + false, + &schema, + )?) as Arc; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); - let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)]; - test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); + // If it succeeds, all results should be NULL + // because the list contains only null type values + let expected = BooleanArray::from(vec![None, None, None]); + assert_eq!(result, &expected); Ok(()) } #[test] - fn in_list_no_cols() -> Result<()> { - // test logic when the in_list expression doesn't have any columns - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let a = Int32Array::from(vec![Some(1), Some(2), None]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + fn test_in_list_null_type_both() -> Result<()> { + // Test when both column and list are DataType::Null + let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]); + let a = NullArray::new(3); + let col_a = col("a", &schema)?; - let list = vec![lit(ScalarValue::from(1i32)), lit(ScalarValue::from(6i32))]; + // Create a NullArray as the list + let null_array = Arc::new(NullArray::new(2)) as ArrayRef; - // 1 IN (1, 6) - let expr = lit(ScalarValue::Int32(Some(1))); - in_list!( - batch, - list.clone(), - &false, - // should have three outputs, as the input batch has three rows - vec![Some(true), Some(true), Some(true)], - expr, - &schema - ); + // Try to create InListExpr with both Null types + let expr = Arc::new(InListExpr::try_new_from_array( + Arc::clone(&col_a), + null_array, + false, + &schema, + )?) as Arc; - // 2 IN (1, 6) - let expr = lit(ScalarValue::Int32(Some(2))); - in_list!( - batch, - list.clone(), - &false, - // should have three outputs, as the input batch has three rows - vec![Some(false), Some(false), Some(false)], - expr, - &schema - ); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); - // NULL IN (1, 6) - let expr = lit(ScalarValue::Int32(None)); - in_list!( - batch, - list.clone(), - &false, - // should have three outputs, as the input batch has three rows - vec![None, None, None], - expr, - &schema - ); + // If successful, all results should be NULL + // null IN [null, null] -> null + let expected = BooleanArray::from(vec![None, None, None]); + assert_eq!(result, &expected); Ok(()) } #[test] - fn in_list_utf8_with_dict_types() -> Result<()> { - fn dict_lit(key_type: DataType, value: &str) -> Arc { - lit(ScalarValue::Dictionary( - Box::new(key_type), - Box::new(ScalarValue::new_utf8(value.to_string())), - )) - } + fn test_in_list_comprehensive_null_handling() -> Result<()> { + // Comprehensive test for IN LIST operations with various NULL handling scenarios. + // This test covers the key cases validated against DuckDB as the source of truth. + // + // Note: Some scalar literal tests (like NULL IN (1, 2)) are omitted as they + // appear to expose an issue with static filter optimization. These are covered + // by existing tests like in_list_no_cols(). - fn null_dict_lit(key_type: DataType) -> Arc { - lit(ScalarValue::Dictionary( - Box::new(key_type), - Box::new(ScalarValue::Utf8(None)), - )) - } + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let col_b = col("b", &schema)?; + let null_i32 = ScalarValue::Int32(None); - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), - true, - )]); - let a: UInt16DictionaryArray = - vec![Some("a"), Some("d"), None].into_iter().collect(); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + // Helper to create a batch + let make_batch = |values: Vec>| -> Result { + let array = Arc::new(Int32Array::from(values)); + Ok(RecordBatch::try_new(Arc::clone(&schema), vec![array])?) + }; - // expression: "a in ("a", "b")" - let lists = [ - vec![lit("a"), lit("b")], - vec![ - dict_lit(DataType::Int8, "a"), - dict_lit(DataType::UInt16, "b"), - ], - ]; - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - } + // Helper to run a test + let run_test = |batch: &RecordBatch, + expr: Arc, + list: Vec>, + expected: Vec>| + -> Result<()> { + let in_expr = in_list(expr, list, &false, schema.as_ref())?; + let result = in_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(expected)); + Ok(()) + }; - // expression: "a not in ("a", "b")" - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - } + // ======================================================================== + // COLUMN TESTS - col(b) IN [1, 2] + // ======================================================================== - // expression: "a in ("a", "b", null)" - let lists = [ - vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], - vec![ - dict_lit(DataType::Int8, "a"), - dict_lit(DataType::UInt16, "b"), - null_dict_lit(DataType::UInt16), - ], - ]; - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - } + // [1] IN (1, 2) => [TRUE] + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(true)], + )?; - // expression: "a not in ("a", "b", null)" - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - } + // [1, 2] IN (1, 2) => [TRUE, TRUE] + let batch = make_batch(vec![Some(1), Some(2)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(true), Some(true)], + )?; - Ok(()) - } + // [3, 4] IN (1, 2) => [FALSE, FALSE] + let batch = make_batch(vec![Some(3), Some(4)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(false), Some(false)], + )?; - #[test] - fn test_fmt_sql_1() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let col_a = col("a", &schema)?; + // [1, NULL] IN (1, 2) => [TRUE, NULL] + let batch = make_batch(vec![Some(1), None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(true), None], + )?; - // Test: a IN ('a', 'b') - let list = vec![lit("a"), lit("b")]; - let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?; - let sql_string = fmt_sql(expr.as_ref()).to_string(); - let display_string = expr.to_string(); - assert_snapshot!(sql_string, @"a IN (a, b)"); - assert_snapshot!(display_string, @"a@0 IN (SET) ([a, b])"); - Ok(()) - } + // [3, NULL] IN (1, 2) => [FALSE, NULL] (no match, NULL is NULL) + let batch = make_batch(vec![Some(3), None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), lit(2i32)], + vec![Some(false), None], + )?; - #[test] - fn test_fmt_sql_2() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let col_a = col("a", &schema)?; + // ======================================================================== + // COLUMN WITH NULL IN LIST - col(b) IN [NULL, 1] + // ======================================================================== - // Test: a NOT IN ('a', 'b') - let list = vec![lit("a"), lit("b")]; - let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?; - let sql_string = fmt_sql(expr.as_ref()).to_string(); - let display_string = expr.to_string(); + // [1] IN (NULL, 1) => [TRUE] (found match) + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(1i32)], + vec![Some(true)], + )?; - assert_snapshot!(sql_string, @"a NOT IN (a, b)"); - assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b])"); - Ok(()) - } + // [2] IN (NULL, 1) => [NULL] (no match, but list has NULL) + let batch = make_batch(vec![Some(2)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(1i32)], + vec![None], + )?; - #[test] - fn test_fmt_sql_3() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let col_a = col("a", &schema)?; - // Test: a IN ('a', 'b', NULL) - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; - let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?; - let sql_string = fmt_sql(expr.as_ref()).to_string(); - let display_string = expr.to_string(); + // [NULL] IN (NULL, 1) => [NULL] + let batch = make_batch(vec![None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(1i32)], + vec![None], + )?; + + // ======================================================================== + // COLUMN WITH ALL NULLS IN LIST - col(b) IN [NULL, NULL] + // ======================================================================== + + // [1] IN (NULL, NULL) => [NULL] + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(null_i32.clone())], + vec![None], + )?; + + // [NULL] IN (NULL, NULL) => [NULL] + let batch = make_batch(vec![None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(null_i32.clone()), lit(null_i32.clone())], + vec![None], + )?; + + // ======================================================================== + // LITERAL IN LIST WITH COLUMN - lit(1) IN [2, col(b)] + // ======================================================================== + + // 1 IN (2, [1]) => [TRUE] (matches column value) + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + lit(1i32), + vec![lit(2i32), Arc::clone(&col_b)], + vec![Some(true)], + )?; + + // 1 IN (2, [3]) => [FALSE] (no match) + let batch = make_batch(vec![Some(3)])?; + run_test( + &batch, + lit(1i32), + vec![lit(2i32), Arc::clone(&col_b)], + vec![Some(false)], + )?; + + // 1 IN (2, [NULL]) => [NULL] (no match, column is NULL) + let batch = make_batch(vec![None])?; + run_test( + &batch, + lit(1i32), + vec![lit(2i32), Arc::clone(&col_b)], + vec![None], + )?; + + // ======================================================================== + // COLUMN IN LIST CONTAINING ITSELF - col(b) IN [1, col(b)] + // ======================================================================== + + // [1] IN (1, [1]) => [TRUE] (always matches - either list literal or itself) + let batch = make_batch(vec![Some(1)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), Arc::clone(&col_b)], + vec![Some(true)], + )?; + + // [2] IN (1, [2]) => [TRUE] (matches itself) + let batch = make_batch(vec![Some(2)])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), Arc::clone(&col_b)], + vec![Some(true)], + )?; + + // [NULL] IN (1, [NULL]) => [NULL] (NULL is never equal to anything) + let batch = make_batch(vec![None])?; + run_test( + &batch, + Arc::clone(&col_b), + vec![lit(1i32), Arc::clone(&col_b)], + vec![None], + )?; - assert_snapshot!(sql_string, @"a IN (a, b, NULL)"); - assert_snapshot!(display_string, @"a@0 IN (SET) ([a, b, NULL])"); Ok(()) } #[test] - fn test_fmt_sql_4() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let col_a = col("a", &schema)?; - // Test: a NOT IN ('a', 'b', NULL) - let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))]; - let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?; - let sql_string = fmt_sql(expr.as_ref()).to_string(); - let display_string = expr.to_string(); - assert_snapshot!(sql_string, @"a NOT IN (a, b, NULL)"); - assert_snapshot!(display_string, @"a@0 NOT IN (SET) ([a, b, NULL])"); - Ok(()) - } + fn test_in_list_scalar_literal_cases() -> Result<()> { + // Test scalar literal cases (both NULL and non-NULL) to ensure SQL three-valued + // logic is correctly implemented. This covers the important case where a scalar + // value is tested against a list containing NULL. + + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + let null_i32 = ScalarValue::Int32(None); + + // Helper to create a batch + let make_batch = |values: Vec>| -> Result { + let array = Arc::new(Int32Array::from(values)); + Ok(RecordBatch::try_new(Arc::clone(&schema), vec![array])?) + }; + + // Helper to run a test + let run_test = |batch: &RecordBatch, + expr: Arc, + list: Vec>, + negated: bool, + expected: Vec>| + -> Result<()> { + let in_expr = in_list(expr, list, &negated, schema.as_ref())?; + let result = in_expr.evaluate(batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + let expected_array = BooleanArray::from(expected); + assert_eq!( + result, + &expected_array, + "Expected {:?}, got {:?}", + expected_array, + result.iter().collect::>() + ); + Ok(()) + }; + + let batch = make_batch(vec![Some(1)])?; + + // ======================================================================== + // NULL LITERAL TESTS + // According to SQL semantics, NULL IN (any_list) should always return NULL + // ======================================================================== + + // NULL IN (1, 1) => NULL + run_test( + &batch, + lit(null_i32.clone()), + vec![lit(1i32), lit(1i32)], + false, + vec![None], + )?; + + // NULL IN (NULL, 1) => NULL + run_test( + &batch, + lit(null_i32.clone()), + vec![lit(null_i32.clone()), lit(1i32)], + false, + vec![None], + )?; + + // NULL IN (NULL, NULL) => NULL + run_test( + &batch, + lit(null_i32.clone()), + vec![lit(null_i32.clone()), lit(null_i32.clone())], + false, + vec![None], + )?; + + // ======================================================================== + // NON-NULL SCALAR LITERALS WITH NULL IN LIST - Int32 + // When a scalar value is NOT in a list containing NULL, the result is NULL + // When a scalar value IS in the list, the result is TRUE (NULL doesn't matter) + // ======================================================================== + + // 3 IN (0, 1, 2, NULL) => NULL (not in list, but list has NULL) + run_test( + &batch, + lit(3i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + false, + vec![None], + )?; + + // 3 NOT IN (0, 1, 2, NULL) => NULL (not in list, but list has NULL) + run_test( + &batch, + lit(3i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + true, + vec![None], + )?; - #[test] - fn in_list_struct() -> Result<()> { - // Create schema with a struct column - let struct_fields = Fields::from(vec![ - Field::new("x", DataType::Int32, false), - Field::new("y", DataType::Utf8, false), - ]); - let schema = Schema::new(vec![Field::new( - "a", - DataType::Struct(struct_fields.clone()), + // 1 IN (0, 1, 2, NULL) => TRUE (found match, NULL doesn't matter) + run_test( + &batch, + lit(1i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + false, + vec![Some(true)], + )?; + + // 1 NOT IN (0, 1, 2, NULL) => FALSE (found match, NULL doesn't matter) + run_test( + &batch, + lit(1i32), + vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], true, - )]); + vec![Some(false)], + )?; - // Create test data: array of structs - let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); - let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); - let struct_array = - StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + // ======================================================================== + // NON-NULL SCALAR LITERALS WITH NULL IN LIST - String + // Same semantics as Int32 but with string type + // ======================================================================== - let col_a = col("a", &schema)?; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + let schema_str = + Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let batch_str = RecordBatch::try_new( + Arc::clone(&schema_str), + vec![Arc::new(StringArray::from(vec![Some("dummy")]))], + )?; + let null_str = ScalarValue::Utf8(None); - // Create literal structs for the IN list - // Struct {x: 1, y: "a"} - let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["a"])), - ], - None, - ))); + let run_test_str = |expr: Arc, + list: Vec>, + negated: bool, + expected: Vec>| + -> Result<()> { + let in_expr = in_list(expr, list, &negated, schema_str.as_ref())?; + let result = in_expr + .evaluate(&batch_str)? + .into_array(batch_str.num_rows())?; + let result = as_boolean_array(&result); + let expected_array = BooleanArray::from(expected); + assert_eq!( + result, + &expected_array, + "Expected {:?}, got {:?}", + expected_array, + result.iter().collect::>() + ); + Ok(()) + }; - // Struct {x: 3, y: "c"} - let struct3 = ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![3])), - Arc::new(StringArray::from(vec!["c"])), - ], - None, - ))); + // 'c' IN ('a', 'b', NULL) => NULL (not in list, but list has NULL) + run_test_str( + lit("c"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + false, + vec![None], + )?; - // Test: a IN ({1, "a"}, {3, "c"}) - let list = vec![lit(struct1.clone()), lit(struct3.clone())]; - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), Some(true)], - Arc::clone(&col_a), - &schema - ); + // 'c' NOT IN ('a', 'b', NULL) => NULL (not in list, but list has NULL) + run_test_str( + lit("c"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + true, + vec![None], + )?; - // Test: a NOT IN ({1, "a"}, {3, "c"}) - in_list_raw!( - batch, - list, - &true, - vec![Some(false), Some(true), Some(false)], - Arc::clone(&col_a), - &schema - ); + // 'a' IN ('a', 'b', NULL) => TRUE (found match, NULL doesn't matter) + run_test_str( + lit("a"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + false, + vec![Some(true)], + )?; + + // 'a' NOT IN ('a', 'b', NULL) => FALSE (found match, NULL doesn't matter) + run_test_str( + lit("a"), + vec![lit("a"), lit("b"), lit(null_str.clone())], + true, + vec![Some(false)], + )?; Ok(()) } #[test] - fn in_list_struct_with_nulls() -> Result<()> { - // Create schema with a struct column - let struct_fields = Fields::from(vec![ - Field::new("x", DataType::Int32, false), - Field::new("y", DataType::Utf8, false), - ]); - let schema = Schema::new(vec![Field::new( - "a", - DataType::Struct(struct_fields.clone()), - true, - )]); + fn test_in_list_tuple_cases() -> Result<()> { + // Test tuple/struct cases from the original request: (lit, lit) IN (lit, lit) + // These test row-wise comparisons like (1, 2) IN ((1, 2), (3, 4)) - // Create test data with a null struct - let x_array = Arc::new(Int32Array::from(vec![1, 2])); - let y_array = Arc::new(StringArray::from(vec!["a", "b"])); - let struct_array = StructArray::new( - struct_fields.clone(), - vec![x_array, y_array], - Some(NullBuffer::from(vec![true, false])), - ); + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); - let col_a = col("a", &schema)?; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + // Helper to create struct scalars for tuple comparisons + let make_struct = |v1: Option, v2: Option| -> ScalarValue { + let fields = Fields::from(vec![ + Field::new("field_0", DataType::Int32, true), + Field::new("field_1", DataType::Int32, true), + ]); + ScalarValue::Struct(Arc::new(StructArray::new( + fields, + vec![ + Arc::new(Int32Array::from(vec![v1])), + Arc::new(Int32Array::from(vec![v2])), + ], + None, + ))) + }; - // Create literal struct for the IN list - let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["a"])), - ], - None, - ))); + // Need a single row batch for scalar tests + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![Some(1)]))], + )?; - // Test: a IN ({1, "a"}) - let list = vec![lit(struct1.clone())]; - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), None], - Arc::clone(&col_a), - &schema - ); + // Helper to run tuple tests + let run_tuple_test = |lhs: ScalarValue, + list: Vec, + expected: Vec>| + -> Result<()> { + let expr = in_list( + lit(lhs), + list.into_iter().map(lit).collect(), + &false, + schema.as_ref(), + )?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(expected)); + Ok(()) + }; - // Test: a NOT IN ({1, "a"}) - in_list_raw!( - batch, - list, - &true, - vec![Some(false), None], - Arc::clone(&col_a), - &schema - ); + // (NULL, NULL) IN ((1, 2)) => FALSE (tuples don't match) + run_tuple_test( + make_struct(None, None), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; - Ok(()) - } + // (NULL, NULL) IN ((NULL, 1)) => FALSE + run_tuple_test( + make_struct(None, None), + vec![make_struct(None, Some(1))], + vec![Some(false)], + )?; - #[test] - fn in_list_struct_with_null_in_list() -> Result<()> { - // Create schema with a struct column - let struct_fields = Fields::from(vec![ - Field::new("x", DataType::Int32, false), - Field::new("y", DataType::Utf8, false), - ]); - let schema = Schema::new(vec![Field::new( - "a", - DataType::Struct(struct_fields.clone()), - true, - )]); + // (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match including nulls) + run_tuple_test( + make_struct(None, None), + vec![make_struct(None, None)], + vec![Some(true)], + )?; - // Create test data - let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); - let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); - let struct_array = - StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + // (NULL, 1) IN ((1, 2)) => FALSE + run_tuple_test( + make_struct(None, Some(1)), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; - let col_a = col("a", &schema)?; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + // (NULL, 1) IN ((NULL, 1)) => TRUE (exact match) + run_tuple_test( + make_struct(None, Some(1)), + vec![make_struct(None, Some(1))], + vec![Some(true)], + )?; - // Create literal structs including a NULL - let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["a"])), - ], - None, - ))); + // (NULL, 1) IN ((NULL, NULL)) => FALSE + run_tuple_test( + make_struct(None, Some(1)), + vec![make_struct(None, None)], + vec![Some(false)], + )?; - let null_struct = ScalarValue::Struct(Arc::new(StructArray::new_null( - struct_fields.clone(), - 1, - ))); + // (1, 2) IN ((1, 2)) => TRUE + run_tuple_test( + make_struct(Some(1), Some(2)), + vec![make_struct(Some(1), Some(2))], + vec![Some(true)], + )?; - // Test: a IN ({1, "a"}, NULL) - let list = vec![lit(struct1), lit(null_struct.clone())]; - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); + // (1, 3) IN ((1, 2)) => FALSE + run_tuple_test( + make_struct(Some(1), Some(3)), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; - // Test: a NOT IN ({1, "a"}, NULL) - in_list_raw!( - batch, - list, - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); + // (4, 4) IN ((1, 2)) => FALSE + run_tuple_test( + make_struct(Some(4), Some(4)), + vec![make_struct(Some(1), Some(2))], + vec![Some(false)], + )?; + + // (1, 1) IN ((NULL, 1)) => FALSE + run_tuple_test( + make_struct(Some(1), Some(1)), + vec![make_struct(None, Some(1))], + vec![Some(false)], + )?; + + // (1, 1) IN ((NULL, NULL)) => FALSE + run_tuple_test( + make_struct(Some(1), Some(1)), + vec![make_struct(None, None)], + vec![Some(false)], + )?; Ok(()) } #[test] - fn in_list_nested_struct() -> Result<()> { - // Create nested struct schema - let inner_struct_fields = Fields::from(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ]); - let outer_struct_fields = Fields::from(vec![ - Field::new( - "inner", - DataType::Struct(inner_struct_fields.clone()), - false, - ), - Field::new("c", DataType::Int32, false), - ]); - let schema = Schema::new(vec![Field::new( - "x", - DataType::Struct(outer_struct_fields.clone()), - true, - )]); - - // Create test data with nested structs - let inner1 = Arc::new(StructArray::new( - inner_struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["x", "y"])), - ], - None, - )); - let c_array = Arc::new(Int32Array::from(vec![10, 20])); - let outer_array = - StructArray::new(outer_struct_fields.clone(), vec![inner1, c_array], None); - - let col_x = col("x", &schema)?; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(outer_array)])?; + fn test_in_list_dictionary_int32() -> Result<()> { + // Create schema with dictionary-encoded Int32 column + let dict_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)); + let schema = Schema::new(vec![Field::new("a", dict_type.clone(), false)]); + let col_a = col("a", &schema)?; - // Create a nested struct literal matching the first row - let inner_match = Arc::new(StructArray::new( - inner_struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["x"])), - ], - None, - )); - let outer_match = ScalarValue::Struct(Arc::new(StructArray::new( - outer_struct_fields.clone(), - vec![inner_match, Arc::new(Int32Array::from(vec![10]))], - None, - ))); + // Create IN list with Int32 literals: (100, 200, 300) + let list = vec![lit(100i32), lit(200i32), lit(300i32)]; - // Test: x IN ({{1, "x"}, 10}) - let list = vec![lit(outer_match)]; - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false)], - Arc::clone(&col_x), - &schema - ); + // Create InListExpr via in_list() - this uses Int32StaticFilter for Int32 lists + let expr = in_list(col_a, list, &false, &schema)?; - // Test: x NOT IN ({{1, "x"}, 10}) - in_list_raw!( - batch, - list, - &true, - vec![Some(false), Some(true)], - Arc::clone(&col_x), - &schema - ); + // Create dictionary-encoded batch with values [100, 200, 500] + // Dictionary: keys [0, 1, 2] -> values [100, 200, 500] + // Using values clearly distinct from keys to avoid confusion + let keys = Int8Array::from(vec![0, 1, 2]); + let values = Int32Array::from(vec![100, 200, 500]); + let dict_array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?); + let batch = RecordBatch::try_new(Arc::new(schema), vec![dict_array])?; + // Expected: [100 IN (100,200,300), 200 IN (100,200,300), 500 IN (100,200,300)] = [true, true, false] + let result = expr.evaluate(&batch)?.into_array(3)?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(vec![true, true, false])); Ok(()) } #[test] - fn in_list_struct_with_exprs_not_array() -> Result<()> { - // Test InList using expressions (not the array constructor) with structs - // By using InListExpr::new directly, we bypass the array optimization - // and use the Exprs variant, testing the expression evaluation path + fn test_in_list_dictionary_types() -> Result<()> { + // Helper functions for creating dictionary literals + fn dict_lit_int64(key_type: DataType, value: i64) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::Int64(Some(value))), + )) + } - // Create schema with a struct column {x: Int32, y: Utf8} - let struct_fields = Fields::from(vec![ - Field::new("x", DataType::Int32, false), - Field::new("y", DataType::Utf8, false), - ]); - let schema = Schema::new(vec![Field::new( - "a", - DataType::Struct(struct_fields.clone()), - true, - )]); + fn dict_lit_float64(key_type: DataType, value: f64) -> Arc { + lit(ScalarValue::Dictionary( + Box::new(key_type), + Box::new(ScalarValue::Float64(Some(value))), + )) + } - // Create test data: array of structs [{1, "a"}, {2, "b"}, {3, "c"}] - let x_array = Arc::new(Int32Array::from(vec![1, 2, 3])); - let y_array = Arc::new(StringArray::from(vec!["a", "b", "c"])); - let struct_array = - StructArray::new(struct_fields.clone(), vec![x_array, y_array], None); + // Test case structures + struct DictNeedleTest { + list_values: Vec>, + expected: Vec>, + } - let col_a = col("a", &schema)?; - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + struct DictionaryInListTestCase { + name: &'static str, + dict_type: DataType, + dict_keys: Vec>, + dict_values: ArrayRef, + list_values_no_null: Vec>, + list_values_with_null: Vec>, + expected_1: Vec>, + expected_2: Vec>, + expected_3: Vec>, + expected_4: Vec>, + dict_needle_test: Option, + } - // Create struct literals with the SAME shape (so types are compatible) - // Struct {x: 1, y: "a"} - let struct1 = ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["a"])), - ], - None, - ))); + // Test harness function + fn run_dictionary_in_list_test( + test_case: DictionaryInListTestCase, + ) -> Result<()> { + // Create schema with dictionary type + let schema = + Schema::new(vec![Field::new("a", test_case.dict_type.clone(), true)]); + let col_a = col("a", &schema)?; + + // Create dictionary array from keys and values + let keys = Int8Array::from(test_case.dict_keys.clone()); + let dict_array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, test_case.dict_values)?); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?; + + let exp1 = test_case.expected_1.clone(); + let exp2 = test_case.expected_2.clone(); + let exp3 = test_case.expected_3.clone(); + let exp4 = test_case.expected_4; + + // Test 1: a IN (values_no_null) + in_list!( + batch, + test_case.list_values_no_null.clone(), + &false, + exp1, + Arc::clone(&col_a), + &schema + ); - // Struct {x: 3, y: "c"} - let struct3 = ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![3])), - Arc::new(StringArray::from(vec!["c"])), - ], - None, - ))); + // Test 2: a NOT IN (values_no_null) + in_list!( + batch, + test_case.list_values_no_null.clone(), + &true, + exp2, + Arc::clone(&col_a), + &schema + ); - // Create list of struct expressions - let list = vec![lit(struct1), lit(struct3)]; + // Test 3: a IN (values_with_null) + in_list!( + batch, + test_case.list_values_with_null.clone(), + &false, + exp3, + Arc::clone(&col_a), + &schema + ); - // Use InListExpr::new directly (not in_list()) to bypass array optimization - // This creates an InList without a static filter - let expr = Arc::new(InListExpr::new(Arc::clone(&col_a), list, false, None)); + // Test 4: a NOT IN (values_with_null) + in_list!( + batch, + test_case.list_values_with_null, + &true, + exp4, + Arc::clone(&col_a), + &schema + ); - // Verify that the expression doesn't have a static filter - // by checking the display string does NOT contain "(SET)" - let display_string = expr.to_string(); - assert!( - !display_string.contains("(SET)"), - "Expected display string to NOT contain '(SET)' (should use Exprs variant), but got: {display_string}", - ); + // Optional: Dictionary needle test (if provided) + if let Some(needle_test) = test_case.dict_needle_test { + in_list_raw!( + batch, + needle_test.list_values, + &false, + needle_test.expected, + Arc::clone(&col_a), + &schema + ); + } - // Evaluate the expression - let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; - let result = as_boolean_array(&result); + Ok(()) + } - // Expected: first row {1, "a"} matches struct1, - // second row {2, "b"} doesn't match, - // third row {3, "c"} matches struct3 - let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); - assert_eq!(result, &expected); + // Test case 1: UTF8 + // Dictionary: keys [0, 1, null] → values ["a", "d", -] + // Rows: ["a", "d", null] + let utf8_case = DictionaryInListTestCase { + name: "dictionary_utf8", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + dict_keys: vec![Some(0), Some(1), None], + dict_values: Arc::new(StringArray::from(vec![Some("a"), Some("d")])), + list_values_no_null: vec![lit("a"), lit("b")], + list_values_with_null: vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], + expected_1: vec![Some(true), Some(false), None], + expected_2: vec![Some(false), Some(true), None], + expected_3: vec![Some(true), None, None], + expected_4: vec![Some(false), None, None], + dict_needle_test: None, + }; - // Test NOT IN as well - let expr_not = Arc::new(InListExpr::new( - Arc::clone(&col_a), - vec![ - lit(ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(StringArray::from(vec!["a"])), - ], - None, - )))), - lit(ScalarValue::Struct(Arc::new(StructArray::new( - struct_fields.clone(), - vec![ - Arc::new(Int32Array::from(vec![3])), - Arc::new(StringArray::from(vec!["c"])), - ], - None, - )))), + // Test case 2: Int64 with dictionary needles + // Dictionary: keys [0, 1, null] → values [10, 20, -] + // Rows: [10, 20, null] + let int64_case = DictionaryInListTestCase { + name: "dictionary_int64", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Int64), + ), + dict_keys: vec![Some(0), Some(1), None], + dict_values: Arc::new(Int64Array::from(vec![Some(10), Some(20)])), + list_values_no_null: vec![lit(10i64), lit(15i64)], + list_values_with_null: vec![ + lit(10i64), + lit(15i64), + lit(ScalarValue::Int64(None)), ], - true, - None, - )); + expected_1: vec![Some(true), Some(false), None], + expected_2: vec![Some(false), Some(true), None], + expected_3: vec![Some(true), None, None], + expected_4: vec![Some(false), None, None], + dict_needle_test: Some(DictNeedleTest { + list_values: vec![ + dict_lit_int64(DataType::Int16, 10), + dict_lit_int64(DataType::Int16, 15), + ], + expected: vec![Some(true), Some(false), None], + }), + }; - let result_not = expr_not.evaluate(&batch)?.into_array(batch.num_rows())?; - let result_not = as_boolean_array(&result_not); + // Test case 3: Float64 with NaN and dictionary needles + // Dictionary: keys [0, 1, null, 2] → values [1.5, 3.7, NaN, -] + // Rows: [1.5, 3.7, null, NaN] + // Note: NaN is a value (not null), so it goes in the values array + let float64_case = DictionaryInListTestCase { + name: "dictionary_float64", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Float64), + ), + dict_keys: vec![Some(0), Some(1), None, Some(2)], + dict_values: Arc::new(Float64Array::from(vec![ + Some(1.5), // index 0 + Some(3.7), // index 1 + Some(f64::NAN), // index 2 + ])), + list_values_no_null: vec![lit(1.5f64), lit(2.0f64)], + list_values_with_null: vec![ + lit(1.5f64), + lit(2.0f64), + lit(ScalarValue::Float64(None)), + ], + // Test 1: a IN (1.5, 2.0) → [true, false, null, false] + // NaN is false because NaN not in list and no NULL in list + expected_1: vec![Some(true), Some(false), None, Some(false)], + // Test 2: a NOT IN (1.5, 2.0) → [false, true, null, true] + // NaN is true because NaN not in list + expected_2: vec![Some(false), Some(true), None, Some(true)], + // Test 3: a IN (1.5, 2.0, NULL) → [true, null, null, null] + // 3.7 and NaN become null due to NULL in list (three-valued logic) + expected_3: vec![Some(true), None, None, None], + // Test 4: a NOT IN (1.5, 2.0, NULL) → [false, null, null, null] + // 3.7 and NaN become null due to NULL in list + expected_4: vec![Some(false), None, None, None], + dict_needle_test: Some(DictNeedleTest { + list_values: vec![ + dict_lit_float64(DataType::UInt16, 1.5), + dict_lit_float64(DataType::UInt16, 2.0), + ], + expected: vec![Some(true), Some(false), None, Some(false)], + }), + }; - let expected_not = BooleanArray::from(vec![Some(false), Some(true), Some(false)]); - assert_eq!(result_not, &expected_not); + // Execute all test cases + let test_name = utf8_case.name; + run_dictionary_in_list_test(utf8_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test '{test_name}' failed: {e}" + )) + })?; - Ok(()) - } + let test_name = int64_case.name; + run_dictionary_in_list_test(int64_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test '{test_name}' failed: {e}" + )) + })?; - #[test] - fn test_in_list_null_handling_comprehensive() -> Result<()> { - // Comprehensive test demonstrating SQL three-valued logic for IN expressions - // This test explicitly shows all possible outcomes: true, false, and null - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + let test_name = float64_case.name; + run_dictionary_in_list_test(float64_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test '{test_name}' failed: {e}" + )) + })?; - // Test data: [1, 2, 3, null] - // - 1 will match in both lists - // - 2 will not match in either list - // - 3 will not match in either list - // - null is always null - let a = Int64Array::from(vec![Some(1), Some(2), Some(3), None]); + // Additional test: Dictionary deduplication with repeated keys + // This tests that multiple rows with the same key (pointing to the same value) + // are evaluated correctly + let dedup_case = DictionaryInListTestCase { + name: "dictionary_deduplication", + dict_type: DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + ), + // Keys: [0, 1, 0, 1, null] - keys 0 and 1 are repeated + // This creates data: ["a", "d", "a", "d", null] + dict_keys: vec![Some(0), Some(1), Some(0), Some(1), None], + dict_values: Arc::new(StringArray::from(vec![Some("a"), Some("d")])), + list_values_no_null: vec![lit("a"), lit("b")], + list_values_with_null: vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], + // Test 1: a IN ("a", "b") → [true, false, true, false, null] + // Rows 0 and 2 both have key 0 → "a", so both are true + expected_1: vec![Some(true), Some(false), Some(true), Some(false), None], + // Test 2: a NOT IN ("a", "b") → [false, true, false, true, null] + expected_2: vec![Some(false), Some(true), Some(false), Some(true), None], + // Test 3: a IN ("a", "b", NULL) → [true, null, true, null, null] + // "d" becomes null due to NULL in list + expected_3: vec![Some(true), None, Some(true), None, None], + // Test 4: a NOT IN ("a", "b", NULL) → [false, null, false, null, null] + expected_4: vec![Some(false), None, Some(false), None, None], + dict_needle_test: None, + }; + + let test_name = dedup_case.name; + run_dictionary_in_list_test(dedup_case).map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "Dictionary test '{test_name}' failed: {e}" + )) + })?; + + // Additional test for Float64 NaN in IN list + let dict_type = + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Float64)); + let schema = Schema::new(vec![Field::new("a", dict_type.clone(), true)]); let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - // Case 1: List WITHOUT null - demonstrates true/false/null outcomes - // "a IN (1, 4)" - 1 matches, 2 and 3 don't match, null is null - let list = vec![lit(1i64), lit(4i64)]; - in_list!( - batch, - list, - &false, - vec![ - Some(true), // 1 is in the list → true - Some(false), // 2 is not in the list → false - Some(false), // 3 is not in the list → false - None, // null IN (...) → null (SQL three-valued logic) - ], - Arc::clone(&col_a), - &schema - ); + let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]); + let values = Float64Array::from(vec![Some(1.5), Some(3.7), Some(f64::NAN)]); + let dict_array: ArrayRef = + Arc::new(DictionaryArray::try_new(keys, Arc::new(values))?); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![dict_array])?; - // Case 2: List WITH null - demonstrates null propagation for non-matches - // "a IN (1, NULL)" - 1 matches (true), 2/3 don't match but list has null (null), null is null - let list = vec![lit(1i64), lit(ScalarValue::Int64(None))]; + // Test: a IN (1.5, 2.0, NaN) + let list_with_nan = vec![lit(1.5f64), lit(2.0f64), lit(f64::NAN)]; in_list!( batch, - list, + list_with_nan, &false, - vec![ - Some(true), // 1 is in the list → true (found match) - None, // 2 is not in list, but list has NULL → null (might match NULL) - None, // 3 is not in list, but list has NULL → null (might match NULL) - None, // null IN (...) → null (SQL three-valued logic) - ], - Arc::clone(&col_a), + vec![Some(true), Some(false), None, Some(true)], + col_a, &schema ); @@ -2085,738 +2898,1131 @@ mod tests { } #[test] - fn test_in_list_with_only_nulls() -> Result<()> { - // Edge case: IN list contains ONLY null values - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); - let a = Int64Array::from(vec![Some(1), Some(2), None]); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + fn test_in_list_esoteric_types() -> Result<()> { + // Test esoteric/less common types to validate the transform and mapping flow. + // These types are reinterpreted to base primitive types (e.g., Timestamp -> UInt64, + // Interval -> Decimal128, Float16 -> UInt16). We just need to verify basic + // functionality works - no need for comprehensive null handling tests. + + // Helper: simple IN test that expects [Some(true), Some(false)] + let test_type = |data_type: DataType, + in_array: ArrayRef, + list_values: Vec| + -> Result<()> { + let schema = Schema::new(vec![Field::new("a", data_type.clone(), false)]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![in_array])?; - // "a IN (NULL, NULL)" - list has only nulls - let list = vec![lit(ScalarValue::Int64(None)), lit(ScalarValue::Int64(None))]; + let list = list_values.into_iter().map(lit).collect(); + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false)], + col_a, + &schema + ); + Ok(()) + }; - // All results should be NULL because: - // - Non-null values (1, 2) can't match anything concrete, but list might contain matching value - // - NULL value is always NULL in IN expressions - in_list!( - batch, - list.clone(), - &false, - vec![None, None, None], - Arc::clone(&col_a), - &schema - ); + // Timestamp types (all units map to Int64 -> UInt64) + test_type( + DataType::Timestamp(TimeUnit::Second, None), + Arc::new(TimestampSecondArray::from(vec![Some(1000), Some(2000)])), + vec![ + ScalarValue::TimestampSecond(Some(1000), None), + ScalarValue::TimestampSecond(Some(1500), None), + ], + )?; - // "a NOT IN (NULL, NULL)" - list has only nulls - // All results should still be NULL due to three-valued logic - in_list!( - batch, - list, - &true, - vec![None, None, None], - Arc::clone(&col_a), - &schema - ); + test_type( + DataType::Timestamp(TimeUnit::Millisecond, None), + Arc::new(TimestampMillisecondArray::from(vec![ + Some(1000000), + Some(2000000), + ])), + vec![ + ScalarValue::TimestampMillisecond(Some(1000000), None), + ScalarValue::TimestampMillisecond(Some(1500000), None), + ], + )?; - Ok(()) - } + test_type( + DataType::Timestamp(TimeUnit::Microsecond, None), + Arc::new(TimestampMicrosecondArray::from(vec![ + Some(1000000000), + Some(2000000000), + ])), + vec![ + ScalarValue::TimestampMicrosecond(Some(1000000000), None), + ScalarValue::TimestampMicrosecond(Some(1500000000), None), + ], + )?; - #[test] - fn test_in_list_multiple_nulls_deduplication() -> Result<()> { - // Test that multiple NULLs in the list are handled correctly - // This verifies deduplication doesn't break null handling - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); - let col_a = col("a", &schema)?; + // Time32 and Time64 (map to Int32 -> UInt32 and Int64 -> UInt64 respectively) + test_type( + DataType::Time32(TimeUnit::Second), + Arc::new(Time32SecondArray::from(vec![Some(3600), Some(7200)])), + vec![ + ScalarValue::Time32Second(Some(3600)), + ScalarValue::Time32Second(Some(5400)), + ], + )?; - // Create array with multiple nulls: [1, 2, NULL, NULL, 3, NULL] - let array = Arc::new(Int64Array::from(vec![ - Some(1), - Some(2), - None, - None, - Some(3), - None, - ])) as ArrayRef; + test_type( + DataType::Time32(TimeUnit::Millisecond), + Arc::new(Time32MillisecondArray::from(vec![ + Some(3600000), + Some(7200000), + ])), + vec![ + ScalarValue::Time32Millisecond(Some(3600000)), + ScalarValue::Time32Millisecond(Some(5400000)), + ], + )?; - // Create InListExpr from array - let expr = Arc::new(InListExpr::try_new_from_array( - Arc::clone(&col_a), - array, - false, - )?) as Arc; + test_type( + DataType::Time64(TimeUnit::Microsecond), + Arc::new(Time64MicrosecondArray::from(vec![ + Some(3600000000), + Some(7200000000), + ])), + vec![ + ScalarValue::Time64Microsecond(Some(3600000000)), + ScalarValue::Time64Microsecond(Some(5400000000)), + ], + )?; - // Create test data: [1, 2, 3, 4, null] - let a = Int64Array::from(vec![Some(1), Some(2), Some(3), Some(4), None]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + test_type( + DataType::Time64(TimeUnit::Nanosecond), + Arc::new(Time64NanosecondArray::from(vec![ + Some(3600000000000), + Some(7200000000000), + ])), + vec![ + ScalarValue::Time64Nanosecond(Some(3600000000000)), + ScalarValue::Time64Nanosecond(Some(5400000000000)), + ], + )?; - // Evaluate the expression - let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; - let result = as_boolean_array(&result); + // Duration types (map to Int64 -> UInt64) + test_type( + DataType::Duration(TimeUnit::Second), + Arc::new(DurationSecondArray::from(vec![Some(86400), Some(172800)])), + vec![ + ScalarValue::DurationSecond(Some(86400)), + ScalarValue::DurationSecond(Some(129600)), + ], + )?; - // Expected behavior with multiple NULLs in list: - // - Values in the list (1,2,3) → true - // - Values not in the list (4) → NULL (because list contains NULL) - // - NULL input → NULL - let expected = BooleanArray::from(vec![ - Some(true), // 1 is in list - Some(true), // 2 is in list - Some(true), // 3 is in list - None, // 4 not in list, but list has NULLs - None, // NULL input - ]); - assert_eq!(result, &expected); + test_type( + DataType::Duration(TimeUnit::Millisecond), + Arc::new(DurationMillisecondArray::from(vec![ + Some(86400000), + Some(172800000), + ])), + vec![ + ScalarValue::DurationMillisecond(Some(86400000)), + ScalarValue::DurationMillisecond(Some(129600000)), + ], + )?; + + test_type( + DataType::Duration(TimeUnit::Microsecond), + Arc::new(DurationMicrosecondArray::from(vec![ + Some(86400000000), + Some(172800000000), + ])), + vec![ + ScalarValue::DurationMicrosecond(Some(86400000000)), + ScalarValue::DurationMicrosecond(Some(129600000000)), + ], + )?; + + test_type( + DataType::Duration(TimeUnit::Nanosecond), + Arc::new(DurationNanosecondArray::from(vec![ + Some(86400000000000), + Some(172800000000000), + ])), + vec![ + ScalarValue::DurationNanosecond(Some(86400000000000)), + ScalarValue::DurationNanosecond(Some(129600000000000)), + ], + )?; + + // Interval types (map to 16-byte Decimal128Type) + test_type( + DataType::Interval(IntervalUnit::YearMonth), + Arc::new(IntervalYearMonthArray::from(vec![Some(12), Some(24)])), + vec![ + ScalarValue::IntervalYearMonth(Some(12)), + ScalarValue::IntervalYearMonth(Some(18)), + ], + )?; + + test_type( + DataType::Interval(IntervalUnit::DayTime), + Arc::new(IntervalDayTimeArray::from(vec![ + Some(IntervalDayTime { + days: 1, + milliseconds: 0, + }), + Some(IntervalDayTime { + days: 2, + milliseconds: 0, + }), + ])), + vec![ + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 1, + milliseconds: 0, + })), + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 1, + milliseconds: 500, + })), + ], + )?; + + test_type( + DataType::Interval(IntervalUnit::MonthDayNano), + Arc::new(IntervalMonthDayNanoArray::from(vec![ + Some(IntervalMonthDayNano { + months: 1, + days: 0, + nanoseconds: 0, + }), + Some(IntervalMonthDayNano { + months: 2, + days: 0, + nanoseconds: 0, + }), + ])), + vec![ + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 1, + days: 0, + nanoseconds: 0, + })), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 1, + days: 15, + nanoseconds: 0, + })), + ], + )?; + + // Decimal256 (maps to Decimal128Type for 16-byte width) + // Need to use with_precision_and_scale() to set the metadata + let precision = 38; + let scale = 10; + test_type( + DataType::Decimal256(precision, scale), + Arc::new( + Decimal256Array::from(vec![ + Some(i256::from(12345)), + Some(i256::from(67890)), + ]) + .with_precision_and_scale(precision, scale)?, + ), + vec![ + ScalarValue::Decimal256(Some(i256::from(12345)), precision, scale), + ScalarValue::Decimal256(Some(i256::from(54321)), precision, scale), + ], + )?; Ok(()) } - #[test] - fn test_not_in_null_handling_comprehensive() -> Result<()> { - // Comprehensive test demonstrating SQL three-valued logic for NOT IN expressions - // This test explicitly shows all possible outcomes for NOT IN: true, false, and null - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + /// Helper: creates an InListExpr with `static_filter = None` + /// to force the column-reference evaluation path. + fn make_in_list_with_columns( + expr: Arc, + list: Vec>, + negated: bool, + ) -> Arc { + Arc::new(InListExpr::new(expr, list, negated, None)) + } - // Test data: [1, 2, 3, null] - let a = Int64Array::from(vec![Some(1), Some(2), Some(3), None]); + #[test] + fn test_in_list_with_columns_int32_scalars() -> Result<()> { + // Column-reference path with scalar literals (bypassing static filter) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + None, + ]))], + )?; - // Case 1: List WITHOUT null - demonstrates true/false/null outcomes for NOT IN - // "a NOT IN (1, 4)" - 1 matches (false), 2 and 3 don't match (true), null is null - let list = vec![lit(1i64), lit(4i64)]; - in_list!( - batch, - list, - &true, - vec![ - Some(false), // 1 is in the list → NOT IN returns false - Some(true), // 2 is not in the list → NOT IN returns true - Some(true), // 3 is not in the list → NOT IN returns true - None, // null NOT IN (...) → null (SQL three-valued logic) - ], - Arc::clone(&col_a), - &schema + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(3))), + ]; + let expr = make_in_list_with_columns(col_a, list, false); + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true), None,]) ); + Ok(()) + } - // Case 2: List WITH null - demonstrates null propagation for NOT IN - // "a NOT IN (1, NULL)" - 1 matches (false), 2/3 don't match but list has null (null), null is null - let list = vec![lit(1i64), lit(ScalarValue::Int64(None))]; - in_list!( - batch, - list, - &true, + #[test] + fn test_in_list_with_columns_int32_column_refs() -> Result<()> { + // IN list with column references + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), vec![ - Some(false), // 1 is in the list → NOT IN returns false - None, // 2 is not in known values, but list has NULL → null (can't prove it's not in list) - None, // 3 is not in known values, but list has NULL → null (can't prove it's not in list) - None, // null NOT IN (...) → null (SQL three-valued logic) + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), None])), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(99), + Some(99), + Some(99), + ])), + Arc::new(Int32Array::from(vec![Some(99), Some(99), Some(3), None])), ], - Arc::clone(&col_a), - &schema - ); + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 IN (1, 99) → true + // row 1: 2 IN (99, 99) → false + // row 2: 3 IN (99, 3) → true + // row 3: NULL IN (99, NULL) → NULL + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true), None,]) + ); Ok(()) } #[test] - fn test_in_list_null_type_column() -> Result<()> { - // Test with a column that has DataType::Null (not just nullable values) - // All values in a NullArray are null by definition - let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]); - let a = NullArray::new(3); + fn test_in_list_with_columns_utf8_column_refs() -> Result<()> { + // IN list with Utf8 column references + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(StringArray::from(vec!["x", "y", "z"])), + Arc::new(StringArray::from(vec!["x", "x", "z"])), + ], + )?; + let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + let list = vec![col("b", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); - // "null_column IN (1, 2)" - comparing Null type against Int64 list - // Note: This tests type coercion behavior between Null and Int64 - let list = vec![lit(1i64), lit(2i64)]; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: "x" IN ("x") → true + // row 1: "y" IN ("x") → false + // row 2: "z" IN ("z") → true + assert_eq!(result, &BooleanArray::from(vec![true, false, true])); + Ok(()) + } - // All results should be NULL because: - // - Every value in the column is null (DataType::Null) - // - null IN (anything) always returns null per SQL three-valued logic - in_list!( - batch, - list.clone(), - &false, - vec![None, None, None], - Arc::clone(&col_a), - &schema - ); + #[test] + fn test_in_list_with_columns_negated() -> Result<()> { + // NOT IN with column references + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![1, 99, 3])), + ], + )?; - // "null_column NOT IN (1, 2)" - // Same behavior for NOT IN - null NOT IN (anything) is still null - in_list!( - batch, - list, - &true, - vec![None, None, None], - Arc::clone(&col_a), - &schema - ); + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, true); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 NOT IN (1) → false + // row 1: 2 NOT IN (99) → true + // row 2: 3 NOT IN (3) → false + assert_eq!(result, &BooleanArray::from(vec![false, true, false])); Ok(()) } #[test] - fn test_in_list_null_type_list() -> Result<()> { - // Test with a list that has DataType::Null - let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); - let a = Int64Array::from(vec![Some(1), Some(2), None]); + fn test_in_list_with_columns_null_in_list() -> Result<()> { + // IN list with NULL scalar (column-reference path) + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(Int32Array::from(vec![1, 2]))], + )?; - // Create a NullArray as the list - let null_array = Arc::new(NullArray::new(2)) as ArrayRef; + let list = vec![ + lit(ScalarValue::Int32(None)), + lit(ScalarValue::Int32(Some(1))), + ]; + let expr = make_in_list_with_columns(col_a, list, false); - // Try to create InListExpr with a NullArray list - // This tests whether try_new_from_array can handle Null type arrays - let expr = Arc::new(InListExpr::try_new_from_array( - Arc::clone(&col_a), - null_array, - false, - )?) as Arc; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; let result = as_boolean_array(&result); - - // If it succeeds, all results should be NULL - // because the list contains only null type values - let expected = BooleanArray::from(vec![None, None, None]); - assert_eq!(result, &expected); - + // row 0: 1 IN (NULL, 1) → true (true OR null = true) + // row 1: 2 IN (NULL, 1) → NULL (false OR null = null) + assert_eq!(result, &BooleanArray::from(vec![Some(true), None])); Ok(()) } #[test] - fn test_in_list_null_type_both() -> Result<()> { - // Test when both column and list are DataType::Null - let schema = Schema::new(vec![Field::new("a", DataType::Null, true)]); - let a = NullArray::new(3); - let col_a = col("a", &schema)?; - - // Create a NullArray as the list - let null_array = Arc::new(NullArray::new(2)) as ArrayRef; + fn test_in_list_with_columns_float_nan() -> Result<()> { + // Verify NaN == NaN is true in the column-reference path + // (consistent with Arrow's totalOrder semantics) + let schema = Schema::new(vec![ + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Float64Array::from(vec![f64::NAN, 1.0, f64::NAN])), + Arc::new(Float64Array::from(vec![f64::NAN, 2.0, 0.0])), + ], + )?; - // Try to create InListExpr with both Null types - let expr = Arc::new(InListExpr::try_new_from_array( - Arc::clone(&col_a), - null_array, - false, - )?) as Arc; + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; let result = as_boolean_array(&result); + // row 0: NaN IN (NaN) → true + // row 1: 1.0 IN (2.0) → false + // row 2: NaN IN (0.0) → false + assert_eq!(result, &BooleanArray::from(vec![true, false, false])); + Ok(()) + } - // If successful, all results should be NULL - // null IN [null, null] -> null - let expected = BooleanArray::from(vec![None, None, None]); - assert_eq!(result, &expected); + /// Tests that short-circuit evaluation produces correct results. + /// When all rows match after the first list item, remaining items + /// should be skipped without affecting correctness. + #[test] + fn test_in_list_with_columns_short_circuit() -> Result<()> { + // a IN (b, c) where b already matches every row of a + // The short-circuit should skip evaluating c + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![1, 2, 3])), // b == a for all rows + Arc::new(Int32Array::from(vec![99, 99, 99])), + ], + )?; + + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!(result, &BooleanArray::from(vec![true, true, true])); Ok(()) } + /// Short-circuit must NOT skip when nulls are present (three-valued logic). + /// Even if all non-null values are true, null rows keep the result as null. #[test] - fn test_in_list_comprehensive_null_handling() -> Result<()> { - // Comprehensive test for IN LIST operations with various NULL handling scenarios. - // This test covers the key cases validated against DuckDB as the source of truth. - // - // Note: Some scalar literal tests (like NULL IN (1, 2)) are omitted as they - // appear to expose an issue with static filter optimization. These are covered - // by existing tests like in_list_no_cols(). - - let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); - let col_b = col("b", &schema)?; - let null_i32 = ScalarValue::Int32(None); + fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> { + // a IN (b, c) where a has nulls + // Even if b matches all non-null rows, result should preserve nulls + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])), + Arc::new(Int32Array::from(vec![1, 2, 3])), // matches non-null rows + Arc::new(Int32Array::from(vec![99, 99, 99])), + ], + )?; - // Helper to create a batch - let make_batch = |values: Vec>| -> Result { - let array = Arc::new(Int32Array::from(values)); - Ok(RecordBatch::try_new(Arc::clone(&schema), vec![array])?) - }; + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); - // Helper to run a test - let run_test = |batch: &RecordBatch, - expr: Arc, - list: Vec>, - expected: Vec>| - -> Result<()> { - let in_expr = in_list(expr, list, &false, schema.as_ref())?; - let result = in_expr.evaluate(batch)?.into_array(batch.num_rows())?; - let result = as_boolean_array(&result); - assert_eq!(result, &BooleanArray::from(expected)); - Ok(()) - }; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: 1 IN (1, 99) → true + // row 1: NULL IN (2, 99) → NULL + // row 2: 3 IN (3, 99) → true + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + Ok(()) + } - // ======================================================================== - // COLUMN TESTS - col(b) IN [1, 2] - // ======================================================================== + /// Tests the make_comparator + collect_bool fallback path using + /// struct column references (nested types don't support arrow_eq). + #[test] + fn test_in_list_with_columns_struct() -> Result<()> { + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let struct_dt = DataType::Struct(struct_fields.clone()); - // [1] IN (1, 2) => [TRUE] - let batch = make_batch(vec![Some(1)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), lit(2i32)], - vec![Some(true)], - )?; + let schema = Schema::new(vec![ + Field::new("a", struct_dt.clone(), true), + Field::new("b", struct_dt.clone(), false), + Field::new("c", struct_dt.clone(), false), + ]); - // [1, 2] IN (1, 2) => [TRUE, TRUE] - let batch = make_batch(vec![Some(1), Some(2)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), lit(2i32)], - vec![Some(true), Some(true)], - )?; + // a: [{1,"a"}, {2,"b"}, NULL, {4,"d"}] + // b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}] + // c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}] + let a = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), + ], + Some(vec![true, true, false, true].into()), + )); + let b = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 9, 3, 4])), + Arc::new(StringArray::from(vec!["a", "z", "c", "d"])), + ], + None, + )); + let c = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![9, 2, 9, 9])), + Arc::new(StringArray::from(vec!["z", "b", "z", "z"])), + ], + None, + )); - // [3, 4] IN (1, 2) => [FALSE, FALSE] - let batch = make_batch(vec![Some(3), Some(4)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), lit(2i32)], - vec![Some(false), Some(false)], - )?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?; - // [1, NULL] IN (1, 2) => [TRUE, NULL] - let batch = make_batch(vec![Some(1), None])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), lit(2i32)], - vec![Some(true), None], - )?; + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, false); - // [3, NULL] IN (1, 2) => [FALSE, NULL] (no match, NULL is NULL) - let batch = make_batch(vec![Some(3), None])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), lit(2i32)], - vec![Some(false), None], - )?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: {1,"a"} IN ({1,"a"}, {9,"z"}) → true (matches b) + // row 1: {2,"b"} IN ({9,"z"}, {2,"b"}) → true (matches c) + // row 2: NULL IN ({3,"c"}, {9,"z"}) → NULL + // row 3: {4,"d"} IN ({4,"d"}, {9,"z"}) → true (matches b) + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(true), None, Some(true)]) + ); - // ======================================================================== - // COLUMN WITH NULL IN LIST - col(b) IN [NULL, 1] - // ======================================================================== + // Also test NOT IN + let col_a = col("a", &schema)?; + let list = vec![col("b", &schema)?, col("c", &schema)?]; + let expr = make_in_list_with_columns(col_a, list, true); - // [1] IN (NULL, 1) => [TRUE] (found match) - let batch = make_batch(vec![Some(1)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(null_i32.clone()), lit(1i32)], - vec![Some(true)], - )?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // row 0: {1,"a"} NOT IN ({1,"a"}, {9,"z"}) → false + // row 1: {2,"b"} NOT IN ({9,"z"}, {2,"b"}) → false + // row 2: NULL NOT IN ({3,"c"}, {9,"z"}) → NULL + // row 3: {4,"d"} NOT IN ({4,"d"}, {9,"z"}) → false + assert_eq!( + result, + &BooleanArray::from(vec![Some(false), Some(false), None, Some(false)]) + ); + Ok(()) + } - // [2] IN (NULL, 1) => [NULL] (no match, but list has NULL) - let batch = make_batch(vec![Some(2)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(null_i32.clone()), lit(1i32)], - vec![None], - )?; + // ----------------------------------------------------------------------- + // Tests for try_new_from_array: evaluates `needle IN in_array`. + // + // This exercises the code path used by HashJoin dynamic filter pushdown, + // where in_array is built directly from the join's build-side arrays. + // Unlike try_new (used by SQL IN expressions), which always produces a + // non-Dictionary in_array because evaluate_list() flattens Dictionary + // scalars, try_new_from_array passes the array directly and can produce + // a Dictionary in_array. + // ----------------------------------------------------------------------- + + fn wrap_in_dict(array: ArrayRef) -> ArrayRef { + let keys = Int32Array::from((0..array.len() as i32).collect::>()); + Arc::new(DictionaryArray::new(keys, array)) + } - // [NULL] IN (NULL, 1) => [NULL] - let batch = make_batch(vec![None])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(null_i32.clone()), lit(1i32)], - vec![None], - )?; + /// Evaluates `needle IN in_array` via try_new_from_array, the same + /// path used by HashJoin dynamic filter pushdown (not the SQL literal + /// IN path which goes through try_new). + fn eval_in_list_from_array( + needle: ArrayRef, + in_array: ArrayRef, + ) -> Result { + let schema = + Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]); + let col_a = col("a", &schema)?; + let expr = Arc::new(InListExpr::try_new_from_array( + col_a, in_array, false, &schema, + )?) as Arc; + let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + Ok(as_boolean_array(&result).clone()) + } - // ======================================================================== - // COLUMN WITH ALL NULLS IN LIST - col(b) IN [NULL, NULL] - // ======================================================================== + #[test] + fn test_in_list_from_array_type_combinations() -> Result<()> { + use arrow::compute::cast; - // [1] IN (NULL, NULL) => [NULL] - let batch = make_batch(vec![Some(1)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(null_i32.clone()), lit(null_i32.clone())], - vec![None], - )?; + // All cases: needle[0] and needle[2] match, needle[1] does not. + let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); - // [NULL] IN (NULL, NULL) => [NULL] - let batch = make_batch(vec![None])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(null_i32.clone()), lit(null_i32.clone())], - vec![None], - )?; + // Base arrays cast to each target type + let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef; + let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef; - // ======================================================================== - // LITERAL IN LIST WITH COLUMN - lit(1) IN [2, col(b)] - // ======================================================================== + // Test all specializations in instantiate_static_filter + let primitive_types = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + ]; - // 1 IN (2, [1]) => [TRUE] (matches column value) - let batch = make_batch(vec![Some(1)])?; - run_test( - &batch, - lit(1i32), - vec![lit(2i32), Arc::clone(&col_b)], - vec![Some(true)], - )?; + for dt in &primitive_types { + let in_array = cast(&base_in, dt)?; + let needle = cast(&base_needle, dt)?; - // 1 IN (2, [3]) => [FALSE] (no match) - let batch = make_batch(vec![Some(3)])?; - run_test( - &batch, - lit(1i32), - vec![lit(2i32), Arc::clone(&col_b)], - vec![Some(false)], - )?; + // T in_array, T needle + assert_eq!( + expected, + eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?, + "same-type failed for {dt:?}" + ); - // 1 IN (2, [NULL]) => [NULL] (no match, column is NULL) - let batch = make_batch(vec![None])?; - run_test( - &batch, - lit(1i32), - vec![lit(2i32), Arc::clone(&col_b)], - vec![None], - )?; + // T in_array, Dict(Int32, T) needle + assert_eq!( + expected, + eval_in_list_from_array(wrap_in_dict(needle), in_array)?, + "dict-needle failed for {dt:?}" + ); + } - // ======================================================================== - // COLUMN IN LIST CONTAINING ITSELF - col(b) IN [1, col(b)] - // ======================================================================== + // Utf8 (falls through to ArrayStaticFilter) + let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef; + let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef; - // [1] IN (1, [1]) => [TRUE] (always matches - either list literal or itself) - let batch = make_batch(vec![Some(1)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), Arc::clone(&col_b)], - vec![Some(true)], - )?; + // Utf8 in_array, Utf8 needle + assert_eq!( + expected, + eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)? + ); - // [2] IN (1, [2]) => [TRUE] (matches itself) - let batch = make_batch(vec![Some(2)])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), Arc::clone(&col_b)], - vec![Some(true)], - )?; + // Utf8 in_array, Dict(Utf8) needle + assert_eq!( + expected, + eval_in_list_from_array( + wrap_in_dict(Arc::clone(&utf8_needle)), + Arc::clone(&utf8_in), + )? + ); - // [NULL] IN (1, [NULL]) => [NULL] (NULL is never equal to anything) - let batch = make_batch(vec![None])?; - run_test( - &batch, - Arc::clone(&col_b), - vec![lit(1i32), Arc::clone(&col_b)], - vec![None], - )?; + // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug + assert_eq!( + expected, + eval_in_list_from_array( + wrap_in_dict(Arc::clone(&utf8_needle)), + wrap_in_dict(Arc::clone(&utf8_in)), + )? + ); + + // Struct in_array, Struct needle: multi-column join + let struct_fields = Fields::from(vec![ + Field::new("c0", DataType::Utf8, true), + Field::new("c1", DataType::Int64, true), + ]); + let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef { + let pairs: Vec<(FieldRef, ArrayRef)> = + struct_fields.iter().cloned().zip([c0, c1]).collect(); + Arc::new(StructArray::from(pairs)) + }; + assert_eq!( + expected, + eval_in_list_from_array( + make_struct( + Arc::clone(&utf8_needle), + Arc::new(Int64Array::from(vec![1, 4, 2])), + ), + make_struct( + Arc::clone(&utf8_in), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ), + )? + ); + + // Struct with Dict fields: multi-column Dict join + let dict_struct_fields = Fields::from(vec![ + Field::new( + "c0", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new("c1", DataType::Int64, true), + ]); + let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef { + let pairs: Vec<(FieldRef, ArrayRef)> = + dict_struct_fields.iter().cloned().zip([c0, c1]).collect(); + Arc::new(StructArray::from(pairs)) + }; + assert_eq!( + expected, + eval_in_list_from_array( + make_dict_struct( + wrap_in_dict(Arc::clone(&utf8_needle)), + Arc::new(Int64Array::from(vec![1, 4, 2])), + ), + make_dict_struct( + wrap_in_dict(Arc::clone(&utf8_in)), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ), + )? + ); Ok(()) } + fn make_int32_dict_array(values: Vec>) -> ArrayRef { + let mut builder = PrimitiveDictionaryBuilder::::new(); + for v in values { + match v { + Some(val) => builder.append_value(val), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + } + + fn make_f64_dict_array(values: Vec>) -> ArrayRef { + let mut builder = PrimitiveDictionaryBuilder::::new(); + for v in values { + match v { + Some(val) => builder.append_value(val), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + } + #[test] - fn test_in_list_scalar_literal_cases() -> Result<()> { - // Test scalar literal cases (both NULL and non-NULL) to ensure SQL three-valued - // logic is correctly implemented. This covers the important case where a scalar - // value is tested against a list containing NULL. + fn test_try_new_from_array_dict_haystack_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let needle = Int32Array::from(vec![1, 2, 3, 4]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; - let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); - let null_i32 = ScalarValue::Int32(None); + let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]); - // Helper to create a batch - let make_batch = |values: Vec>| -> Result { - let array = Arc::new(Int32Array::from(values)); - Ok(RecordBatch::try_new(Arc::clone(&schema), vec![array])?) - }; + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), None, Some(true), None]) + ); - // Helper to run a test - let run_test = |batch: &RecordBatch, - expr: Arc, - list: Vec>, - negated: bool, - expected: Vec>| - -> Result<()> { - let in_expr = in_list(expr, list, &negated, schema.as_ref())?; - let result = in_expr.evaluate(batch)?.into_array(batch.num_rows())?; - let result = as_boolean_array(&result); - let expected_array = BooleanArray::from(expected); - assert_eq!( - result, - &expected_array, - "Expected {:?}, got {:?}", - expected_array, - result.iter().collect::>() - ); - Ok(()) - }; + Ok(()) + } - let batch = make_batch(vec![Some(1)])?; + #[test] + fn test_in_list_from_array_type_mismatch_errors() -> Result<()> { + // Utf8 needle, Dict(Utf8) in_array: now works with dict haystack support + assert_eq!( + BooleanArray::from(vec![Some(true), Some(false), Some(true)]), + eval_in_list_from_array( + Arc::new(StringArray::from(vec!["a", "d", "b"])), + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), + )? + ); + + // Dict(Utf8) needle, Int64 in_array: type validation rejects at construction + let err = eval_in_list_from_array( + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))), + Arc::new(Int64Array::from(vec![1, 2, 3])), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("The data type inlist should be same"), "{err}"); + + // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different + // value types, type validation rejects at construction + let err = eval_in_list_from_array( + wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))), + wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("The data type inlist should be same"), "{err}"); - // ======================================================================== - // NULL LITERAL TESTS - // According to SQL semantics, NULL IN (any_list) should always return NULL - // ======================================================================== + Ok(()) + } - // NULL IN (1, 1) => NULL - run_test( - &batch, - lit(null_i32.clone()), - vec![lit(1i32), lit(1i32)], - false, - vec![None], - )?; + #[test] + fn test_try_new_from_array_dict_haystack_negated() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let needle = Int32Array::from(vec![1, 2, 3, 4]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; - // NULL IN (NULL, 1) => NULL - run_test( - &batch, - lit(null_i32.clone()), - vec![lit(null_i32.clone()), lit(1i32)], - false, - vec![None], - )?; + let haystack = make_int32_dict_array(vec![Some(1), None, Some(3)]); - // NULL IN (NULL, NULL) => NULL - run_test( - &batch, - lit(null_i32.clone()), - vec![lit(null_i32.clone()), lit(null_i32.clone())], - false, - vec![None], - )?; + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, true, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(false), None, Some(false), None]) + ); - // ======================================================================== - // NON-NULL SCALAR LITERALS WITH NULL IN LIST - Int32 - // When a scalar value is NOT in a list containing NULL, the result is NULL - // When a scalar value IS in the list, the result is TRUE (NULL doesn't matter) - // ======================================================================== + Ok(()) + } - // 3 IN (0, 1, 2, NULL) => NULL (not in list, but list has NULL) - run_test( - &batch, - lit(3i32), - vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], - false, - vec![None], - )?; + #[test] + fn test_try_new_from_array_dict_haystack_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let needle = StringArray::from(vec!["a", "b", "c"]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; - // 3 NOT IN (0, 1, 2, NULL) => NULL (not in list, but list has NULL) - run_test( - &batch, - lit(3i32), - vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], - true, - vec![None], - )?; + let dict_builder = StringDictionaryBuilder::::new(); + let mut builder = dict_builder; + builder.append_value("a"); + builder.append_value("c"); + let haystack: ArrayRef = Arc::new(builder.finish()); - // 1 IN (0, 1, 2, NULL) => TRUE (found match, NULL doesn't matter) - run_test( - &batch, - lit(1i32), - vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true)]) + ); + + Ok(()) + } + + #[test] + fn test_try_new_from_array_dict_needle_and_plain_haystack() -> Result<()> { + let schema = Schema::new(vec![Field::new( + "a", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), false, - vec![Some(true)], - )?; + )]); - // 1 NOT IN (0, 1, 2, NULL) => FALSE (found match, NULL doesn't matter) - run_test( - &batch, - lit(1i32), - vec![lit(0i32), lit(1i32), lit(2i32), lit(null_i32.clone())], - true, - vec![Some(false)], - )?; + let needle = make_int32_dict_array(vec![Some(1), Some(2), Some(3), Some(4)]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::clone(&needle)])?; - // ======================================================================== - // NON-NULL SCALAR LITERALS WITH NULL IN LIST - String - // Same semantics as Int32 but with string type - // ======================================================================== + let haystack: ArrayRef = Arc::new(Int32Array::from(vec![1, 3])); + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]) + ); - let schema_str = - Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); - let batch_str = RecordBatch::try_new( - Arc::clone(&schema_str), - vec![Arc::new(StringArray::from(vec![Some("dummy")]))], - )?; - let null_str = ScalarValue::Utf8(None); + Ok(()) + } - let run_test_str = |expr: Arc, - list: Vec>, - negated: bool, - expected: Vec>| - -> Result<()> { - let in_expr = in_list(expr, list, &negated, schema_str.as_ref())?; - let result = in_expr - .evaluate(&batch_str)? - .into_array(batch_str.num_rows())?; - let result = as_boolean_array(&result); - let expected_array = BooleanArray::from(expected); - assert_eq!( - result, - &expected_array, - "Expected {:?}, got {:?}", - expected_array, - result.iter().collect::>() - ); - Ok(()) - }; + #[test] + fn test_try_new_from_array_dict_haystack_float64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + let needle = Float64Array::from(vec![1.0, 2.0, 3.0]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(needle)])?; - // 'c' IN ('a', 'b', NULL) => NULL (not in list, but list has NULL) - run_test_str( - lit("c"), - vec![lit("a"), lit("b"), lit(null_str.clone())], - false, - vec![None], - )?; + let haystack = make_f64_dict_array(vec![Some(1.0), Some(3.0)]); - // 'c' NOT IN ('a', 'b', NULL) => NULL (not in list, but list has NULL) - run_test_str( - lit("c"), - vec![lit("a"), lit("b"), lit(null_str.clone())], - true, - vec![None], - )?; + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array(col_a, haystack, false, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), Some(true)]) + ); - // 'a' IN ('a', 'b', NULL) => TRUE (found match, NULL doesn't matter) - run_test_str( - lit("a"), - vec![lit("a"), lit("b"), lit(null_str.clone())], - false, - vec![Some(true)], - )?; + Ok(()) + } - // 'a' NOT IN ('a', 'b', NULL) => FALSE (found match, NULL doesn't matter) - run_test_str( - lit("a"), - vec![lit("a"), lit("b"), lit(null_str.clone())], - true, - vec![Some(false)], - )?; + #[test] + fn test_try_new_from_array_type_mismatch_rejects() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let col_a = col("a", &schema)?; + let haystack: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0])); + let result = InListExpr::try_new_from_array(col_a, haystack, false, &schema); + assert!(result.is_err()); Ok(()) } #[test] - fn test_in_list_tuple_cases() -> Result<()> { - // Test tuple/struct cases from the original request: (lit, lit) IN (lit, lit) - // These test row-wise comparisons like (1, 2) IN ((1, 2), (3, 4)) + fn test_try_new_from_array_struct_haystack() -> Result<()> { + let struct_fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, false), + ]); + let struct_dt = DataType::Struct(struct_fields.clone()); + let schema = Schema::new(vec![Field::new("a", struct_dt, true)]); - let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + // Needle: [{1,"a"}, {2,"b"}, NULL, {4,"d"}] + let needle = Arc::new(StructArray::new( + struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["a", "b", "c", "d"])), + ], + Some(vec![true, true, false, true].into()), + )); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![needle])?; - // Helper to create struct scalars for tuple comparisons - let make_struct = |v1: Option, v2: Option| -> ScalarValue { - let fields = Fields::from(vec![ - Field::new("field_0", DataType::Int32, true), - Field::new("field_1", DataType::Int32, true), - ]); - ScalarValue::Struct(Arc::new(StructArray::new( - fields, - vec![ - Arc::new(Int32Array::from(vec![v1])), - Arc::new(Int32Array::from(vec![v2])), - ], - None, - ))) - }; + // Haystack: [{1,"a"}, {4,"d"}] + let haystack: ArrayRef = Arc::new(StructArray::new( + struct_fields, + vec![ + Arc::new(Int32Array::from(vec![1, 4])), + Arc::new(StringArray::from(vec!["a", "d"])), + ], + None, + )); - // Need a single row batch for scalar tests - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new(Int32Array::from(vec![Some(1)]))], + let col_a = col("a", &schema)?; + let expr = InListExpr::try_new_from_array( + Arc::clone(&col_a), + Arc::clone(&haystack), + false, + &schema, )?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + // {1,"a"} -> true, {2,"b"} -> false, NULL -> NULL, {4,"d"} -> true + assert_eq!( + result, + &BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]) + ); - // Helper to run tuple tests - let run_tuple_test = |lhs: ScalarValue, - list: Vec, - expected: Vec>| - -> Result<()> { - let expr = in_list( - lit(lhs), - list.into_iter().map(lit).collect(), - &false, - schema.as_ref(), - )?; - let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; - let result = as_boolean_array(&result); - assert_eq!(result, &BooleanArray::from(expected)); - Ok(()) - }; + // Negated path + let expr = InListExpr::try_new_from_array(col_a, haystack, true, &schema)?; + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_boolean_array(&result); + assert_eq!( + result, + &BooleanArray::from(vec![Some(false), Some(true), None, Some(false)]) + ); - // (NULL, NULL) IN ((1, 2)) => FALSE (tuples don't match) - run_tuple_test( - make_struct(None, None), - vec![make_struct(Some(1), Some(2))], - vec![Some(false)], - )?; + Ok(()) + } +} - // (NULL, NULL) IN ((NULL, 1)) => FALSE - run_tuple_test( - make_struct(None, None), - vec![make_struct(None, Some(1))], - vec![Some(false)], - )?; +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col, lit}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalInListNode, physical_expr_node, + }; + + /// Build an `InListExpr` proto node with the given children. + fn in_list_node( + expr: Option>, + list: Vec, + negated: bool, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::InList(Box::new( + PhysicalInListNode { + expr, + list, + negated, + }, + ))), + } + } - // (NULL, NULL) IN ((NULL, NULL)) => TRUE (exact match including nulls) - run_tuple_test( - make_struct(None, None), - vec![make_struct(None, None)], - vec![Some(true)], - )?; + /// An `InListExpr` over a column with one literal value. + fn in_list_fixture() -> InListExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + InListExpr::try_new(col("a", &schema).unwrap(), vec![lit(1)], false, &schema) + .unwrap() + } - // (NULL, 1) IN ((1, 2)) => FALSE - run_tuple_test( - make_struct(None, Some(1)), - vec![make_struct(Some(1), Some(2))], - vec![Some(false)], - )?; + #[test] + fn try_to_proto_encodes_in_list() { + let in_list = in_list_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = in_list + .try_to_proto(&ctx) + .unwrap() + .expect("InListExpr should encode to Some(node)"); + + // Built-in exprs never set expr_id; only dynamic filters do. + assert!(node.expr_id.is_none()); + let in_list_node = match node.expr_type { + Some(physical_expr_node::ExprType::InList(boxed)) => *boxed, + other => panic!("expected an InList node, got {other:?}"), + }; + assert!(!in_list_node.negated); + assert!(in_list_node.expr.is_some()); + assert_eq!(in_list_node.list.len(), 1); + } - // (NULL, 1) IN ((NULL, 1)) => TRUE (exact match) - run_tuple_test( - make_struct(None, Some(1)), - vec![make_struct(None, Some(1))], - vec![Some(true)], - )?; + #[test] + fn try_to_proto_propagates_expr_encode_error() { + let in_list = in_list_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = in_list.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } - // (NULL, 1) IN ((NULL, NULL)) => FALSE - run_tuple_test( - make_struct(None, Some(1)), - vec![make_struct(None, None)], - vec![Some(false)], - )?; + #[test] + fn try_to_proto_propagates_list_encode_error() { + let in_list = in_list_fixture(); + // Call 1 is for `expr`, Call 2 is for the first element of `list` + let encoder = StubEncoder::failing_on(2); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = in_list.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 2"))); + } - // (1, 2) IN ((1, 2)) => TRUE - run_tuple_test( - make_struct(Some(1), Some(2)), - vec![make_struct(Some(1), Some(2))], - vec![Some(true)], - )?; + #[test] + fn try_from_proto_decodes_in_list() { + let node = in_list_node( + Some(Box::new(column_node("a"))), + vec![column_node("b")], + true, + ); + let schema = Schema::new(vec![Field::new("decoded", DataType::Int32, true)]); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); - // (1, 3) IN ((1, 2)) => FALSE - run_tuple_test( - make_struct(Some(1), Some(3)), - vec![make_struct(Some(1), Some(2))], - vec![Some(false)], - )?; + let decoded = InListExpr::try_from_proto(&node, &ctx).unwrap(); + let in_list = decoded + .downcast_ref::() + .expect("decoded expr should be an InListExpr"); - // (4, 4) IN ((1, 2)) => FALSE - run_tuple_test( - make_struct(Some(4), Some(4)), - vec![make_struct(Some(1), Some(2))], - vec![Some(false)], - )?; + assert!(in_list.negated()); + assert!(in_list.expr().downcast_ref::().is_some()); + assert_eq!(in_list.list().len(), 1); + } - // (1, 1) IN ((NULL, 1)) => FALSE - run_tuple_test( - make_struct(Some(1), Some(1)), - vec![make_struct(None, Some(1))], - vec![Some(false)], - )?; + #[test] + fn try_from_proto_rejects_non_in_list_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = InListExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a InList") + )); + } - // (1, 1) IN ((NULL, NULL)) => FALSE - run_tuple_test( - make_struct(Some(1), Some(1)), - vec![make_struct(None, None)], - vec![Some(false)], - )?; + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = in_list_node(None, vec![column_node("b")], false); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = InListExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("InListExpr is missing required field 'expr'") + )); + } - Ok(()) + #[test] + fn try_from_proto_propagates_expr_decode_error() { + let node = in_list_node( + Some(Box::new(column_node("a"))), + vec![column_node("b")], + false, + ); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = InListExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_propagates_list_decode_error() { + let node = in_list_node( + Some(Box::new(column_node("a"))), + vec![column_node("b")], + false, + ); + let schema = Schema::empty(); + // Call 1 is `expr`, Call 2 is the first element of `list` + let decoder = StubDecoder::failing_on(2); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = InListExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 2"))); } } diff --git a/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs new file mode 100644 index 0000000000000..93bfcd49600d0 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, + make_comparator, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::{SortOptions, take}; +use arrow::datatypes::DataType; +use arrow::util::bit_iterator::BitIndexIterator; +use datafusion_common::HashMap; +use datafusion_common::Result; +use datafusion_common::hash_utils::{RandomState, with_hashes}; +use hashbrown::hash_map::RawEntryMut; + +use super::static_filter::StaticFilter; + +/// Static filter for InList that stores the array and hash set for O(1) lookups +#[derive(Debug, Clone)] +pub(super) struct ArrayStaticFilter { + in_array: ArrayRef, + state: RandomState, + /// Used to provide a lookup from value to in list index + /// + /// Note: usize::hash is not used, instead the raw entry + /// API is used to store entries w.r.t their value + map: HashMap, +} + +impl StaticFilter for ArrayStaticFilter { + fn null_count(&self) -> usize { + self.in_array.null_count() + } + + /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Null type comparisons always return null (SQL three-valued logic) + if v.data_type() == &DataType::Null + || self.in_array.data_type() == &DataType::Null + { + let nulls = NullBuffer::new_null(v.len()); + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(nulls), + )); + } + + // Unwrap dictionary-encoded needles when the value type matches + // in_array, evaluating against the dictionary values and mapping + // back via keys. + downcast_dictionary_array! { + v => { + // Only unwrap when the haystack (in_array) type matches + // the dictionary value type + if v.values().data_type() == self.in_array.data_type() { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())); + } + } + _ => {} + } + + let needle_nulls = v.logical_nulls(); + let needle_nulls = needle_nulls.as_ref(); + let haystack_has_nulls = self.in_array.null_count() != 0; + + with_hashes([v], &self.state, |hashes| { + let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; + Ok((0..v.len()) + .map(|i| { + // SQL three-valued logic: null IN (...) is always null + if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { + return None; + } + + let hash = hashes[i]; + let contains = self + .map + .raw_entry() + .from_hash(hash, |idx| cmp(i, *idx).is_eq()) + .is_some(); + + match contains { + true => Some(!negated), + false if haystack_has_nulls => None, + false => Some(negated), + } + }) + .collect()) + }) + } +} + +impl ArrayStaticFilter { + /// Computes a [`StaticFilter`] for the provided [`Array`] if there + /// are nulls present or there are more than the configured number of + /// elements. + /// + /// Note: This is split into a separate function as higher-rank trait bounds currently + /// cause type inference to misbehave + pub(super) fn try_new(in_array: ArrayRef) -> Result { + // Null type has no natural order - return empty hash set + if in_array.data_type() == &DataType::Null { + return Ok(ArrayStaticFilter { + in_array, + state: RandomState::default(), + map: HashMap::with_hasher(()), + }); + } + + let state = RandomState::default(); + let mut map: HashMap = HashMap::with_hasher(()); + + with_hashes([&in_array], &state, |hashes| -> Result<()> { + let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + if let RawEntryMut::Vacant(v) = map + .raw_entry_mut() + .from_hash(hash, |x| cmp(*x, idx).is_eq()) + { + v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); + } + }; + + match in_array.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..in_array.len()).for_each(insert_value), + } + + Ok(()) + })?; + + Ok(Self { + in_array, + state, + map, + }) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs new file mode 100644 index 0000000000000..2c084a1cb247b --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, downcast_array, downcast_dictionary_array, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::take; +use arrow::datatypes::*; +use datafusion_common::{HashSet, Result, exec_datafusion_err}; +use std::hash::{Hash, Hasher}; + +use super::static_filter::StaticFilter; + +/// Wrapper for f32 that implements Hash and Eq using bit comparison. +/// This treats NaN values as equal to each other when they have the same bit pattern. +#[derive(Clone, Copy)] +struct OrderedFloat32(f32); + +impl Hash for OrderedFloat32 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat32 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for OrderedFloat32 {} + +impl From for OrderedFloat32 { + fn from(v: f32) -> Self { + Self(v) + } +} + +/// Wrapper for f64 that implements Hash and Eq using bit comparison. +/// This treats NaN values as equal to each other when they have the same bit pattern. +#[derive(Clone, Copy)] +struct OrderedFloat64(f64); + +impl Hash for OrderedFloat64 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat64 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for OrderedFloat64 {} + +impl From for OrderedFloat64 { + fn from(v: f64) -> Self { + Self(v) + } +} + +// Macro to generate specialized StaticFilter implementations for primitive types +macro_rules! primitive_static_filter { + ($Name:ident, $ArrowType:ty) => { + primitive_static_filter!( + $Name, + $ArrowType, + <$ArrowType as ArrowPrimitiveType>::Native, + |v| v + ); + }; + ($Name:ident, $ArrowType:ty, $SetValueType:ty, $to_set_value:expr) => { + pub(super) struct $Name { + null_count: usize, + values: HashSet<$SetValueType>, + } + + impl $Name { + pub(super) fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(($to_set_value)(v)); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + + let haystack_has_nulls = self.null_count > 0; + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: + // ("-" means the value doesn't affect the result) + // + // | needle_null | haystack_null | negated | in set? | result | + // |-------------|---------------|---------|---------|--------| + // | true | - | false | - | null | + // | true | - | true | - | null | + // | false | true | false | yes | true | + // | false | true | false | no | null | + // | false | true | true | yes | false | + // | false | true | true | no | null | + // | false | false | false | yes | true | + // | false | false | false | no | false | + // | false | false | true | yes | false | + // | false | false | true | no | true | + + // Compute the "contains" result using collect_bool (fast batched approach) + // This ignores nulls - we handle them separately + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&($to_set_value)(needle_values[i])) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&($to_set_value)(needle_values[i])) + }) + }; + + // Compute the null mask + // Output is null when: + // 1. needle value is null, OR + // 2. needle value is not in set AND haystack has nulls + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => { + // No nulls anywhere + None + } + (true, false) => { + // Only needle has nulls - just use needle's null mask + needle_nulls.cloned() + } + (false, true) => { + // Only haystack has nulls - result is null when value not in set + // Valid (not null) when original "in set" is true + // For NOT IN: contains_buffer = !original, so validity = !contains_buffer + let validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) + } + (true, true) => { + // Both have nulls - combine needle nulls with haystack-induced nulls + let needle_validity = needle_nulls.map(|n| n.inner().clone()) + .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + + // Valid when original "in set" is true (see above) + let haystack_validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + + // Combined validity: valid only where both are valid + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) + } + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) + } + } + }; +} + +// Generate specialized filters for all integer primitive types +primitive_static_filter!(Int8StaticFilter, Int8Type); +primitive_static_filter!(Int16StaticFilter, Int16Type); +primitive_static_filter!(Int32StaticFilter, Int32Type); +primitive_static_filter!(Int64StaticFilter, Int64Type); +primitive_static_filter!(UInt8StaticFilter, UInt8Type); +primitive_static_filter!(UInt16StaticFilter, UInt16Type); +primitive_static_filter!(UInt32StaticFilter, UInt32Type); +primitive_static_filter!(UInt64StaticFilter, UInt64Type); + +// Macro to generate specialized StaticFilter implementations for float types +// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics +macro_rules! float_static_filter { + ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { + primitive_static_filter!($Name, $ArrowType, $OrderedType, <$OrderedType>::from); + }; +} + +// Generate specialized filters for float types using ordered wrappers +float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); +float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs new file mode 100644 index 0000000000000..218bd27950266 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, BooleanArray}; +use datafusion_common::Result; + +/// Trait for InList static filters. +/// +/// Static filters store a pre-computed set of values (the haystack) and check +/// whether needle values are contained in that set. The haystack is always +/// represented in its non-dictionary (value) type. Dictionary haystacks are +/// flattened via `cast()` before construction. +/// +/// Dictionary-encoded needles are unwrapped inside `contains()` and +/// evaluated against the dictionary's values. +pub(super) trait StaticFilter { + fn null_count(&self) -> usize; + + /// Checks if values in `v` (needle) are contained in this filter's + /// haystack. `v` may be dictionary-encoded, in which case the + /// implementation unwraps the dictionary and operates on its values. + fn contains(&self, v: &dyn Array, negated: bool) -> Result; +} diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs new file mode 100644 index 0000000000000..b7ee3dd1a3b9d --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::compute::cast; +use arrow::datatypes::DataType; +use datafusion_common::Result; + +use super::array_static_filter::ArrayStaticFilter; +use super::primitive_filter::*; +use super::static_filter::StaticFilter; + +pub(super) fn instantiate_static_filter( + in_array: ArrayRef, +) -> Result> { + // Flatten dictionary-encoded haystacks to their value type so that + // specialized filters (e.g. Int32StaticFilter) are used instead of + // falling through to the generic ArrayStaticFilter. + let in_array = match in_array.data_type() { + DataType::Dictionary(_, value_type) => cast(&in_array, value_type.as_ref())?, + _ => in_array, + }; + match in_array.data_type() { + // Integer primitive types + DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + // Float primitive types (use ordered wrappers for Hash/Eq) + DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), + _ => { + /* fall through to generic implementation for unsupported types (Struct, etc.) */ + Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) + } + } +} diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 62be8ebbc13e3..3f3b7d16e543a 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -22,11 +22,10 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ColumnarValue; use std::hash::Hash; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; /// IS NOT NULL expression #[derive(Debug, Eq)] @@ -67,11 +66,6 @@ impl std::fmt::Display for IsNotNullExpr { } impl PhysicalExpr for IsNotNullExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Boolean) } @@ -108,6 +102,48 @@ impl PhysicalExpr for IsNotNullExpr { self.arg.fmt_sql(f)?; write!(f, " IS NOT NULL") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { + expr: Some(Box::new(ctx.encode_child(&self.arg)?)), + }), + )), + })) + } +} + +#[cfg(feature = "proto")] +impl IsNotNullExpr { + /// Reconstruct an [`IsNotNullExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let node = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::IsNotNullExpr, + "IsNotNullExpr", + ); + let expr = ctx.decode_required_expression( + node.expr.as_deref(), + "IsNotNullExpr", + "expr", + )?; + + Ok(Arc::new(IsNotNullExpr::new(expr))) + } } /// Create an IS NOT NULL expression @@ -218,3 +254,109 @@ mod tests { Ok(()) } } + +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalIsNotNull, physical_expr_node, + }; + + fn is_not_null_node(expr: Option>) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::IsNotNullExpr(Box::new( + PhysicalIsNotNull { expr }, + ))), + } + } + + fn is_not_null_fixture() -> IsNotNullExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + IsNotNullExpr::new(col("a", &schema).unwrap()) + } + + #[test] + fn try_to_proto_encodes_is_not_null_expr() { + let is_not_null = is_not_null_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = is_not_null + .try_to_proto(&ctx) + .unwrap() + .expect("IsNotNullExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let is_not_null_node = match node.expr_type { + Some(physical_expr_node::ExprType::IsNotNullExpr(boxed)) => *boxed, + other => panic!("expected an IsNotNullExpr node, got {other:?}"), + }; + assert!(is_not_null_node.expr.is_some()); + } + + #[test] + fn try_to_proto_propagates_expr_encode_error() { + let is_not_null = is_not_null_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = is_not_null.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_decodes_is_not_null_expr() { + let node = is_not_null_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = IsNotNullExpr::try_from_proto(&node, &ctx).unwrap(); + let is_not_null = decoded + .downcast_ref::() + .expect("decoded expr should be an IsNotNullExpr"); + assert!(is_not_null.arg().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_is_not_null_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = IsNotNullExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a IsNotNullExpr")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = is_not_null_node(None); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = IsNotNullExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("IsNotNullExpr is missing required field 'expr'")) + ); + } + + #[test] + fn try_from_proto_propagates_expr_decode_error() { + let node = is_not_null_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = IsNotNullExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } +} diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 356fe2a866672..da008a1cfb821 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -22,11 +22,10 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::ColumnarValue; use std::hash::Hash; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; /// IS NULL expression #[derive(Debug, Eq)] @@ -67,11 +66,6 @@ impl std::fmt::Display for IsNullExpr { } impl PhysicalExpr for IsNullExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Boolean) } @@ -107,6 +101,45 @@ impl PhysicalExpr for IsNullExpr { self.arg.fmt_sql(f)?; write!(f, " IS NULL") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(ctx.encode_child(&self.arg)?)), + }), + )), + })) + } +} + +#[cfg(feature = "proto")] +impl IsNullExpr { + /// Reconstruct an [`IsNullExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let node = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::IsNullExpr, + "IsNullExpr", + ); + let expr = + ctx.decode_required_expression(node.expr.as_deref(), "IsNullExpr", "expr")?; + + Ok(Arc::new(IsNullExpr::new(expr))) + } } /// Create an IS NULL expression @@ -229,3 +262,109 @@ mod tests { Ok(()) } } + +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalIsNull, physical_expr_node, + }; + + fn is_null_node(expr: Option>) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::IsNullExpr(Box::new( + PhysicalIsNull { expr }, + ))), + } + } + + fn is_null_fixture() -> IsNullExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); + IsNullExpr::new(col("a", &schema).unwrap()) + } + + #[test] + fn try_to_proto_encodes_is_null_expr() { + let is_null = is_null_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = is_null + .try_to_proto(&ctx) + .unwrap() + .expect("IsNullExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let is_null_node = match node.expr_type { + Some(physical_expr_node::ExprType::IsNullExpr(boxed)) => *boxed, + other => panic!("expected an IsNullExpr node, got {other:?}"), + }; + assert!(is_null_node.expr.is_some()); + } + + #[test] + fn try_to_proto_propagates_expr_encode_error() { + let is_null = is_null_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = is_null.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_decodes_is_null_expr() { + let node = is_null_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = IsNullExpr::try_from_proto(&node, &ctx).unwrap(); + let is_null = decoded + .downcast_ref::() + .expect("decoded expr should be an IsNullExpr"); + assert!(is_null.arg().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_is_null_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = IsNullExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a IsNullExpr")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = is_null_node(None); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = IsNullExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("IsNullExpr is missing required field 'expr'")) + ); + } + + #[test] + fn try_from_proto_propagates_expr_decode_error() { + let node = is_null_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = IsNullExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } +} diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs new file mode 100644 index 0000000000000..9275821ae9150 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical lambda expression: [`LambdaExpr`] + +use std::hash::Hash; +use std::sync::Arc; + +use crate::{ + ScalarFunctionExpr, + expressions::{Column, LambdaVariable}, + physical_expr::PhysicalExpr, +}; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::{ + HashMap, plan_err, + tree_node::{Transformed, TreeNode, TreeNodeRecursion}, +}; +use datafusion_common::{HashSet, Result, internal_err}; +use datafusion_expr::ColumnarValue; + +/// Represents a lambda with the given parameters names and body +#[derive(Debug, Eq, Clone)] +pub struct LambdaExpr { + params: Vec, + body: Arc, + projected_body: Arc, + projection: Vec, +} + +// Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] +impl PartialEq for LambdaExpr { + fn eq(&self, other: &Self) -> bool { + self.params.eq(&other.params) && self.body.eq(&other.body) + } +} + +impl Hash for LambdaExpr { + fn hash(&self, state: &mut H) { + self.params.hash(state); + self.body.hash(state); + } +} + +impl LambdaExpr { + /// Create a new lambda expression with the given parameters and body + pub fn try_new(params: Vec, body: Arc) -> Result { + if !all_unique(¶ms) { + return plan_err!( + "lambda params must be unique, got ({})", + params.join(", ") + ); + } + + check_async_udf(&body)?; + + Ok(Self::new(params, body)) + } + + fn new(params: Vec, body: Arc) -> Self { + let mut used_column_indices = HashSet::new(); + + body.apply(|node| { + if let Some(col) = node.downcast_ref::() { + used_column_indices.insert(col.index()); + } else if let Some(var) = node.downcast_ref::() { + used_column_indices.insert(var.index()); + } + + Ok(TreeNodeRecursion::Continue) + }) + .expect("closure should be infallible"); + + let mut projection = used_column_indices.into_iter().collect::>(); + + projection.sort(); + + let column_index_map = projection + .iter() + .enumerate() + .map(|(projected, original)| (*original, projected)) + .collect::>(); + + let projected_body = Arc::clone(&body) + .transform_down(|e| { + if let Some(column) = e.downcast_ref::() { + let original = column.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + projected, + )))); + } + } else if let Some(lambda_variable) = e.downcast_ref::() { + let original = lambda_variable.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(LambdaVariable::new( + projected, + Arc::clone(lambda_variable.field()), + )))); + } + } + Ok(Transformed::no(e)) + }) + .expect("closure should be infallible") + .data; + + Self { + params, + body, + projected_body, + projection, + } + } + + /// Get the lambda's params names + pub fn params(&self) -> &[String] { + &self.params + } + + /// Get the lambda's body + pub fn body(&self) -> &Arc { + &self.body + } + + pub(crate) fn projection(&self) -> &[usize] { + &self.projection + } + + pub(crate) fn projected_body(&self) -> &Arc { + &self.projected_body + } +} + +impl std::fmt::Display for LambdaExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} + +impl PhysicalExpr for LambdaExpr { + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Null) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + internal_err!("LambdaExpr::evaluate() should not be called") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [body] = children.as_slice() else { + return internal_err!( + "LambdaExpr expects exactly 1 child, got {}", + children.len() + ); + }; + + check_async_udf(body)?; + + Ok(Arc::new(Self::new(self.params.clone(), Arc::clone(body)))) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) -> {}", self.params.join(", "), self.body) + } +} + +/// Create a lambda expression +pub fn lambda( + params: impl IntoIterator>, + body: Arc, +) -> Result> { + Ok(Arc::new(LambdaExpr::try_new( + params.into_iter().map(Into::into).collect(), + body, + )?)) +} + +fn all_unique(params: &[String]) -> bool { + match params.len() { + 0 | 1 => true, + 2 => params[0] != params[1], + _ => { + let mut set = HashSet::with_capacity(params.len()); + + params.iter().all(|p| set.insert(p.as_str())) + } + } +} + +fn check_async_udf(body: &Arc) -> Result<()> { + if body.exists(|expr| { + Ok(expr + .downcast_ref::() + .is_some_and(|udf| udf.fun().as_async().is_some())) + })? { + return plan_err!( + "Async functions in lambdas aren't supported, see https://github.com/apache/datafusion/issues/22091" + ); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::expressions::{NoOp, lambda::lambda}; + use arrow::{array::RecordBatch, datatypes::Schema}; + use std::sync::Arc; + + #[test] + fn test_lambda_evaluate() { + let lambda = lambda(["a"], Arc::new(NoOp::new())).unwrap(); + let batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + assert!(lambda.evaluate(&batch).is_err()); + } + + #[test] + fn test_lambda_duplicate_name() { + assert!(lambda(["a", "a"], Arc::new(NoOp::new())).is_err()); + } +} diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs new file mode 100644 index 0000000000000..1c130ab12e9bb --- /dev/null +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical lambda variable reference: [`LambdaVariable`] + +use std::hash::Hash; +use std::sync::Arc; + +use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::FieldRef; +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; + +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::ColumnarValue; + +/// Represents the lambda variable with a given index and field +#[derive(Debug, Clone)] +pub struct LambdaVariable { + index: usize, + field: FieldRef, +} + +impl Eq for LambdaVariable {} + +impl PartialEq for LambdaVariable { + fn eq(&self, other: &Self) -> bool { + self.index == other.index && self.field == other.field + } +} + +impl Hash for LambdaVariable { + fn hash(&self, state: &mut H) { + self.index.hash(state); + self.field.hash(state); + } +} + +impl LambdaVariable { + /// Create a new lambda variable expression + pub fn new(index: usize, field: FieldRef) -> Self { + Self { index, field } + } + + /// Get the variable's name + pub fn name(&self) -> &str { + self.field.name() + } + + /// Get the variable's index + pub fn index(&self) -> usize { + self.index + } + + /// Get the variable's field + pub fn field(&self) -> &FieldRef { + &self.field + } +} + +impl std::fmt::Display for LambdaVariable { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}@{}", self.name(), self.index) + } +} + +impl PhysicalExpr for LambdaVariable { + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.field.data_type().clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.field.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + if self.index >= batch.num_columns() { + return internal_err!( + "PhysicalExpr LambdaVariable references column '{}' at index {} (zero-based) but batch only has {} columns: {:?}", + self.name(), + self.index, + batch.num_columns(), + batch + .schema_ref() + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ); + } + + if self.field.as_ref() != batch.schema_ref().field(self.index) { + return exec_err!( + "Field of physical LambdaVariable with index {} doesn't match batch field during evaluation {} != {}", + self.index, + self.field, + batch.schema_ref().field(self.index) + ); + } + + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.field)) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}@{}", self.name(), self.index) + } +} + +/// Create a lambda variable expression +pub fn lambda_variable(name: &str, schema: &Schema) -> Result> { + let index = schema.index_of(name)?; + let field = Arc::clone(&schema.fields()[index]); + + Ok(Arc::new(LambdaVariable::new(index, field))) +} diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 5502def5820f6..7535f109a0a92 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -18,11 +18,11 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{assert_or_internal_err, Result}; +use datafusion_common::{Result, assert_or_internal_err}; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::apply_cmp; use std::hash::Hash; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; // Like expression #[derive(Debug, Eq)] @@ -105,10 +105,6 @@ impl std::fmt::Display for LikeExpr { } impl PhysicalExpr for LikeExpr { - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Boolean) } @@ -149,6 +145,65 @@ impl PhysicalExpr for LikeExpr { write!(f, " {} ", self.op_name())?; self.pattern.fmt_sql(f) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: self.negated, + case_insensitive: self.case_insensitive, + expr: Some(Box::new(ctx.encode_child(&self.expr)?)), + pattern: Some(Box::new(ctx.encode_child(&self.pattern)?)), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl LikeExpr { + /// Reconstruct a [`LikeExpr`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`] so the decode signature matches + /// other migrated expressions and can inspect outer-node metadata if + /// needed in the future. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let like_expr = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::LikeExpr, + "LikeExpr", + ); + + Ok(Arc::new(LikeExpr::new( + like_expr.negated, + like_expr.case_insensitive, + ctx.decode_required_expression( + like_expr.expr.as_deref(), + "LikeExpr", + "expr", + )?, + ctx.decode_required_expression( + like_expr.pattern.as_deref(), + "LikeExpr", + "pattern", + )?, + ))) + } } /// used for optimize Dictionary like @@ -287,3 +342,189 @@ mod test { Ok(()) } } + +/// Tests for the `try_to_proto` / `try_from_proto` hooks. +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalLikeExprNode, physical_expr_node, + }; + + /// Build a `LikeExpr` proto node with the given children. + fn like_node( + negated: bool, + case_insensitive: bool, + expr: Option>, + pattern: Option>, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::LikeExpr(Box::new( + PhysicalLikeExprNode { + negated, + case_insensitive, + expr, + pattern, + }, + ))), + } + } + + /// A `LikeExpr` over two `Utf8` columns with both flags set, so the + /// `negated` / `case_insensitive` wiring is actually exercised. + fn like_fixture() -> LikeExpr { + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ]); + LikeExpr::new( + true, + true, + col("a", &schema).unwrap(), + col("b", &schema).unwrap(), + ) + } + + #[test] + fn try_to_proto_encodes_like_expr() { + let like = like_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = like + .try_to_proto(&ctx) + .unwrap() + .expect("LikeExpr should encode to Some(node)"); + + // Built-in exprs never set expr_id; only dynamic filters do. + assert!(node.expr_id.is_none()); + let like_node = match node.expr_type { + Some(physical_expr_node::ExprType::LikeExpr(boxed)) => *boxed, + other => panic!("expected a LikeExpr node, got {other:?}"), + }; + assert!(like_node.negated); + assert!(like_node.case_insensitive); + assert!(like_node.expr.is_some()); + assert!(like_node.pattern.is_some()); + } + + #[test] + fn try_to_proto_propagates_expr_encode_error() { + let like = like_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = like.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_to_proto_propagates_pattern_encode_error() { + let like = like_fixture(); + let encoder = StubEncoder::failing_on(2); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = like.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 2"))); + } + + #[test] + fn try_from_proto_decodes_like_expr() { + let node = like_node( + true, + true, + Some(Box::new(column_node("a"))), + Some(Box::new(column_node("b"))), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = LikeExpr::try_from_proto(&node, &ctx).unwrap(); + let like = decoded + .downcast_ref::() + .expect("decoded expr should be a LikeExpr"); + assert!(like.negated()); + assert!(like.case_insensitive()); + assert!(like.expr().downcast_ref::().is_some()); + assert!(like.pattern().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_like_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = LikeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a LikeExpr") + )); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = like_node(false, false, None, Some(Box::new(column_node("b")))); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = LikeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("LikeExpr is missing required field 'expr'") + )); + } + + #[test] + fn try_from_proto_rejects_missing_pattern() { + let node = like_node(false, false, Some(Box::new(column_node("a"))), None); + let schema = Schema::empty(); + // `expr` is present, so it is decoded before the missing-`pattern` + // check fires; use a decoder that succeeds for that first child. + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = LikeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(msg) if msg.contains("LikeExpr is missing required field 'pattern'") + )); + } + + #[test] + fn try_from_proto_propagates_expr_decode_error() { + let node = like_node( + false, + false, + Some(Box::new(column_node("a"))), + Some(Box::new(column_node("b"))), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = LikeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_propagates_pattern_decode_error() { + let node = like_node( + false, + false, + Some(Box::new(column_node("a"))), + Some(Box::new(column_node("b"))), + ); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(2); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = LikeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 2"))); + } +} diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 1f3fefc60b7ad..5fb9a3b2cd29b 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -17,7 +17,6 @@ //! Literal expressions for physical operations -use std::any::Any; use std::hash::Hash; use std::sync::Arc; @@ -33,6 +32,7 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Expr; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::placement::ExpressionPlacement; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value @@ -91,11 +91,6 @@ impl std::fmt::Display for Literal { } impl PhysicalExpr for Literal { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(self.value.data_type()) } @@ -134,6 +129,45 @@ impl PhysicalExpr for Literal { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Literal + } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( + (&self.value).try_into()?, + )), + })) + } +} + +#[cfg(feature = "proto")] +impl Literal { + /// Reconstruct a [`Literal`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let scalar_proto = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::Literal, + "Literal", + ); + let value = ScalarValue::try_from(scalar_proto)?; + Ok(Arc::new(Literal::new(value))) + } } /// Create a literal expression @@ -150,7 +184,6 @@ mod tests { use super::*; use arrow::array::Int32Array; - use arrow::datatypes::Field; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; @@ -192,3 +225,103 @@ mod tests { Ok(()) } } + +/// Tests for the `try_to_proto` / `try_from_proto` hooks. +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::proto_test_util::{StubEncoder, UnreachableDecoder, column_node}; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::physical_expr_node; + + fn i32_literal() -> Literal { + Literal::new(ScalarValue::Int32(Some(42))) + } + + // ── try_to_proto ───────────────────────────────────────────────────────── + + #[test] + fn try_to_proto_encodes_literal() { + let literal = i32_literal(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = literal + .try_to_proto(&ctx) + .unwrap() + .expect("Literal should encode to Some(node)"); + + // Literal nodes never set expr_id. + assert!(node.expr_id.is_none()); + // Variant must be Literal, not any other expr type. + assert!(matches!( + node.expr_type, + Some(physical_expr_node::ExprType::Literal(_)) + )); + } + + #[test] + fn try_to_proto_null_literal() { + let literal = Literal::new(ScalarValue::Int32(None)); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = literal + .try_to_proto(&ctx) + .unwrap() + .expect("null Literal should encode to Some(node)"); + + assert!(matches!( + node.expr_type, + Some(physical_expr_node::ExprType::Literal(_)) + )); + + // Decode and verify the null payload round-trips correctly. + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let dec_ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let decoded = Literal::try_from_proto(&node, &dec_ctx).unwrap(); + let lit = decoded + .downcast_ref::() + .expect("decoded expr should be a Literal"); + assert_eq!(lit.value(), &ScalarValue::Int32(None)); + } + + // ── try_from_proto ─────────────────────────────────────────────────────── + + #[test] + fn try_from_proto_roundtrip() { + let original = i32_literal(); + let encoder = StubEncoder::ok(); + let enc_ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = original + .try_to_proto(&enc_ctx) + .unwrap() + .expect("should encode"); + + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let dec_ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = Literal::try_from_proto(&node, &dec_ctx).unwrap(); + let lit = decoded + .downcast_ref::() + .expect("decoded expr should be a Literal"); + assert_eq!(lit.value(), &ScalarValue::Int32(Some(42))); + } + + #[test] + fn try_from_proto_rejects_non_literal_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = Literal::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(ref msg) if msg.contains("PhysicalExprNode is not a Literal")) + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 59d675753d985..05a04f88dcadf 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -21,12 +21,13 @@ mod binary; mod case; mod cast; -mod cast_column; mod column; mod dynamic_filters; mod in_list; mod is_not_null; mod is_null; +mod lambda; +mod lambda_variable; mod like; mod literal; mod negative; @@ -35,24 +36,30 @@ mod not; mod try_cast; mod unknown_column; +pub use crate::PhysicalSortExpr; /// Module with some convenient methods used in expression building pub use crate::aggregate::stats::StatsType; -pub use crate::PhysicalSortExpr; -pub use binary::{binary, similar_to, BinaryExpr}; -pub use case::{case, CaseExpr}; -pub use cast::{cast, CastExpr}; -pub use cast_column::CastColumnExpr; -pub use column::{col, with_new_schema, Column}; +pub use binary::{BinaryExpr, binary, similar_to}; +pub use case::{CaseExpr, case}; +pub use cast::{CastExpr, cast}; +pub use column::{Column, col, with_new_schema}; pub use datafusion_expr::utils::format_state_name; -pub use dynamic_filters::DynamicFilterPhysicalExpr; -pub use in_list::{in_list, InListExpr}; -pub use is_not_null::{is_not_null, IsNotNullExpr}; -pub use is_null::{is_null, IsNullExpr}; -pub use like::{like, LikeExpr}; -pub use literal::{lit, Literal}; -pub use negative::{negative, NegativeExpr}; +pub use dynamic_filters::{ + DynamicFilterPhysicalExpr, DynamicFilterTracker, DynamicFilterTracking, + Inner as DynamicFilterInner, +}; +pub use in_list::{InListExpr, in_list}; +pub use is_not_null::{IsNotNullExpr, is_not_null}; +pub use is_null::{IsNullExpr, is_null}; +pub use lambda::{LambdaExpr, lambda}; +pub use lambda_variable::{LambdaVariable, lambda_variable}; +pub use like::{LikeExpr, like}; +pub use literal::{Literal, lit}; +pub use negative::{NegativeExpr, negative}; pub use no_op::NoOp; -pub use not::{not, NotExpr}; -pub use try_cast::{try_cast, TryCastExpr}; +pub use not::{NotExpr, not}; +pub use try_cast::{TryCastExpr, try_cast}; pub use unknown_column::UnKnownColumn; + +pub(crate) use cast::cast_with_target_field; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index fa7224768a777..9fbf38361c89c 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -17,7 +17,6 @@ //! Negation (-) expression -use std::any::Any; use std::hash::Hash; use std::sync::Arc; @@ -29,15 +28,16 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; +#[expect(deprecated)] use datafusion_expr::statistics::Distribution::{ self, Bernoulli, Exponential, Gaussian, Generic, Uniform, }; use datafusion_expr::{ - type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, + type_coercion::{is_interval, is_signed_numeric, is_timestamp}, }; /// Negative expression @@ -79,11 +79,6 @@ impl std::fmt::Display for NegativeExpr { } impl PhysicalExpr for NegativeExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, input_schema: &Schema) -> Result { self.arg.data_type(input_schema) } @@ -140,6 +135,7 @@ impl PhysicalExpr for NegativeExpr { .map(|result| vec![result])) } + #[expect(deprecated)] fn evaluate_statistics(&self, children: &[&Distribution]) -> Result { match children[0] { Uniform(u) => Distribution::new_uniform(u.range().arithmetic_negate()?), @@ -178,6 +174,45 @@ impl PhysicalExpr for NegativeExpr { self.arg.fmt_sql(f)?; write!(f, ")") } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( + protobuf::PhysicalNegativeNode { + expr: Some(Box::new(ctx.encode_child(&self.arg)?)), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl NegativeExpr { + /// Reconstruct a [`NegativeExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let n = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::Negative, + "Negative", + ); + let expr = + ctx.decode_required_expression(n.expr.as_deref(), "NegativeExpr", "expr")?; + + Ok(Arc::new(NegativeExpr::new(expr))) + } } /// Creates a unary expression NEGATIVE @@ -190,7 +225,7 @@ pub fn negative( input_schema: &Schema, ) -> Result> { let data_type = arg.data_type(input_schema)?; - if is_null(&data_type) { + if data_type.is_null() { Ok(arg) } else if !is_signed_numeric(&data_type) && !is_interval(&data_type) @@ -205,19 +240,18 @@ pub fn negative( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, Column}; + use crate::expressions::{Column, col}; use arrow::array::*; - use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64, Int8}; + use arrow::datatypes::DataType::{Float32, Float64, Int8, Int16, Int32, Int64}; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_physical_expr_common::physical_expr::fmt_sql; - use paste::paste; macro_rules! test_array_negative_op { - ($DATA_TY:tt, $($VALUE:expr),* ) => { + ($DATA_TY:tt, $ARRAY_TY:ty, $($VALUE:expr),* ) => { let schema = Schema::new(vec![Field::new("a", DataType::$DATA_TY, true)]); let expr = negative(col("a", &schema)?, &schema)?; assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TY); @@ -230,8 +264,8 @@ mod tests { )+ arr.push(None); arr_expected.push(None); - let input = paste!{[<$DATA_TY Array>]::from(arr)}; - let expected = &paste!{[<$DATA_TY Array>]::from(arr_expected)}; + let input = <$ARRAY_TY>::from(arr); + let expected = &<$ARRAY_TY>::from(arr_expected); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; let result = expr.evaluate(&batch)?.into_array(batch.num_rows()).expect("Failed to convert to array"); @@ -243,12 +277,12 @@ mod tests { #[test] fn array_negative_op() -> Result<()> { - test_array_negative_op!(Int8, 2i8, 1i8); - test_array_negative_op!(Int16, 234i16, 123i16); - test_array_negative_op!(Int32, 2345i32, 1234i32); - test_array_negative_op!(Int64, 23456i64, 12345i64); - test_array_negative_op!(Float32, 2345.0f32, 1234.0f32); - test_array_negative_op!(Float64, 23456.0f64, 12345.0f64); + test_array_negative_op!(Int8, Int8Array, 2i8, 1i8); + test_array_negative_op!(Int16, Int16Array, 234i16, 123i16); + test_array_negative_op!(Int32, Int32Array, 2345i32, 1234i32); + test_array_negative_op!(Int64, Int64Array, 23456i64, 12345i64); + test_array_negative_op!(Float32, Float32Array, 2345.0f32, 1234.0f32); + test_array_negative_op!(Float64, Float64Array, 23456.0f64, 12345.0f64); Ok(()) } @@ -265,6 +299,7 @@ mod tests { } #[test] + #[expect(deprecated)] fn test_evaluate_statistics() -> Result<()> { let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); @@ -277,11 +312,13 @@ mod tests { ); // Bernoulli - assert!(negative_expr - .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from( - 0.75 - ))?]) - .is_err()); + assert!( + negative_expr + .evaluate_statistics(&[&Distribution::new_bernoulli(ScalarValue::from( + 0.75 + ))?]) + .is_err() + ); // Exponential assert_eq!( @@ -342,6 +379,7 @@ mod tests { } #[test] + #[expect(deprecated)] fn test_propagate_statistics_range_holders() -> Result<()> { let negative_expr = NegativeExpr::new(Arc::new(Column::new("a", 0))); let original_child_interval = Interval::make(Some(-2), Some(3))?; @@ -403,3 +441,111 @@ mod tests { Ok(()) } } + +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalNegativeNode, physical_expr_node, + }; + + /// Build a `NegativeExpr` proto node with the given children. + fn negative_node(expr: Option>) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::Negative(Box::new( + PhysicalNegativeNode { expr }, + ))), + } + } + + /// A `NegativeExpr` over a column of type Int32. + fn negative_fixture() -> NegativeExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + NegativeExpr::new(col("a", &schema).unwrap()) + } + + #[test] + fn try_to_proto_encodes_negative_expr() { + let negative = negative_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = negative + .try_to_proto(&ctx) + .unwrap() + .expect("NegativeExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let negative_node = match node.expr_type { + Some(physical_expr_node::ExprType::Negative(boxed)) => *boxed, + other => panic!("expected a NegativeExpr node, got {other:?}"), + }; + assert!(negative_node.expr.is_some()); + } + + #[test] + fn try_to_proto_propagates_expr_encode_error() { + let negative = negative_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = negative.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_decodes_negative_expr() { + let node = negative_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = NegativeExpr::try_from_proto(&node, &ctx).unwrap(); + let negative = decoded + .downcast_ref::() + .expect("decoded expr should be a NegativeExpr"); + assert!(negative.arg().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_negative_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = NegativeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a Negative")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = negative_node(None); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = NegativeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("NegativeExpr is missing required field 'expr'")) + ); + } + + #[test] + fn try_from_proto_propagates_expr_decode_error() { + let node = negative_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = NegativeExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } +} diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 94610996c6b00..c866c6ab07113 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -17,7 +17,6 @@ //! NoOp placeholder for physical operations -use std::any::Any; use std::hash::Hash; use std::sync::Arc; @@ -26,7 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::ColumnarValue; /// A place holder expression, can not be evaluated. @@ -49,11 +48,6 @@ impl std::fmt::Display for NoOp { } impl PhysicalExpr for NoOp { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Null) } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 655fc7f92e965..f856dd568a8da 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -17,7 +17,6 @@ //! Not expression -use std::any::Any; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -26,10 +25,11 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, cast::as_boolean_array, internal_err}; +use datafusion_expr::ColumnarValue; use datafusion_expr::interval_arithmetic::Interval; +#[expect(deprecated)] use datafusion_expr::statistics::Distribution::{self, Bernoulli}; -use datafusion_expr::ColumnarValue; /// Not expression #[derive(Debug, Eq)] @@ -70,11 +70,6 @@ impl fmt::Display for NotExpr { } impl PhysicalExpr for NotExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Boolean) } @@ -132,6 +127,7 @@ impl PhysicalExpr for NotExpr { .map(|result| vec![result])) } + #[expect(deprecated)] fn evaluate_statistics(&self, children: &[&Distribution]) -> Result { match children[0] { Bernoulli(b) => { @@ -147,6 +143,7 @@ impl PhysicalExpr for NotExpr { } } + #[expect(deprecated)] fn propagate_statistics( &self, parent: &Distribution, @@ -184,6 +181,45 @@ impl PhysicalExpr for NotExpr { write!(f, "NOT ")?; self.arg.fmt_sql(f) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( + protobuf::PhysicalNot { + expr: Some(Box::new(ctx.encode_child(&self.arg)?)), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl NotExpr { + /// Reconstruct a [`NotExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let not_expr = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::NotExpr, + "NotExpr", + ); + let expr = + ctx.decode_required_expression(not_expr.expr.as_deref(), "NotExpr", "expr")?; + + Ok(Arc::new(NotExpr::new(expr))) + } } /// Creates a unary expression NOT @@ -196,7 +232,7 @@ mod tests { use std::sync::LazyLock; use super::*; - use crate::expressions::{col, Column}; + use crate::expressions::{Column, col}; use arrow::{array::BooleanArray, datatypes::*}; use datafusion_physical_expr_common::physical_expr::fmt_sql; @@ -259,34 +295,38 @@ mod tests { } #[test] + #[expect(deprecated)] fn test_evaluate_statistics() -> Result<()> { let _schema = &Schema::new(vec![Field::new("a", DataType::Boolean, false)]); let a = Arc::new(Column::new("a", 0)) as _; let expr = not(a)?; // Uniform with non-boolean bounds - assert!(expr - .evaluate_statistics(&[&Distribution::new_uniform( + assert!( + expr.evaluate_statistics(&[&Distribution::new_uniform( Interval::make_unbounded(&DataType::Float64)? )?]) - .is_err()); + .is_err() + ); // Exponential - assert!(expr - .evaluate_statistics(&[&Distribution::new_exponential( + assert!( + expr.evaluate_statistics(&[&Distribution::new_exponential( ScalarValue::from(1.0), ScalarValue::from(1.0), true )?]) - .is_err()); + .is_err() + ); // Gaussian - assert!(expr - .evaluate_statistics(&[&Distribution::new_gaussian( + assert!( + expr.evaluate_statistics(&[&Distribution::new_gaussian( ScalarValue::from(1.0), ScalarValue::from(1.0), )?]) - .is_err()); + .is_err() + ); // Bernoulli assert_eq!( @@ -310,24 +350,26 @@ mod tests { Distribution::new_bernoulli(ScalarValue::from(0.75))? ); - assert!(expr - .evaluate_statistics(&[&Distribution::new_generic( + assert!( + expr.evaluate_statistics(&[&Distribution::new_generic( ScalarValue::Null, ScalarValue::Null, ScalarValue::Null, Interval::make_unbounded(&DataType::UInt8)? )?]) - .is_err()); + .is_err() + ); // Unknown with non-boolean interval as range - assert!(expr - .evaluate_statistics(&[&Distribution::new_generic( + assert!( + expr.evaluate_statistics(&[&Distribution::new_generic( ScalarValue::Null, ScalarValue::Null, ScalarValue::Null, Interval::make_unbounded(&DataType::Float64)? )?]) - .is_err()); + .is_err() + ); Ok(()) } @@ -354,3 +396,112 @@ mod tests { Arc::clone(&SCHEMA) } } + +/// Tests for the `try_to_proto` / `try_from_proto` hooks. +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalNot, physical_expr_node, + }; + + /// Build a `NotExpr` proto node with the given child. + fn not_node(expr: Option>) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::NotExpr(Box::new( + PhysicalNot { expr }, + ))), + } + } + + /// A `NotExpr` over a boolean column. + fn not_fixture() -> NotExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + NotExpr::new(col("a", &schema).unwrap()) + } + + #[test] + fn try_to_proto_encodes_not_expr() { + let not = not_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = not + .try_to_proto(&ctx) + .unwrap() + .expect("NotExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let not_node = match node.expr_type { + Some(physical_expr_node::ExprType::NotExpr(boxed)) => *boxed, + other => panic!("expected a NotExpr node, got {other:?}"), + }; + assert!(not_node.expr.is_some()); + } + + #[test] + fn try_to_proto_propagates_expr_encode_error() { + let not = not_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = not.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_decodes_not_expr() { + let node = not_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = NotExpr::try_from_proto(&node, &ctx).unwrap(); + let not = decoded + .downcast_ref::() + .expect("decoded expr should be a NotExpr"); + assert!(not.arg().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_not_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = NotExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a NotExpr")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = not_node(None); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = NotExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("NotExpr is missing required field 'expr'")) + ); + } + + #[test] + fn try_from_proto_propagates_expr_decode_error() { + let node = not_node(Some(Box::new(column_node("a")))); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = NotExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } +} diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index b32aabbe5b006..65b953fd181b7 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -27,7 +26,7 @@ use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{Result, not_impl_err}; use datafusion_expr::ColumnarValue; /// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast @@ -72,16 +71,11 @@ impl TryCastExpr { impl fmt::Display for TryCastExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TRY_CAST({} AS {:?})", self.expr, self.cast_type) + write!(f, "TRY_CAST({} AS {})", self.expr, self.cast_type) } } impl PhysicalExpr for TryCastExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(self.cast_type.clone()) } @@ -125,6 +119,56 @@ impl PhysicalExpr for TryCastExpr { self.expr.fmt_sql(f)?; write!(f, " AS {:?})", self.cast_type) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(ctx.encode_child(&self.expr)?)), + arrow_type: Some(self.cast_type().try_into()?), + }, + ))), + })) + } +} + +#[cfg(feature = "proto")] +impl TryCastExpr { + /// Reconstruct a [`TryCastExpr`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_physical_expr_common::physical_expr::proto_decode::require_proto_field; + use datafusion_proto_models::protobuf; + + let try_cast = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::TryCast, + "TryCastExpr", + ); + let expr = ctx.decode_required_expression( + try_cast.expr.as_deref(), + "TryCastExpr", + "expr", + )?; + let arrow_type = require_proto_field( + try_cast.arrow_type.as_ref(), + "TryCastExpr", + "arrow_type", + )?; + let cast_type: DataType = arrow_type.try_into()?; + + Ok(Arc::new(TryCastExpr::new(expr, cast_type))) + } } /// Return a PhysicalExpression representing `expr` casted to @@ -155,8 +199,8 @@ mod tests { }; use arrow::{ array::{ - Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, TimestampNanosecondArray, UInt32Array, + Array, Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, + Int64Array, TimestampNanosecondArray, UInt32Array, }, datatypes::*, }; @@ -180,7 +224,7 @@ mod tests { // verify that its display is correct assert_eq!( - format!("TRY_CAST(a@0 AS {:?})", $TYPE), + format!("TRY_CAST(a@0 AS {})", $TYPE), format!("{}", expression) ); @@ -206,7 +250,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; @@ -231,7 +275,7 @@ mod tests { // verify that its display is correct assert_eq!( - format!("TRY_CAST(a@0 AS {:?})", $TYPE), + format!("TRY_CAST(a@0 AS {})", $TYPE), format!("{}", expression) ); @@ -260,7 +304,7 @@ mod tests { for (i, x) in $VEC.iter().enumerate() { match x { Some(x) => assert_eq!(result.value(i), *x), - None => assert!(!result.is_valid(i)), + None => assert!(result.is_null(i)), } } }}; @@ -599,3 +643,143 @@ mod tests { Ok(()) } } + +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::expressions::{Column, col}; + use crate::proto_test_util::{ + StubDecoder, StubEncoder, UnreachableDecoder, column_node, + }; + use arrow::datatypes::Field; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::datafusion_common::ArrowType; + use datafusion_proto_models::protobuf::{ + PhysicalExprNode, PhysicalTryCastNode, physical_expr_node, + }; + + fn try_cast_fixture() -> TryCastExpr { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + TryCastExpr::new(col("a", &schema).unwrap(), DataType::Int32) + } + + fn int32_arrow_type() -> ArrowType { + (&DataType::Int32).try_into().unwrap() + } + + fn try_cast_node( + expr: Option>, + arrow_type: Option, + ) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::TryCast(Box::new( + PhysicalTryCastNode { expr, arrow_type }, + ))), + } + } + + #[test] + fn try_to_proto_encodes_try_cast_expr() { + let try_cast = try_cast_fixture(); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = try_cast + .try_to_proto(&ctx) + .unwrap() + .expect("TryCastExpr should encode to Some(node)"); + + assert!(node.expr_id.is_none()); + let try_cast_node = match node.expr_type { + Some(physical_expr_node::ExprType::TryCast(boxed)) => *boxed, + other => panic!("expected a TryCastExpr node, got {other:?}"), + }; + assert!(try_cast_node.expr.is_some()); + + let arrow_type = try_cast_node + .arrow_type + .as_ref() + .expect("try cast type should be encoded"); + let data_type: DataType = arrow_type.try_into().unwrap(); + assert_eq!(data_type, DataType::Int32); + } + + #[test] + fn try_to_proto_propagates_child_encode_error() { + let try_cast = try_cast_fixture(); + let encoder = StubEncoder::failing_on(1); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + let err = try_cast.try_to_proto(&ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } + + #[test] + fn try_from_proto_decodes_try_cast_expr() { + let node = + try_cast_node(Some(Box::new(column_node("a"))), Some(int32_arrow_type())); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = TryCastExpr::try_from_proto(&node, &ctx).unwrap(); + let try_cast = decoded + .downcast_ref::() + .expect("decoded expr should be a TryCastExpr"); + + assert_eq!(try_cast.cast_type(), &DataType::Int32); + assert!(try_cast.expr().downcast_ref::().is_some()); + } + + #[test] + fn try_from_proto_rejects_non_try_cast_node() { + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("PhysicalExprNode is not a TryCastExpr")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_expr() { + let node = try_cast_node(None, Some(int32_arrow_type())); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'expr'")) + ); + } + + #[test] + fn try_from_proto_rejects_missing_arrow_type() { + let node = try_cast_node(Some(Box::new(column_node("a"))), None); + let schema = Schema::empty(); + let decoder = StubDecoder::ok(); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!( + matches!(err, DataFusionError::Internal(msg) if msg.contains("TryCastExpr is missing required field 'arrow_type'")) + ); + } + + #[test] + fn try_from_proto_propagates_child_decode_error() { + let node = + try_cast_node(Some(Box::new(column_node("a"))), Some(int32_arrow_type())); + let schema = Schema::empty(); + let decoder = StubDecoder::failing_on(1); + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = TryCastExpr::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(msg) if msg.contains("call 1"))); + } +} diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index 2face4eb6bdb6..ed85f20dd274b 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -17,7 +17,6 @@ //! UnKnownColumn expression -use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -27,7 +26,8 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; + use datafusion_expr::ColumnarValue; #[derive(Debug, Clone, Eq)] @@ -56,11 +56,6 @@ impl std::fmt::Display for UnKnownColumn { } impl PhysicalExpr for UnKnownColumn { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - /// Get the data type of this expression, given the schema of the input fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Null) @@ -90,6 +85,42 @@ impl PhysicalExpr for UnKnownColumn { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( + protobuf::UnknownColumn { + name: self.name.clone(), + }, + )), + })) + } +} + +#[cfg(feature = "proto")] +impl UnKnownColumn { + /// Reconstruct an [`UnKnownColumn`] from its protobuf representation. + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_physical_expr_common::expect_expr_variant; + use datafusion_proto_models::protobuf; + + let unknown_col = expect_expr_variant!( + node, + protobuf::physical_expr_node::ExprType::UnknownColumn, + "UnKnownColumn", + ); + Ok(Arc::new(UnKnownColumn::new(&unknown_col.name))) + } } impl Hash for UnKnownColumn { @@ -105,3 +136,103 @@ impl PartialEq for UnKnownColumn { false } } + +/// Tests for the `try_to_proto` / `try_from_proto` hooks. +#[cfg(all(test, feature = "proto"))] +mod proto_tests { + use super::*; + use crate::proto_test_util::{StubEncoder, UnreachableDecoder, column_node}; + use arrow::datatypes::Schema; + use datafusion_common::DataFusionError; + use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx; + use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx; + use datafusion_proto_models::protobuf::{self, physical_expr_node}; + + // ── try_to_proto ───────────────────────────────────────────────────────── + + #[test] + fn try_to_proto_encodes_unknown_column() { + let expr = UnKnownColumn::new("my_col"); + let encoder = StubEncoder::ok(); + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = expr + .try_to_proto(&ctx) + .unwrap() + .expect("UnKnownColumn should encode to Some(node)"); + + // Built-in exprs never set expr_id; only dynamic filters do. + assert!(node.expr_id.is_none()); + + // Verify the encoded name matches the original. + let protobuf::UnknownColumn { name } = match node.expr_type { + Some(physical_expr_node::ExprType::UnknownColumn(c)) => c, + other => panic!("expected UnknownColumn proto node, got {other:?}"), + }; + assert_eq!(name, "my_col"); + } + + // ── try_from_proto ─────────────────────────────────────────────────────── + + #[test] + fn try_from_proto_decodes_name() { + let node = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::UnknownColumn( + protobuf::UnknownColumn { + name: "my_col".to_string(), + }, + )), + }; + let schema = Schema::empty(); + // UnKnownColumn has no child exprs so the decoder is never called. + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = UnKnownColumn::try_from_proto(&node, &ctx).unwrap(); + let col = decoded + .downcast_ref::() + .expect("decoded expr should be an UnKnownColumn"); + assert_eq!(col.name(), "my_col"); + } + + #[test] + fn try_from_proto_rejects_non_unknown_column_node() { + // column_node produces an ExprType::Column node, not UnknownColumn. + let node = column_node("a"); + let schema = Schema::empty(); + let decoder = UnreachableDecoder; + let ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + let err = UnKnownColumn::try_from_proto(&node, &ctx).unwrap_err(); + assert!(matches!( + err, + DataFusionError::Internal(ref msg) + if msg.contains("PhysicalExprNode is not a UnKnownColumn") + )); + } + + // ── roundtrip ──────────────────────────────────────────────────────────── + + #[test] + fn unknown_column_proto_roundtrip() { + let expr = UnKnownColumn::new("col_b"); + let encoder = StubEncoder::ok(); + let enc_ctx = PhysicalExprEncodeCtx::new(&encoder); + + let node = expr + .try_to_proto(&enc_ctx) + .unwrap() + .expect("UnKnownColumn should encode to Some(node)"); + + let schema = Schema::empty(); + // UnKnownColumn has no child exprs so the decoder is never called. + let decoder = UnreachableDecoder; + let dec_ctx = PhysicalExprDecodeCtx::new(&schema, &decoder); + + let decoded = UnKnownColumn::try_from_proto(&node, &dec_ctx).unwrap(); + let col = decoded + .downcast_ref::() + .expect("decoded expr should be an UnKnownColumn"); + assert_eq!(col.name(), "col_b"); + } +} diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs new file mode 100644 index 0000000000000..7390eb33a0922 --- /dev/null +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -0,0 +1,718 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Declaration of built-in (higher order) functions. +//! This module contains built-in functions' enumeration and metadata. +//! +//! Generally, a function has: +//! * a signature +//! * a return type, that is a function of the incoming argument's types +//! * the computation, that must accept each valid signature +//! +//! * Signature: see `Signature` +//! * Return type: a function `(arg_types) -> return_type`. E.g. for array_transform, ([[f32]], v -> v*2) -> [f32], ([[f32]], v -> v > 3.0) -> [bool]. +//! +//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed +//! to a function that supports f64, it is coerced to f64. + +use std::fmt::{self, Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::PhysicalExpr; +use crate::expressions::{LambdaExpr, Literal}; + +use arrow::array::{Array, RecordBatch}; +use arrow::datatypes::{DataType, FieldRef, Schema}; +use datafusion_common::config::{ConfigEntry, ConfigOptions}; +use datafusion_common::datatype::FieldExt; +use datafusion_common::utils::remove_list_null_values; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_datafusion_err, internal_err, + plan_datafusion_err, plan_err, +}; +use datafusion_expr::type_coercion::functions::value_fields_with_higher_order_udf; +use datafusion_expr::{ + ColumnarValue, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, HigherOrderUDF, + LambdaArgument, LambdaParametersProgress, ValueOrLambda, Volatility, expr_vec_fmt, +}; + +/// Per-argument classification cached at construction time. +/// +/// Walking the wrapped lambda tree and scanning a `Vec` of lambda +/// positions used to be done on every `evaluate` call. Both costs collapse +/// to a single up-front pass by storing the classification (and the resolved +/// inner [`LambdaExpr`]) here. +enum ArgSlot { + /// A regular value-producing expression at this position. + Value, + /// A lambda position. Stores the inner [`LambdaExpr`] pre-extracted from + /// any wrapper expressions that may have been introduced via + /// [`PhysicalExpr::with_new_children`] tree rewrites. + Lambda(Arc), +} + +/// Physical expression of a higher order function +pub struct HigherOrderFunctionExpr { + /// A shared instance of the higher-order function + fun: Arc, + /// The name of the higher-order function + name: String, + /// List of expressions to feed to the function as arguments + /// + /// For example, for `array_transform([2, 3], v -> v != 2)`, this will be: + /// + /// ```text + /// ListExpression [2,3] + /// LambdaExpression + /// parameters: ["v"] + /// body: + /// BinaryExpression (!=) + /// left: + /// LambdaVariableExpression("v", Field::new("", Int32, false)) + /// right: + /// LiteralExpression(2) + /// ``` + args: Vec>, + /// Per-arg classification, parallel to `args`. Length always equals + /// `args.len()`. Lambda variants carry the resolved inner [`LambdaExpr`] + /// so `evaluate` doesn't walk through wrapper nodes. + slots: Vec, + /// The output field associated this expression + /// + /// For example, for `array_transform([2, 3], v -> v != 2)`, this will be + /// `Field::new("", DataType::new_list(DataType::Boolean, true), true)` + return_field: FieldRef, + /// The config options at execution time + config_options: Arc, +} + +impl Debug for HigherOrderFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let lambda_positions: Vec<_> = self + .slots + .iter() + .enumerate() + .filter_map(|(i, slot)| matches!(slot, ArgSlot::Lambda(_)).then_some(i)) + .collect(); + f.debug_struct("HigherOrderFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.args) + .field("lambda_positions", &lambda_positions) + .field("return_field", &self.return_field) + .finish() + } +} + +impl HigherOrderFunctionExpr { + /// Create a new Higher Order function + /// + /// Note that lambda arguments must be present directly in args as [LambdaExpr], + /// and not as a wrapped child of any arg + pub fn try_new_with_schema( + fun: Arc, + args: Vec>, + schema: &Schema, + config_options: Arc, + ) -> Result { + let name = fun.name().to_string(); + let mut slots = Vec::with_capacity(args.len()); + let arg_fields = args + .iter() + .map(|e| match e.downcast_ref::() { + Some(lambda) => { + slots.push(ArgSlot::Lambda(Arc::new(lambda.clone()))); + Ok(ValueOrLambda::Lambda(lambda.body().return_field(schema)?)) + } + None => { + slots.push(ArgSlot::Value); + Ok(ValueOrLambda::Value(e.return_field(schema)?)) + } + }) + .collect::>>()?; + + // verify that input data types is consistent with function's `HigherOrderTypeSignature` + value_fields_with_higher_order_udf(&arg_fields, fun.as_ref())?; + + let arguments = args + .iter() + .map(|e| e.downcast_ref::().map(|literal| literal.value())) + .collect::>(); + + let ret_args = HigherOrderReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &arguments, + }; + + let return_field = fun.return_field_from_args(ret_args)?; + + Ok(Self { + fun, + name, + args, + slots, + return_field, + config_options, + }) + } + + /// Get the higher order function implementation + pub fn fun(&self) -> &HigherOrderUDF { + self.fun.as_ref() + } + + /// The name for this expression + pub fn name(&self) -> &str { + &self.name + } + + /// Input arguments + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + self.return_field.data_type() + } + + pub fn nullable(&self) -> bool { + self.return_field.is_nullable() + } + + pub fn config_options(&self) -> &ConfigOptions { + &self.config_options + } + + /// Resolve every lambda's parameter list. Returns an empty `Vec` when + /// there are no lambdas, avoiding the [`datafusion_expr::HigherOrderUDFImpl::lambda_parameters`] + /// virtual call entirely. + fn resolve_lambda_parameters( + &self, + fields: &[ValueOrLambda>], + ) -> Result>> { + let num_lambdas = self + .slots + .iter() + .filter(|s| matches!(s, ArgSlot::Lambda(_))) + .count(); + if num_lambdas == 0 { + return Ok(Vec::new()); + } + match self.fun().lambda_parameters(0, fields)? { + LambdaParametersProgress::Partial(_) => plan_err!( + "{} lambda_parameters returned a partial result when the return type of all it's lambdas were provided", + self.name() + ), + LambdaParametersProgress::Complete(items) => { + // functions can support multiple lambdas where some trailing ones are optional, + // but to simplify the implementor, lambda_parameters returns the parameters of all of them, + // so we can't do equality check. one example is spark reduce: + // https://spark.apache.org/docs/latest/api/sql/index.html#reduce + if items.len() < num_lambdas { + return exec_err!( + "{} invocation defined {num_lambdas} but lambda_parameters returned only {}", + self.name(), + items.len() + ); + } + Ok(items) + } + } + } +} + +impl fmt::Display for HigherOrderFunctionExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.name, expr_vec_fmt!(self.args)) + } +} + +impl PartialEq for HigherOrderFunctionExpr { + fn eq(&self, o: &Self) -> bool { + if std::ptr::eq(self, o) { + // The equality implementation is somewhat expensive, so let's short-circuit when possible. + return true; + } + // `slots` is a deterministic function of `fun` and `args`, so it's + // not part of the comparison. + let Self { + fun, + name, + args, + slots: _, + return_field, + config_options, + } = self; + fun.eq(&o.fun) + && name.eq(&o.name) + && args.eq(&o.args) + && return_field.eq(&o.return_field) + && (Arc::ptr_eq(config_options, &o.config_options) + || sorted_config_entries(config_options) + == sorted_config_entries(&o.config_options)) + } +} +impl Eq for HigherOrderFunctionExpr {} +impl Hash for HigherOrderFunctionExpr { + fn hash(&self, state: &mut H) { + let Self { + fun, + name, + args, + slots: _, + return_field, + config_options: _, // expensive to hash, and often equal + } = self; + fun.hash(state); + name.hash(state); + args.hash(state); + return_field.hash(state); + } +} + +fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { + let mut entries = config_options.entries(); + entries.sort_by(|l, r| l.key.cmp(&r.key)); + entries +} + +impl PhysicalExpr for HigherOrderFunctionExpr { + fn evaluate(&self, batch: &RecordBatch) -> Result { + let mut arg_fields = Vec::with_capacity(self.args.len()); + let mut fields = Vec::with_capacity(self.args.len()); + for (arg, slot) in self.args.iter().zip(&self.slots) { + match slot { + ArgSlot::Lambda(lambda) => { + let field = lambda.body().return_field(batch.schema_ref())?; + arg_fields.push(ValueOrLambda::Lambda(Arc::clone(&field))); + fields.push(ValueOrLambda::Lambda(Some(field))); + } + ArgSlot::Value => { + let field = arg.return_field(batch.schema_ref())?; + arg_fields.push(ValueOrLambda::Value(Arc::clone(&field))); + fields.push(ValueOrLambda::Value(field)); + } + } + } + + let mut lambda_parameters = self.resolve_lambda_parameters(&fields)?.into_iter(); + + let args = self + .args + .iter() + .zip(&self.slots) + .map(|(arg, slot)| match slot { + ArgSlot::Lambda(lambda) => { + let lambda_params = lambda_parameters.next().ok_or_else(|| { + internal_datafusion_err!( + "params len should have been checked above" + ) + })?; + + if lambda.params().len() > lambda_params.len() { + return exec_err!( + "lambda defined {} params but higher-order function support only {}", + lambda.params().len(), + lambda_params.len() + ); + } + + let params = std::iter::zip(lambda.params(), lambda_params) + .map(|(name, param)| param.renamed(name.as_str())) + .collect(); + + // lambda.projection may include indexes of nested lambda variables not present on this batch + let projection = lambda + .projection() + .iter() + .copied() + .filter(|i| *i < batch.num_columns()) + .collect::>(); + + Ok(ValueOrLambda::Lambda(LambdaArgument::new( + params, + Arc::clone(lambda.projected_body()), + if projection.is_empty() { + None + } else { + Some(batch.project(&projection)?) + }, + ))) + } + ArgSlot::Value => { + let value = arg.evaluate(batch)?; + + let value = if self.fun.clear_null_values() + && matches!( + value.data_type(), + DataType::List(_) | DataType::LargeList(_) + ) + { + let arr = value.into_array(batch.num_rows())?; + if arr.null_count() == 0 { + ColumnarValue::Array(arr) + } else { + ColumnarValue::Array(remove_list_null_values(&arr)?) + } + } else { + value + }; + + Ok(ValueOrLambda::Value(value)) + } + }) + .collect::>>()?; + + let input_empty = args.is_empty(); + let input_all_scalar = args + .iter() + .all(|arg| matches!(arg, ValueOrLambda::Value(ColumnarValue::Scalar(_)))); + + // evaluate the function + let output = self.fun.invoke_with_args(HigherOrderFunctionArgs { + args, + arg_fields, + number_rows: batch.num_rows(), + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })?; + + if let ColumnarValue::Array(array) = &output + && array.len() != batch.num_rows() + { + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar; + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!( + "higher-order function {} returned a different number of rows than expected. Expected: {}, Got: {}", + self.name, + batch.num_rows(), + array.len() + ) + }; + } + Ok(output) + } + + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::clone(&self.return_field)) + } + + fn children(&self) -> Vec<&Arc> { + self.args.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() != self.args.len() { + return internal_err!( + "HigherOrderFunctionExpr expects exactly {} child, got {}", + self.args.len(), + children.len() + ); + } + + // Re-derive `slots` for the new children using the original slot kinds + // as the source of truth for which positions must (still) be lambdas. + let mut new_slots = Vec::with_capacity(children.len()); + for (i, child) in children.iter().enumerate() { + match &self.slots[i] { + ArgSlot::Lambda(_) => { + let lambda = wrapped_lambda(child).ok_or_else(|| { + plan_datafusion_err!( + "{} unable to unwrap lambda from {} at position {i}", + &children[i], + self.name() + ) + })?; + new_slots.push(ArgSlot::Lambda(Arc::new(lambda.clone()))); + } + ArgSlot::Value => { + if child.is::() { + return plan_err!( + "{} received a lambda via with_new_children at position {i} that wasn't a lambda before", + self.name() + ); + } + new_slots.push(ArgSlot::Value); + } + } + } + + Ok(Arc::new(HigherOrderFunctionExpr { + name: self.name.clone(), + fun: Arc::clone(&self.fun), + args: children, + slots: new_slots, + return_field: Arc::clone(&self.return_field), + config_options: Arc::clone(&self.config_options), + })) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}(", self.name)?; + for (i, expr) in self.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + expr.fmt_sql(f)?; + } + write!(f, ")") + } + + fn is_volatile_node(&self) -> bool { + self.fun.signature().volatility == Volatility::Volatile + } +} + +fn wrapped_lambda(expr: &Arc) -> Option<&LambdaExpr> { + let mut current = expr; + + loop { + if let Some(lambda) = current.downcast_ref::() { + return Some(lambda); + } else if current.is::() { + return None; + } + + match current.children().as_slice() { + [single_child] => current = *single_child, + _ => return None, + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::HigherOrderFunctionExpr; + use crate::expressions::Column; + use crate::expressions::NoOp; + use crate::expressions::lambda; + use crate::expressions::not; + use arrow::array::NullArray; + use arrow::array::RecordBatchOptions; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_common::assert_contains; + use datafusion_expr::{ + HigherOrderFunctionArgs, HigherOrderSignature, HigherOrderUDF, HigherOrderUDFImpl, + }; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + use datafusion_physical_expr_common::physical_expr::is_volatile; + + /// Test helper to create a mock UDF with a specific volatility + #[derive(Debug, PartialEq, Eq, Hash)] + struct MockHigherOrderUDF { + signature: HigherOrderSignature, + } + + impl HigherOrderUDFImpl for MockHigherOrderUDF { + fn name(&self) -> &str { + "mock_function" + } + + fn signature(&self) -> &HigherOrderSignature { + &self.signature + } + + fn lambda_parameters( + &self, + _step: usize, + _fields: &[ValueOrLambda>], + ) -> Result { + Ok(LambdaParametersProgress::Complete(vec![vec![Arc::new( + Field::new("", DataType::Null, true), + )]])) + } + + fn return_field_from_args( + &self, + args: HigherOrderReturnFieldArgs, + ) -> Result { + match &args.arg_fields[0] { + ValueOrLambda::Lambda(field) | ValueOrLambda::Value(field) => { + Ok(Arc::clone(field)) + } + } + } + + fn invoke_with_args( + &self, + args: HigherOrderFunctionArgs, + ) -> Result { + match &args.args[0] { + ValueOrLambda::Lambda(lambda) => lambda.evaluate( + &[&|| Ok(Arc::new(NullArray::new(args.number_rows)))], + |arrays| Ok(arrays.to_vec()), + ), + ValueOrLambda::Value(value) => Ok(value.clone()), + } + } + } + + #[test] + fn test_higher_order_function_volatile_node() { + // Create a volatile UDF + let volatile_udf = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Volatile), + })); + + // Create a non-volatile UDF + let stable_udf = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Stable), + })); + + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let args = vec![Arc::new(Column::new("a", 0)) as Arc]; + let config_options = Arc::new(ConfigOptions::new()); + + // Test volatile function + let volatile_expr = HigherOrderFunctionExpr::try_new_with_schema( + volatile_udf, + args.clone(), + &schema, + Arc::clone(&config_options), + ) + .unwrap(); + + assert!(volatile_expr.is_volatile_node()); + let volatile_arc: Arc = Arc::new(volatile_expr); + assert!(is_volatile(&volatile_arc)); + + // Test non-volatile function + let stable_expr = HigherOrderFunctionExpr::try_new_with_schema( + stable_udf, + args, + &schema, + config_options, + ) + .unwrap(); + + assert!(!stable_expr.is_volatile_node()); + let stable_arc: Arc = Arc::new(stable_expr); + assert!(!is_volatile(&stable_arc)); + } + + #[test] + fn test_higher_order_function_wrapped_lambda() { + let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Stable), + })); + + let expected = ScalarValue::Int32(Some(42)); + + let hof = HigherOrderFunctionExpr::try_new_with_schema( + fun, + vec![lambda(["a"], Arc::new(Literal::new(expected.clone()))).unwrap()], + &Schema::empty(), + Arc::new(ConfigOptions::new()), + ) + .unwrap(); + + let new_children = vec![not(Arc::clone(&hof.args[0])).unwrap()]; + let wrapped = Arc::new(hof).with_new_children(new_children).unwrap(); + + let result = wrapped + .evaluate( + &RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(0)), + ) + .unwrap(), + ) + .unwrap(); + + let ColumnarValue::Scalar(result) = result else { + unreachable!() + }; + + assert_eq!(result, expected); + } + + #[test] + fn test_higher_order_function_badly_wrapped_lambda() { + let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Stable), + })); + + let hof = HigherOrderFunctionExpr::try_new_with_schema( + fun, + vec![ + not( + lambda(["a"], Arc::new(Literal::new(ScalarValue::Int32(Some(42))))) + .unwrap(), + ) + .unwrap(), + ], + &Schema::empty(), + Arc::new(ConfigOptions::new()), + ) + .unwrap(); + + let result = hof + .evaluate( + &RecordBatch::try_new_with_options( + Arc::new(Schema::empty()), + vec![], + &RecordBatchOptions::new().with_row_count(Some(0)), + ) + .unwrap(), + ) + .unwrap_err(); + + assert_contains!( + result.to_string(), + "LambdaExpr::evaluate() should not be called" + ); + } + + #[test] + fn test_higher_order_function_unexpected_lambda() { + let fun = Arc::new(HigherOrderUDF::new_from_impl(MockHigherOrderUDF { + signature: HigherOrderSignature::variadic_any(Volatility::Stable), + })); + + let hof = HigherOrderFunctionExpr::try_new_with_schema( + fun, + vec![Arc::new(NoOp::new())], + &Schema::empty(), + Arc::new(ConfigOptions::new()), + ) + .unwrap(); + + let result = Arc::new(hof) + .with_new_children(vec![lambda(["a"], Arc::new(NoOp::new())).unwrap()]) + .unwrap_err(); + + assert_contains!( + result.to_string(), + "mock_function received a lambda via with_new_children at position 0 that wasn't a lambda before" + ); + } +} diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index b1fe75a223010..aee65f35dc49c 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -148,19 +148,19 @@ use std::sync::Arc; use super::utils::{ convert_duration_type_to_interval, convert_interval_type_to_duration, get_inverse_op, }; -use crate::expressions::{BinaryExpr, Literal}; -use crate::utils::{build_dag, ExprTreeNode}; use crate::PhysicalExpr; +use crate::expressions::{BinaryExpr, Literal}; +use crate::utils::{ExprTreeNode, build_dag}; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{internal_err, not_impl_err, Result}; -use datafusion_expr::interval_arithmetic::{apply_operator, satisfy_greater, Interval}; +use datafusion_common::{Result, internal_err, not_impl_err}; use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::{Interval, apply_operator, satisfy_greater}; +use petgraph::Outgoing; use petgraph::graph::NodeIndex; use petgraph::stable_graph::{DefaultIx, StableGraph}; use petgraph::visit::{Bfs, Dfs, DfsPostOrder, EdgeRef}; -use petgraph::Outgoing; /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute ranges for expressions through interval arithmetic. @@ -220,7 +220,7 @@ impl ExprIntervalGraphNode { /// any other expression starts with an indefinite interval (`[-∞, ∞]`). pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = Arc::clone(&node.expr); - if let Some(literal) = expr.as_any().downcast_ref::() { + if let Some(literal) = expr.downcast_ref::() { let value = literal.value(); Interval::try_new(value.clone(), value.clone()) .map(|interval| Self::new_with_interval(expr, interval)) @@ -646,7 +646,6 @@ impl ExprIntervalGraph { if node_interval == &Interval::TRUE && self.graph[node] .expr - .as_any() .downcast_ref::() .is_some_and(|expr| expr.op() == &Operator::Or) { @@ -768,7 +767,7 @@ fn reverse_tuple((first, second): (T, U)) -> (U, T) { #[cfg(test)] mod tests { use super::*; - use crate::expressions::{BinaryExpr, Column}; + use crate::expressions::Column; use crate::intervals::test_utils::gen_conjunctive_numerical_expr; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; @@ -780,7 +779,7 @@ mod tests { use rand::{Rng, SeedableRng}; use rstest::*; - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn experiment( expr: Arc, exprs_with_interval: (Arc, Arc), @@ -892,12 +891,12 @@ mod tests { PropagationResult::Success, &Schema::new(vec![ Field::new( - left_col.as_any().downcast_ref::().unwrap().name(), + left_col.downcast_ref::().unwrap().name(), DataType::$SCALAR, true, ), Field::new( - right_col.as_any().downcast_ref::().unwrap().name(), + right_col.downcast_ref::().unwrap().name(), DataType::$SCALAR, true, ), @@ -939,16 +938,8 @@ mod tests { Interval::make(Some(100), None)?, PropagationResult::Infeasible, &Schema::new(vec![ - Field::new( - left_col.as_any().downcast_ref::().unwrap().name(), - DataType::Int32, - true, - ), - Field::new( - right_col.as_any().downcast_ref::().unwrap().name(), - DataType::Int32, - true, - ), + Field::new(left_col.name(), DataType::Int32, true), + Field::new(right_col.name(), DataType::Int32, true), ]), ) } diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index c3d38a974ab02..805ffd27613ee 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -19,13 +19,13 @@ use std::sync::Arc; -use crate::expressions::{binary, BinaryExpr, Literal}; use crate::PhysicalExpr; +use crate::expressions::{BinaryExpr, Literal, binary}; use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::Operator; -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] /// This test function generates a conjunctive statement with two numeric /// terms with the following form: /// left_col (op_1) a >/>= right_col (op_2) b AND left_col (op_3) c right_col (op_2) b AND left_col (op_3) c < right_col (op_4) d diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 22752a00e9259..1090660a6b5e6 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -20,15 +20,15 @@ use std::sync::Arc; use crate::{ - expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, PhysicalExpr, + expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, }; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::interval_arithmetic::Interval; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; /// Indicates whether interval arithmetic is supported for the given expression. /// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. @@ -36,26 +36,25 @@ use datafusion_expr::Operator; /// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. /// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { - let expr_any = expr.as_any(); - if let Some(binary_expr) = expr_any.downcast_ref::() { + if let Some(binary_expr) = expr.downcast_ref::() { is_operator_supported(binary_expr.op()) && check_support(binary_expr.left(), schema) && check_support(binary_expr.right(), schema) - } else if let Some(column) = expr_any.downcast_ref::() { + } else if let Some(column) = expr.downcast_ref::() { if let Ok(field) = schema.field_with_name(column.name()) { is_datatype_supported(field.data_type()) } else { false } - } else if let Some(literal) = expr_any.downcast_ref::() { + } else if let Some(literal) = expr.downcast_ref::() { if let Ok(dt) = literal.data_type(schema) { is_datatype_supported(&dt) } else { false } - } else if let Some(cast) = expr_any.downcast_ref::() { + } else if let Some(cast) = expr.downcast_ref::() { check_support(cast.expr(), schema) - } else if let Some(negative) = expr_any.downcast_ref::() { + } else if let Some(negative) = expr.downcast_ref::() { check_support(negative.arg(), schema) } else { false @@ -104,6 +103,9 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool { | &DataType::UInt8 | &DataType::Float64 | &DataType::Float32 + | &DataType::Date32 + | &DataType::Date64 + | &DataType::Timestamp(_, _) ) } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index f59582f405064..b55bd70bdf185 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] // Backward compatibility @@ -36,12 +34,17 @@ pub mod binary_map { pub mod async_scalar_function; pub mod equivalence; pub mod expressions; +pub mod higher_order_function; pub mod intervals; mod partitioning; mod physical_expr; pub mod planner; pub mod projection; +/// Shared test helpers for the `try_to_proto` / `try_from_proto` unit tests +#[cfg(all(test, feature = "proto"))] +pub(crate) mod proto_test_util; mod scalar_function; +pub mod scalar_subquery; pub mod simplifier; pub mod statistics; pub mod utils; @@ -54,11 +57,13 @@ pub mod execution_props { } pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; -pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; +pub use analysis::{AnalysisContext, ExprBoundaries, analyze}; +pub use datafusion_common::SplitPoint; pub use equivalence::{ - calculate_union, AcrossPartitions, ConstExpr, EquivalenceProperties, + AcrossPartitions, ConstExpr, EquivalenceProperties, calculate_union, }; -pub use partitioning::{Distribution, Partitioning}; +pub use expressions::{DynamicFilterTracker, DynamicFilterTracking}; +pub use partitioning::{Distribution, Partitioning, RangePartitioning}; pub use physical_expr::{ add_offset_to_expr, add_offset_to_physical_sort_exprs, create_lex_ordering, create_ordering, create_physical_sort_expr, create_physical_sort_exprs, @@ -71,6 +76,7 @@ pub use datafusion_physical_expr_common::sort_expr::{ PhysicalSortRequirement, }; +pub use higher_order_function::HigherOrderFunctionExpr; pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; pub use simplifier::PhysicalExprSimplifier; diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index d6b2b1b046f75..2e0aaaf3fb4b7 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -18,10 +18,13 @@ //! [`Partitioning`] and [`Distribution`] for `ExecutionPlans` use crate::{ - equivalence::ProjectionMapping, expressions::UnKnownColumn, physical_exprs_equal, - EquivalenceProperties, PhysicalExpr, + EquivalenceProperties, PhysicalExpr, equivalence::ProjectionMapping, + expressions::UnKnownColumn, physical_exprs_equal, }; +pub use datafusion_common::SplitPoint; +use datafusion_common::{Result, validate_range_split_points}; use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::fmt; use std::fmt::Display; use std::sync::Arc; @@ -117,6 +120,8 @@ pub enum Partitioning { /// Allocate rows based on a hash of one of more expressions and the specified number of /// partitions Hash(Vec>, usize), + /// Partition rows by source-declared ranges + Range(RangePartitioning), /// Unknown partitioning scheme with a known number of partitions UnknownPartitioning(usize), } @@ -133,66 +138,399 @@ impl Display for Partitioning { .join(", "); write!(f, "Hash([{phy_exprs_str}], {size})") } + Partitioning::Range(range) => write!(f, "{range}"), Partitioning::UnknownPartitioning(size) => { write!(f, "UnknownPartitioning({size})") } } } } + +/// Physical range partitioning. +/// +/// [`RangePartitioning`] describes an ordered key space with split points. +/// +/// - `ordering` defines the partitioning key and ordering. +/// - `split_points` define the boundaries between adjacent partitions. +/// +/// Comparisons use the lexicographic order defined by `ordering`, including +/// `ASC`/`DESC` and null ordering. Split points must be strictly ordered +/// according to that ordering, and each split point must have one value per +/// ordering expression. See [`SplitPoint`] for the shared boundary convention. +/// +/// Like other user-specified data properties such as sortedness, if a source +/// declares range partitioning, it is responsible for placing each row in the +/// partition described by the split points. DataFusion will not validate this is +/// upheld. +/// +/// For a single range key: +/// +/// ```text +/// ordering = [date ASC NULLS LAST] +/// split_points = [ +/// (2022-01-01), +/// (2023-01-01), +/// ] +/// +/// partition 0: date before 2022-01-01 +/// partition 1: date between 2022-01-01 (inclusive) and 2023-01-01 (exclusive) +/// partition 2: date at/after 2023-01-01 +/// ``` +/// +/// The same model extends to compound keys. +/// For `ordering = [time ASC, city ASC]`, split points are ordered +/// lexicographically by `(time, city)`: +/// +/// ```text +/// ordering = [time ASC NULLS LAST, city ASC NULLS LAST] +/// split_points = [ +/// (2022, Allston), +/// (2023, Allston), +/// ] +/// +/// partition 0: keys before (2022, Allston) +/// partition 1: keys between (2022, Allston) and (2023, Allston) +/// partition 2: keys at/after (2023, Allston) +/// ``` +/// +/// NOTE: Optimizer and execution behavior for this partitioning is intentionally +/// not implemented and will be introduced incrementally. See +/// . +#[derive(Debug, Clone, PartialEq)] +pub struct RangePartitioning { + /// Ordered partitioning key. + ordering: LexOrdering, + /// Boundaries between adjacent partitions. + split_points: Vec, +} + +impl RangePartitioning { + /// Creates range partitioning metadata without validating split points. + /// + /// Use [`Self::try_new`] to validate the contract documented on + /// [`RangePartitioning`]. + pub fn new(ordering: LexOrdering, split_points: Vec) -> Self { + Self { + ordering, + split_points, + } + } + + /// Creates range partitioning metadata and validates split point shape and + /// ordering. + pub fn try_new(ordering: LexOrdering, split_points: Vec) -> Result { + validate_range_split_points( + &split_points, + &ordering + .iter() + .map(|sort_expr| sort_expr.options) + .collect::>(), + )?; + Ok(Self::new(ordering, split_points)) + } + + /// Returns the ordering that defines the range key. + pub fn ordering(&self) -> &LexOrdering { + &self.ordering + } + + /// Returns the ordered split points between partitions. + pub fn split_points(&self) -> &[SplitPoint] { + &self.split_points + } + + /// Returns the number of partitions. + pub fn partition_count(&self) -> usize { + self.split_points.len() + 1 + } + + /// Returns true when `self` and `other` describe the same range partition + /// map. + /// + /// Single-partition range partitionings are always compatible. Otherwise, + /// the two partitionings must have identical split points and equivalent + /// ordering expressions with the same sort options. + pub fn compatible_with( + &self, + other: &Self, + eq_properties: &EquivalenceProperties, + ) -> bool { + if self.partition_count() == 1 && other.partition_count() == 1 { + return true; + } + + if self.split_points != other.split_points + || self.ordering.len() != other.ordering.len() + { + return false; + } + + if !self + .ordering + .iter() + .zip(other.ordering.iter()) + .all(|(left, right)| left.options == right.options) + { + return false; + } + + let left_exprs = self + .ordering + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + let right_exprs = other + .ordering + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + + equivalent_exprs(&left_exprs, &right_exprs, eq_properties) + } + + /// Calculates the range partitioning after applying the given projection. + /// + /// Returns `None` if any range key cannot be projected or if projection + /// collapses distinct range keys into duplicate output expressions. + fn project( + &self, + mapping: &ProjectionMapping, + input_eq_properties: &EquivalenceProperties, + ) -> Option { + let exprs = self + .ordering + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + let projected_exprs = input_eq_properties + .project_expressions(&exprs, mapping) + .collect::>>()?; + let sort_exprs = self + .ordering + .iter() + .zip(projected_exprs) + .map(|(sort_expr, expr)| PhysicalSortExpr::new(expr, sort_expr.options)) + .collect::>(); + let ordering = LexOrdering::new(sort_exprs)?; + if ordering.len() != self.ordering.len() { + return None; + } + + Some(Self { + ordering, + split_points: self.split_points.clone(), + }) + } +} + +impl Display for RangePartitioning { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let split_points = format_range_split_points(&self.split_points); + write!( + f, + "Range([{}], [{}], {})", + self.ordering, + split_points, + self.partition_count() + ) + } +} + +fn format_range_split_points(split_points: &[SplitPoint]) -> String { + split_points + .iter() + .map(ToString::to_string) + .collect::>() + .join(", ") +} + +fn equivalent_exprs( + left: &[Arc], + right: &[Arc], + eq_properties: &EquivalenceProperties, +) -> bool { + if physical_exprs_equal(left, right) { + return true; + } + + let eq_groups = eq_properties.eq_group(); + if eq_groups.is_empty() { + return false; + } + + let normalized_left = normalize_exprs(left, eq_properties); + let normalized_right = normalize_exprs(right, eq_properties); + + physical_exprs_equal(&normalized_left, &normalized_right) +} + +fn normalize_exprs( + exprs: &[Arc], + eq_properties: &EquivalenceProperties, +) -> Vec> { + let eq_groups = eq_properties.eq_group(); + exprs + .iter() + .map(|expr| eq_groups.normalize_expr(Arc::clone(expr))) + .collect() +} + +/// Represents how a [`Partitioning`] satisfies a [`Distribution`] requirement. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PartitioningSatisfaction { + /// The partitioning does not satisfy the distribution requirement + NotSatisfied, + /// The partitioning exactly matches the distribution requirement + Exact, + /// The partitioning satisfies the distribution requirement via subset logic + Subset, +} + +impl PartitioningSatisfaction { + pub fn is_satisfied(&self) -> bool { + matches!(self, Self::Exact | Self::Subset) + } + + pub fn is_subset(&self) -> bool { + *self == Self::Subset + } +} + impl Partitioning { /// Returns the number of partitions in this partitioning scheme pub fn partition_count(&self) -> usize { use Partitioning::*; match self { RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n, + Range(range) => range.partition_count(), } } - /// Returns true when the guarantees made by this [`Partitioning`] are sufficient to - /// satisfy the partitioning scheme mandated by the `required` [`Distribution`]. + /// Returns true when `self` and `other` describe compatible partition maps. + /// + /// Compatible partition maps can be used for partition-local behavior: if + /// this returns true, partition `i` from both partitionings can be treated + /// as covering the same partition domain. This is stricter than + /// [`Self::satisfaction`], which only answers whether this partitioning can + /// satisfy a required distribution. + pub fn compatible_with( + &self, + other: &Self, + eq_properties: &EquivalenceProperties, + ) -> bool { + if self.partition_count() == 1 && other.partition_count() == 1 { + return true; + } + + match (self, other) { + ( + Partitioning::Hash(left_exprs, left_count), + Partitioning::Hash(right_exprs, right_count), + ) => { + if left_count != right_count { + return false; + } + if left_exprs.is_empty() || right_exprs.is_empty() { + return false; + } + equivalent_exprs(left_exprs, right_exprs, eq_properties) + } + (Partitioning::Range(left), Partitioning::Range(right)) => { + left.compatible_with(right, eq_properties) + } + _ => false, + } + } + + /// Returns true if `subset_exprs` is a subset of `exprs`. + /// For example: Hash(a, b) is subset of Hash(a) since a partition with all occurrences of + /// a distinct (a) must also contain all occurrences of a distinct (a, b) with the same (a). + fn is_subset_partitioning( + subset_exprs: &[Arc], + superset_exprs: &[Arc], + ) -> bool { + // Require strict subset: fewer expressions, not equal + if subset_exprs.is_empty() || subset_exprs.len() >= superset_exprs.len() { + return false; + } + + subset_exprs.iter().all(|subset_expr| { + superset_exprs + .iter() + .any(|superset_expr| subset_expr.eq(superset_expr)) + }) + } + + #[deprecated(since = "52.0.0", note = "Use satisfaction instead")] pub fn satisfy( &self, required: &Distribution, eq_properties: &EquivalenceProperties, ) -> bool { + self.satisfaction(required, eq_properties, false) + == PartitioningSatisfaction::Exact + } + + /// Returns how this [`Partitioning`] satisfies the partitioning scheme mandated + /// by the `required` [`Distribution`]. + pub fn satisfaction( + &self, + required: &Distribution, + eq_properties: &EquivalenceProperties, + allow_subset: bool, + ) -> PartitioningSatisfaction { match required { - Distribution::UnspecifiedDistribution => true, - Distribution::SinglePartition if self.partition_count() == 1 => true, + Distribution::UnspecifiedDistribution => PartitioningSatisfaction::Exact, + Distribution::SinglePartition if self.partition_count() == 1 => { + PartitioningSatisfaction::Exact + } // When partition count is 1, hash requirement is satisfied. - Distribution::HashPartitioned(_) if self.partition_count() == 1 => true, - Distribution::HashPartitioned(required_exprs) => { - match self { - // Here we do not check the partition count for hash partitioning and assumes the partition count - // and hash functions in the system are the same. In future if we plan to support storage partition-wise joins, - // then we need to have the partition count and hash functions validation. - Partitioning::Hash(partition_exprs, _) => { - let fast_match = - physical_exprs_equal(required_exprs, partition_exprs); - // If the required exprs do not match, need to leverage the eq_properties provided by the child - // and normalize both exprs based on the equivalent groups. - if !fast_match { - let eq_groups = eq_properties.eq_group(); - if !eq_groups.is_empty() { - let normalized_required_exprs = required_exprs - .iter() - .map(|e| eq_groups.normalize_expr(Arc::clone(e))) - .collect::>(); - let normalized_partition_exprs = partition_exprs - .iter() - .map(|e| eq_groups.normalize_expr(Arc::clone(e))) - .collect::>(); - return physical_exprs_equal( - &normalized_required_exprs, - &normalized_partition_exprs, - ); + Distribution::HashPartitioned(_) if self.partition_count() == 1 => { + PartitioningSatisfaction::Exact + } + Distribution::HashPartitioned(required_exprs) => match self { + // Here we do not check the partition count for hash partitioning and assumes the partition count + // and hash functions in the system are the same. In future if we plan to support storage partition-wise joins, + // then we need to have the partition count and hash functions validation. + Partitioning::Hash(partition_exprs, _) => { + // Empty hash partitioning is invalid + if partition_exprs.is_empty() || required_exprs.is_empty() { + return PartitioningSatisfaction::NotSatisfied; + } + + if equivalent_exprs(required_exprs, partition_exprs, eq_properties) { + return PartitioningSatisfaction::Exact; + } + + let eq_groups = eq_properties.eq_group(); + if !eq_groups.is_empty() { + if allow_subset { + let normalized_partition_exprs = + normalize_exprs(partition_exprs, eq_properties); + let normalized_required_exprs = + normalize_exprs(required_exprs, eq_properties); + if Self::is_subset_partitioning( + &normalized_partition_exprs, + &normalized_required_exprs, + ) { + return PartitioningSatisfaction::Subset; } } - fast_match + } else if allow_subset + && Self::is_subset_partitioning(partition_exprs, required_exprs) + { + return PartitioningSatisfaction::Subset; } - _ => false, + + PartitioningSatisfaction::NotSatisfied } - } - _ => false, + Partitioning::RoundRobinBatch(_) + | Partitioning::Range(_) + | Partitioning::UnknownPartitioning(_) => { + PartitioningSatisfaction::NotSatisfied + } + }, + Distribution::SinglePartition => PartitioningSatisfaction::NotSatisfied, } } @@ -202,19 +540,29 @@ impl Partitioning { mapping: &ProjectionMapping, input_eq_properties: &EquivalenceProperties, ) -> Self { - if let Partitioning::Hash(exprs, part) = self { - let normalized_exprs = input_eq_properties - .project_expressions(exprs, mapping) - .zip(exprs) - .map(|(proj_expr, expr)| { - proj_expr.unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) + match self { + Partitioning::Hash(exprs, part) => { + let normalized_exprs = input_eq_properties + .project_expressions(exprs, mapping) + .zip(exprs) + .map(|(proj_expr, expr)| { + proj_expr.unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) }) - }) - .collect(); - Partitioning::Hash(normalized_exprs, *part) - } else { - self.clone() + .collect(); + Partitioning::Hash(normalized_exprs, *part) + } + Partitioning::Range(range) => { + if let Some(projected) = range.project(mapping, input_eq_properties) { + Partitioning::Range(projected) + } else { + Partitioning::UnknownPartitioning(range.partition_count()) + } + } + Partitioning::RoundRobinBatch(_) | Partitioning::UnknownPartitioning(_) => { + self.clone() + } } } } @@ -231,6 +579,7 @@ impl PartialEq for Partitioning { { true } + (Partitioning::Range(left), Partitioning::Range(right)) => left == right, _ => false, } } @@ -281,47 +630,157 @@ mod tests { use super::*; use crate::expressions::Column; + use crate::projection::ProjectionTargets; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::{Result, ScalarValue}; + + struct PartitioningTestFixture { + schema: SchemaRef, + cols: Vec>, + eq_properties: EquivalenceProperties, + } + + impl PartitioningTestFixture { + fn new(fields: Vec<(&str, DataType)>) -> Result { + let schema = Arc::new(Schema::new( + fields + .iter() + .map(|(name, data_type)| Field::new(*name, data_type.clone(), false)) + .collect::>(), + )); + let cols = fields + .iter() + .map(|(name, _)| { + Ok(Arc::new(Column::new_with_schema(name, &schema)?) + as Arc) + }) + .collect::>()?; + let eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + + Ok(Self { + schema, + cols, + eq_properties, + }) + } + + fn int64(names: &[&str]) -> Result { + Self::new(names.iter().map(|name| (*name, DataType::Int64)).collect()) + } + + fn col(&self, index: usize) -> Arc { + Arc::clone(&self.cols[index]) + } + + fn cols( + &self, + indices: impl IntoIterator, + ) -> Vec> { + indices.into_iter().map(|index| self.col(index)).collect() + } + + fn hash_partitioning( + &self, + indices: impl IntoIterator, + partition_count: usize, + ) -> Partitioning { + Partitioning::Hash(self.cols(indices), partition_count) + } + + fn hash_distribution( + &self, + indices: impl IntoIterator, + ) -> Distribution { + Distribution::HashPartitioned(self.cols(indices)) + } + + fn range_sort_expr( + &self, + index: usize, + options: SortOptions, + ) -> PhysicalSortExpr { + PhysicalSortExpr::new(self.col(index), options) + } + + fn range_ordering( + &self, + indices: impl IntoIterator, + ) -> LexOrdering { + LexOrdering::new( + indices + .into_iter() + .map(|index| PhysicalSortExpr::new_default(self.col(index))), + ) + .expect("ordering must not be empty") + } + + fn range( + &self, + indices: impl IntoIterator, + split_points: Vec, + ) -> RangePartitioning { + RangePartitioning::try_new(self.range_ordering(indices), split_points) + .expect("test range partitioning should be valid") + } + + fn range_partitioning( + &self, + indices: impl IntoIterator, + split_points: Vec, + ) -> Partitioning { + Partitioning::Range(self.range(indices, split_points)) + } + + fn range_partitioning_with_ordering( + &self, + ordering: LexOrdering, + split_points: Vec, + ) -> Partitioning { + Partitioning::Range( + RangePartitioning::try_new(ordering, split_points) + .expect("test range partitioning should be valid"), + ) + } + } #[test] fn partitioning_satisfy_distribution() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("column_1", DataType::Int64, false), - Field::new("column_2", DataType::Utf8, false), - ])); - - let partition_exprs1: Vec> = vec![ - Arc::new(Column::new_with_schema("column_1", &schema).unwrap()), - Arc::new(Column::new_with_schema("column_2", &schema).unwrap()), - ]; - - let partition_exprs2: Vec> = vec![ - Arc::new(Column::new_with_schema("column_2", &schema).unwrap()), - Arc::new(Column::new_with_schema("column_1", &schema).unwrap()), - ]; + let fixture = PartitioningTestFixture::new(vec![ + ("column_1", DataType::Int64), + ("column_2", DataType::Utf8), + ])?; let distribution_types = vec![ Distribution::UnspecifiedDistribution, Distribution::SinglePartition, - Distribution::HashPartitioned(partition_exprs1.clone()), + fixture.hash_distribution([0, 1]), ]; let single_partition = Partitioning::UnknownPartitioning(1); let unspecified_partition = Partitioning::UnknownPartitioning(10); let round_robin_partition = Partitioning::RoundRobinBatch(10); - let hash_partition1 = Partitioning::Hash(partition_exprs1, 10); - let hash_partition2 = Partitioning::Hash(partition_exprs2, 10); - let eq_properties = EquivalenceProperties::new(schema); + let hash_partition1 = fixture.hash_partitioning([0, 1], 10); + let hash_partition2 = fixture.hash_partitioning([1, 0], 10); for distribution in distribution_types { let result = ( - single_partition.satisfy(&distribution, &eq_properties), - unspecified_partition.satisfy(&distribution, &eq_properties), - round_robin_partition.satisfy(&distribution, &eq_properties), - hash_partition1.satisfy(&distribution, &eq_properties), - hash_partition2.satisfy(&distribution, &eq_properties), + single_partition + .satisfaction(&distribution, &fixture.eq_properties, true) + .is_satisfied(), + unspecified_partition + .satisfaction(&distribution, &fixture.eq_properties, true) + .is_satisfied(), + round_robin_partition + .satisfaction(&distribution, &fixture.eq_properties, true) + .is_satisfied(), + hash_partition1 + .satisfaction(&distribution, &fixture.eq_properties, true) + .is_satisfied(), + hash_partition2 + .satisfaction(&distribution, &fixture.eq_properties, true) + .is_satisfied(), ); match distribution { @@ -339,4 +798,625 @@ mod tests { Ok(()) } + + #[test] + fn test_partitioning_satisfy_by_subset() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b", "c"])?; + + let test_cases = vec![ + ( + "Hash([a]) vs Hash([a, b])", + fixture.hash_partitioning([0], 4), + fixture.hash_distribution([0, 1]), + PartitioningSatisfaction::Subset, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a]) vs Hash([a, b, c])", + fixture.hash_partitioning([0], 4), + fixture.hash_distribution([0, 1, 2]), + PartitioningSatisfaction::Subset, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a, b]) vs Hash([a, b, c])", + fixture.hash_partitioning([0, 1], 4), + fixture.hash_distribution([0, 1, 2]), + PartitioningSatisfaction::Subset, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([b]) vs Hash([a, b, c])", + fixture.hash_partitioning([1], 4), + fixture.hash_distribution([0, 1, 2]), + PartitioningSatisfaction::Subset, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([b, a]) vs Hash([a, b, c])", + fixture.hash_partitioning([1, 0], 4), + fixture.hash_distribution([0, 1, 2]), + PartitioningSatisfaction::Subset, + PartitioningSatisfaction::NotSatisfied, + ), + ]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + #[test] + fn test_partitioning_current_superset() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b", "c"])?; + + let test_cases = vec![ + ( + "Hash([a, b]) vs Hash([a])", + fixture.hash_partitioning([0, 1], 4), + fixture.hash_distribution([0]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a, b, c]) vs Hash([a])", + fixture.hash_partitioning([0, 1, 2], 4), + fixture.hash_distribution([0]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a, b, c]) vs Hash([a, b])", + fixture.hash_partitioning([0, 1, 2], 4), + fixture.hash_distribution([0, 1]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + #[test] + fn test_partitioning_partial_overlap() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b", "c"])?; + + let test_cases = vec![( + "Partial overlap: Hash([a, c]) vs Hash([a, b])", + fixture.hash_partitioning([0, 2], 4), + fixture.hash_distribution([0, 1]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + )]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + #[test] + fn test_partitioning_no_overlap() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b", "c"])?; + + let test_cases = vec![ + ( + "Hash([a]) vs Hash([b, c])", + fixture.hash_partitioning([0], 4), + fixture.hash_distribution([1, 2]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a, b]) vs Hash([c])", + fixture.hash_partitioning([0, 1], 4), + fixture.hash_distribution([2]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + #[test] + fn test_partitioning_exact_match() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + + let test_cases = vec![ + ( + "Hash([a, b]) vs Hash([a, b])", + fixture.hash_partitioning([0, 1], 4), + fixture.hash_distribution([0, 1]), + PartitioningSatisfaction::Exact, + PartitioningSatisfaction::Exact, + ), + ( + "Hash([a]) vs Hash([a])", + fixture.hash_partitioning([0], 4), + fixture.hash_distribution([0]), + PartitioningSatisfaction::Exact, + PartitioningSatisfaction::Exact, + ), + ]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + #[test] + fn test_partitioning_unknown() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let unknown: Arc = Arc::new(UnKnownColumn::new("dropped")); + + let test_cases = vec![ + ( + "Hash([unknown]) vs Hash([a, b])", + Partitioning::Hash(vec![Arc::clone(&unknown)], 4), + fixture.hash_distribution([0, 1]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a, b]) vs Hash([unknown])", + fixture.hash_partitioning([0, 1], 4), + Distribution::HashPartitioned(vec![Arc::clone(&unknown)]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([unknown]) vs Hash([unknown])", + Partitioning::Hash(vec![Arc::clone(&unknown)], 4), + Distribution::HashPartitioned(vec![Arc::clone(&unknown)]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + #[test] + fn test_partitioning_empty_hash() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a"])?; + + let test_cases = vec![ + ( + "Hash([]) vs Hash([a])", + Partitioning::Hash(vec![], 4), + fixture.hash_distribution([0]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([a]) vs Hash([])", + fixture.hash_partitioning([0], 4), + Distribution::HashPartitioned(vec![]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ( + "Hash([]) vs Hash([])", + Partitioning::Hash(vec![], 4), + Distribution::HashPartitioned(vec![]), + PartitioningSatisfaction::NotSatisfied, + PartitioningSatisfaction::NotSatisfied, + ), + ]; + + for (desc, partition, required, expected_with_subset, expected_without_subset) in + test_cases + { + let result = partition.satisfaction(&required, &fixture.eq_properties, true); + assert_eq!( + result, expected_with_subset, + "Failed for {desc} with subset enabled" + ); + + let result = partition.satisfaction(&required, &fixture.eq_properties, false); + assert_eq!( + result, expected_without_subset, + "Failed for {desc} with subset disabled" + ); + } + + Ok(()) + } + + fn int_split_point(values: impl IntoIterator) -> SplitPoint { + SplitPoint::new( + values + .into_iter() + .map(|value| ScalarValue::Int64(Some(value))) + .collect(), + ) + } + + fn assert_range_try_new_error( + ordering: LexOrdering, + split_points: Vec, + expected: &str, + ) { + let error = RangePartitioning::try_new(ordering, split_points) + .unwrap_err() + .to_string(); + assert!(error.contains(expected), "{error}"); + } + + #[test] + fn test_range_partitioning_metadata() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + + let range_partitioning = + fixture.range([0], vec![int_split_point([10]), int_split_point([20])]); + assert_eq!(range_partitioning.ordering()[0].to_string(), "a@0 ASC"); + assert_eq!( + range_partitioning.split_points(), + &[int_split_point([10]), int_split_point([20])] + ); + let partitioning = Partitioning::Range(range_partitioning); + + assert_eq!(partitioning.partition_count(), 3); + assert_eq!( + partitioning.to_string(), + "Range([a@0 ASC], [(10), (20)], 3)" + ); + + Ok(()) + } + + #[test] + fn test_range_partitioning_try_new_validates_split_points() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let asc_a = fixture.range_ordering([0]); + let ordering_ab = fixture.range_ordering([0, 1]); + + assert_range_try_new_error( + ordering_ab.clone(), + vec![int_split_point([10])], + "split point 0 has width 1, but ordering has width 2", + ); + + RangePartitioning::try_new( + [fixture.range_sort_expr(0, SortOptions::new(true, false))].into(), + vec![int_split_point([20]), int_split_point([10])], + )?; + + assert_range_try_new_error( + asc_a, + vec![int_split_point([20]), int_split_point([10])], + "split points must be strictly ordered", + ); + + assert_range_try_new_error( + [fixture.range_sort_expr(0, SortOptions::new(false, false))].into(), + vec![ + SplitPoint::new(vec![ScalarValue::Int64(None)]), + int_split_point([10]), + ], + "split points must be strictly ordered", + ); + + RangePartitioning::try_new( + ordering_ab.clone(), + vec![int_split_point([10, 20]), int_split_point([10, 30])], + )?; + + assert_range_try_new_error( + ordering_ab, + vec![int_split_point([10, 30]), int_split_point([10, 20])], + "split points must be strictly ordered", + ); + + Ok(()) + } + + #[test] + fn test_range_partitioning_project_preserves_or_degrades() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let range_partitioning = fixture.range_partitioning_with_ordering( + [fixture.range_sort_expr(1, SortOptions::new(true, false))].into(), + vec![int_split_point([10])], + ); + + let keep_b_mapping = ProjectionMapping::from_indices(&[1], &fixture.schema)?; + let projected = + range_partitioning.project(&keep_b_mapping, &fixture.eq_properties); + assert_eq!( + projected.to_string(), + "Range([b@0 DESC NULLS LAST], [(10)], 2)" + ); + + let drop_b_mapping = ProjectionMapping::from_indices(&[0], &fixture.schema)?; + let projected = + range_partitioning.project(&drop_b_mapping, &fixture.eq_properties); + let Partitioning::UnknownPartitioning(partition_count) = projected else { + panic!("expected UnknownPartitioning, got {projected:?}"); + }; + assert_eq!(partition_count, 2); + + Ok(()) + } + + #[test] + fn test_range_partitioning_project_degrades_if_ordering_collapses() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let target: Arc = Arc::new(Column::new("x", 0)); + let range_partitioning = + fixture.range_partitioning([0, 1], vec![int_split_point([10, 100])]); + let mapping = ProjectionMapping::from_iter([ + ( + fixture.col(0), + ProjectionTargets::from(vec![(Arc::clone(&target), 0)]), + ), + ( + fixture.col(1), + ProjectionTargets::from(vec![(Arc::clone(&target), 0)]), + ), + ]); + + let projected = range_partitioning.project(&mapping, &fixture.eq_properties); + let Partitioning::UnknownPartitioning(partition_count) = projected else { + panic!("expected UnknownPartitioning, got {projected:?}"); + }; + assert_eq!(partition_count, 2); + + Ok(()) + } + + #[test] + fn test_range_partitioning_compatible_with() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let mut eq_properties = fixture.eq_properties.clone(); + eq_properties.add_equal_conditions(fixture.col(0), fixture.col(1))?; + + let split_points = vec![int_split_point([10]), int_split_point([20])]; + let range_a = fixture.range([0], split_points.clone()); + let range_a_same = fixture.range([0], split_points.clone()); + let range_b_equivalent = fixture.range([1], split_points.clone()); + let range_b_different_split = fixture.range([1], vec![int_split_point([30])]); + let range_a_desc = RangePartitioning::try_new( + [fixture.range_sort_expr(0, SortOptions::new(true, false))].into(), + vec![int_split_point([10])], + )?; + let single_partition_range_a = fixture.range([0], vec![]); + let single_partition_range_b = fixture.range([1], vec![]); + + assert!(range_a.compatible_with(&range_a_same, &fixture.eq_properties)); + assert!(range_a.compatible_with(&range_b_equivalent, &eq_properties)); + assert!(!range_a.compatible_with(&range_b_equivalent, &fixture.eq_properties)); + assert!(!range_a.compatible_with(&range_b_different_split, &eq_properties)); + assert!(!range_a.compatible_with(&range_a_desc, &eq_properties)); + assert!( + single_partition_range_a + .compatible_with(&single_partition_range_b, &fixture.eq_properties) + ); + + assert!( + fixture + .range_partitioning([0], vec![int_split_point([10])]) + .compatible_with( + &fixture.range_partitioning([1], vec![int_split_point([10])]), + &eq_properties + ) + ); + assert!( + !fixture + .range_partitioning([0], vec![int_split_point([10])]) + .compatible_with( + &fixture.range_partitioning([0], vec![int_split_point([20])]), + &fixture.eq_properties + ) + ); + assert!( + !fixture + .range_partitioning([0], vec![int_split_point([10])]) + .compatible_with( + &fixture.hash_partitioning([0], 2), + &fixture.eq_properties + ) + ); + + Ok(()) + } + + #[test] + fn test_hash_partitioning_compatible_with() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let mut eq_properties = fixture.eq_properties.clone(); + eq_properties.add_equal_conditions(fixture.col(0), fixture.col(1))?; + + assert!( + fixture.hash_partitioning([0], 2).compatible_with( + &fixture.hash_partitioning([0], 2), + &fixture.eq_properties + ) + ); + assert!( + fixture + .hash_partitioning([0], 2) + .compatible_with(&fixture.hash_partitioning([1], 2), &eq_properties) + ); + assert!( + !fixture.hash_partitioning([0], 2).compatible_with( + &fixture.hash_partitioning([1], 2), + &fixture.eq_properties + ) + ); + assert!( + !fixture.hash_partitioning([0], 2).compatible_with( + &fixture.hash_partitioning([0], 3), + &fixture.eq_properties + ) + ); + assert!(!fixture.hash_partitioning([0], 2).compatible_with( + &fixture.hash_partitioning([0, 1], 2), + &fixture.eq_properties + )); + assert!( + !Partitioning::Hash(vec![], 2) + .compatible_with(&Partitioning::Hash(vec![], 2), &fixture.eq_properties) + ); + assert!(!fixture.hash_partitioning([0], 2).compatible_with( + &fixture.range_partitioning([0], vec![int_split_point([10])]), + &fixture.eq_properties + )); + assert!( + fixture.hash_partitioning([0], 1).compatible_with( + &Partitioning::RoundRobinBatch(1), + &fixture.eq_properties + ) + ); + + Ok(()) + } + + #[test] + fn test_round_robin_partitioning_compatible_with() { + let eq_properties = EquivalenceProperties::new(Arc::new(Schema::empty())); + + assert!( + Partitioning::RoundRobinBatch(1) + .compatible_with(&Partitioning::RoundRobinBatch(1), &eq_properties) + ); + assert!( + !Partitioning::RoundRobinBatch(2) + .compatible_with(&Partitioning::RoundRobinBatch(2), &eq_properties) + ); + assert!( + Partitioning::RoundRobinBatch(1) + .compatible_with(&Partitioning::UnknownPartitioning(1), &eq_properties) + ); + assert!( + !Partitioning::RoundRobinBatch(2) + .compatible_with(&Partitioning::UnknownPartitioning(2), &eq_properties) + ); + } + + #[test] + fn test_unknown_partitioning_compatible_with() { + let eq_properties = EquivalenceProperties::new(Arc::new(Schema::empty())); + + assert!( + Partitioning::UnknownPartitioning(1) + .compatible_with(&Partitioning::UnknownPartitioning(1), &eq_properties) + ); + assert!( + !Partitioning::UnknownPartitioning(2) + .compatible_with(&Partitioning::UnknownPartitioning(2), &eq_properties) + ); + assert!( + Partitioning::UnknownPartitioning(1) + .compatible_with(&Partitioning::RoundRobinBatch(1), &eq_properties) + ); + assert!( + !Partitioning::UnknownPartitioning(2) + .compatible_with(&Partitioning::RoundRobinBatch(2), &eq_properties) + ); + } + + #[test] + fn test_multi_partition_range_does_not_satisfy_hash_distribution() -> Result<()> { + let fixture = PartitioningTestFixture::int64(&["a", "b"])?; + let range_partitioning = + fixture.range_partitioning([0, 1], vec![int_split_point([10, 100])]); + let required = fixture.hash_distribution([0, 1]); + + assert_eq!( + range_partitioning.satisfaction(&required, &fixture.eq_properties, false), + PartitioningSatisfaction::NotSatisfied + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c658a8eddc233..77ede76e1daa8 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -18,13 +18,13 @@ use std::sync::Arc; use crate::expressions::{self, Column}; -use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr}; +use crate::{LexOrdering, PhysicalSortExpr, create_physical_expr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Result}; use datafusion_common::{DFSchema, HashMap}; +use datafusion_common::{Result, plan_err}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{Expr, SortExpr}; @@ -38,7 +38,7 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down(|e| match e.as_any().downcast_ref::() { + expr.transform_down(|e| match e.downcast_ref::() { Some(col) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index overflow"); @@ -233,18 +233,17 @@ pub fn add_offset_to_physical_sort_exprs( mod tests { use super::*; - use crate::expressions::{BinaryExpr, Column, Literal}; + use crate::expressions::{BinaryExpr, Literal}; use crate::physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, }; use datafusion_physical_expr_common::physical_expr::is_volatile; - use arrow::datatypes::{DataType, Schema}; + use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; use datafusion_expr::Operator; - use std::any::Any; use std::fmt; #[test] @@ -394,10 +393,6 @@ mod tests { } impl PhysicalExpr for MockVolatileExpr { - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::Boolean) } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7790380dffd56..d0d0508a106a5 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,24 +17,30 @@ use std::sync::Arc; -use crate::ScalarFunctionExpr; +use crate::scalar_subquery::ScalarSubqueryExpr; +use crate::{HigherOrderFunctionExpr, ScalarFunctionExpr}; use crate::{ - expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, + expressions::{self, Column, Literal, binary, like, similar_to}, }; use arrow::datatypes::Schema; use datafusion_common::config::ConfigOptions; -use datafusion_common::metadata::FieldMetadata; +use datafusion_common::datatype::FieldExt; +use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ - exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, + DFSchema, Result, ScalarValue, TableReference, ToDFSchema, exec_err, + internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; -use datafusion_expr::var_provider::is_system_variables; +use datafusion_expr::expr::{ + Alias, Cast, HigherOrderFunction, InList, Lambda, LambdaVariable, Placeholder, + ScalarFunction, +}; use datafusion_expr::var_provider::VarType; +use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, + Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, binary_expr, lit, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -105,6 +111,7 @@ use datafusion_expr::{ /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references /// to qualified or unqualified fields by name. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, @@ -287,16 +294,31 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) } - Expr::Cast(Cast { expr, data_type }) => expressions::cast( - create_physical_expr(expr, input_dfschema, execution_props)?, - input_schema, - data_type.clone(), - ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( + Expr::Cast(Cast { expr, field }) => expressions::cast_with_target_field( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + Arc::clone(field), + None, ), + Expr::TryCast(TryCast { expr, field }) => { + if !field.metadata().is_empty() { + let (_, src_field) = expr.to_field(input_dfschema)?; + return plan_err!( + "TryCast from {} to {} is not supported", + format_type_and_metadata( + src_field.data_type(), + Some(src_field.metadata()), + ), + format_type_and_metadata(field.data_type(), Some(field.metadata())) + ); + } + + expressions::try_cast( + create_physical_expr(expr, input_dfschema, execution_props)?, + input_schema, + field.data_type().clone(), + ) + } Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) } @@ -380,9 +402,169 @@ pub fn create_physical_expr( expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, + Expr::ScalarSubquery(sq) => { + match execution_props.subquery_indexes.get(sq) { + Some(&index) => { + let schema = sq.subquery.schema(); + if schema.fields().len() != 1 { + return plan_err!( + "Scalar subquery must return exactly one column, got {}", + schema.fields().len() + ); + } + let dt = schema.field(0).data_type().clone(); + let nullable = schema.field(0).is_nullable(); + Ok(Arc::new(ScalarSubqueryExpr::new( + dt, + nullable, + index, + execution_props.subquery_results.clone(), + ))) + } + None => { + // Not found: either a correlated subquery that wasn't + // rewritten to a join, or an uncorrelated one that wasn't + // registered by the physical planner. + not_impl_err!( + "Physical plan does not support logical expression {e:?}" + ) + } + } + } Expr::Placeholder(Placeholder { id, .. }) => { exec_err!("Placeholder '{id}' was not provided a value for execution.") } + Expr::HigherOrderFunction(invocation @ HigherOrderFunction { func, args }) => { + let num_lambdas = args + .iter() + .filter(|arg| matches!(arg, Expr::Lambda(_))) + .count(); + + let mut lambda_parameters = + invocation.lambda_parameters(input_dfschema)?.into_iter(); + + if num_lambdas > lambda_parameters.len() { + return plan_err!( + "{} lambda_parameters returned only {} values for {num_lambdas} lambdas", + func.name(), + lambda_parameters.len() + ); + } + + let lambda_qualifier = 1 + input_dfschema + .iter() + .filter_map(|(qualifier, _field)| { + qualifier.and_then(|tbl| { + tbl.table().strip_prefix("lambda_")?.parse::().ok() + }) + }) + .max() + .unwrap_or_default(); + + let qualifier = TableReference::bare(format!("lambda_{lambda_qualifier}")); + + let physical_args = args + .iter() + .map(|arg| match arg { + Expr::Lambda(lambda) => { + let lambda_parameters = lambda_parameters + .next() + .ok_or_else(|| { + internal_datafusion_err!( + "lambda_parameters len should have been checked above" + ) + })? + .into_iter() + .zip(&lambda.params) + .map(|(field, name)| { + (Some(qualifier.clone()), field.renamed(name.as_str())) + }); + + let new_fields = input_dfschema + .iter() + .map(|(tbl, field)| (tbl.cloned(), Arc::clone(field))) + .chain(lambda_parameters) + .collect(); + + let lambda_schema = DFSchema::new_with_metadata( + new_fields, + input_dfschema.metadata().clone(), + )?; + + let execution_props = execution_props + .clone() + .with_qualified_lambda_variables(&qualifier, &lambda.params); + + create_physical_expr(arg, &lambda_schema, &execution_props) + } + _ => create_physical_expr(arg, input_dfschema, execution_props), + }) + .collect::>()?; + + let config_options = match execution_props.config_options.as_ref() { + Some(config_options) => Arc::clone(config_options), + None => Arc::new(ConfigOptions::default()), + }; + + Ok(Arc::new(HigherOrderFunctionExpr::try_new_with_schema( + Arc::clone(func), + physical_args, + input_schema, + config_options, + )?)) + } + Expr::Lambda(Lambda { params, body }) => expressions::lambda( + params, + create_physical_expr(body, input_dfschema, execution_props)?, + ), + Expr::LambdaVariable(LambdaVariable { + name, + field, + spans: _, + }) => { + let field = field.as_ref().ok_or_else(|| { + plan_datafusion_err!("unresolved LambdaVariable {name}") + })?; + + let qualifier = execution_props + .lambda_variable_qualifier + .get(name) + .ok_or_else(|| { + plan_datafusion_err!("qualifier for lambda variable {name} not found") + })?; + + let index = input_dfschema + .index_of_column_by_name(Some(qualifier), name) + .ok_or_else(|| { + plan_datafusion_err!( + "lambda variable {qualifier}.{name} not found in planning schema" + ) + })?; + + let schema_field = input_dfschema.field(index); + + // LambdaVariable.field will be made optional as in Expr::Placeholder + // and only LambdaVariable.name used, and field.name ignored, + // so they're not enforced to match for logical expressions + // Rename the field to match the schema one and use it's PartialEq impl instead + // of checking property by property and fail if new properties get's added to it. + // While not necessary, the sql planner does create lambda vars with matching names, + // so this shouldn't allocate with a lambda var from it + let renamed_field = Arc::clone(field).renamed(name); + + if &renamed_field != schema_field { + return plan_err!( + "LambdaVariable field and schema field mismatch {} != {}", + renamed_field, + schema_field + ); + } + + Ok(Arc::new(expressions::LambdaVariable::new( + index, + Arc::clone(schema_field), + ))) + } other => { not_impl_err!("Physical plan does not support logical expression {other:?}") } @@ -416,11 +598,25 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field}; - - use datafusion_expr::{col, lit}; + use datafusion_expr::col; use super::*; + fn test_cast_schema() -> Schema { + Schema::new(vec![Field::new("a", DataType::Int32, false)]) + } + + fn lower_cast_expr(expr: &Expr, schema: &Schema) -> Result> { + let df_schema = DFSchema::try_from(schema.clone())?; + create_physical_expr(expr, &df_schema, &ExecutionProps::new()) + } + + fn as_planner_cast(physical: &Arc) -> &expressions::CastExpr { + physical + .downcast_ref::() + .expect("planner should lower logical CAST to CastExpr") + } + #[test] fn test_create_physical_expr_scalar_input_output() -> Result<()> { let expr = col("letter").eq(lit("A")); @@ -445,4 +641,96 @@ mod tests { Ok(()) } + + #[test] + fn test_cast_lowering_preserves_target_field_metadata() -> Result<()> { + let schema = test_cast_schema(); + let target_field = Arc::new( + Field::new("cast_target", DataType::Int64, true) + .with_metadata([("target_meta".to_string(), "1".to_string())].into()), + ); + let cast_expr = Expr::Cast(Cast::new_from_field( + Box::new(col("a")), + Arc::clone(&target_field), + )); + + let physical = lower_cast_expr(&cast_expr, &schema)?; + let cast = as_planner_cast(&physical); + + assert_eq!(cast.target_field(), &target_field); + assert_eq!(physical.return_field(&schema)?, target_field); + assert!(physical.nullable(&schema)?); + + Ok(()) + } + + #[test] + fn test_cast_lowering_preserves_standard_cast_semantics() -> Result<()> { + let schema = test_cast_schema(); + let cast_expr = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int64)); + + let physical = lower_cast_expr(&cast_expr, &schema)?; + let cast = as_planner_cast(&physical); + let returned_field = physical.return_field(&schema)?; + + assert_eq!(cast.cast_type(), &DataType::Int64); + assert_eq!(returned_field.name(), "a"); + assert_eq!(returned_field.data_type(), &DataType::Int64); + assert!(!physical.nullable(&schema)?); + + Ok(()) + } + + #[test] + fn test_cast_lowering_preserves_same_type_field_semantics() -> Result<()> { + let schema = test_cast_schema(); + let target_field = Arc::new( + Field::new("same_type_cast", DataType::Int32, true).with_metadata( + [("target_meta".to_string(), "same-type".to_string())].into(), + ), + ); + let cast_expr = Expr::Cast(Cast::new_from_field( + Box::new(col("a")), + Arc::clone(&target_field), + )); + + let physical = lower_cast_expr(&cast_expr, &schema)?; + let cast = as_planner_cast(&physical); + + assert_eq!(cast.target_field(), &target_field); + assert_eq!(physical.return_field(&schema)?, target_field); + assert!(physical.nullable(&schema)?); + + Ok(()) + } + + /// Test that deeply nested expressions do not cause a stack overflow. + /// + /// This test only runs when the `recursive_protection` feature is enabled, + /// as it would overflow the stack otherwise. + #[test] + #[cfg_attr(not(feature = "recursive_protection"), ignore)] + fn test_deeply_nested_binary_expr() -> Result<()> { + // Create a deeply nested binary expression tree: ((((a + a) + a) + a) + ... ) + // With 1000 levels of nesting, this would overflow the stack without recursion protection. + let depth = 1000; + + let mut expr = col("a"); + for _ in 0..depth { + expr = Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), + op: Operator::Plus, + right: Box::new(col("a")), + }); + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let df_schema = DFSchema::try_from(schema)?; + + // This should not stack overflow + let _physical_expr = + create_physical_expr(&expr, &df_schema, &ExecutionProps::new())?; + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 3d6740510bec6..cee95685e8440 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -15,27 +15,34 @@ // specific language governing permissions and limitations // under the License. +//! [`ProjectionExpr`] and [`ProjectionExprs`] for representing projections. + use std::ops::Deref; use std::sync::Arc; -use crate::expressions::Column; -use crate::utils::collect_columns; use crate::PhysicalExpr; +use crate::expressions::{CastExpr, Column, Literal}; +use crate::scalar_function::ScalarFunctionExpr; +use crate::utils::collect_columns; use arrow::array::{RecordBatch, RecordBatchOptions}; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - assert_or_internal_err, internal_datafusion_err, plan_err, Result, + Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err, + plan_err, }; +use datafusion_physical_expr_common::metrics::ExecutionPlanMetricsSet; +use datafusion_physical_expr_common::metrics::ExpressionEvaluatorMetrics; +use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; +use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays_with_metrics; use indexmap::IndexMap; use itertools::Itertools; -/// A projection expression as used by projection operations. +/// An expression used by projection operations. /// /// The expression is evaluated and the result is stored in a column /// with the name specified by `alias`. @@ -43,6 +50,8 @@ use itertools::Itertools; /// For example, the SQL expression `a + b AS sum_ab` would be represented /// as a `ProjectionExpr` where `expr` is the expression `a + b` /// and `alias` is the string `sum_ab`. +/// +/// See [`ProjectionExprs`] for a collection of projection expressions. #[derive(Debug, Clone)] pub struct ProjectionExpr { /// The expression that will be evaluated. @@ -72,7 +81,8 @@ impl std::fmt::Display for ProjectionExpr { impl ProjectionExpr { /// Create a new projection expression - pub fn new(expr: Arc, alias: String) -> Self { + pub fn new(expr: Arc, alias: impl Into) -> Self { + let alias = alias.into(); Self { expr, alias } } @@ -107,14 +117,18 @@ impl From for (Arc, String) { } } -/// A collection of projection expressions. +/// A collection of [`ProjectionExpr`] instances, representing a complete +/// projection operation. +/// +/// Projection operations are used in query plans to select specific columns or +/// compute new columns based on existing ones. /// -/// This struct encapsulates multiple `ProjectionExpr` instances, -/// representing a complete projection operation and provides -/// methods to manipulate and analyze the projection as a whole. +/// See [`ProjectionExprs::from_indices`] to select a subset of columns by +/// indices. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ProjectionExprs { - exprs: Vec, + /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance. + exprs: Arc<[ProjectionExpr]>, } impl std::fmt::Display for ProjectionExprs { @@ -126,14 +140,16 @@ impl std::fmt::Display for ProjectionExprs { impl From> for ProjectionExprs { fn from(value: Vec) -> Self { - Self { exprs: value } + Self { + exprs: value.into(), + } } } impl From<&[ProjectionExpr]> for ProjectionExprs { fn from(value: &[ProjectionExpr]) -> Self { Self { - exprs: value.to_vec(), + exprs: value.iter().cloned().collect(), } } } @@ -141,7 +157,7 @@ impl From<&[ProjectionExpr]> for ProjectionExprs { impl FromIterator for ProjectionExprs { fn from_iter>(exprs: T) -> Self { Self { - exprs: exprs.into_iter().collect::>(), + exprs: exprs.into_iter().collect(), } } } @@ -153,12 +169,17 @@ impl AsRef<[ProjectionExpr]> for ProjectionExprs { } impl ProjectionExprs { - pub fn new(exprs: I) -> Self - where - I: IntoIterator, - { + /// Make a new [`ProjectionExprs`] from expressions iterator. + pub fn new(exprs: impl IntoIterator) -> Self { + Self { + exprs: exprs.into_iter().collect(), + } + } + + /// Make a new [`ProjectionExprs`] from expressions. + pub fn from_expressions(exprs: impl Into>) -> Self { Self { - exprs: exprs.into_iter().collect::>(), + exprs: exprs.into(), } } @@ -240,6 +261,50 @@ impl ProjectionExprs { self.exprs.iter().map(|e| Arc::clone(&e.expr)) } + /// Apply a fallible transformation to the [`PhysicalExpr`] of each projection. + /// + /// This method transforms the expression in each [`ProjectionExpr`] while preserving + /// the alias. This is useful for rewriting expressions, such as when adapting + /// expressions to a different schema. + /// + /// # Example + /// + /// ```rust + /// use std::sync::Arc; + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_common::Result; + /// use datafusion_physical_expr::expressions::Column; + /// use datafusion_physical_expr::projection::ProjectionExprs; + /// use datafusion_physical_expr::PhysicalExpr; + /// + /// // Create a schema and projection + /// let schema = Arc::new(Schema::new(vec![ + /// Field::new("a", DataType::Int32, false), + /// Field::new("b", DataType::Int32, false), + /// ])); + /// let projection = ProjectionExprs::from_indices(&[0, 1], &schema); + /// + /// // Transform each expression (this example just clones them) + /// let transformed = projection.try_map_exprs(|expr| Ok(expr))?; + /// assert_eq!(transformed.as_ref().len(), 2); + /// # Ok::<(), datafusion_common::DataFusionError>(()) + /// ``` + pub fn try_map_exprs(self, mut f: F) -> Result + where + F: FnMut(Arc) -> Result>, + { + let exprs = self + .exprs + .iter() + .cloned() + .map(|mut proj| { + proj.expr = f(proj.expr)?; + Ok(proj) + }) + .collect::>>()?; + Ok(Self::from_expressions(exprs)) + } + /// Apply another projection on top of this projection, returning the combined projection. /// For example, if this projection is `SELECT c@2 AS x, b@1 AS y, a@0 as z` and the other projection is `SELECT x@0 + 1 AS c1, y@1 + z@2 as c2`, /// we return a projection equivalent to `SELECT c@2 + 1 AS c1, b@1 + a@0 as c2`. @@ -307,17 +372,9 @@ impl ProjectionExprs { /// applied on top of this projection. pub fn try_merge(&self, other: &ProjectionExprs) -> Result { let mut new_exprs = Vec::with_capacity(other.exprs.len()); - for proj_expr in &other.exprs { - let new_expr = update_expr(&proj_expr.expr, &self.exprs, true)? - .ok_or_else(|| { - internal_datafusion_err!( - "Failed to combine projections: expression {} could not be applied on top of existing projections {}", - proj_expr.expr, - self.exprs.iter().map(|e| format!("{e}")).join(", ") - ) - })?; + for proj_expr in other.exprs.iter() { new_exprs.push(ProjectionExpr { - expr: new_expr, + expr: self.unproject_expr(&proj_expr.expr)?, alias: proj_expr.alias.clone(), }); } @@ -364,12 +421,19 @@ impl ProjectionExprs { /// /// Use [`column_indices()`](Self::column_indices) instead if the projection may contain /// non-column expressions or if you need a deduplicated sorted list. + /// + /// # Panics + /// + /// Panics if any expression in the projection is not a simple column reference. + #[deprecated( + since = "52.0.0", + note = "Use column_indices() instead. This method will be removed in 58.0.0 or 6 months after 52.0.0 is released, whichever comes first." + )] pub fn ordered_column_indices(&self) -> Vec { self.exprs .iter() .map(|e| { e.expr - .as_any() .downcast_ref::() .expect("Expected column reference in projection") .index() @@ -378,9 +442,16 @@ impl ProjectionExprs { } /// Project a schema according to this projection. - /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, - /// if the input schema is `[a: Int32, b: Int32, c: Int32]`, the output schema would be `[x: Int32, y: Int32]`. - /// Fields' metadata are preserved from the input schema. + /// + /// For example, given a projection: + /// * `SELECT a AS x, b + 1 AS y` + /// * where `a` is at index 0 + /// * `b` is at index 1 + /// + /// If the input schema is `[a: Int32, b: Int32, c: Int32]`, the output + /// schema would be `[x: Int32, y: Int32]`. + /// + /// Note that [`Field`] metadata are preserved from the input schema. pub fn project_schema(&self, input_schema: &Schema) -> Result { let fields: Result> = self .exprs @@ -409,6 +480,48 @@ impl ProjectionExprs { )) } + /// "unproject" an expression by applying this projection in reverse, + /// returning a new set of expressions that reference the original input + /// columns. + /// + /// For example, consider + /// * an expression `c1_c2 > 5`, and a schema `[c1, c2]` + /// * a projection `c1 + c2 as c1_c2` + /// + /// This method would rewrite the expression to `c1 + c2 > 5` + pub fn unproject_expr( + &self, + expr: &Arc, + ) -> Result> { + update_expr(expr, &self.exprs, true)?.ok_or_else(|| { + internal_datafusion_err!( + "Failed to unproject an expression {} with ProjectionExprs {}", + expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + }) + } + + /// "project" an expression using these projection's expressions + /// + /// For example, consider + /// * an expression `c1 + c2 > 5`, and a schema `[c1, c2]` + /// * a projection `c1 + c2 as c1_c2` + /// + /// * This method would rewrite the expression to `c1_c2 > 5` + pub fn project_expr( + &self, + expr: &Arc, + ) -> Result> { + update_expr(expr, &self.exprs, false)?.ok_or_else(|| { + internal_datafusion_err!( + "Failed to project an expression {} with ProjectionExprs {}", + expr, + self.exprs.iter().map(|e| format!("{e}")).join(", ") + ) + }) + } + /// Create a new [`Projector`] from this projection and an input schema. /// /// A [`Projector`] can be used to apply this projection to record batches. @@ -422,48 +535,231 @@ impl ProjectionExprs { Ok(Projector { projection: self.clone(), output_schema, + expression_metrics: None, }) } + pub fn create_expression_metrics( + &self, + metrics: &ExecutionPlanMetricsSet, + partition: usize, + ) -> ExpressionEvaluatorMetrics { + let labels: Vec = self + .exprs + .iter() + .map(|proj_expr| { + let expr_sql = fmt_sql(proj_expr.expr.as_ref()).to_string(); + if proj_expr.expr.to_string() == proj_expr.alias { + expr_sql + } else { + format!("{expr_sql} AS {}", proj_expr.alias) + } + }) + .collect(); + ExpressionEvaluatorMetrics::new(metrics, partition, labels) + } + /// Project statistics according to this projection. /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1, /// if the input statistics has column statistics for columns `a`, `b`, and `c`, the output statistics would have column statistics for columns `x` and `y`. + /// + /// # Example + /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_common::stats::{ColumnStatistics, Precision, Statistics}; + /// use datafusion_physical_expr::projection::ProjectionExprs; + /// use datafusion_common::Result; + /// use datafusion_common::ScalarValue; + /// use std::sync::Arc; + /// + /// fn main() -> Result<()> { + /// // Input schema: a: Int32, b: Int32, c: Int32 + /// let input_schema = Arc::new(Schema::new(vec![ + /// Field::new("a", DataType::Int32, false), + /// Field::new("b", DataType::Int32, false), + /// Field::new("c", DataType::Int32, false), + /// ])); + /// + /// // Input statistics with column stats for a, b, c + /// let input_stats = Statistics { + /// num_rows: Precision::Exact(100), + /// total_byte_size: Precision::Exact(1200), + /// column_statistics: vec![ + /// // Column a stats + /// ColumnStatistics::new_unknown() + /// .with_null_count(Precision::Exact(0)) + /// .with_min_value(Precision::Exact(ScalarValue::Int32(Some(0)))) + /// .with_max_value(Precision::Exact(ScalarValue::Int32(Some(100)))) + /// .with_distinct_count(Precision::Exact(100)), + /// // Column b stats + /// ColumnStatistics::new_unknown() + /// .with_null_count(Precision::Exact(0)) + /// .with_min_value(Precision::Exact(ScalarValue::Int32(Some(10)))) + /// .with_max_value(Precision::Exact(ScalarValue::Int32(Some(60)))) + /// .with_distinct_count(Precision::Exact(50)), + /// // Column c stats + /// ColumnStatistics::new_unknown() + /// .with_null_count(Precision::Exact(5)) + /// .with_min_value(Precision::Exact(ScalarValue::Int32(Some(-10)))) + /// .with_max_value(Precision::Exact(ScalarValue::Int32(Some(200)))) + /// .with_distinct_count(Precision::Exact(25)), + /// ], + /// }; + /// + /// // Create a projection that selects columns c and a (indices 2 and 0) + /// let projection = ProjectionExprs::from_indices(&[2, 0], &input_schema); + /// + /// // Compute output schema + /// let output_schema = projection.project_schema(&input_schema)?; + /// + /// // Project the statistics + /// let output_stats = projection.project_statistics(input_stats, &output_schema)?; + /// + /// // The output should have 2 column statistics (for c and a, in that order) + /// assert_eq!(output_stats.column_statistics.len(), 2); + /// + /// // First column in output is c (was at index 2) + /// assert_eq!( + /// output_stats.column_statistics[0].min_value, + /// Precision::Exact(ScalarValue::Int32(Some(-10))) + /// ); + /// assert_eq!( + /// output_stats.column_statistics[0].null_count, + /// Precision::Exact(5) + /// ); + /// + /// // Second column in output is a (was at index 0) + /// assert_eq!( + /// output_stats.column_statistics[1].min_value, + /// Precision::Exact(ScalarValue::Int32(Some(0))) + /// ); + /// assert_eq!( + /// output_stats.column_statistics[1].distinct_count, + /// Precision::Exact(100) + /// ); + /// + /// // Total byte size is recalculated based on projected columns + /// assert_eq!( + /// output_stats.total_byte_size, + /// Precision::Exact(800), // each Int32 column is 4 bytes * 100 rows * 2 columns + /// ); + /// + /// // Number of rows remains the same + /// assert_eq!(output_stats.num_rows, Precision::Exact(100)); + /// + /// Ok(()) + /// } + /// ``` pub fn project_statistics( &self, - mut stats: datafusion_common::Statistics, - input_schema: &Schema, - ) -> Result { - let mut primitive_row_size = 0; - let mut primitive_row_size_possible = true; - let mut column_statistics = vec![]; - - for proj_expr in &self.exprs { + mut stats: Statistics, + output_schema: &Schema, + ) -> Result { + let mut column_statistics = Vec::with_capacity(self.exprs.len()); + + for proj_expr in self.exprs.iter() { let expr = &proj_expr.expr; - let col_stats = if let Some(col) = expr.as_any().downcast_ref::() { + let col_stats = if let Some(col) = expr.downcast_ref::() { stats.column_statistics[col.index()].clone() + } else if let Some(literal) = expr.downcast_ref::() { + // Handle literal expressions (constants) by calculating proper statistics + let data_type = expr.data_type(output_schema)?; + + if literal.value().is_null() { + let null_count = match stats.num_rows { + Precision::Exact(num_rows) => Precision::Exact(num_rows), + _ => Precision::Absent, + }; + + ColumnStatistics { + min_value: Precision::Exact(literal.value().clone()), + max_value: Precision::Exact(literal.value().clone()), + distinct_count: Precision::Exact(1), + null_count, + sum_value: Precision::Exact(literal.value().clone()), + byte_size: Precision::Exact(0), + } + } else { + let value = literal.value(); + let distinct_count = Precision::Exact(1); + let null_count = Precision::Exact(0); + + let byte_size = if let Some(byte_width) = data_type.primitive_width() + { + stats.num_rows.multiply(&Precision::Exact(byte_width)) + } else { + // Complex types depend on array encoding, so set to Absent + Precision::Absent + }; + + let widened_sum = Precision::Exact(value.clone()).cast_to_sum_type(); + let sum_value = widened_sum + .get_value() + .and_then(|sum| { + Precision::::from(stats.num_rows) + .cast_to(&sum.data_type()) + .ok() + }) + .map(|row_count| widened_sum.multiply(&row_count)) + .unwrap_or(Precision::Absent); + + ColumnStatistics { + min_value: Precision::Exact(value.clone()), + max_value: Precision::Exact(value.clone()), + distinct_count, + null_count, + sum_value, + byte_size, + } + } } else { - // TODO stats: estimate more statistics from expressions - // (expressions should compute their statistics themselves) - ColumnStatistics::new_unknown() + project_column_statistics_through_expr( + expr.as_ref(), + &stats.column_statistics, + ) }; column_statistics.push(col_stats); - let data_type = expr.data_type(input_schema)?; - if let Some(value) = data_type.primitive_width() { - primitive_row_size += value; - continue; - } - primitive_row_size_possible = false; - } - - if primitive_row_size_possible { - stats.total_byte_size = - Precision::Exact(primitive_row_size).multiply(&stats.num_rows); } + stats.calculate_total_byte_size(output_schema); stats.column_statistics = column_statistics; Ok(stats) } } +/// Propagate column statistics through CAST projections. Other expressions +/// return unknown — generalizing via [`PhysicalExpr::evaluate_bounds`] is +/// unsafe for aggregate folding since many impls (e.g. `sin`) return a fixed +/// envelope rather than tight bounds on the actual inputs. +fn project_column_statistics_through_expr( + expr: &dyn PhysicalExpr, + column_stats: &[ColumnStatistics], +) -> ColumnStatistics { + if let Some(col) = expr.downcast_ref::() { + return column_stats[col.index()].clone(); + } + let Some(cast_expr) = expr.downcast_ref::() else { + return ColumnStatistics::new_unknown(); + }; + let inner_stats = + project_column_statistics_through_expr(cast_expr.expr.as_ref(), column_stats); + let target_type = cast_expr.cast_type(); + ColumnStatistics { + min_value: inner_stats + .min_value + .cast_to(target_type) + .unwrap_or(Precision::Absent), + max_value: inner_stats + .max_value + .cast_to(target_type) + .unwrap_or(Precision::Absent), + null_count: inner_stats.null_count, + distinct_count: inner_stats.distinct_count, + sum_value: Precision::Absent, + byte_size: Precision::Absent, + } +} + impl<'a> IntoIterator for &'a ProjectionExprs { type Item = &'a ProjectionExpr; type IntoIter = std::slice::Iter<'a, ProjectionExpr>; @@ -485,9 +781,30 @@ impl<'a> IntoIterator for &'a ProjectionExprs { pub struct Projector { projection: ProjectionExprs, output_schema: SchemaRef, + /// If `Some`, metrics will be tracked for projection evaluation. + expression_metrics: Option, } impl Projector { + /// Construct the projector with metrics. After execution, related metrics will + /// be tracked inside `ExecutionPlanMetricsSet` + /// + /// See [`ExpressionEvaluatorMetrics`] for details. + pub fn with_metrics( + &self, + metrics: &ExecutionPlanMetricsSet, + partition: usize, + ) -> Self { + let expr_metrics = self + .projection + .create_expression_metrics(metrics, partition); + Self { + expression_metrics: Some(expr_metrics), + projection: self.projection.clone(), + output_schema: Arc::clone(&self.output_schema), + } + } + /// Project a record batch according to this projector's expressions. /// /// # Errors @@ -495,9 +812,10 @@ impl Projector { /// or if the output schema of the resulting record batch does not match /// the pre-computed output schema of the projector. pub fn project_batch(&self, batch: &RecordBatch) -> Result { - let arrays = evaluate_expressions_to_arrays( + let arrays = evaluate_expressions_to_arrays_with_metrics( self.projection.exprs.iter().map(|p| &p.expr), batch, + self.expression_metrics.as_ref(), )?; if arrays.is_empty() { @@ -524,35 +842,92 @@ impl Projector { } } -impl IntoIterator for ProjectionExprs { - type Item = ProjectionExpr; - type IntoIter = std::vec::IntoIter; +/// Describes an immutable reference counted projection. +/// +/// This structure represents projecting a set of columns by index. +/// [`Arc`] is used to make it cheap to clone. +pub type ProjectionRef = Arc<[usize]>; - fn into_iter(self) -> Self::IntoIter { - self.exprs.into_iter() - } +/// Combine two projections. +/// +/// If `p1` is [`None`] then there are no changes. +/// Otherwise, if passed `p2` is not [`None`] then it is remapped +/// according to the `p1`. Otherwise, there are no changes. +/// +/// # Example +/// +/// If stored projection is [0, 2] and we call `apply_projection([0, 2, 3])`, +/// then the resulting projection will be [0, 3]. +/// +/// # Error +/// +/// Returns an internal error if `p1` contains index that is greater than `p2` len. +/// +pub fn combine_projections( + p1: Option<&ProjectionRef>, + p2: Option<&ProjectionRef>, +) -> Result> { + let Some(p1) = p1 else { + return Ok(None); + }; + let Some(p2) = p2 else { + return Ok(Some(Arc::clone(p1))); + }; + + Ok(Some( + p1.iter() + .map(|i| { + let idx = *i; + assert_or_internal_err!( + idx < p2.len(), + "unable to apply projection: index {} is greater than new projection len {}", + idx, + p2.len(), + ); + Ok(p2[*i]) + }) + .collect::>>()?, + )) } -/// The function operates in two modes: +/// The function projects / unprojects an expression with respect to set of +/// projection expressions. +/// +/// See also [`ProjectionExprs::unproject_expr`] and [`ProjectionExprs::project_expr`] +/// +/// 1) When `unproject` is `true`: /// -/// 1) When `sync_with_child` is `true`: +/// Rewrites an expression with respect to the projection expressions, +/// effectively "unprojecting" it to reference the original input columns. /// -/// The function updates the indices of `expr` if the expression resides -/// in the input plan. For instance, given the expressions `a@1 + b@2` -/// and `c@0` with the input schema `c@2, a@0, b@1`, the expressions are -/// updated to `a@0 + b@1` and `c@2`. +/// For example, given +/// * the expressions `a@1 + b@2` and `c@0` +/// * and projection expressions `c@2, a@0, b@1` /// -/// 2) When `sync_with_child` is `false`: +/// Then +/// * `a@1 + b@2` becomes `a@0 + b@1` +/// * `c@0` becomes `c@2` /// -/// The function determines how the expression would be updated if a projection -/// was placed before the plan associated with the expression. If the expression -/// cannot be rewritten after the projection, it returns `None`. For example, -/// given the expressions `c@0`, `a@1` and `b@2`, and the projection with -/// an output schema of `a, c_new`, then `c@0` becomes `c_new@1`, `a@1` becomes -/// `a@0`, but `b@2` results in `None` since the projection does not include `b`. +/// 2) When `unproject` is `false`: +/// +/// Rewrites the expression to reference the projected expressions, +/// effectively "projecting" it. The resulting expression will reference the +/// indices as they appear in the projection. +/// +/// If the expression cannot be rewritten after the projection, it returns +/// `None`. +/// +/// For example, given +/// * the expressions `c@0`, `a@1` and `b@2` +/// * the projection `a@1 as a, c@0 as c_new`, +/// +/// Then +/// * `c@0` becomes `c_new@1` +/// * `a@1` becomes `a@0` +/// * `b@2` results in `None` since the projection does not include `b`. /// /// # Errors -/// This function returns an error if `sync_with_child` is `true` and if any expression references +/// This function returns an error if `unproject` is `true` and if any expression references /// an index that is out of bounds for `projected_exprs`. /// For example: /// @@ -563,7 +938,7 @@ impl IntoIterator for ProjectionExprs { pub fn update_expr( expr: &Arc, projected_exprs: &[ProjectionExpr], - sync_with_child: bool, + unproject: bool, ) -> Result>> { #[derive(Debug, PartialEq)] enum RewriteState { @@ -584,12 +959,10 @@ pub fn update_expr( return Ok(Transformed::no(expr)); } - let Some(column) = expr.as_any().downcast_ref::() else { + let Some(column) = expr.downcast_ref::() else { return Ok(Transformed::no(expr)); }; - if sync_with_child { - state = RewriteState::RewrittenValid; - // Update the index of `column`: + if unproject { let projected_expr = projected_exprs.get(column.index()).ok_or_else(|| { internal_datafusion_err!( "Column index {} out of bounds for projected expressions of length {}", @@ -597,6 +970,17 @@ pub fn update_expr( projected_exprs.len() ) })?; + // Skip rebuilding the parent if substituting with an equal + // Column (e.g. pass-through `c0@0` -> `c0@0` during chained + // projection collapse). Without this, every CASE/BinaryExpr + // containing such a Column is reconstructed unnecessarily. + if let Some(projected_col) = + projected_expr.expr.downcast_ref::() + && projected_col == column + { + return Ok(Transformed::no(expr)); + } + state = RewriteState::RewrittenValid; Ok(Transformed::yes(Arc::clone(&projected_expr.expr))) } else { // default to invalid, in case we can't find the relevant column @@ -606,7 +990,7 @@ pub fn update_expr( .iter() .enumerate() .find_map(|(index, proj_expr)| { - proj_expr.expr.as_any().downcast_ref::().and_then( + proj_expr.expr.downcast_ref::().and_then( |projected_column| { (column.name().eq(projected_column.name()) && column.index() == projected_column.index()) @@ -702,7 +1086,7 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down(|e| match e.downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure @@ -722,9 +1106,66 @@ impl ProjectionMapping { None => Ok(Transformed::no(e)), }) .data()?; - map.entry(source_expr) + map.entry(Arc::clone(&source_expr)) .or_default() - .push((target_expr, expr_idx)); + .push((Arc::clone(&target_expr), expr_idx)); + + // For struct-producing functions (e.g. named_struct), decompose + // into field-level mapping entries so that orderings propagate + // through struct projections. For example, if the projection has + // `named_struct('ticker', p.ticker, ...) AS details`, this adds: + // p.ticker → get_field(col("details"), "ticker") + // enabling the optimizer to know that sorting by + // `details.ticker` is equivalent to sorting by `p.ticker`. + if let Some(func_expr) = source_expr.downcast_ref::() { + let literal_args: Vec> = func_expr + .args() + .iter() + .map(|arg| arg.downcast_ref::().map(|l| l.value().clone())) + .collect(); + + if let Some(field_mapping) = + func_expr.fun().struct_field_mapping(&literal_args) + && let DataType::Struct(struct_fields) = func_expr.return_type() + { + for (accessor_args, source_arg_idx) in &field_mapping.fields { + let value_expr = Arc::clone(&func_expr.args()[*source_arg_idx]); + + // Build accessor args: [target_col, ...field_name_literals] + let mut accessor_fn_args: Vec> = + vec![Arc::clone(&target_expr)]; + accessor_fn_args.extend(accessor_args.iter().map(|sv| { + Arc::new(Literal::new(sv.clone())) as Arc + })); + + // Look up the field's return type from the struct schema + let return_field = accessor_args + .first() + .and_then(|sv| sv.try_as_str().flatten()) + .and_then(|field_name| { + struct_fields + .iter() + .find(|f| f.name() == field_name) + .cloned() + }); + + if let Some(return_field) = return_field { + let field_access_expr = Arc::new(ScalarFunctionExpr::new( + field_mapping.field_accessor.name(), + Arc::clone(&field_mapping.field_accessor), + accessor_fn_args, + return_field, + Arc::new(func_expr.config_options().clone()), + )) + as Arc; + + map.entry(value_expr) + .or_default() + .push((field_access_expr, expr_idx)); + } + } + } + } } Ok(Self { map }) } @@ -822,7 +1263,7 @@ pub fn project_ordering( let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { let transformed = Arc::clone(expr).transform_up(|expr| { - let Some(col) = expr.as_any().downcast_ref::() else { + let Some(col) = expr.downcast_ref::() else { return Ok(Transformed::no(expr)); }; @@ -857,15 +1298,14 @@ pub(crate) mod tests { use std::collections::HashMap; use super::*; - use crate::equivalence::{convert_to_orderings, EquivalenceProperties}; - use crate::expressions::{col, BinaryExpr, Literal}; + use crate::equivalence::{EquivalenceProperties, convert_to_orderings}; + use crate::expressions::{BinaryExpr, CastExpr, col}; use crate::utils::tests::TestScalarUDF; use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::config::ConfigOptions; - use datafusion_common::{ScalarValue, Statistics}; use datafusion_expr::{Operator, ScalarUDF}; use insta::assert_snapshot; @@ -879,8 +1319,10 @@ pub(crate) mod tests { let data_type = source.data_type(input_schema)?; let nullable = source.nullable(input_schema)?; for (target, _) in targets.iter() { - let Some(column) = target.as_any().downcast_ref::() else { - return plan_err!("Expects to have column"); + // Skip non-Column targets (e.g. struct field decomposition + // entries which are ScalarFunctionExpr targets). + let Some(column) = target.downcast_ref::() else { + continue; }; fields.push(Field::new(column.name(), data_type.clone(), nullable)); } @@ -1731,6 +2173,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(1), @@ -1738,6 +2181,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::from("a")), sum_value: Precision::Absent, null_count: Precision::Exact(3), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Absent, @@ -1745,6 +2189,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))), null_count: Precision::Absent, + byte_size: Precision::Absent, }, ], } @@ -1773,11 +2218,15 @@ pub(crate) mod tests { }, ]); - let result = projection.project_statistics(source, &schema).unwrap(); + let result = projection + .project_statistics(source, &projection.project_schema(&schema).unwrap()) + .unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), - total_byte_size: Precision::Exact(23), + // Because there is a variable length Utf8 column we cannot calculate exact byte size after projection + // Thus we set it to Inexact (originally it was Exact(23)) + total_byte_size: Precision::Inexact(23), column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Exact(1), @@ -1785,6 +2234,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::from("a")), sum_value: Precision::Absent, null_count: Precision::Exact(3), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(5), @@ -1792,6 +2242,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ], }; @@ -1815,7 +2266,9 @@ pub(crate) mod tests { }, ]); - let result = projection.project_statistics(source, &schema).unwrap(); + let result = projection + .project_statistics(source, &projection.project_schema(&schema).unwrap()) + .unwrap(); let expected = Statistics { num_rows: Precision::Exact(5), @@ -1827,6 +2280,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), sum_value: Precision::Exact(ScalarValue::Float32(Some(5.5))), null_count: Precision::Absent, + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(5), @@ -1834,6 +2288,7 @@ pub(crate) mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ], }; @@ -2130,11 +2585,7 @@ pub(crate) mod tests { let result_expr = result.unwrap(); assert_eq!( - result_expr - .as_any() - .downcast_ref::() - .unwrap() - .value(), + result_expr.downcast_ref::().unwrap().value(), &ScalarValue::Int64(Some(42)) ); @@ -2163,17 +2614,15 @@ pub(crate) mod tests { let result_expr = result.unwrap(); let binary = result_expr - .as_any() .downcast_ref::() .expect("Should be a BinaryExpr"); // Left side should still be the literal - assert!(binary.left().as_any().downcast_ref::().is_some()); + assert!(binary.left().downcast_ref::().is_some()); // Right side should be updated to reference column at index 5 let right_col = binary .right() - .as_any() .downcast_ref::() .expect("Right should be a Column"); assert_eq!(right_col.index(), 5); @@ -2300,7 +2749,10 @@ pub(crate) mod tests { }, ]); - let output_stats = projection.project_statistics(input_stats, &input_schema)?; + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; // Row count should be preserved assert_eq!(output_stats.num_rows, Precision::Exact(5)); @@ -2352,7 +2804,10 @@ pub(crate) mod tests { }, ]); - let output_stats = projection.project_statistics(input_stats, &input_schema)?; + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; // Row count should be preserved assert_eq!(output_stats.num_rows, Precision::Exact(5)); @@ -2379,6 +2834,88 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn test_project_statistics_with_cast() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // SELECT CAST(col0 AS Int32) AS casted + let projection = ProjectionExprs::new(vec![ProjectionExpr { + expr: Arc::new(CastExpr::new( + Arc::new(Column::new("col0", 0)), + DataType::Int32, + None, + )), + alias: "casted".to_string(), + }]); + + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; + + assert_eq!( + output_stats.column_statistics[0].min_value, + Precision::Exact(ScalarValue::Int32(Some(-4))) + ); + assert_eq!( + output_stats.column_statistics[0].max_value, + Precision::Exact(ScalarValue::Int32(Some(21))) + ); + + Ok(()) + } + + #[test] + fn test_project_statistics_duplicate_column() -> Result<()> { + let input_stats = get_stats(); + let col0 = input_stats.column_statistics[0].clone(); + let projection = ProjectionExprs::new([ + ProjectionExpr::new(Arc::new(Column::new("col0", 0)), "a"), + ProjectionExpr::new(Arc::new(Column::new("col0", 0)), "b"), + ]); + + let output_schema = projection.project_schema(&get_schema())?; + let output_stats = projection.project_statistics(input_stats, &output_schema)?; + + assert_eq!(output_stats.column_statistics, vec![col0.clone(), col0]); + Ok(()) + } + + #[test] + fn test_project_statistics_column_and_cast() -> Result<()> { + let input_stats = get_stats(); + let col0 = input_stats.column_statistics[0].clone(); + let projection = ProjectionExprs::new([ + ProjectionExpr::new(Arc::new(Column::new("col0", 0)), "num"), + ProjectionExpr::new( + Arc::new(CastExpr::new( + Arc::new(Column::new("col0", 0)), + DataType::Int32, + None, + )), + "casted", + ), + ]); + + let output_schema = projection.project_schema(&get_schema())?; + let output_stats = projection.project_statistics(input_stats, &output_schema)?; + + assert_eq!(output_stats.column_statistics[0], col0); + assert_eq!( + output_stats.column_statistics[1], + ColumnStatistics { + min_value: Precision::Exact(ScalarValue::Int32(Some(-4))), + max_value: Precision::Exact(ScalarValue::Int32(Some(21))), + distinct_count: Precision::Exact(5), + null_count: Precision::Exact(0), + sum_value: Precision::Absent, + byte_size: Precision::Absent, + } + ); + Ok(()) + } + #[test] fn test_project_statistics_primitive_width_only() -> Result<()> { let input_stats = get_stats(); @@ -2396,7 +2933,10 @@ pub(crate) mod tests { }, ]); - let output_stats = projection.project_statistics(input_stats, &input_schema)?; + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; // Row count should be preserved assert_eq!(output_stats.num_rows, Precision::Exact(5)); @@ -2418,7 +2958,10 @@ pub(crate) mod tests { let projection = ProjectionExprs::new(vec![]); - let output_stats = projection.project_statistics(input_stats, &input_schema)?; + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; // Row count should be preserved assert_eq!(output_stats.num_rows, Precision::Exact(5)); @@ -2431,4 +2974,246 @@ pub(crate) mod tests { Ok(()) } + + // Test statistics calculation for non-null literal (numeric constant) + #[test] + fn test_project_statistics_with_literal() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // Projection with literal: SELECT 42 AS constant, col0 AS num + let projection = ProjectionExprs::new(vec![ + ProjectionExpr { + expr: Arc::new(Literal::new(ScalarValue::Int64(Some(42)))), + alias: "constant".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "num".to_string(), + }, + ]); + + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Should have 2 column statistics + assert_eq!(output_stats.column_statistics.len(), 2); + + // First column (literal 42) should have proper constant statistics + assert_eq!( + output_stats.column_statistics[0].min_value, + Precision::Exact(ScalarValue::Int64(Some(42))) + ); + assert_eq!( + output_stats.column_statistics[0].max_value, + Precision::Exact(ScalarValue::Int64(Some(42))) + ); + assert_eq!( + output_stats.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + assert_eq!( + output_stats.column_statistics[0].null_count, + Precision::Exact(0) + ); + // Int64 is 8 bytes, 5 rows = 40 bytes + assert_eq!( + output_stats.column_statistics[0].byte_size, + Precision::Exact(40) + ); + // For a constant column, sum_value = value * num_rows = 42 * 5 = 210 + assert_eq!( + output_stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::Int64(Some(210))) + ); + + // Second column (col0) should preserve statistics + assert_eq!( + output_stats.column_statistics[1].distinct_count, + Precision::Exact(5) + ); + assert_eq!( + output_stats.column_statistics[1].max_value, + Precision::Exact(ScalarValue::Int64(Some(21))) + ); + + Ok(()) + } + + #[test] + fn test_project_statistics_with_i32_literal_sum_widens_to_i64() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + let projection = ProjectionExprs::new(vec![ + ProjectionExpr { + expr: Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + alias: "constant".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "num".to_string(), + }, + ]); + + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; + + assert_eq!( + output_stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::Int64(Some(50))) + ); + + Ok(()) + } + + // Test statistics calculation for NULL literal (constant NULL column) + #[test] + fn test_project_statistics_with_null_literal() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // Projection with NULL literal: SELECT NULL AS null_col, col0 AS num + let projection = ProjectionExprs::new(vec![ + ProjectionExpr { + expr: Arc::new(Literal::new(ScalarValue::Int64(None))), + alias: "null_col".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "num".to_string(), + }, + ]); + + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Should have 2 column statistics + assert_eq!(output_stats.column_statistics.len(), 2); + + // First column (NULL literal) should have proper constant NULL statistics + assert_eq!( + output_stats.column_statistics[0].min_value, + Precision::Exact(ScalarValue::Int64(None)) + ); + assert_eq!( + output_stats.column_statistics[0].max_value, + Precision::Exact(ScalarValue::Int64(None)) + ); + assert_eq!( + output_stats.column_statistics[0].distinct_count, + Precision::Exact(1) // All NULLs are considered the same + ); + assert_eq!( + output_stats.column_statistics[0].null_count, + Precision::Exact(5) // All rows are NULL + ); + assert_eq!( + output_stats.column_statistics[0].byte_size, + Precision::Exact(0) + ); + assert_eq!( + output_stats.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::Int64(None)) + ); + + // Second column (col0) should preserve statistics + assert_eq!( + output_stats.column_statistics[1].distinct_count, + Precision::Exact(5) + ); + assert_eq!( + output_stats.column_statistics[1].max_value, + Precision::Exact(ScalarValue::Int64(Some(21))) + ); + + Ok(()) + } + + // Test statistics calculation for complex type literal (e.g., Utf8 string) + #[test] + fn test_project_statistics_with_complex_type_literal() -> Result<()> { + let input_stats = get_stats(); + let input_schema = get_schema(); + + // Projection with Utf8 literal (complex type): SELECT 'hello' AS text, col0 AS num + let projection = ProjectionExprs::new(vec![ + ProjectionExpr { + expr: Arc::new(Literal::new(ScalarValue::Utf8(Some( + "hello".to_string(), + )))), + alias: "text".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("col0", 0)), + alias: "num".to_string(), + }, + ]); + + let output_stats = projection.project_statistics( + input_stats, + &projection.project_schema(&input_schema)?, + )?; + + // Row count should be preserved + assert_eq!(output_stats.num_rows, Precision::Exact(5)); + + // Should have 2 column statistics + assert_eq!(output_stats.column_statistics.len(), 2); + + // First column (Utf8 literal 'hello') should have proper constant statistics + // but byte_size should be Absent for complex types + assert_eq!( + output_stats.column_statistics[0].min_value, + Precision::Exact(ScalarValue::Utf8(Some("hello".to_string()))) + ); + assert_eq!( + output_stats.column_statistics[0].max_value, + Precision::Exact(ScalarValue::Utf8(Some("hello".to_string()))) + ); + assert_eq!( + output_stats.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + assert_eq!( + output_stats.column_statistics[0].null_count, + Precision::Exact(0) + ); + // Complex types (Utf8, List, etc.) should have byte_size = Absent + // because we can't calculate exact size without knowing the actual data + assert_eq!( + output_stats.column_statistics[0].byte_size, + Precision::Absent + ); + // Non-numeric types (Utf8) should have sum_value = Absent + // because sum is only meaningful for numeric types + assert_eq!( + output_stats.column_statistics[0].sum_value, + Precision::Absent + ); + + // Second column (col0) should preserve statistics + assert_eq!( + output_stats.column_statistics[1].distinct_count, + Precision::Exact(5) + ); + assert_eq!( + output_stats.column_statistics[1].max_value, + Precision::Exact(ScalarValue::Int64(Some(21))) + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/proto_test_util.rs b/datafusion/physical-expr/src/proto_test_util.rs new file mode 100644 index 0000000000000..ab280335800b9 --- /dev/null +++ b/datafusion/physical-expr/src/proto_test_util.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared test helpers for proto serialization / deserialization in expression unit tests +//! without depending on `datafusion-proto` (which would create circular deps). + +use std::cell::Cell; +use std::sync::Arc; + +use arrow::datatypes::Schema; +use datafusion_common::{DataFusionError, Result}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecode; +use datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncode; +use datafusion_proto_models::protobuf::{self, PhysicalExprNode, physical_expr_node}; + +use crate::expressions::Column; + +/// A proto node for a `Column`, useful as a stand-in child node when building +/// an expression's proto representation in tests. +pub(crate) fn column_node(name: &str) -> PhysicalExprNode { + PhysicalExprNode { + expr_id: None, + expr_type: Some(physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: name.to_string(), + index: 0, + }, + )), + } +} + +/// Decoder stub for driving `try_from_proto`: returns a fixed `Column` for each +/// child node, optionally failing on the Nth `decode` call so the +/// `ctx.decode(..)?` error arms can be exercised. +pub(crate) struct StubDecoder { + fail_on_call: Option, + calls: Cell, +} + +impl StubDecoder { + /// Always succeeds, returning a placeholder `Column` per child. + pub(crate) fn ok() -> Self { + Self { + fail_on_call: None, + calls: Cell::new(0), + } + } + + /// Fails on the `call`-th invocation (1-based), succeeding otherwise. + pub(crate) fn failing_on(call: usize) -> Self { + Self { + fail_on_call: Some(call), + calls: Cell::new(0), + } + } +} + +impl PhysicalExprDecode for StubDecoder { + fn decode( + &self, + _node: &PhysicalExprNode, + _schema: &Schema, + ) -> Result> { + let call = self.calls.get() + 1; + self.calls.set(call); + if Some(call) == self.fail_on_call { + return Err(DataFusionError::Internal(format!( + "stub decode failure on call {call}" + ))); + } + Ok(Arc::new(Column::new("decoded", 0))) + } +} + +/// Decoder that must never run: used to assert that the reject paths of a +/// `try_from_proto` (wrong node, missing child) bail out before decoding. +pub(crate) struct UnreachableDecoder; + +impl PhysicalExprDecode for UnreachableDecoder { + fn decode( + &self, + _node: &PhysicalExprNode, + _schema: &Schema, + ) -> Result> { + unreachable!("decode must not be reached when the node is rejected") + } +} + +/// Encoder stub for driving `try_to_proto`: emits a placeholder `Column` node +/// for each child, optionally failing on the Nth `encode` call so the +/// `ctx.encode_child(..)?` error arms can be exercised. +pub(crate) struct StubEncoder { + fail_on_call: Option, + calls: Cell, +} + +impl StubEncoder { + /// Always succeeds, emitting a placeholder `Column` node per child. + pub(crate) fn ok() -> Self { + Self { + fail_on_call: None, + calls: Cell::new(0), + } + } + + /// Fails on the `call`-th invocation (1-based), succeeding otherwise. + pub(crate) fn failing_on(call: usize) -> Self { + Self { + fail_on_call: Some(call), + calls: Cell::new(0), + } + } +} + +impl PhysicalExprEncode for StubEncoder { + fn encode(&self, _expr: &Arc) -> Result { + let call = self.calls.get() + 1; + self.calls.set(call); + if Some(call) == self.fail_on_call { + return Err(DataFusionError::Internal(format!( + "stub encode failure on call {call}" + ))); + } + Ok(column_node("child")) + } +} diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 743d5b99cde95..418d005c971ea 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -29,24 +29,23 @@ //! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed //! to a function that supports f64, it is coerced to f64. -use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::Literal; use crate::PhysicalExpr; +use crate::expressions::Literal; use arrow::array::{Array, RecordBatch}; use arrow::datatypes::{DataType, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, internal_err}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; +use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - Volatility, + ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Volatility, expr_vec_fmt, }; /// Physical expression of a scalar function @@ -101,19 +100,11 @@ impl ScalarFunctionExpr { .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - let arg_types = arg_fields - .iter() - .map(|f| f.data_type().clone()) - .collect::>(); - data_types_with_scalar_udf(&arg_types, &fun)?; + fields_with_udf(&arg_fields, fun.as_ref())?; let arguments = args .iter() - .map(|e| { - e.as_any() - .downcast_ref::() - .map(|literal| literal.value()) - }) + .map(|e| e.downcast_ref::().map(|literal| literal.value())) .collect::>(); let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, @@ -173,19 +164,10 @@ impl ScalarFunctionExpr { /// Otherwise returns `Some(ScalarFunctionExpr)`. pub fn try_downcast_func(expr: &dyn PhysicalExpr) -> Option<&ScalarFunctionExpr> where - T: 'static, + T: ScalarUDFImpl, { - match expr.as_any().downcast_ref::() { - Some(scalar_expr) - if scalar_expr - .fun() - .inner() - .as_any() - .downcast_ref::() - .is_some() => - { - Some(scalar_expr) - } + match expr.downcast_ref::() { + Some(scalar_expr) if scalar_expr.fun().inner().is::() => Some(scalar_expr), _ => None, } } @@ -243,11 +225,6 @@ fn sorted_config_entries(config_options: &ConfigOptions) -> Vec { } impl PhysicalExpr for ScalarFunctionExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { Ok(self.return_field.data_type().clone()) } @@ -283,19 +260,22 @@ impl PhysicalExpr for ScalarFunctionExpr { config_options: Arc::clone(&self.config_options), })?; - if let ColumnarValue::Array(array) = &output { - if array.len() != batch.num_rows() { - // If the arguments are a non-empty slice of scalar values, we can assume that - // returning a one-element array is equivalent to returning a scalar. - let preserve_scalar = - array.len() == 1 && !input_empty && input_all_scalar; - return if preserve_scalar { - ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) - } else { - internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}", - self.name, batch.num_rows(), array.len()) - }; - } + if let ColumnarValue::Array(array) = &output + && array.len() != batch.num_rows() + { + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = array.len() == 1 && !input_empty && input_all_scalar; + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!( + "UDF {} returned a different number of rows than expected. Expected: {}, Got: {}", + self.name, + batch.num_rows(), + array.len() + ) + }; } Ok(output) } @@ -363,16 +343,21 @@ impl PhysicalExpr for ScalarFunctionExpr { fn is_volatile_node(&self) -> bool { self.fun.signature().volatility == Volatility::Volatile } + + fn placement(&self) -> ExpressionPlacement { + let arg_placements: Vec<_> = + self.args.iter().map(|arg| arg.placement()).collect(); + self.fun.placement(&arg_placements) + } } #[cfg(test)] mod tests { use super::*; use crate::expressions::Column; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; + use arrow::datatypes::Field; + use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_physical_expr_common::physical_expr::is_volatile; - use std::any::Any; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] @@ -381,10 +366,6 @@ mod tests { } impl ScalarUDFImpl for MockScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "mock_function" } diff --git a/datafusion/physical-expr/src/scalar_subquery.rs b/datafusion/physical-expr/src/scalar_subquery.rs new file mode 100644 index 0000000000000..ea00847151e66 --- /dev/null +++ b/datafusion/physical-expr/src/scalar_subquery.rs @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical expression for uncorrelated scalar subqueries. + +use std::fmt; +use std::hash::Hash; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_datafusion_err}; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +/// A physical expression whose value is provided by a scalar subquery. +/// +/// Subquery execution is handled by `ScalarSubqueryExec`, which stores the +/// result in a shared [`ScalarSubqueryResults`] container. This expression +/// simply reads from that container at the appropriate index. +#[derive(Debug)] +pub struct ScalarSubqueryExpr { + data_type: DataType, + nullable: bool, + /// Index of this subquery in the shared results container. + index: SubqueryIndex, + /// Shared results container populated by `ScalarSubqueryExec`. + results: ScalarSubqueryResults, +} + +impl ScalarSubqueryExpr { + pub fn new( + data_type: DataType, + nullable: bool, + index: SubqueryIndex, + results: ScalarSubqueryResults, + ) -> Self { + Self { + data_type, + nullable, + index, + results, + } + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn nullable(&self) -> bool { + self.nullable + } + + /// Returns the index of this subquery in the shared results container. + pub fn index(&self) -> SubqueryIndex { + self.index + } + + pub fn results(&self) -> &ScalarSubqueryResults { + &self.results + } +} + +impl fmt::Display for ScalarSubqueryExpr { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.results.get(self.index) { + Some(v) => write!(f, "scalar_subquery({v})"), + None => write!(f, "scalar_subquery()"), + } + } +} + +// Two ScalarSubqueryExprs are considered the "same" if they refer to the +// same underlying shared results container and the same index within it. +impl Hash for ScalarSubqueryExpr { + fn hash(&self, state: &mut H) { + self.results.hash(state); + self.index.hash(state); + } +} + +impl PartialEq for ScalarSubqueryExpr { + fn eq(&self, other: &Self) -> bool { + self.results == other.results && self.index == other.index + } +} + +impl Eq for ScalarSubqueryExpr {} + +impl PhysicalExpr for ScalarSubqueryExpr { + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(Arc::new(Field::new( + "scalar_subquery", + self.data_type.clone(), + self.nullable, + ))) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + let value = self.results.get(self.index).ok_or_else(|| { + internal_datafusion_err!( + "ScalarSubqueryExpr evaluated before the subquery was executed" + ) + })?; + Ok(ColumnarValue::Scalar(value)) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties::new_unknown().with_order(SortProperties::Singleton)) + } + + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(scalar subquery)") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use arrow::array::Int32Array; + use arrow::datatypes::Field; + use datafusion_common::ScalarValue; + + fn make_results(values: Vec>) -> ScalarSubqueryResults { + let results = ScalarSubqueryResults::new(values.len()); + for (index, value) in values.into_iter().enumerate() { + if let Some(value) = value { + results.set(SubqueryIndex::new(index), value).unwrap(); + } + } + results + } + + #[test] + fn test_evaluate_with_value() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![1, 2, 3]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + + let results = make_results(vec![Some(ScalarValue::Int32(Some(42)))]); + let expr = ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(0), + results, + ); + + let result = expr.evaluate(&batch)?; + match result { + ColumnarValue::Scalar(ScalarValue::Int32(Some(42))) => {} + other => panic!("Expected Scalar(Int32(42)), got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_evaluate_before_populated() { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![1]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap(); + + let results = ScalarSubqueryResults::new(1); + let expr = ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(0), + results, + ); + + let result = expr.evaluate(&batch); + assert!(result.is_err()); + } + + #[test] + fn test_identity_equality() { + let results = make_results(vec![None, None]); + + let e1a = ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(0), + results.clone(), + ); + let e1b = ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(0), + results.clone(), + ); + let e2 = ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(1), + results.clone(), + ); + + // Same container + same index → equal + assert_eq!(e1a, e1b); + // Same container, different index → not equal + assert_ne!(e1a, e2); + + // Different container, same index → not equal + let other_results = make_results(vec![None]); + let e3 = ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(0), + other_results, + ); + assert_ne!(e1a, e3); + } +} diff --git a/datafusion/physical-expr/src/simplifier/const_evaluator.rs b/datafusion/physical-expr/src/simplifier/const_evaluator.rs new file mode 100644 index 0000000000000..ba62c49803350 --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/const_evaluator.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Constant expression evaluation for the physical expression simplifier + +use std::sync::Arc; + +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{Result, ScalarValue, internal_datafusion_err}; +use datafusion_expr_common::columnar_value::ColumnarValue; + +use crate::PhysicalExpr; +use crate::expressions::{Column, Literal}; + +/// Simplify expressions that consist only of literals by evaluating them. +/// +/// This function checks if all children of the given expression are literals. +/// If so, it evaluates the expression against a dummy RecordBatch and returns +/// the result as a new Literal. +/// +/// # Example transformations +/// - `1 + 2` -> `3` +/// - `(1 + 2) * 3` -> `9` (with bottom-up traversal) +/// - `'hello' || ' world'` -> `'hello world'` +#[deprecated( + since = "53.0.0", + note = "This function will be removed in a future release in favor of a private implementation that depends on other implementation details. Please open an issue if you have a use case for keeping it." +)] +pub fn simplify_const_expr( + expr: Arc, +) -> Result>> { + let batch = create_dummy_batch()?; + // If expr is already a const literal or can't be evaluated into one. + if expr.is::() || (!can_evaluate_as_constant(&expr)) { + return Ok(Transformed::no(expr)); + } + + // Evaluate the expression + match expr.evaluate(batch) { + Ok(ColumnarValue::Scalar(scalar)) => { + Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) + } + Ok(ColumnarValue::Array(arr)) if arr.len() == 1 => { + // Some operations return an array even for scalar inputs + let scalar = ScalarValue::try_from_array(&arr, 0)?; + Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) + } + Ok(_) => { + // Unexpected result - keep original expression + Ok(Transformed::no(expr)) + } + Err(_) => { + // On error, keep original expression + // The expression might succeed at runtime due to short-circuit evaluation + // or other runtime conditions + Ok(Transformed::no(expr)) + } + } +} + +/// Simplify expressions whose immediate children are all literals. +/// +/// This function only checks the direct children of the expression, +/// not the entire subtree. It is designed to be used with bottom-up tree +/// traversal, where children are simplified before parents. +/// +/// # Example transformations +/// - `1 + 2` -> `3` +/// - `(1 + 2) * 3` -> `9` (with bottom-up traversal, inner expr simplified first) +/// - `'hello' || ' world'` -> `'hello world'` +pub(crate) fn simplify_const_expr_immediate( + expr: Arc, + batch: &RecordBatch, +) -> Result>> { + // Already a literal - nothing to do + if expr.is::() { + return Ok(Transformed::no(expr)); + } + + // Column references cannot be evaluated at plan time + if expr.is::() { + return Ok(Transformed::no(expr)); + } + + // Volatile nodes cannot be evaluated at plan time + if expr.is_volatile_node() { + return Ok(Transformed::no(expr)); + } + + // Since transform visits bottom-up, children have already been simplified. + // If all children are now Literals, this node can be const-evaluated. + // This is O(k) where k = number of children, instead of O(subtree). + // + // Leaf nodes (zero children) are rejected here. Const-folding is only + // sound for a node whose value is fully determined by its child literals; + // a leaf has no children, so there is nothing to derive constness from. + // The known leaves that are constant (`Literal`) or known-non-constant + // (`Column`, volatile) are handled by the dedicated checks above. Any + // other leaf is opaque to the simplifier and must be preserved as-is, + // otherwise `all` over an empty child list would vacuously hold and the + // node would be evaluated against the dummy batch, producing a value + // unrelated to its real runtime semantics. + let children = expr.children(); + if children.is_empty() || !children.iter().all(|c| c.is::()) { + return Ok(Transformed::no(expr)); + } + + // Evaluate the expression + match expr.evaluate(batch) { + Ok(ColumnarValue::Scalar(scalar)) => { + Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) + } + Ok(ColumnarValue::Array(arr)) if arr.len() == 1 => { + // Some operations return an array even for scalar inputs + let scalar = ScalarValue::try_from_array(&arr, 0)?; + Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) + } + Ok(_) => { + // Unexpected result - keep original expression + Ok(Transformed::no(expr)) + } + Err(_) => { + // On error, keep original expression + // The expression might succeed at runtime due to short-circuit evaluation + // or other runtime conditions + Ok(Transformed::no(expr)) + } + } +} + +/// Create a 1-row dummy RecordBatch for evaluating constant expressions. +/// +/// The batch is never actually accessed for data - it's just needed because +/// the PhysicalExpr::evaluate API requires a RecordBatch. For expressions +/// that only contain literals, the batch content is irrelevant. +/// +/// This is the same approach used in the logical expression `ConstEvaluator`. +pub(crate) fn create_dummy_batch() -> Result<&'static RecordBatch> { + static DUMMY_BATCH: std::sync::OnceLock> = + std::sync::OnceLock::new(); + DUMMY_BATCH + .get_or_init(|| { + // RecordBatch requires at least one column + let dummy_schema = + Arc::new(Schema::new(vec![Field::new("_", DataType::Null, true)])); + let col = new_null_array(&DataType::Null, 1); + Ok(RecordBatch::try_new(dummy_schema, vec![col])?) + }) + .as_ref() + .map_err(|e| { + internal_datafusion_err!( + "Failed to create dummy batch for constant expression evaluation: {e}" + ) + }) +} + +fn can_evaluate_as_constant(expr: &Arc) -> bool { + let mut can_evaluate = true; + + expr.apply(|e| { + if e.is::() || e.is_volatile_node() { + can_evaluate = false; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .expect("apply should not fail"); + + can_evaluate +} + +/// Check if this expression has any column references. +#[deprecated( + since = "53.0.0", + note = "This function isn't used internally and is trivial to implement, therefore it will be removed in a future release." +)] +pub fn has_column_references(expr: &Arc) -> bool { + let mut has_columns = false; + expr.apply(|expr| { + if expr.downcast_ref::().is_some() { + has_columns = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }) + .expect("apply should not fail"); + has_columns +} diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 80d6ee0a7b914..3f791569d766e 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -18,16 +18,22 @@ //! Simplifier for Physical Expressions use arrow::datatypes::Schema; -use datafusion_common::{ - tree_node::{Transformed, TreeNode, TreeNodeRewriter}, - Result, -}; +use datafusion_common::{Result, tree_node::TreeNode}; use std::sync::Arc; -use crate::PhysicalExpr; +use crate::{ + PhysicalExpr, + simplifier::{ + const_evaluator::create_dummy_batch, unwrap_cast::unwrap_cast_in_comparison, + }, +}; +pub mod const_evaluator; +pub mod not; pub mod unwrap_cast; +const MAX_LOOP_COUNT: usize = 5; + /// Simplifies physical expressions by applying various optimizations /// /// This can be useful after adapting expressions from a table schema @@ -44,37 +50,54 @@ impl<'a> PhysicalExprSimplifier<'a> { } /// Simplify a physical expression - pub fn simplify( - &mut self, - expr: Arc, - ) -> Result> { - Ok(expr.rewrite(self)?.data) - } -} + pub fn simplify(&self, expr: Arc) -> Result> { + let mut current_expr = expr; + let mut count = 0; + let schema = self.schema; -impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> { - type Node = Arc; + let batch = create_dummy_batch()?; - fn f_up(&mut self, node: Self::Node) -> Result> { - // Apply unwrap cast optimization - #[cfg(test)] - let original_type = node.data_type(self.schema).unwrap(); - let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, self.schema)?; - #[cfg(test)] - assert_eq!( - unwrapped.data.data_type(self.schema).unwrap(), - original_type, - "Simplified expression should have the same data type as the original" - ); - Ok(unwrapped) + while count < MAX_LOOP_COUNT { + count += 1; + let result = current_expr.transform(|node| { + #[cfg(debug_assertions)] + let original_type = node.data_type(schema).unwrap(); + + // Apply NOT expression simplification first, then unwrap cast optimization, + // then constant expression evaluation + #[expect(deprecated, reason = "`simplify_not_expr` is marked as deprecated until it's made private.")] + let rewritten = not::simplify_not_expr(node, schema)? + .transform_data(|node| unwrap_cast_in_comparison(node, schema))? + .transform_data(|node| { + const_evaluator::simplify_const_expr_immediate(node, batch) + })?; + + #[cfg(debug_assertions)] + assert_eq!( + rewritten.data.data_type(schema).unwrap(), + original_type, + "Simplified expression should have the same data type as the original" + ); + + Ok(rewritten) + })?; + + if !result.transformed { + return Ok(result.data); + } + current_expr = result.data; + } + Ok(current_expr) } } #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; - use arrow::datatypes::{DataType, Field, Schema}; + use crate::expressions::{ + BinaryExpr, CastExpr, Literal, NotExpr, TryCastExpr, col, in_list, lit, + }; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; @@ -86,10 +109,43 @@ mod tests { ]) } + fn not_test_schema() -> Schema { + Schema::new(vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]) + } + + /// Helper function to extract a Literal from a PhysicalExpr + fn as_literal(expr: &Arc) -> &Literal { + expr.downcast_ref::() + .unwrap_or_else(|| panic!("Expected Literal, got: {expr}")) + } + + /// Helper function to extract a BinaryExpr from a PhysicalExpr + fn as_binary(expr: &Arc) -> &BinaryExpr { + expr.downcast_ref::() + .unwrap_or_else(|| panic!("Expected BinaryExpr, got: {expr}")) + } + + /// Assert that simplifying `input` produces `expected` + fn assert_not_simplify( + simplifier: &PhysicalExprSimplifier, + input: Arc, + expected: Arc, + ) { + let result = simplifier.simplify(Arc::clone(&input)).unwrap(); + assert_eq!( + &result, &expected, + "Simplification should transform:\n input: {input}\n to: {expected}\n got: {result}" + ); + } + #[test] fn test_simplify() { let schema = test_schema(); - let mut simplifier = PhysicalExprSimplifier::new(&schema); + let simplifier = PhysicalExprSimplifier::new(&schema); // Create: cast(c2 as INT32) != INT32(99) let column_expr = col("c2", &schema).unwrap(); @@ -101,26 +157,22 @@ mod tests { // Apply full simplification (uses TreeNodeRewriter) let optimized = simplifier.simplify(binary_expr).unwrap(); - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = as_binary(&optimized); // Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match) let left_expr = optimized_binary.left(); assert!( - left_expr.as_any().downcast_ref::().is_none() - && left_expr.as_any().downcast_ref::().is_none() + left_expr.downcast_ref::().is_none() + && left_expr.downcast_ref::().is_none() ); - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = as_literal(optimized_binary.right()); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(99))); } #[test] fn test_nested_expression_simplification() { let schema = test_schema(); - let mut simplifier = PhysicalExprSimplifier::new(&schema); + let simplifier = PhysicalExprSimplifier::new(&schema); // Create nested expression: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10)) let c1_expr = col("c1", &schema).unwrap(); @@ -138,51 +190,491 @@ mod tests { // Apply simplification let optimized = simplifier.simplify(or_expr).unwrap(); - let or_binary = optimized.as_any().downcast_ref::().unwrap(); + let or_binary = as_binary(&optimized); // Verify left side: c1 > INT32(5) - let left_binary = or_binary - .left() - .as_any() - .downcast_ref::() - .unwrap(); + let left_binary = as_binary(or_binary.left()); let left_left_expr = left_binary.left(); assert!( - left_left_expr.as_any().downcast_ref::().is_none() - && left_left_expr - .as_any() - .downcast_ref::() - .is_none() + left_left_expr.downcast_ref::().is_none() + && left_left_expr.downcast_ref::().is_none() ); - let left_literal = left_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let left_literal = as_literal(left_binary.right()); assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(5))); // Verify right side: c2 <= INT64(10) - let right_binary = or_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_binary = as_binary(or_binary.right()); let right_left_expr = right_binary.left(); assert!( - right_left_expr - .as_any() - .downcast_ref::() - .is_none() - && right_left_expr - .as_any() - .downcast_ref::() - .is_none() + right_left_expr.downcast_ref::().is_none() + && right_left_expr.downcast_ref::().is_none() ); - let right_literal = right_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = as_literal(right_binary.right()); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(10))); } + + #[test] + fn test_double_negation_elimination() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(c > 5)) -> c > 5 + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr))); + let double_not: Arc = Arc::new(NotExpr::new(inner_not)); + + let expected = inner_expr; + assert_not_simplify(&simplifier, double_not, expected); + Ok(()) + } + + #[test] + fn test_not_literal() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(TRUE) -> FALSE + let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true))))); + let expected = lit(ScalarValue::Boolean(Some(false))); + assert_not_simplify(&simplifier, not_true, expected); + + // NOT(FALSE) -> TRUE + let not_false = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false))))); + let expected = lit(ScalarValue::Boolean(Some(true))); + assert_not_simplify(&simplifier, not_false, expected); + + Ok(()) + } + + #[test] + fn test_negate_comparison() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c = 5) -> c != 5 + let not_eq = Arc::new(NotExpr::new(Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(5))), + )))); + let expected = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(5))), + )); + assert_not_simplify(&simplifier, not_eq, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_law_and() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(a AND b) -> NOT a OR NOT b + let and_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + )); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(NotExpr::new(col("a", &schema)?)), + Operator::Or, + Arc::new(NotExpr::new(col("b", &schema)?)), + )); + assert_not_simplify(&simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_law_or() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(a OR b) -> NOT a AND NOT b + let or_expr = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + let not_or: Arc = Arc::new(NotExpr::new(or_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(NotExpr::new(col("a", &schema)?)), + Operator::And, + Arc::new(NotExpr::new(col("b", &schema)?)), + )); + assert_not_simplify(&simplifier, not_or, expected); + + Ok(()) + } + + #[test] + fn test_demorgans_with_comparison_simplification() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c = 1 AND c = 2) -> c != 1 OR c != 2 + let eq1 = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(1))), + )); + let eq2 = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Eq, + lit(ScalarValue::Int32(Some(2))), + )); + let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(1))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::NotEq, + lit(ScalarValue::Int32(Some(2))), + )), + )); + assert_not_simplify(&simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_not_of_not_and_not() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(a) AND NOT(b)) -> a OR b + let not_a = Arc::new(NotExpr::new(col("a", &schema)?)); + let not_b = Arc::new(NotExpr::new(col("b", &schema)?)); + let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b)); + let not_and: Arc = Arc::new(NotExpr::new(and_expr)); + + let expected: Arc = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Or, + col("b", &schema)?, + )); + assert_not_simplify(&simplifier, not_and, expected); + + Ok(()) + } + + #[test] + fn test_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c IN (1, 2, 3)) -> c NOT IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?; + let not_in: Arc = Arc::new(NotExpr::new(in_list_expr)); + + let expected = in_list(col("c", &schema)?, list, &true, &schema)?; + assert_not_simplify(&simplifier, not_in, expected); + + Ok(()) + } + + #[test] + fn test_not_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(c NOT IN (1, 2, 3)) -> c IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let not_in_list_expr = in_list(col("c", &schema)?, list.clone(), &true, &schema)?; + let not_not_in: Arc = Arc::new(NotExpr::new(not_in_list_expr)); + + let expected = in_list(col("c", &schema)?, list, &false, &schema)?; + assert_not_simplify(&simplifier, not_not_in, expected); + + Ok(()) + } + + #[test] + fn test_double_not_in_list() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // NOT(NOT(c IN (1, 2, 3))) -> c IN (1, 2, 3) + let list = vec![ + lit(ScalarValue::Int32(Some(1))), + lit(ScalarValue::Int32(Some(2))), + lit(ScalarValue::Int32(Some(3))), + ]; + let in_list_expr = in_list(col("c", &schema)?, list.clone(), &false, &schema)?; + let not_in = Arc::new(NotExpr::new(in_list_expr)); + let double_not: Arc = Arc::new(NotExpr::new(not_in)); + + let expected = in_list(col("c", &schema)?, list, &false, &schema)?; + assert_not_simplify(&simplifier, double_not, expected); + + Ok(()) + } + + #[test] + fn test_deeply_nested_not() -> Result<()> { + let schema = not_test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // Create a deeply nested NOT expression: NOT(NOT(NOT(...NOT(c > 5)...))) + // This tests that we don't get stack overflow with many nested NOTs. + // With recursive_protection enabled (default), this should work by + // automatically growing the stack as needed. + let inner_expr: Arc = Arc::new(BinaryExpr::new( + col("c", &schema)?, + Operator::Gt, + lit(ScalarValue::Int32(Some(5))), + )); + + let mut expr = Arc::clone(&inner_expr); + // Create 200 layers of NOT to test deep recursion handling + for _ in 0..200 { + expr = Arc::new(NotExpr::new(expr)); + } + + // With 200 NOTs (even number), should simplify back to the original expression + let expected = inner_expr; + assert_not_simplify(&simplifier, Arc::clone(&expr), expected); + + // Manually dismantle the deep input expression to avoid Stack Overflow on Drop + // If we just let `expr` go out of scope, Rust's recursive Drop will blow the stack + // even with recursive_protection, because Drop doesn't use the #[recursive] attribute. + // We peel off layers one by one to avoid deep recursion in Drop. + while let Some(not_expr) = expr.downcast_ref::() { + // Clone the child (Arc increment). + // Now child has 2 refs: one in parent, one in `child`. + let child = Arc::clone(not_expr.arg()); + + // Reassign `expr` to `child`. + // This drops the old `expr` (Parent). + // Parent refcount -> 0, Parent is dropped. + // Parent drops its reference to Child. + // Child refcount decrements 2 -> 1. + // Child is NOT dropped recursively because we still hold it in `expr` + expr = child; + } + + Ok(()) + } + + #[test] + fn test_simplify_literal_binary_expr() { + let schema = Schema::empty(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // 1 + 2 -> 3 + let expr: Arc = + Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))); + let result = simplifier.simplify(expr).unwrap(); + let literal = as_literal(&result); + assert_eq!(literal.value(), &ScalarValue::Int32(Some(3))); + } + + #[test] + fn test_simplify_literal_comparison() { + let schema = Schema::empty(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // 5 > 3 -> true + let expr: Arc = + Arc::new(BinaryExpr::new(lit(5i32), Operator::Gt, lit(3i32))); + let result = simplifier.simplify(expr).unwrap(); + let literal = as_literal(&result); + assert_eq!(literal.value(), &ScalarValue::Boolean(Some(true))); + + // 2 > 3 -> false + let expr: Arc = + Arc::new(BinaryExpr::new(lit(2i32), Operator::Gt, lit(3i32))); + let result = simplifier.simplify(expr).unwrap(); + let literal = as_literal(&result); + assert_eq!(literal.value(), &ScalarValue::Boolean(Some(false))); + } + + #[test] + fn test_simplify_nested_literal_expr() { + let schema = Schema::empty(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // (1 + 2) * 3 -> 9 + let inner: Arc = + Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))); + let expr: Arc = + Arc::new(BinaryExpr::new(inner, Operator::Multiply, lit(3i32))); + let result = simplifier.simplify(expr).unwrap(); + let literal = as_literal(&result); + assert_eq!(literal.value(), &ScalarValue::Int32(Some(9))); + } + + #[test] + fn test_simplify_deeply_nested_literals() { + let schema = Schema::empty(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // ((1 + 2) * 3) + ((4 - 1) * 2) -> 9 + 6 -> 15 + let left: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))), + Operator::Multiply, + lit(3i32), + )); + let right: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new(lit(4i32), Operator::Minus, lit(1i32))), + Operator::Multiply, + lit(2i32), + )); + let expr: Arc = + Arc::new(BinaryExpr::new(left, Operator::Plus, right)); + let result = simplifier.simplify(expr).unwrap(); + let literal = as_literal(&result); + assert_eq!(literal.value(), &ScalarValue::Int32(Some(15))); + } + + #[test] + fn test_no_simplify_with_column() { + let schema = test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // c1 + 2 should NOT be simplified (has column reference) + let expr: Arc = Arc::new(BinaryExpr::new( + col("c1", &schema).unwrap(), + Operator::Plus, + lit(2i32), + )); + let result = simplifier.simplify(expr).unwrap(); + // Should remain a BinaryExpr, not become a Literal + assert!(result.downcast_ref::().is_some()); + } + + #[test] + fn test_partial_simplify_with_column() { + let schema = test_schema(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // (1 + 2) + c1 should simplify the literal part: 3 + c1 + let literal_part: Arc = + Arc::new(BinaryExpr::new(lit(1i32), Operator::Plus, lit(2i32))); + let expr: Arc = Arc::new(BinaryExpr::new( + literal_part, + Operator::Plus, + col("c1", &schema).unwrap(), + )); + let result = simplifier.simplify(expr).unwrap(); + + // Should be a BinaryExpr with a Literal(3) on the left + let binary = as_binary(&result); + let left_literal = as_literal(binary.left()); + assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(3))); + } + + /// Regression test for https://github.com/apache/datafusion/issues/22367. + /// + /// A leaf `PhysicalExpr` that is neither a `Literal` nor a `Column` + /// (nor volatile) must not be const-folded: it has no children to + /// derive constness from, and evaluating it against the dummy batch + /// produces a value unrelated to its real runtime semantics. Without + /// the zero-children guard, `all(empty)` would vacuously hold and the + /// node would be replaced with whatever scalar fell out of the dummy + /// evaluation. Verify the node is left untouched. + #[test] + fn test_no_simplify_opaque_leaf_expr() { + use arrow::array::ArrayRef; + use arrow::array::Int32Array; + use arrow::record_batch::RecordBatch; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr as PhysicalExprTrait; + use std::fmt; + + #[derive(Debug, Clone, PartialEq, Eq, Hash)] + struct OpaqueLeaf; + + impl fmt::Display for OpaqueLeaf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "OpaqueLeaf") + } + } + + impl PhysicalExprTrait for OpaqueLeaf { + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Int32) + } + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + fn evaluate(&self, batch: &RecordBatch) -> Result { + // Simulate the broken FFI Column path: when handed a dummy + // batch, return whatever scalar happens to materialize. If + // the simplifier ever reaches this branch for a leaf node, + // the predicate has already been silently corrupted. + let arr: ArrayRef = Arc::new(Int32Array::from(vec![0; batch.num_rows()])); + Ok(ColumnarValue::Array(arr)) + } + fn children(&self) -> Vec<&Arc> { + vec![] + } + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "OpaqueLeaf") + } + } + + let schema = Schema::empty(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + let opaque: Arc = Arc::new(OpaqueLeaf); + let result = simplifier.simplify(Arc::clone(&opaque)).unwrap(); + + assert!( + result.downcast_ref::().is_none(), + "opaque leaf must not be rewritten to a Literal, got: {result}" + ); + assert_eq!(&result, &opaque); + } + + #[test] + fn test_simplify_literal_string_concat() { + let schema = Schema::empty(); + let simplifier = PhysicalExprSimplifier::new(&schema); + + // 'hello' || ' world' -> 'hello world' + let expr: Arc = Arc::new(BinaryExpr::new( + lit("hello"), + Operator::StringConcat, + lit(" world"), + )); + let result = simplifier.simplify(expr).unwrap(); + let literal = as_literal(&result); + assert_eq!( + literal.value(), + &ScalarValue::Utf8(Some("hello world".to_string())) + ); + } } diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs new file mode 100644 index 0000000000000..886cadd6a262d --- /dev/null +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Simplify NOT expressions in physical expressions +//! +//! This module provides optimizations for NOT expressions such as: +//! - Double negation elimination: NOT(NOT(expr)) -> expr +//! - NOT with binary comparisons: NOT(a = b) -> a != b +//! - NOT with IN expressions: NOT(a IN (list)) -> a NOT IN (list) +//! - De Morgan's laws: NOT(A AND B) -> NOT A OR NOT B +//! - Constant folding: NOT(TRUE) -> FALSE, NOT(FALSE) -> TRUE +//! +//! This function is designed to work with TreeNodeRewriter's f_up traversal, +//! which means children are already simplified when this function is called. +//! The TreeNodeRewriter will automatically call this function repeatedly until +//! no more transformations are possible. + +use std::sync::Arc; + +use arrow::datatypes::Schema; +use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; +use datafusion_expr::Operator; + +use crate::PhysicalExpr; +use crate::expressions::{BinaryExpr, InListExpr, Literal, NotExpr, in_list, lit}; + +/// Attempts to simplify NOT expressions by applying one level of transformation +/// +/// This function applies a single simplification rule and returns. When used with +/// TreeNodeRewriter, multiple passes will automatically be applied until no more +/// transformations are possible. +#[deprecated( + since = "53.0.0", + note = "This function will be made private in a future release, please file an issue if you have a reason for keeping it public." +)] +pub fn simplify_not_expr( + expr: Arc, + schema: &Schema, +) -> Result>> { + // Check if this is a NOT expression + let not_expr = match expr.downcast_ref::() { + Some(not_expr) => not_expr, + None => return Ok(Transformed::no(expr)), + }; + + let inner_expr = not_expr.arg(); + + // Handle NOT(NOT(expr)) -> expr (double negation elimination) + if let Some(inner_not) = inner_expr.downcast_ref::() { + return Ok(Transformed::yes(Arc::clone(inner_not.arg()))); + } + + // Handle NOT(literal) -> !literal + if let Some(literal) = inner_expr.downcast_ref::() { + if let ScalarValue::Boolean(Some(val)) = literal.value() { + return Ok(Transformed::yes(lit(ScalarValue::Boolean(Some(!val))))); + } + if let ScalarValue::Boolean(None) = literal.value() { + return Ok(Transformed::yes(lit(ScalarValue::Boolean(None)))); + } + } + + // Handle NOT(IN list) -> NOT IN list + if let Some(in_list_expr) = inner_expr.downcast_ref::() { + let negated = !in_list_expr.negated(); + let new_in_list = in_list( + Arc::clone(in_list_expr.expr()), + in_list_expr.list().to_vec(), + &negated, + schema, + )?; + return Ok(Transformed::yes(new_in_list)); + } + + // Handle NOT(binary_expr) + if let Some(binary_expr) = inner_expr.downcast_ref::() { + if let Some(negated_op) = binary_expr.op().negate() { + let new_binary = Arc::new(BinaryExpr::new( + Arc::clone(binary_expr.left()), + negated_op, + Arc::clone(binary_expr.right()), + )); + return Ok(Transformed::yes(new_binary)); + } + + // Handle De Morgan's laws for AND/OR + match binary_expr.op() { + Operator::And => { + // NOT(A AND B) -> NOT A OR NOT B + let not_left: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); + let not_right: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); + let new_binary = + Arc::new(BinaryExpr::new(not_left, Operator::Or, not_right)); + return Ok(Transformed::yes(new_binary)); + } + Operator::Or => { + // NOT(A OR B) -> NOT A AND NOT B + let not_left: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.left()))); + let not_right: Arc = + Arc::new(NotExpr::new(Arc::clone(binary_expr.right()))); + let new_binary = + Arc::new(BinaryExpr::new(not_left, Operator::And, not_right)); + return Ok(Transformed::yes(new_binary)); + } + _ => {} + } + } + + // If no simplification possible, return the original expression + Ok(Transformed::no(expr)) +} diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index d409ce9cb5bf2..4f4dfb2c20a81 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,29 +34,24 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - tree_node::{Transformed, TreeNode}, - Result, ScalarValue, -}; +use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; -use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr}; use crate::PhysicalExpr; +use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit}; /// Attempts to unwrap casts in comparison expressions. pub(crate) fn unwrap_cast_in_comparison( expr: Arc, schema: &Schema, ) -> Result>> { - expr.transform_down(|e| { - if let Some(binary) = e.as_any().downcast_ref::() { - if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? { - return Ok(Transformed::yes(unwrapped)); - } - } - Ok(Transformed::no(e)) - }) + if let Some(binary) = expr.downcast_ref::() + && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? + { + return Ok(Transformed::yes(unwrapped)); + } + Ok(Transformed::no(expr)) } /// Try to unwrap casts in binary expressions @@ -67,37 +62,34 @@ fn try_unwrap_cast_binary( // Case 1: cast(left_expr) op literal if let (Some((inner_expr, _cast_type)), Some(literal)) = ( extract_cast_info(binary.left()), - binary.right().as_any().downcast_ref::(), - ) { - if binary.op().supports_propagation() { - if let Some(unwrapped) = try_unwrap_cast_comparison( - Arc::clone(inner_expr), - literal.value(), - *binary.op(), - schema, - )? { - return Ok(Some(unwrapped)); - } - } + binary.right().downcast_ref::(), + ) && binary.op().supports_propagation() + && let Some(unwrapped) = try_unwrap_cast_comparison( + Arc::clone(inner_expr), + literal.value(), + *binary.op(), + schema, + )? + { + return Ok(Some(unwrapped)); } // Case 2: literal op cast(right_expr) if let (Some(literal), Some((inner_expr, _cast_type))) = ( - binary.left().as_any().downcast_ref::(), + binary.left().downcast_ref::(), extract_cast_info(binary.right()), ) { // For literal op cast(expr), we need to swap the operator - if let Some(swapped_op) = binary.op().swap() { - if binary.op().supports_propagation() { - if let Some(unwrapped) = try_unwrap_cast_comparison( - Arc::clone(inner_expr), - literal.value(), - swapped_op, - schema, - )? { - return Ok(Some(unwrapped)); - } - } + if let Some(swapped_op) = binary.op().swap() + && binary.op().supports_propagation() + && let Some(unwrapped) = try_unwrap_cast_comparison( + Arc::clone(inner_expr), + literal.value(), + swapped_op, + schema, + )? + { + return Ok(Some(unwrapped)); } // If the operator cannot be swapped, we skip this optimization case // but don't prevent other optimizations @@ -113,9 +105,9 @@ fn try_unwrap_cast_binary( fn extract_cast_info( expr: &Arc, ) -> Option<(&Arc, &DataType)> { - if let Some(cast) = expr.as_any().downcast_ref::() { + if let Some(cast) = expr.downcast_ref::() { Some((cast.expr(), cast.cast_type())) - } else if let Some(try_cast) = expr.as_any().downcast_ref::() { + } else if let Some(try_cast) = expr.downcast_ref::() { Some((try_cast.expr(), try_cast.cast_type())) } else { None @@ -145,27 +137,25 @@ fn try_unwrap_cast_comparison( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit}; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; - use datafusion_expr::Operator; + use crate::expressions::col; + use arrow::datatypes::Field; + use datafusion_common::tree_node::TreeNode; /// Check if an expression is a cast expression fn is_cast_expr(expr: &Arc) -> bool { - expr.as_any().downcast_ref::().is_some() - || expr.as_any().downcast_ref::().is_some() + expr.downcast_ref::().is_some() + || expr.downcast_ref::().is_some() } /// Check if a binary expression is suitable for cast unwrapping fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool { // Check if left is cast and right is literal let left_cast_right_literal = is_cast_expr(binary.left()) - && binary.right().as_any().downcast_ref::().is_some(); + && binary.right().downcast_ref::().is_some(); // Check if left is literal and right is cast - let left_literal_right_cast = - binary.left().as_any().downcast_ref::().is_some() - && is_cast_expr(binary.right()); + let left_literal_right_cast = binary.left().downcast_ref::().is_some() + && is_cast_expr(binary.right()); left_cast_right_literal || left_literal_right_cast } @@ -197,17 +187,13 @@ mod tests { // The result should be: c1 > INT32(10) let optimized = result.data; - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = optimized.downcast_ref::().unwrap(); // Check that left side is no longer a cast assert!(!is_cast_expr(optimized_binary.left())); // Check that right side is a literal with the correct type and value - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = optimized_binary.right().downcast_ref::().unwrap(); assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(10))); } @@ -230,7 +216,7 @@ mod tests { // The result should be equivalent to: c1 > INT32(10) let optimized = result.data; - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = optimized.downcast_ref::().unwrap(); // Check the operator was swapped assert_eq!(*optimized_binary.op(), Operator::Gt); @@ -263,9 +249,7 @@ mod tests { let literal_expr = lit(10i64); let binary_expr = Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); - let binary_ref = binary_expr.as_any().downcast_ref::().unwrap(); - - assert!(is_binary_expr_with_cast_and_literal(binary_ref)); + assert!(is_binary_expr_with_cast_and_literal(&binary_expr)); } #[test] @@ -297,7 +281,7 @@ mod tests { // The result should be: decimal_col >= Decimal128(400, 9, 2) let optimized = result.data; - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = optimized.downcast_ref::().unwrap(); // Check operator was swapped correctly assert_eq!(*optimized_binary.op(), Operator::GtEq); @@ -306,11 +290,7 @@ mod tests { assert!(!is_cast_expr(optimized_binary.left())); // Check that right side is a literal with the correct type - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = optimized_binary.right().downcast_ref::().unwrap(); assert_eq!( right_literal.value().data_type(), DataType::Decimal128(9, 2) @@ -346,8 +326,7 @@ mod tests { assert!(result.transformed); let optimized = result.data; - let optimized_binary = - optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = optimized.downcast_ref::().unwrap(); // Check the operator was swapped correctly assert_eq!( @@ -360,11 +339,8 @@ mod tests { assert!(!is_cast_expr(optimized_binary.left())); // Check that the literal was cast to the column type - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = + optimized_binary.right().downcast_ref::().unwrap(); assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100))); } } @@ -437,12 +413,8 @@ mod tests { // Verify the NULL was cast to the column type let optimized = result.data; - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let optimized_binary = optimized.downcast_ref::().unwrap(); + let right_literal = optimized_binary.right().downcast_ref::().unwrap(); assert_eq!(right_literal.value(), &ScalarValue::Int32(None)); } @@ -487,28 +459,22 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); // Verify the AND operator is preserved let optimized = result.data; - let and_binary = optimized.as_any().downcast_ref::().unwrap(); + let and_binary = optimized.downcast_ref::().unwrap(); assert_eq!(*and_binary.op(), Operator::And); // Both sides should have their casts unwrapped - let left_binary = and_binary - .left() - .as_any() - .downcast_ref::() - .unwrap(); - let right_binary = and_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let left_binary = and_binary.left().downcast_ref::().unwrap(); + let right_binary = and_binary.right().downcast_ref::().unwrap(); assert!(!is_cast_expr(left_binary.left())); assert!(!is_cast_expr(right_binary.left())); @@ -532,17 +498,13 @@ mod tests { assert!(result.transformed); let optimized = result.data; - let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + let optimized_binary = optimized.downcast_ref::().unwrap(); // Verify the try_cast was removed assert!(!is_cast_expr(optimized_binary.left())); // Verify the literal was converted - let right_literal = optimized_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = optimized_binary.right().downcast_ref::().unwrap(); assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(100))); } @@ -605,42 +567,28 @@ mod tests { // Create AND expression let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); // Verify both sides of the AND were optimized let optimized = result.data; - let and_binary = optimized.as_any().downcast_ref::().unwrap(); + let and_binary = optimized.downcast_ref::().unwrap(); // Left side should be: c1 > INT32(10) - let left_binary = and_binary - .left() - .as_any() - .downcast_ref::() - .unwrap(); + let left_binary = and_binary.left().downcast_ref::().unwrap(); assert!(!is_cast_expr(left_binary.left())); - let left_literal = left_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let left_literal = left_binary.right().downcast_ref::().unwrap(); assert_eq!(left_literal.value(), &ScalarValue::Int32(Some(10))); // Right side should be: c2 = INT64(20) (c2 is already INT64, literal cast to match) - let right_binary = and_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_binary = and_binary.right().downcast_ref::().unwrap(); assert!(!is_cast_expr(right_binary.left())); - let right_literal = right_binary - .right() - .as_any() - .downcast_ref::() - .unwrap(); + let right_literal = right_binary.right().downcast_ref::().unwrap(); assert_eq!(right_literal.value(), &ScalarValue::Int64(Some(20))); } } diff --git a/datafusion/physical-expr/src/statistics/mod.rs b/datafusion/physical-expr/src/statistics/mod.rs index 02897e0594578..115e1b66ebfb5 100644 --- a/datafusion/physical-expr/src/statistics/mod.rs +++ b/datafusion/physical-expr/src/statistics/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! Statistics and constraint propagation library +//! Statistics and constraint propagation library. +//! +//! All items exported from this module are **deprecated**; +//! see for details. pub mod stats_solver; diff --git a/datafusion/physical-expr/src/statistics/stats_solver.rs b/datafusion/physical-expr/src/statistics/stats_solver.rs index fae6996646b27..862ff4a032871 100644 --- a/datafusion/physical-expr/src/statistics/stats_solver.rs +++ b/datafusion/physical-expr/src/statistics/stats_solver.rs @@ -15,26 +15,37 @@ // specific language governing permissions and limitations // under the License. +//! DAG-based statistics propagation for the Statistics V2 framework. +//! +//! All public items in this module are **deprecated** as of `54.0.0`. +//! See for details. + +#![allow(deprecated)] + use std::sync::Arc; use crate::expressions::Literal; use crate::intervals::cp_solver::PropagationResult; use crate::physical_expr::PhysicalExpr; -use crate::utils::{build_dag, ExprTreeNode}; +use crate::utils::{ExprTreeNode, build_dag}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::statistics::Distribution; use datafusion_expr_common::interval_arithmetic::Interval; +use petgraph::Outgoing; use petgraph::adj::DefaultIx; use petgraph::prelude::Bfs; use petgraph::stable_graph::{NodeIndex, StableGraph}; use petgraph::visit::DfsPostOrder; -use petgraph::Outgoing; /// This object implements a directed acyclic expression graph (DAEG) that /// is used to compute statistics/distributions for expressions hierarchically. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug)] pub struct ExprStatisticsGraph { graph: StableGraph, @@ -43,6 +54,10 @@ pub struct ExprStatisticsGraph { /// This is a node in the DAEG; it encapsulates a reference to the actual /// [`PhysicalExpr`] as well as its statistics/distribution. +#[deprecated( + since = "54.0.0", + note = "Part of the unused Statistics V2 framework; see https://github.com/apache/datafusion/pull/22071" +)] #[derive(Clone, Debug)] pub struct ExprStatisticsGraphNode { expr: Arc, @@ -86,7 +101,7 @@ impl ExprStatisticsGraphNode { /// indefinite range (i.e. `[-∞, ∞]`). pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { let expr = Arc::clone(&node.expr); - if let Some(literal) = expr.as_any().downcast_ref::() { + if let Some(literal) = expr.downcast_ref::() { let value = literal.value(); Interval::try_new(value.clone(), value.clone()) .and_then(|interval| Self::new_uniform(expr, interval)) @@ -205,7 +220,7 @@ impl ExprStatisticsGraph { mod tests { use std::sync::Arc; - use crate::expressions::{binary, try_cast, Column}; + use crate::expressions::{Column, binary, try_cast}; use crate::intervals::cp_solver::PropagationResult; use crate::statistics::stats_solver::ExprStatisticsGraph; diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index d63a9590c3f66..c36e69603681e 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -19,7 +19,7 @@ //! constant. use crate::utils::split_disjunction; -use crate::{split_conjunction, PhysicalExpr}; +use crate::{PhysicalExpr, split_conjunction}; use datafusion_common::{Column, HashMap, ScalarValue}; use datafusion_expr::Operator; use std::collections::HashSet; @@ -93,6 +93,7 @@ impl LiteralGuarantee { /// Create a new instance of the guarantee if the provided operator is /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to /// create these structures from an predicate (boolean expression). + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn new<'a>( column_name: impl Into, guarantee: Guarantee, @@ -125,9 +126,8 @@ impl LiteralGuarantee { .fold(GuaranteeBuilder::new(), |builder, expr| { if let Some(cel) = ColOpLit::try_new(expr) { builder.aggregate_conjunct(&cel) - } else if let Some(inlist) = expr - .as_any() - .downcast_ref::() + } else if let Some(inlist) = + expr.downcast_ref::() { if let Some(inlist) = ColInList::try_new(inlist) { builder.aggregate_multi_conjunct( @@ -233,7 +233,7 @@ impl LiteralGuarantee { builder = builder.aggregate_multi_conjunct( col, Guarantee::In, - literals.into_iter(), + literals, ); } @@ -309,6 +309,7 @@ impl<'a> GuaranteeBuilder<'a> { /// * `AND (a IN (1,2,3))`: a is in (1, 2, or 3) /// * `AND (a != 1 OR a != 2 OR a != 3)`: a is not in (1, 2, or 3) /// * `AND (a NOT IN (1,2,3))`: a is not in (1, 2, or 3) + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn aggregate_multi_conjunct( mut self, col: &'a crate::expressions::Column, @@ -391,15 +392,10 @@ impl<'a> ColOpLit<'a> { /// /// Returns None otherwise fn try_new(expr: &'a Arc) -> Option { - let binary_expr = expr - .as_any() - .downcast_ref::()?; + let binary_expr = expr.downcast_ref::()?; - let (left, op, right) = ( - binary_expr.left().as_any(), - binary_expr.op(), - binary_expr.right().as_any(), - ); + let (left, op, right) = + (binary_expr.left(), binary_expr.op(), binary_expr.right()); let guarantee = match op { Operator::Eq => Guarantee::In, Operator::NotEq => Guarantee::NotIn, @@ -447,15 +443,12 @@ impl<'a> ColInList<'a> { /// Returns None otherwise fn try_new(inlist: &'a crate::expressions::InListExpr) -> Option { // Only support single-column inlist currently, multi-column inlist is not supported - let col = inlist - .expr() - .as_any() - .downcast_ref::()?; + let col = inlist.expr().downcast_ref::()?; let literals = inlist .list() .iter() - .map(|e| e.as_any().downcast_ref::()) + .map(|e| e.downcast_ref::()) .collect::>>()?; let guarantee = if inlist.negated() { @@ -480,10 +473,7 @@ enum ColOpLitOrInList<'a> { impl<'a> ColOpLitOrInList<'a> { fn try_new(expr: &'a Arc) -> Option { - match expr - .as_any() - .downcast_ref::() - { + match expr.downcast_ref::() { Some(inlist) => Some(Self::ColInList(ColInList::try_new(inlist)?)), None => ColOpLit::try_new(expr).map(Self::ColOpLit), } @@ -550,7 +540,7 @@ mod test { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::expr_fn::*; - use datafusion_expr::{lit, Expr}; + use datafusion_expr::{Expr, lit}; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 745ae855efee2..1be57c9192626 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -21,10 +21,11 @@ pub use guarantee::{Guarantee, LiteralGuarantee}; use std::borrow::Borrow; use std::sync::Arc; -use crate::expressions::{BinaryExpr, Column}; +use crate::expressions::{BinaryExpr, Column, Literal}; use crate::tree_node::ExprContext; -use crate::PhysicalExpr; -use crate::PhysicalSortExpr; +use crate::{ + AcrossPartitions, ConstExpr, EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, +}; use arrow::datatypes::Schema; use datafusion_common::tree_node::{ @@ -45,6 +46,65 @@ pub fn split_conjunction( split_impl(Operator::And, predicate, vec![]) } +impl ConstExpr { + /// Collects predicate-derived constants from equality conjunctions. + /// + /// For each equality predicate of the form `lhs = rhs`, if either side is + /// already known constant according to `input_eqs`, or is a literal, then + /// the other side is also constant and will be returned as a [`ConstExpr`]. + /// + /// Literals are treated as uniform constants across partitions, so + /// `col = literal` produces a constant for `col` with the literal value. + /// + /// For example, given predicate `a = 5 AND b = c` where `c` is already + /// known constant, this returns constants for both `a` (Uniform with value + /// 5) and `b` (propagating `c`'s across-partitions value). + pub fn collect_predicate_constants( + input_eqs: &EquivalenceProperties, + predicate: &Arc, + ) -> Vec { + /// Returns the `AcrossPartitions` value for `expr` if it is constant: + /// either already known constant in `input_eqs`, or a `Literal` + /// (which is inherently constant across all partitions). + fn expr_constant_or_literal( + expr: &Arc, + input_eqs: &EquivalenceProperties, + ) -> Option { + input_eqs.is_expr_constant(expr).or_else(|| { + expr.downcast_ref::() + .map(|l| AcrossPartitions::Uniform(Some(l.value().clone()))) + }) + } + + let mut constants = Vec::new(); + for conjunction in split_conjunction(predicate) { + if let Some(binary) = conjunction.downcast_ref::() + && binary.op() == &Operator::Eq + { + // Check if either side is constant — either already known + // constant from the input equivalence properties, or a literal + // value (which is inherently constant across all partitions). + let left_const = expr_constant_or_literal(binary.left(), input_eqs); + let right_const = expr_constant_or_literal(binary.right(), input_eqs); + + if let Some(left_across) = left_const { + // LEFT is constant, so RIGHT must also be constant. + // Use RIGHT's known across value if available, otherwise + // propagate LEFT's (e.g. Uniform from a literal). + let across = right_const.unwrap_or(left_across); + constants.push(ConstExpr::new(Arc::clone(binary.right()), across)); + } else if let Some(right_across) = right_const { + // RIGHT is constant, so LEFT must also be constant. + constants + .push(ConstExpr::new(Arc::clone(binary.left()), right_across)); + } + } + } + + constants + } +} + /// Create a conjunction of the given predicates. /// If the input is empty, return a literal true. /// If the input contains a single predicate, return the predicate. @@ -84,7 +144,7 @@ fn split_impl<'a>( predicate: &'a Arc, mut exprs: Vec<&'a Arc>, ) -> Vec<&'a Arc> { - match predicate.as_any().downcast_ref::() { + match predicate.downcast_ref::() { Some(binary) if binary.op() == &operator => { let exprs = split_impl(operator, binary.left(), exprs); split_impl(operator, binary.right(), exprs) @@ -115,16 +175,14 @@ pub fn map_columns_before_projection( let column_mapping = proj_exprs .iter() .filter_map(|(expr, name)| { - expr.as_any() - .downcast_ref::() + expr.downcast_ref::() .map(|column| (name.clone(), column.clone())) }) .collect::>(); parent_required .iter() .filter_map(|r| { - r.as_any() - .downcast_ref::() + r.downcast_ref::() .and_then(|c| column_mapping.get(c.name())) }) .map(|e| Arc::new(e.clone()) as _) @@ -228,8 +286,8 @@ where pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); expr.apply(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - columns.get_or_insert_owned(column); + if let Some(column) = expr.downcast_ref::() { + columns.get_or_insert_with(column, |c| c.clone()); } Ok(TreeNodeRecursion::Continue) }) @@ -252,7 +310,7 @@ pub fn reassign_expr_columns( schema: &Schema, ) -> Result> { expr.transform_down(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { + if let Some(column) = expr.downcast_ref::() { let index = schema.index_of(column.name())?; return Ok(Transformed::yes(Arc::new(Column::new( @@ -267,15 +325,15 @@ pub fn reassign_expr_columns( #[cfg(test)] pub(crate) mod tests { - use std::any::Any; + use std::fmt::{Display, Formatter}; use super::*; - use crate::expressions::{binary, cast, col, in_list, lit, Literal}; + use crate::expressions::{Literal, binary, cast, col, in_list, lit}; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{ScalarValue, exec_err, internal_datafusion_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -302,9 +360,6 @@ pub(crate) mod tests { } impl ScalarUDFImpl for TestScalarUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "test-scalar-udf" } @@ -394,11 +449,11 @@ pub(crate) mod tests { fn make_dummy_node(node: &ExprTreeNode) -> Result { let expr = Arc::clone(&node.expr); - let dummy_property = if expr.as_any().is::() { + let dummy_property = if expr.is::() { "Binary" - } else if expr.as_any().is::() { + } else if expr.is::() { "Column" - } else if expr.as_any().is::() { + } else if expr.is::() { "Literal" } else { "Other" @@ -562,4 +617,31 @@ pub(crate) mod tests { assert_eq!(collect_columns(&expr3), expected); Ok(()) } + + #[test] + fn test_collect_predicate_constants_propagates_uniform_literal_value() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "ticker", + DataType::Utf8, + false, + )])); + let predicate = binary( + col("ticker", schema.as_ref())?, + Operator::Eq, + lit(ScalarValue::Utf8(Some("NGJ26".to_string()))), + schema.as_ref(), + )?; + let eq_properties = EquivalenceProperties::new(schema); + + let constants = + ConstExpr::collect_predicate_constants(&eq_properties, &predicate); + + assert_eq!(constants.len(), 1); + assert_eq!( + constants[0].across_partitions, + AcrossPartitions::Uniform(Some(ScalarValue::Utf8(Some("NGJ26".to_string())))) + ); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 2ed9770902d58..1ff13d107c036 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::aggregate::AggregateFunctionExpr; use crate::window::standard::add_new_ordering_expr_with_partition_by; -use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn}; +use crate::window::window_expr::{AggregateWindowExpr, WindowFn, filter_array}; use crate::window::{ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, }; @@ -33,7 +33,7 @@ use arrow::array::ArrayRef; use arrow::array::BooleanArray; use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{exec_datafusion_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_datafusion_err}; use datafusion_expr::{Accumulator, WindowFrame, WindowFrameBound, WindowFrameUnits}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index f93b13fef4dff..a71df3ec88472 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -22,11 +22,11 @@ use std::ops::Range; use std::sync::Arc; use crate::aggregate::AggregateFunctionExpr; -use crate::window::window_expr::{filter_array, AggregateWindowExpr, WindowFn}; +use crate::window::window_expr::{AggregateWindowExpr, WindowFn, filter_array}; use crate::window::{ PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr, }; -use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; +use crate::{PhysicalExpr, expressions::PhysicalSortExpr}; use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::FieldRef; diff --git a/datafusion/physical-expr/src/window/standard.rs b/datafusion/physical-expr/src/window/standard.rs index e9e7f6abf6368..46f3cabbadd48 100644 --- a/datafusion/physical-expr/src/window/standard.rs +++ b/datafusion/physical-expr/src/window/standard.rs @@ -22,17 +22,17 @@ use std::ops::Range; use std::sync::Arc; use super::{StandardWindowFunctionExpr, WindowExpr}; -use crate::window::window_expr::{get_orderby_values, WindowFn}; +use crate::window::window_expr::{WindowFn, get_orderby_values}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; use crate::{EquivalenceProperties, PhysicalExpr}; -use arrow::array::{new_empty_array, ArrayRef}; +use arrow::array::{ArrayRef, new_empty_array}; use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; +use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; /// A window expr that takes the form of a [`StandardWindowFunctionExpr`]. @@ -242,7 +242,7 @@ impl WindowExpr for StandardWindowExpr { // fast path when the result only has a single row row_wise_results[0].to_array()? } else { - ScalarValue::iter_to_array(row_wise_results.into_iter())? + ScalarValue::iter_to_array(row_wise_results)? }; state.update(&out_col, partition_batch_state)?; diff --git a/datafusion/physical-expr/src/window/standard_window_function_expr.rs b/datafusion/physical-expr/src/window/standard_window_function_expr.rs index 9b1213450c2fb..a6ea5e44a4997 100644 --- a/datafusion/physical-expr/src/window/standard_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/standard_window_function_expr.rs @@ -28,14 +28,13 @@ use std::any::Any; use std::sync::Arc; /// Evaluates a window function by instantiating a -/// `[PartitionEvaluator]` for calculating the function's output in +/// [`PartitionEvaluator`] for calculating the function's output in /// that partition. /// /// Note that unlike aggregation based window functions, some window /// functions such as `rank` ignore the values in the window frame, /// but others such as `first_value`, `last_value`, and /// `nth_value` need the value. -#[allow(rustdoc::private_intra_doc_links)] pub trait StandardWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Returns the aggregate expression as [`Any`] so that it can be /// downcast to a specific implementation. diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 47f970d276e00..0f0ec647a50ae 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -23,17 +23,16 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::array::BooleanArray; -use arrow::array::{new_empty_array, Array, ArrayRef}; +use arrow::array::{Array, ArrayRef, new_empty_array}; +use arrow::compute::SortOptions; use arrow::compute::filter as arrow_filter; use arrow::compute::kernels::sort::SortColumn; -use arrow::compute::SortOptions; use arrow::datatypes::FieldRef; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::utils::compare_rows; use datafusion_common::{ - arrow_datafusion_err, exec_datafusion_err, internal_err, DataFusionError, Result, - ScalarValue, + Result, ScalarValue, arrow_datafusion_err, exec_datafusion_err, internal_err, }; use datafusion_expr::window_state::{ PartitionBatchState, WindowAggState, WindowFrameContext, WindowFrameStateGroups, @@ -282,7 +281,7 @@ pub trait AggregateWindowExpr: WindowExpr { /// * `window_frame_ctx`: Details about the window frame (see [`WindowFrameContext`]). /// * `idx`: The index of the current row in the record batch. /// * `not_end`: is the current row not the end of the partition (see [`PartitionBatchState`]). - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] fn get_result_column( &self, accumulator: &mut Box, diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 395da10d629ba..38c8a7c37211f 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -56,5 +56,6 @@ recursive = { workspace = true, optional = true } [dev-dependencies] datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-functions-window = { workspace = true } insta = { workspace = true } tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 672317060d902..d0be53d59b3cf 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -16,15 +16,17 @@ // under the License. //! Utilizing exact statistics from sources to avoid scanning data +use datafusion_common::Result; use datafusion_common::config::ConfigOptions; use datafusion_common::scalar::ScalarValue; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::Result; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateInputMode, AggregateMode, +}; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; -use datafusion_physical_plan::{expressions, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, expressions}; use std::sync::Arc; use crate::PhysicalOptimizerRule; @@ -34,7 +36,7 @@ use crate::PhysicalOptimizerRule; pub struct AggregateStatistics {} impl AggregateStatistics { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -42,15 +44,15 @@ impl AggregateStatistics { impl PhysicalOptimizerRule for AggregateStatistics { #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + #[expect(clippy::allow_attributes)] // See https://github.com/apache/datafusion/issues/18881#issuecomment-3621545670 #[allow(clippy::only_used_in_recursion)] // See https://github.com/rust-lang/rust-clippy/issues/14566 fn optimize( &self, plan: Arc, config: &ConfigOptions, ) -> Result> { - if let Some(partial_agg_exec) = take_optimizable(&*plan) { + if let Some(partial_agg_exec) = take_optimizable(&plan) { let partial_agg_exec = partial_agg_exec - .as_any() .downcast_ref::() .expect("take_optimizable() ensures that this is a AggregateExec"); let stats = partial_agg_exec.input().partition_statistics(None)?; @@ -106,35 +108,38 @@ impl PhysicalOptimizerRule for AggregateStatistics { } } -/// assert if the node passed as argument is a final `AggregateExec` node that can be optimized: -/// - its child (with possible intermediate layers) is a partial `AggregateExec` node -/// - they both have no grouping expression -/// -/// If this is the case, return a ref to the partial `AggregateExec`, else `None`. -/// We would have preferred to return a casted ref to AggregateExec but the recursion requires -/// the `ExecutionPlan.children()` method that returns an owned reference. -fn take_optimizable(node: &dyn ExecutionPlan) -> Option> { - if let Some(final_agg_exec) = node.as_any().downcast_ref::() { - if !final_agg_exec.mode().is_first_stage() - && final_agg_exec.group_expr().is_empty() - { - let mut child = Arc::clone(final_agg_exec.input()); - loop { - if let Some(partial_agg_exec) = - child.as_any().downcast_ref::() - { - if partial_agg_exec.mode().is_first_stage() - && partial_agg_exec.group_expr().is_empty() - && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) - { - return Some(child); - } - } - if let [childrens_child] = child.children().as_slice() { - child = Arc::clone(childrens_child); - } else { - break; - } +/// Returns an `AggregateExec` whose statistics can replace the aggregate with +/// literal values: either a `Single`/`SinglePartitioned` aggregate, or a +/// `Final` aggregate wrapping a `Partial`. Must have no GROUP BY and no +/// filters. +fn take_optimizable(plan: &Arc) -> Option> { + let agg_exec = plan.downcast_ref::()?; + + if matches!( + agg_exec.mode(), + AggregateMode::Single | AggregateMode::SinglePartitioned + ) && agg_exec.group_expr().is_empty() + && agg_exec.filter_expr().iter().all(|e| e.is_none()) + { + return Some(Arc::clone(plan)); + } + + if agg_exec.mode().input_mode() == AggregateInputMode::Partial + && agg_exec.group_expr().is_empty() + { + let mut child = Arc::clone(agg_exec.input()); + loop { + if let Some(partial_agg_exec) = child.downcast_ref::() + && partial_agg_exec.mode().input_mode() == AggregateInputMode::Raw + && partial_agg_exec.group_expr().is_empty() + && partial_agg_exec.filter_expr().iter().all(|e| e.is_none()) + { + return Some(child); + } + if let [childrens_child] = child.children().as_slice() { + child = Arc::clone(childrens_child); + } else { + break; } } } diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index bffb2c9df98ec..297a92c45a16d 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -21,26 +21,27 @@ use std::sync::Arc; use datafusion_common::error::Result; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; -use datafusion_physical_plan::ExecutionPlan; use crate::PhysicalOptimizerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; -use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr}; +use datafusion_physical_expr::{PhysicalExpr, physical_exprs_equal}; /// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs /// into a Single AggregateExec if their grouping exprs and aggregate exprs equal. /// -/// This rule should be applied after the EnforceDistribution and EnforceSorting rules +/// This rule should be applied after the `EnsureRequirements` rule (which +/// handles both distribution and sorting enforcement). #[derive(Default, Debug)] pub struct CombinePartialFinalAggregate {} impl CombinePartialFinalAggregate { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -54,7 +55,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { ) -> Result> { plan.transform_down(|plan| { // Check if the plan is AggregateExec - let Some(agg_exec) = plan.as_any().downcast_ref::() else { + let Some(agg_exec) = plan.downcast_ref::() else { return Ok(Transformed::no(plan)); }; @@ -66,13 +67,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { } // Check if the input is AggregateExec - let Some(input_agg_exec) = - agg_exec.input().as_any().downcast_ref::() + let Some(input_agg_exec) = agg_exec.input().downcast_ref::() else { return Ok(Transformed::no(plan)); }; - let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial) + let transformed = if *input_agg_exec.mode() == AggregateMode::Partial && can_combine( ( agg_exec.group_expr(), @@ -98,7 +98,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { Arc::clone(input_agg_exec.input()), input_agg_exec.input_schema(), ) - .map(|combined_agg| combined_agg.with_limit(agg_exec.limit())) + .map(|combined_agg| { + combined_agg.with_limit_options(agg_exec.limit_options()) + }) .ok() .map(Arc::new) } else { diff --git a/datafusion/physical-optimizer/src/ensure_coop.rs b/datafusion/physical-optimizer/src/ensure_coop.rs index 0c0b63c0b3e79..e7aacb2321b67 100644 --- a/datafusion/physical-optimizer/src/ensure_coop.rs +++ b/datafusion/physical-optimizer/src/ensure_coop.rs @@ -25,12 +25,12 @@ use std::sync::Arc; use crate::PhysicalOptimizerRule; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::execution_plan::{EvaluationType, SchedulingType}; -use datafusion_physical_plan::ExecutionPlan; /// `EnsureCooperative` is a [`PhysicalOptimizerRule`] that inspects the physical plan for /// sub plans that do not participate in cooperative scheduling. The plan is subdivided into sub @@ -67,23 +67,57 @@ impl PhysicalOptimizerRule for EnsureCooperative { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_up(|plan| { - let is_leaf = plan.children().is_empty(); - let is_exchange = plan.properties().evaluation_type == EvaluationType::Eager; - if (is_leaf || is_exchange) - && plan.properties().scheduling_type != SchedulingType::Cooperative - { - // Wrap non-cooperative leaves or eager evaluation roots in a cooperative exec to - // ensure the plans they participate in are properly cooperative. - Ok(Transformed::new( - Arc::new(CooperativeExec::new(Arc::clone(&plan))), - true, - TreeNodeRecursion::Continue, - )) - } else { + use std::cell::RefCell; + + let ancestry_stack = RefCell::new(Vec::<(SchedulingType, EvaluationType)>::new()); + + plan.transform_down_up( + // Down phase: Push parent properties into the stack + |plan| { + let props = plan.properties(); + ancestry_stack + .borrow_mut() + .push((props.scheduling_type, props.evaluation_type)); Ok(Transformed::no(plan)) - } - }) + }, + // Up phase: Wrap nodes with CooperativeExec if needed + |plan| { + ancestry_stack.borrow_mut().pop(); + + let props = plan.properties(); + let is_cooperative = props.scheduling_type == SchedulingType::Cooperative; + let is_leaf = plan.children().is_empty(); + let is_exchange = props.evaluation_type == EvaluationType::Eager; + + let mut is_under_cooperative_context = false; + for (scheduling_type, evaluation_type) in + ancestry_stack.borrow().iter().rev() + { + // If nearest ancestor is cooperative, we are under a cooperative context + if *scheduling_type == SchedulingType::Cooperative { + is_under_cooperative_context = true; + break; + // If nearest ancestor is eager, the cooperative context will be reset + } else if *evaluation_type == EvaluationType::Eager { + is_under_cooperative_context = false; + break; + } + } + + // Wrap if: + // 1. Node is a leaf or exchange point + // 2. Node is not already cooperative + // 3. Not under any Cooperative context + if (is_leaf || is_exchange) + && !is_cooperative + && !is_under_cooperative_context + { + return Ok(Transformed::yes(Arc::new(CooperativeExec::new(plan)))); + } + + Ok(Transformed::no(plan)) + }, + ) .map(|t| t.data) } @@ -96,7 +130,6 @@ impl PhysicalOptimizerRule for EnsureCooperative { #[cfg(test)] mod tests { use super::*; - use datafusion_common::config::ConfigOptions; use datafusion_physical_plan::{displayable, test::scan_partitioned}; use insta::assert_snapshot; @@ -110,9 +143,264 @@ mod tests { let display = displayable(optimized.as_ref()).indent(true).to_string(); // Use insta snapshot to ensure full plan structure - assert_snapshot!(display, @r###" - CooperativeExec - DataSourceExec: partitions=1, partition_sizes=[1] - "###); + assert_snapshot!(display, @r" + CooperativeExec + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } + + #[tokio::test] + async fn test_optimizer_is_idempotent() { + // Comprehensive idempotency test: verify f(f(...f(x))) = f(x) + // This test covers: + // 1. Multiple runs on unwrapped plan + // 2. Multiple runs on already-wrapped plan + // 3. No accumulation of CooperativeExec nodes + + let config = ConfigOptions::new(); + let rule = EnsureCooperative::new(); + + // Test 1: Start with unwrapped plan, run multiple times + let unwrapped_plan = scan_partitioned(1); + let mut current = unwrapped_plan; + let mut stable_result = String::new(); + + for run in 1..=5 { + current = rule.optimize(current, &config).unwrap(); + let display = displayable(current.as_ref()).indent(true).to_string(); + + if run == 1 { + stable_result = display.clone(); + assert_eq!(display.matches("CooperativeExec").count(), 1); + } else { + assert_eq!( + display, stable_result, + "Run {run} should match run 1 (idempotent)" + ); + assert_eq!( + display.matches("CooperativeExec").count(), + 1, + "Should always have exactly 1 CooperativeExec, not accumulate" + ); + } + } + + // Test 2: Start with already-wrapped plan, verify no double wrapping + let pre_wrapped = Arc::new(CooperativeExec::new(scan_partitioned(1))); + let result = rule.optimize(pre_wrapped, &config).unwrap(); + let display = displayable(result.as_ref()).indent(true).to_string(); + + assert_eq!( + display.matches("CooperativeExec").count(), + 1, + "Should not double-wrap already cooperative plans" + ); + assert_eq!( + display, stable_result, + "Pre-wrapped plan should produce same result as unwrapped after optimization" + ); + } + + #[tokio::test] + async fn test_selective_wrapping() { + // Test that wrapping is selective: only leaf/eager nodes, not intermediate nodes + // Also verify depth tracking prevents double wrapping in subtrees + use datafusion_physical_expr::expressions::lit; + use datafusion_physical_plan::filter::FilterExec; + + let config = ConfigOptions::new(); + let rule = EnsureCooperative::new(); + + // Case 1: Filter -> Scan (middle node should not be wrapped) + let scan = scan_partitioned(1); + let filter = Arc::new(FilterExec::try_new(lit(true), scan).unwrap()); + let optimized = rule.optimize(filter, &config).unwrap(); + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + assert_eq!(display.matches("CooperativeExec").count(), 1); + assert!(display.contains("FilterExec")); + + // Case 2: Filter -> CoopExec -> Scan (depth tracking prevents double wrap) + let scan2 = scan_partitioned(1); + let wrapped_scan = Arc::new(CooperativeExec::new(scan2)); + let filter2 = Arc::new(FilterExec::try_new(lit(true), wrapped_scan).unwrap()); + let optimized2 = rule.optimize(filter2, &config).unwrap(); + let display2 = displayable(optimized2.as_ref()).indent(true).to_string(); + + assert_eq!(display2.matches("CooperativeExec").count(), 1); + } + + #[tokio::test] + async fn test_multiple_leaf_nodes() { + // When there are multiple leaf nodes, each should be wrapped separately + use datafusion_physical_plan::union::UnionExec; + + let scan1 = scan_partitioned(1); + let scan2 = scan_partitioned(1); + let union = UnionExec::try_new(vec![scan1, scan2]).unwrap(); + + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new() + .optimize(union as Arc, &config) + .unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + // Each leaf should have its own CooperativeExec + assert_eq!( + display.matches("CooperativeExec").count(), + 2, + "Each leaf node should be wrapped separately" + ); + assert_eq!( + display.matches("DataSourceExec").count(), + 2, + "Both data sources should be present" + ); + } + + #[tokio::test] + async fn test_eager_evaluation_resets_cooperative_context() { + // Test that cooperative context is reset when encountering an eager evaluation boundary. + use arrow::datatypes::Schema; + use datafusion_common::internal_err; + use datafusion_execution::TaskContext; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, Partitioning, PlanProperties, + SendableRecordBatchStream, + execution_plan::{Boundedness, EmissionType}, + }; + + #[derive(Debug)] + struct DummyExec { + name: String, + input: Arc, + scheduling_type: SchedulingType, + evaluation_type: EvaluationType, + properties: Arc, + } + + impl DummyExec { + fn new( + name: &str, + input: Arc, + scheduling_type: SchedulingType, + evaluation_type: EvaluationType, + ) -> Self { + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::new(Schema::empty())), + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + ) + .with_scheduling_type(scheduling_type) + .with_evaluation_type(evaluation_type); + + Self { + name: name.to_string(), + input, + scheduling_type, + evaluation_type, + properties: Arc::new(properties), + } + } + } + + impl DisplayAs for DummyExec { + fn fmt_as( + &self, + _: DisplayFormatType, + f: &mut Formatter, + ) -> std::fmt::Result { + write!(f, "{}", self.name) + } + } + + impl ExecutionPlan for DummyExec { + fn name(&self) -> &str { + &self.name + } + fn properties(&self) -> &Arc { + &self.properties + } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(DummyExec::new( + &self.name, + Arc::clone(&children[0]), + self.scheduling_type, + self.evaluation_type, + ))) + } + fn execute( + &self, + _: usize, + _: Arc, + ) -> Result { + internal_err!("DummyExec does not support execution") + } + } + + // Build a plan similar to the original test: + // scan -> exch1(NonCoop,Eager) -> CoopExec -> filter -> exch2(Coop,Eager) -> filter + let scan = scan_partitioned(1); + let exch1 = Arc::new(DummyExec::new( + "exch1", + scan, + SchedulingType::NonCooperative, + EvaluationType::Eager, + )); + let coop = Arc::new(CooperativeExec::new(exch1)); + let filter1 = Arc::new(DummyExec::new( + "filter1", + coop, + SchedulingType::NonCooperative, + EvaluationType::Lazy, + )); + let exch2 = Arc::new(DummyExec::new( + "exch2", + filter1, + SchedulingType::Cooperative, + EvaluationType::Eager, + )); + let filter2 = Arc::new(DummyExec::new( + "filter2", + exch2, + SchedulingType::NonCooperative, + EvaluationType::Lazy, + )); + + let config = ConfigOptions::new(); + let optimized = EnsureCooperative::new().optimize(filter2, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + + // Expected wrapping: + // - Scan (leaf) gets wrapped + // - exch1 (eager+noncoop) keeps its manual CooperativeExec wrapper + // - filter1 is protected by exch2's cooperative context, no extra wrap + // - exch2 (already Cooperative) does NOT get wrapped + // - filter2 (not leaf or eager) does NOT get wrapped + assert_eq!( + display.matches("CooperativeExec").count(), + 2, + "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1" + ); + + assert_snapshot!(display, @r" + filter2 + exch2 + filter1 + CooperativeExec + exch1 + CooperativeExec + DataSourceExec: partitions=1, partition_sizes=[1] + "); } } diff --git a/datafusion/physical-optimizer/src/ensure_requirements/enforce_distribution.rs b/datafusion/physical-optimizer/src/ensure_requirements/enforce_distribution.rs new file mode 100644 index 0000000000000..ada7b6d741cf2 --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_requirements/enforce_distribution.rs @@ -0,0 +1,1423 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Distribution enforcement helpers. The standalone `EnforceDistribution` +//! rule that previously lived here has been retired in favour of +//! `EnsureRequirements` (which composes distribution and sorting +//! enforcement into a single idempotent pass). The helpers in this +//! module — `adjust_input_keys_ordering`, `reorder_join_keys_to_inputs`, +//! `DistributionContext`, `ensure_distribution`, … — are used directly +//! by `EnsureRequirements`. +//! +//! These helpers inspect the physical plan with respect to distribution +//! requirements and add [`RepartitionExec`]s to satisfy them when necessary. +//! If increasing parallelism is beneficial (and also desirable according to +//! configuration), they increase partition counts in the physical plan. + +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::output_requirements::OutputRequirementExec; +use crate::utils::{ + add_sort_above_with_check, is_coalesce_partitions, is_repartition, + is_sort_preserving_merge, +}; + +use arrow::compute::SortOptions; +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::Result; +use datafusion_common::stats::Precision; +use datafusion_common::tree_node::Transformed; +use datafusion_expr::logical_plan::{Aggregate, JoinType}; +use datafusion_physical_expr::expressions::{Column, NoOp}; +use datafusion_physical_expr::utils::map_columns_before_projection; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalExpr, PhysicalExprRef, physical_exprs_equal, +}; +use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::execution_plan::EmissionType; +use datafusion_physical_plan::joins::{ + CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, +}; +use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::tree_node::PlanContext; +use datafusion_physical_plan::union::{InterleaveExec, UnionExec, can_interleave}; +use datafusion_physical_plan::windows::WindowAggExec; +use datafusion_physical_plan::windows::{BoundedWindowAggExec, get_best_fitting_window}; +use datafusion_physical_plan::{ + Distribution, ExecutionPlan, Partitioning, with_new_children_if_necessary, +}; + +use itertools::izip; + +// The `EnforceDistribution` rule was retired in favour of `EnsureRequirements`, +// which composes distribution and sorting enforcement into a single idempotent +// pass. The helper functions below (`adjust_input_keys_ordering`, +// `reorder_join_keys_to_inputs`, `DistributionContext`, `ensure_distribution`, +// etc.) remain — `EnsureRequirements` calls into them directly. + +#[derive(Debug, Clone)] +struct JoinKeyPairs { + left_keys: Vec>, + right_keys: Vec>, +} + +/// Keeps track of parent required key orderings. +pub type PlanWithKeyRequirements = PlanContext>>; + +/// When the physical planner creates the Joins, the ordering of join keys is from the original query. +/// That might not match with the output partitioning of the join node's children +/// A Top-Down process will use this method to adjust children's output partitioning based on the parent key reordering requirements: +/// +/// Example: +/// TopJoin on (a, b, c) +/// bottom left join on(b, a, c) +/// bottom right join on(c, b, a) +/// +/// Will be adjusted to: +/// TopJoin on (a, b, c) +/// bottom left join on(a, b, c) +/// bottom right join on(a, b, c) +/// +/// Example: +/// TopJoin on (a, b, c) +/// Agg1 group by (b, a, c) +/// Agg2 group by (c, b, a) +/// +/// Will be adjusted to: +/// TopJoin on (a, b, c) +/// Projection(b, a, c) +/// Agg1 group by (a, b, c) +/// Projection(c, b, a) +/// Agg2 group by (a, b, c) +/// +/// Following is the explanation of the reordering process: +/// +/// 1) If the current plan is Partitioned HashJoin, SortMergeJoin, check whether the requirements can be satisfied by adjusting join keys ordering: +/// Requirements can not be satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. +/// Requirements is already satisfied, clear the current requirements, generate new requirements(to pushdown) based on the current join keys, return the unchanged plan. +/// Requirements can be satisfied by adjusting keys ordering, clear the current requirements, generate new requirements(to pushdown) based on the adjusted join keys, return the changed plan. +/// +/// 2) If the current plan is Aggregation, check whether the requirements can be satisfied by adjusting group by keys ordering: +/// Requirements can not be satisfied, clear all the requirements, return the unchanged plan. +/// Requirements is already satisfied, clear all the requirements, return the unchanged plan. +/// Requirements can be satisfied by adjusting keys ordering, clear all the requirements, return the changed plan. +/// +/// 3) If the current plan is RepartitionExec, CoalescePartitionsExec or WindowAggExec, clear all the requirements, return the unchanged plan +/// 4) If the current plan is Projection, transform the requirements to the columns before the Projection and push down requirements +/// 5) For other types of operators, by default, pushdown the parent requirements to children. +pub fn adjust_input_keys_ordering( + mut requirements: PlanWithKeyRequirements, +) -> Result> { + let plan = Arc::clone(&requirements.plan); + + if let Some( + exec @ HashJoinExec { + left, + on, + join_type, + mode, + .. + }, + ) = plan.downcast_ref::() + { + match mode { + PartitionMode::Partitioned => { + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + exec.builder() + .with_partition_mode(PartitionMode::Partitioned) + .with_on(new_conditions.0) + .build_exec() + }; + return reorder_partitioned_join_keys( + requirements, + on, + &[], + &join_constructor, + ) + .map(Transformed::yes); + } + PartitionMode::CollectLeft => { + // Push down requirements to the right side + requirements.children[1].data = match join_type { + JoinType::Inner | JoinType::Right => shift_right_required( + &requirements.data, + left.schema().fields().len(), + ) + .unwrap_or_default(), + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { + requirements.data.clone() + } + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftMark => vec![], + }; + } + PartitionMode::Auto => { + // Can not satisfy, clear the current requirements and generate new empty requirements + requirements.data.clear(); + } + } + } else if let Some(CrossJoinExec { left, .. }) = plan.downcast_ref::() + { + let left_columns_len = left.schema().fields().len(); + // Push down requirements to the right side + requirements.children[1].data = + shift_right_required(&requirements.data, left_columns_len) + .unwrap_or_default(); + } else if let Some(SortMergeJoinExec { + left, + right, + on, + filter, + join_type, + sort_options, + null_equality, + .. + }) = plan.downcast_ref::() + { + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + SortMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + new_conditions.0, + filter.clone(), + *join_type, + new_conditions.1, + *null_equality, + ) + .map(|e| Arc::new(e) as _) + }; + return reorder_partitioned_join_keys( + requirements, + on, + sort_options, + &join_constructor, + ) + .map(Transformed::yes); + } else if let Some(aggregate_exec) = plan.downcast_ref::() { + if !requirements.data.is_empty() { + if aggregate_exec.mode() == &AggregateMode::FinalPartitioned { + return reorder_aggregate_keys(requirements, aggregate_exec) + .map(Transformed::yes); + } else { + requirements.data.clear(); + } + } else { + // Keep everything unchanged + return Ok(Transformed::no(requirements)); + } + } else if let Some(proj) = plan.downcast_ref::() { + let expr = proj.expr(); + // For Projection, we need to transform the requirements to the columns before the Projection + // And then to push down the requirements + // Construct a mapping from new name to the original Column + let proj_exprs: Vec<_> = expr + .iter() + .map(|p| (Arc::clone(&p.expr), p.alias.clone())) + .collect(); + let new_required = map_columns_before_projection(&requirements.data, &proj_exprs); + if new_required.len() == requirements.data.len() { + requirements.children[0].data = new_required; + } else { + // Can not satisfy, clear the current requirements and generate new empty requirements + requirements.data.clear(); + } + } else if plan.is::() + || plan.is::() + || plan.is::() + { + requirements.data.clear(); + } else if requirements.data.is_empty() { + // No requirements to push down and no plan changes — skip rebuild. + return Ok(Transformed::no(requirements)); + } else { + // By default, push down the parent requirements to children + for child in requirements.children.iter_mut() { + child.data.clone_from(&requirements.data); + } + } + Ok(Transformed::yes(requirements)) +} + +pub fn reorder_partitioned_join_keys( + mut join_plan: PlanWithKeyRequirements, + on: &[(PhysicalExprRef, PhysicalExprRef)], + sort_options: &[SortOptions], + join_constructor: &F, +) -> Result +where + F: Fn( + (Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec), + ) -> Result>, +{ + let parent_required = &join_plan.data; + let join_key_pairs = extract_join_keys(on); + let eq_properties = join_plan.plan.equivalence_properties(); + + let ( + JoinKeyPairs { + left_keys, + right_keys, + }, + positions, + ) = try_reorder(join_key_pairs, parent_required, eq_properties); + + if let Some(positions) = positions + && !positions.is_empty() + { + let new_join_on = new_join_conditions(&left_keys, &right_keys); + let new_sort_options = (0..sort_options.len()) + .map(|idx| sort_options[positions[idx]]) + .collect(); + join_plan.plan = join_constructor((new_join_on, new_sort_options))?; + } + + join_plan.children[0].data = left_keys; + join_plan.children[1].data = right_keys; + Ok(join_plan) +} + +pub fn reorder_aggregate_keys( + mut agg_node: PlanWithKeyRequirements, + agg_exec: &AggregateExec, +) -> Result { + let parent_required = &agg_node.data; + let output_columns = agg_exec + .group_expr() + .expr() + .iter() + .enumerate() + .map(|(index, (_, name))| Column::new(name, index)) + .collect::>(); + + let output_exprs = output_columns + .iter() + .map(|c| Arc::new(c.clone()) as _) + .collect::>(); + + if parent_required.len() == output_exprs.len() + && agg_exec.group_expr().null_expr().is_empty() + && !physical_exprs_equal(&output_exprs, parent_required) + && let Some(positions) = expected_expr_positions(&output_exprs, parent_required) + && let Some(agg_exec) = agg_exec.input().downcast_ref::() + && *agg_exec.mode() == AggregateMode::Partial + { + let group_exprs = agg_exec.group_expr().expr(); + let new_group_exprs = positions + .into_iter() + .map(|idx| group_exprs[idx].clone()) + .collect(); + let partial_agg = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(new_group_exprs), + agg_exec.aggr_expr().to_vec(), + agg_exec.filter_expr().to_vec(), + Arc::clone(agg_exec.input()), + Arc::clone(&agg_exec.input_schema), + )?); + // Build new group expressions that correspond to the output + // of the "reordered" aggregator: + let group_exprs = partial_agg.group_expr().expr(); + let new_group_by = PhysicalGroupBy::new_single( + partial_agg + .output_group_expr() + .into_iter() + .enumerate() + .map(|(idx, expr)| (expr, group_exprs[idx].1.clone())) + .collect(), + ); + let new_final_agg = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + new_group_by, + agg_exec.aggr_expr().to_vec(), + agg_exec.filter_expr().to_vec(), + Arc::clone(&partial_agg) as _, + agg_exec.input_schema(), + )?); + + agg_node.plan = Arc::clone(&new_final_agg) as _; + agg_node.data.clear(); + agg_node.children = vec![PlanWithKeyRequirements::new( + partial_agg as _, + vec![], + agg_node.children.swap_remove(0).children, + )]; + + // Need to create a new projection to change the expr ordering back + let agg_schema = new_final_agg.schema(); + let mut proj_exprs = output_columns + .iter() + .map(|col| { + let name = col.name(); + let index = agg_schema.index_of(name)?; + Ok(ProjectionExpr { + expr: Arc::new(Column::new(name, index)) as _, + alias: name.to_owned(), + }) + }) + .collect::>>()?; + let agg_fields = agg_schema.fields(); + for (idx, field) in agg_fields.iter().enumerate().skip(output_columns.len()) { + let name = field.name(); + let plan = Arc::new(Column::new(name, idx)) as _; + proj_exprs.push(ProjectionExpr { + expr: plan, + alias: name.clone(), + }) + } + return ProjectionExec::try_new(proj_exprs, new_final_agg) + .map(|p| PlanWithKeyRequirements::new(Arc::new(p), vec![], vec![agg_node])); + } + Ok(agg_node) +} + +fn shift_right_required( + parent_required: &[Arc], + left_columns_len: usize, +) -> Option>> { + let new_right_required = parent_required + .iter() + .filter_map(|r| { + (r.as_ref() as &dyn Any) + .downcast_ref::() + .and_then(|col| { + col.index() + .checked_sub(left_columns_len) + .map(|index| Arc::new(Column::new(col.name(), index)) as _) + }) + }) + .collect::>(); + + // if the parent required are all coming from the right side, the requirements can be pushdown + (new_right_required.len() == parent_required.len()).then_some(new_right_required) +} + +/// When the physical planner creates the Joins, the ordering of join keys is from the original query. +/// That might not match with the output partitioning of the join node's children +/// This method will try to change the ordering of the join keys to match with the +/// partitioning of the join nodes' children. If it can not match with both sides, it will try to +/// match with one, either the left side or the right side. +/// +/// Example: +/// TopJoin on (a, b, c) +/// bottom left join on(b, a, c) +/// bottom right join on(c, b, a) +/// +/// Will be adjusted to: +/// TopJoin on (b, a, c) +/// bottom left join on(b, a, c) +/// bottom right join on(c, b, a) +/// +/// Compared to the Top-Down reordering process, this Bottom-Up approach is much simpler, but might not reach a best result. +/// The Bottom-Up approach will be useful in future if we plan to support storage partition-wised Joins. +/// In that case, the datasources/tables might be pre-partitioned and we can't adjust the key ordering of the datasources +/// and then can't apply the Top-Down reordering process. +pub fn reorder_join_keys_to_inputs( + plan: Arc, +) -> Result> { + if let Some( + exec @ HashJoinExec { + left, + right, + on, + mode, + .. + }, + ) = plan.downcast_ref::() + { + if *mode == PartitionMode::Partitioned { + let (join_keys, positions) = reorder_current_join_keys( + extract_join_keys(on), + Some(left.output_partitioning()), + Some(right.output_partitioning()), + left.equivalence_properties(), + right.equivalence_properties(), + ); + if positions.is_some_and(|idxs| !idxs.is_empty()) { + let JoinKeyPairs { + left_keys, + right_keys, + } = join_keys; + let new_join_on = new_join_conditions(&left_keys, &right_keys); + return exec + .builder() + .with_partition_mode(PartitionMode::Partitioned) + .with_on(new_join_on) + .build_exec(); + } + } + } else if let Some(SortMergeJoinExec { + left, + right, + on, + filter, + join_type, + sort_options, + null_equality, + .. + }) = plan.downcast_ref::() + { + let (join_keys, positions) = reorder_current_join_keys( + extract_join_keys(on), + Some(left.output_partitioning()), + Some(right.output_partitioning()), + left.equivalence_properties(), + right.equivalence_properties(), + ); + if let Some(positions) = positions + && !positions.is_empty() + { + let JoinKeyPairs { + left_keys, + right_keys, + } = join_keys; + let new_join_on = new_join_conditions(&left_keys, &right_keys); + let new_sort_options = (0..sort_options.len()) + .map(|idx| sort_options[positions[idx]]) + .collect(); + return SortMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + new_join_on, + filter.clone(), + *join_type, + new_sort_options, + *null_equality, + ) + .map(|smj| Arc::new(smj) as _); + } + } + Ok(plan) +} + +/// Reorder the current join keys ordering based on either left partition or right partition +fn reorder_current_join_keys( + join_keys: JoinKeyPairs, + left_partition: Option<&Partitioning>, + right_partition: Option<&Partitioning>, + left_equivalence_properties: &EquivalenceProperties, + right_equivalence_properties: &EquivalenceProperties, +) -> (JoinKeyPairs, Option>) { + match (left_partition, right_partition) { + (Some(Partitioning::Hash(left_exprs, _)), _) => { + match try_reorder(join_keys, left_exprs, left_equivalence_properties) { + (join_keys, None) => reorder_current_join_keys( + join_keys, + None, + right_partition, + left_equivalence_properties, + right_equivalence_properties, + ), + result => result, + } + } + (_, Some(Partitioning::Hash(right_exprs, _))) => { + try_reorder(join_keys, right_exprs, right_equivalence_properties) + } + _ => (join_keys, None), + } +} + +fn try_reorder( + join_keys: JoinKeyPairs, + expected: &[Arc], + equivalence_properties: &EquivalenceProperties, +) -> (JoinKeyPairs, Option>) { + let eq_groups = equivalence_properties.eq_group(); + let mut normalized_expected = vec![]; + let mut normalized_left_keys = vec![]; + let mut normalized_right_keys = vec![]; + if join_keys.left_keys.len() != expected.len() { + return (join_keys, None); + } + if physical_exprs_equal(expected, &join_keys.left_keys) + || physical_exprs_equal(expected, &join_keys.right_keys) + { + return (join_keys, Some(vec![])); + } else if !equivalence_properties.eq_group().is_empty() { + normalized_expected = expected + .iter() + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) + .collect(); + + normalized_left_keys = join_keys + .left_keys + .iter() + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) + .collect(); + + normalized_right_keys = join_keys + .right_keys + .iter() + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) + .collect(); + + if physical_exprs_equal(&normalized_expected, &normalized_left_keys) + || physical_exprs_equal(&normalized_expected, &normalized_right_keys) + { + return (join_keys, Some(vec![])); + } + } + + let Some(positions) = expected_expr_positions(&join_keys.left_keys, expected) + .or_else(|| expected_expr_positions(&join_keys.right_keys, expected)) + .or_else(|| expected_expr_positions(&normalized_left_keys, &normalized_expected)) + .or_else(|| { + expected_expr_positions(&normalized_right_keys, &normalized_expected) + }) + else { + return (join_keys, None); + }; + + let mut new_left_keys = vec![]; + let mut new_right_keys = vec![]; + for pos in positions.iter() { + new_left_keys.push(Arc::clone(&join_keys.left_keys[*pos])); + new_right_keys.push(Arc::clone(&join_keys.right_keys[*pos])); + } + let pairs = JoinKeyPairs { + left_keys: new_left_keys, + right_keys: new_right_keys, + }; + + (pairs, Some(positions)) +} + +/// Return the expected expressions positions. +/// For example, the current expressions are ['c', 'a', 'a', b'], the expected expressions are ['b', 'c', 'a', 'a'], +/// +/// This method will return a Vec [3, 0, 1, 2] +fn expected_expr_positions( + current: &[Arc], + expected: &[Arc], +) -> Option> { + if current.is_empty() || expected.is_empty() { + return None; + } + let mut indexes: Vec = vec![]; + let mut current = current.to_vec(); + for expr in expected.iter() { + // Find the position of the expected expr in the current expressions + if let Some(expected_position) = current.iter().position(|e| e.eq(expr)) { + current[expected_position] = Arc::new(NoOp::new()); + indexes.push(expected_position); + } else { + return None; + } + } + Some(indexes) +} + +fn extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) -> JoinKeyPairs { + let (left_keys, right_keys) = on + .iter() + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) + .unzip(); + JoinKeyPairs { + left_keys, + right_keys, + } +} + +fn new_join_conditions( + new_left_keys: &[Arc], + new_right_keys: &[Arc], +) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + new_left_keys + .iter() + .zip(new_right_keys.iter()) + .map(|(l_key, r_key)| (Arc::clone(l_key), Arc::clone(r_key))) + .collect() +} + +/// Adds RoundRobin repartition operator to the plan increase parallelism. +/// +/// # Arguments +/// +/// * `input`: Current node. +/// * `n_target`: desired target partition number, if partition number of the +/// current executor is less than this value. Partition number will be increased. +/// +/// # Returns +/// +/// A [`Result`] object that contains new execution plan where the desired +/// partition number is achieved by adding a RoundRobin repartition. +fn add_roundrobin_on_top( + input: DistributionContext, + n_target: usize, +) -> Result { + // Adding repartition is helpful: + if input.plan.output_partitioning().partition_count() < n_target { + // When there is an existing ordering, we preserve ordering + // during repartition. This will be un-done in the future + // If any of the following conditions is true + // - Preserving ordering is not helpful in terms of satisfying ordering requirements + // - Usage of order preserving variants is not desirable + // (determined by flag `config.optimizer.prefer_existing_sort`) + let partitioning = Partitioning::RoundRobinBatch(n_target); + let repartition = + RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)? + .with_preserve_order(); + + let new_plan = Arc::new(repartition) as _; + + Ok(DistributionContext::new(new_plan, true, vec![input])) + } else { + // Partition is not helpful, we already have desired number of partitions. + Ok(input) + } +} + +/// Adds a hash repartition operator: +/// - to increase parallelism, and/or +/// - to satisfy requirements of the subsequent operators. +/// +/// Repartition(Hash) is added on top of operator `input`. +/// +/// # Arguments +/// +/// * `input`: Current node. +/// * `hash_exprs`: Stores Physical Exprs that are used during hashing. +/// * `n_target`: desired target partition number, if partition number of the +/// current executor is less than this value. Partition number will be increased. +/// * `allow_subset_satisfy_partitioning`: Whether to allow subset partitioning logic in satisfaction checks. +/// Set to `false` for partitioned hash joins to ensure exact hash matching. +/// +/// # Returns +/// +/// A [`Result`] object that contains new execution plan where the desired +/// distribution is satisfied by adding a Hash repartition. +fn add_hash_on_top( + input: DistributionContext, + hash_exprs: Vec>, + n_target: usize, + allow_subset_satisfy_partitioning: bool, +) -> Result { + // Early return if hash repartition is unnecessary + // `RepartitionExec: partitioning=Hash([...], 1), input_partitions=1` is unnecessary. + if n_target == 1 && input.plan.output_partitioning().partition_count() == 1 { + return Ok(input); + } + + let dist = Distribution::HashPartitioned(hash_exprs); + let satisfaction = input.plan.output_partitioning().satisfaction( + &dist, + input.plan.equivalence_properties(), + allow_subset_satisfy_partitioning, + ); + + // Add hash repartitioning when: + // - When subset satisfaction is enabled (current >= threshold): only repartition if not satisfied + // - When below threshold (current < threshold): repartition if expressions don't match OR to increase parallelism + let needs_repartition = if allow_subset_satisfy_partitioning { + !satisfaction.is_satisfied() + } else { + !satisfaction.is_satisfied() + || n_target > input.plan.output_partitioning().partition_count() + }; + + if needs_repartition { + // When there is an existing ordering, we preserve ordering during + // repartition. This will be rolled back in the future if any of the + // following conditions is true: + // - Preserving ordering is not helpful in terms of satisfying ordering + // requirements. + // - Usage of order preserving variants is not desirable (per the flag + // `config.optimizer.prefer_existing_sort`). + let partitioning = dist.create_partitioning(n_target); + let repartition = + RepartitionExec::try_new(Arc::clone(&input.plan), partitioning)? + .with_preserve_order(); + let plan = Arc::new(repartition) as _; + + return Ok(DistributionContext::new(plan, true, vec![input])); + } + + Ok(input) +} + +/// Adds a [`SortPreservingMergeExec`] or a [`CoalescePartitionsExec`] operator +/// on top of the given plan node to satisfy a single partition requirement +/// while preserving ordering constraints. +/// +/// # Parameters +/// +/// * `input`: Current node. +/// +/// Checks whether preserving the child's ordering enables the parent to +/// run in streaming mode. Compares the parent's pipeline behavior with +/// the ordered child vs. an unordered (coalesced) child. If removing the +/// ordering would cause the parent to switch from streaming to blocking, +/// keeping the order-preserving variant is beneficial. +/// +/// Only applicable to single-child operators; returns `Ok(false)` for +/// multi-child operators (e.g. joins) where child substitution semantics are +/// ambiguous. +fn preserving_order_enables_streaming( + parent: &Arc, + ordered_child: &Arc, +) -> Result { + // Only applicable to single-child operators that maintain input order + // (e.g. AggregateExec in PartiallySorted mode). Operators that don't + // maintain input order (e.g. SortExec) handle ordering themselves — + // preserving SPM for them is unnecessary. + if parent.children().len() != 1 { + return Ok(false); + } + if !parent.maintains_input_order()[0] { + return Ok(false); + } + // Build parent with the ordered child + let with_ordered = + Arc::clone(parent).with_new_children(vec![Arc::clone(ordered_child)])?; + if with_ordered.pipeline_behavior() == EmissionType::Final { + // Parent is blocking even with ordering — no benefit + return Ok(false); + } + // Build parent with an unordered child via CoalescePartitionsExec. + let unordered_child: Arc = + Arc::new(CoalescePartitionsExec::new(Arc::clone(ordered_child))); + let without_ordered = Arc::clone(parent).with_new_children(vec![unordered_child])?; + Ok(without_ordered.pipeline_behavior() == EmissionType::Final) +} + +/// # Returns +/// +/// Updated node with an execution plan, where the desired single distribution +/// requirement is satisfied. +fn add_merge_on_top( + input: DistributionContext, + fetch: Option, +) -> DistributionContext { + // Apply only when the partition count is larger than one. + if input.plan.output_partitioning().partition_count() > 1 { + // When there is an existing ordering, we preserve ordering + // when decreasing partitions. This will be un-done in the future + // if any of the following conditions is true + // - Preserving ordering is not helpful in terms of satisfying ordering requirements + // - Usage of order preserving variants is not desirable + // (determined by flag `config.optimizer.prefer_existing_sort`) + let new_plan: Arc = if let Some(req) = + input.plan.output_ordering() + { + let mut spm = + SortPreservingMergeExec::new(req.clone(), Arc::clone(&input.plan)); + if let Some(f) = fetch { + spm = spm.with_fetch(Some(f)); + } + Arc::new(spm) + } else { + // If there is no input order, we can simply coalesce partitions: + Arc::new( + CoalescePartitionsExec::new(Arc::clone(&input.plan)).with_fetch(fetch), + ) + }; + + DistributionContext::new(new_plan, true, vec![input]) + } else { + input + } +} + +/// Updates the physical plan inside [`DistributionContext`] so that distribution +/// changing operators are removed from the top. If they are necessary, they will +/// be added in subsequent stages. +/// +/// Assume that following plan is given: +/// ```text +/// "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", +/// ``` +/// +/// Since `RepartitionExec`s change the distribution, this function removes +/// them and returns following plan: +/// +/// ```text +/// "DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", +/// ``` +/// Returned by [`remove_dist_changing_operators`] to carry the fetch value +/// that may have been on a removed `SortPreservingMergeExec` or `CoalescePartitionsExec`. +struct RemovedDistOps { + context: DistributionContext, + /// The fetch value from the removed SPM/Coalesce, if any. + /// Must be re-applied when distribution operators are re-inserted. + removed_fetch: Option, +} + +fn remove_dist_changing_operators( + mut distribution_context: DistributionContext, +) -> Result { + let mut removed_fetch = None; + while is_repartition(&distribution_context.plan) + || is_coalesce_partitions(&distribution_context.plan) + || is_sort_preserving_merge(&distribution_context.plan) + { + // Preserve fetch from SPM or CoalescePartitions before removing (#14150). + if let Some(fetch) = distribution_context.plan.fetch() { + removed_fetch = Some( + removed_fetch + .map(|existing: usize| existing.min(fetch)) + .unwrap_or(fetch), + ); + } + // All of above operators have a single child. First child is only child. + // Remove any distribution changing operators at the beginning: + distribution_context = distribution_context.children.swap_remove(0); + // Note that they will be re-inserted later on if necessary or helpful. + } + + Ok(RemovedDistOps { + context: distribution_context, + removed_fetch, + }) +} + +/// Updates the [`DistributionContext`] if preserving ordering while changing partitioning is not helpful or desirable. +/// +/// Assume that following plan is given: +/// ```text +/// "SortPreservingMergeExec: \[a@0 ASC]" +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2, preserve_order=true", +/// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", +/// ``` +/// +/// This function converts plan above to the following: +/// +/// ```text +/// "CoalescePartitionsExec" +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", +/// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", +/// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", +/// ``` +pub fn replace_order_preserving_variants( + mut context: DistributionContext, +) -> Result { + context.children = context + .children + .into_iter() + .map(|child| { + if child.data { + replace_order_preserving_variants(child) + } else { + Ok(child) + } + }) + .collect::>>()?; + + if is_sort_preserving_merge(&context.plan) { + let child_plan = Arc::clone(&context.children[0].plan); + context.plan = Arc::new( + CoalescePartitionsExec::new(child_plan).with_fetch(context.plan.fetch()), + ); + return Ok(context); + } else if let Some(repartition) = context.plan.downcast_ref::() + && repartition.preserve_order() + { + context.plan = Arc::new(RepartitionExec::try_new( + Arc::clone(&context.children[0].plan), + repartition.partitioning().clone(), + )?); + return Ok(context); + } + + context.update_plan_from_children() +} + +/// A struct to keep track of repartition requirements for each child node. +struct RepartitionRequirementStatus { + /// The distribution requirement for the node. + requirement: Distribution, + /// Designates whether round robin partitioning is theoretically beneficial; + /// i.e. the operator can actually utilize parallelism. + roundrobin_beneficial: bool, + /// Designates whether round robin partitioning is beneficial according to + /// the statistical information we have on the number of rows. + roundrobin_beneficial_stats: bool, + /// Designates whether hash partitioning is necessary. + hash_necessary: bool, +} + +/// Calculates the `RepartitionRequirementStatus` for each children to generate +/// consistent and sensible (in terms of performance) distribution requirements. +/// As an example, a hash join's left (build) child might produce +/// +/// ```text +/// RepartitionRequirementStatus { +/// .., +/// hash_necessary: true +/// } +/// ``` +/// +/// while its right (probe) child might have very few rows and produce: +/// +/// ```text +/// RepartitionRequirementStatus { +/// .., +/// hash_necessary: false +/// } +/// ``` +/// +/// These statuses are not consistent as all children should agree on hash +/// partitioning. This function aligns the statuses to generate consistent +/// hash partitions for each children. After alignment, the right child's +/// status would turn into: +/// +/// ```text +/// RepartitionRequirementStatus { +/// .., +/// hash_necessary: true +/// } +/// ``` +fn get_repartition_requirement_status( + plan: &Arc, + batch_size: usize, + should_use_estimates: bool, +) -> Result> { + let mut needs_alignment = false; + let children = plan.children(); + let rr_beneficial = plan.benefits_from_input_partitioning(); + let requirements = plan.required_input_distribution(); + let mut repartition_status_flags = vec![]; + for (child, requirement, roundrobin_beneficial) in + izip!(children.into_iter(), requirements, rr_beneficial) + { + // Decide whether adding a round robin is beneficial depending on + // the statistical information we have on the number of rows: + let roundrobin_beneficial_stats = match child.partition_statistics(None)?.num_rows + { + Precision::Exact(n_rows) => n_rows > batch_size, + Precision::Inexact(n_rows) => !should_use_estimates || (n_rows > batch_size), + Precision::Absent => true, + }; + let is_hash = matches!(requirement, Distribution::HashPartitioned(_)); + // Hash re-partitioning is necessary when the input has more than one + // partitions: + let multi_partitions = child.output_partitioning().partition_count() > 1; + let roundrobin_sensible = roundrobin_beneficial && roundrobin_beneficial_stats; + needs_alignment |= is_hash && (multi_partitions || roundrobin_sensible); + repartition_status_flags.push(( + is_hash, + RepartitionRequirementStatus { + requirement, + roundrobin_beneficial, + roundrobin_beneficial_stats, + hash_necessary: is_hash && multi_partitions, + }, + )); + } + // Align hash necessary flags for hash partitions to generate consistent + // hash partitions at each children: + if needs_alignment { + // When there is at least one hash requirement that is necessary or + // beneficial according to statistics, make all children require hash + // repartitioning: + for (is_hash, status) in &mut repartition_status_flags { + if *is_hash { + status.hash_necessary = true; + } + } + } + Ok(repartition_status_flags + .into_iter() + .map(|(_, status)| status) + .collect()) +} + +/// This function checks whether we need to add additional data exchange +/// operators to satisfy distribution requirements. Since this function +/// takes care of such requirements, we should avoid manually adding data +/// exchange operators in other places. +/// +/// This function is intended to be used in a bottom up traversal, as it +/// can first repartition (or newly partition) at the datasources -- these +/// source partitions may be later repartitioned with additional data exchange operators. +pub fn ensure_distribution( + dist_context: DistributionContext, + config: &ConfigOptions, +) -> Result> { + let dist_context = update_children(dist_context)?; + + if dist_context.plan.children().is_empty() { + return Ok(Transformed::no(dist_context)); + } + + let target_partitions = config.execution.target_partitions; + // When `false`, round robin repartition will not be added to increase parallelism + let enable_round_robin = config.optimizer.enable_round_robin_repartition; + let repartition_file_scans = config.optimizer.repartition_file_scans; + let batch_size = config.execution.batch_size; + let should_use_estimates = config + .execution + .use_row_number_estimates_to_optimize_partitioning; + let subset_satisfaction_threshold = config.optimizer.subset_repartition_threshold; + let unbounded_and_pipeline_friendly = dist_context.plan.boundedness().is_unbounded() + && matches!( + dist_context.plan.pipeline_behavior(), + EmissionType::Incremental | EmissionType::Both + ); + // Use order preserving variants either of the conditions true + // - it is desired according to config + // - when plan is unbounded + // - when it is pipeline friendly (can incrementally produce results) + let order_preserving_variants_desirable = + unbounded_and_pipeline_friendly || config.optimizer.prefer_existing_sort; + + // Remove unnecessary repartition from the physical plan if any. + // Preserve fetch from removed SPM/Coalesce (#14150). + let RemovedDistOps { + context: + DistributionContext { + mut plan, + data, + children, + }, + removed_fetch, + } = remove_dist_changing_operators(dist_context)?; + + if let Some(exec) = plan.downcast_ref::() { + if let Some(updated_window) = get_best_fitting_window( + exec.window_expr(), + exec.input(), + &exec.partition_keys(), + )? { + plan = updated_window; + } + } else if let Some(exec) = plan.downcast_ref::() + && let Some(updated_window) = get_best_fitting_window( + exec.window_expr(), + exec.input(), + &exec.partition_keys(), + )? + { + plan = updated_window; + }; + + // For joins in partitioned mode, we need exact hash matching between + // both sides, so subset partitioning logic must be disabled. + // + // Why: Different hash expressions produce different hash values, causing + // rows with the same join key to land in different partitions. Since + // partitioned joins match partition N left with partition N right, rows + // that should match may be in different partitions and miss each other. + // + // Example JOIN ON left.a = right.a: + // + // Left: Hash([a]) + // Partition 1: a=1 + // Partition 2: a=2 + // + // Right: Hash([a, b]) + // Partition 1: (a=1, b=1) -> Same a=1 + // Partition 2: (a=2, b=2) + // Partition 3: (a=1, b=2) -> Same a=1 + // + // Partitioned join execution: + // P1 left (a=1) joins P1 right (a=1, b=1) -> Match + // P2 left (a=2) joins P2 right (a=2, b=2) -> Match + // P3 left (empty) joins P3 right (a=1, b=2) -> Missing, errors + // + // The row (a=1, b=2) should match left.a=1 but they're in different + // partitions, causing panics. + // + // CollectLeft/CollectRight modes are safe because one side is collected + // to a single partition which eliminates partition-to-partition mapping. + let is_partitioned_join = plan + .downcast_ref::() + .is_some_and(|join| join.mode == PartitionMode::Partitioned) + || plan.is::(); + + let repartition_status_flags = + get_repartition_requirement_status(&plan, batch_size, should_use_estimates)?; + // This loop iterates over all the children to: + // - Increase parallelism for every child if it is beneficial. + // - Satisfy the distribution requirements of every child, if it is not + // already satisfied. + // We store the updated children in `new_children`. + let children = izip!( + children.into_iter(), + plan.required_input_ordering(), + plan.maintains_input_order(), + repartition_status_flags.into_iter() + ) + .map( + |( + mut child, + required_input_ordering, + maintains, + RepartitionRequirementStatus { + requirement, + roundrobin_beneficial, + roundrobin_beneficial_stats, + hash_necessary, + }, + )| { + let increases_partition_count = + child.plan.output_partitioning().partition_count() < target_partitions; + + let add_roundrobin = enable_round_robin + // Operator benefits from partitioning (e.g. filter): + && roundrobin_beneficial + && roundrobin_beneficial_stats + // Unless partitioning increases the partition count, it is not beneficial: + && increases_partition_count; + + // Allow subset satisfaction when: + // 1. Current partition count >= threshold + // 2. Not a partitioned join since must use exact hash matching for joins + // 3. Not a grouping set aggregate (requires exact hash including __grouping_id) + let current_partitions = child.plan.output_partitioning().partition_count(); + + // Check if the hash partitioning requirement includes __grouping_id column. + // Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash + // partitioning on all group columns including __grouping_id to ensure partial + // aggregates from different partitions are correctly combined. + let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs) + if exprs.iter().any(|expr| { + (expr.as_ref() as &dyn Any) + .downcast_ref::() + .is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID) + }) + ); + + let allow_subset_satisfy_partitioning = (current_partitions + >= subset_satisfaction_threshold + // `preserve_file_partitions` exposes existing file-group + // partitioning to the optimizer. Respect it when the only + // reason to repartition would be to increase partition count + // beyond the preserved file-group count. + || (config.optimizer.preserve_file_partitions > 0 + && current_partitions < target_partitions)) + && !is_partitioned_join + && !requires_grouping_id; + + // When `repartition_file_scans` is set, attempt to increase + // parallelism at the source. + // + // If repartitioning is not possible (a.k.a. None is returned from `ExecutionPlan::repartitioned`) + // then no repartitioning will have occurred. As the default implementation returns None, it is only + // specific physical plan nodes, such as certain datasources, which are repartitioned. + if repartition_file_scans + && roundrobin_beneficial_stats + && let Some(new_child) = + child.plan.repartitioned(target_partitions, config)? + { + child.plan = new_child; + } + + // Satisfy the distribution requirement if it is unmet. + match &requirement { + Distribution::SinglePartition => { + child = add_merge_on_top(child, removed_fetch); + } + Distribution::HashPartitioned(exprs) => { + // See https://github.com/apache/datafusion/issues/18341#issuecomment-3503238325 for background + // When inserting hash is necessary to satisfy hash requirement, insert hash repartition. + if hash_necessary { + child = add_hash_on_top( + child, + exprs.to_vec(), + target_partitions, + allow_subset_satisfy_partitioning, + )?; + } + } + Distribution::UnspecifiedDistribution => { + if add_roundrobin { + // Add round-robin repartitioning on top of the operator + // to increase parallelism. + child = add_roundrobin_on_top(child, target_partitions)?; + } + } + }; + + let streaming_benefit = if child.data { + preserving_order_enables_streaming(&plan, &child.plan)? + } else { + false + }; + + // There is an ordering requirement of the operator: + if let Some(required_input_ordering) = required_input_ordering { + // Either: + // - Ordering requirement cannot be satisfied by preserving ordering through repartitions, or + // - using order preserving variant is not desirable. + let sort_req = required_input_ordering.into_single(); + let ordering_satisfied = child + .plan + .equivalence_properties() + .ordering_satisfy_requirement(sort_req.clone())?; + + if (!ordering_satisfied || !order_preserving_variants_desirable) + && !streaming_benefit + && child.data + { + child = replace_order_preserving_variants(child)?; + // If ordering requirements were satisfied before repartitioning, + // make sure ordering requirements are still satisfied after. + if ordering_satisfied { + // Make sure to satisfy ordering requirement: + child = add_sort_above_with_check( + child, + sort_req, + plan.downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + )?; + } + } + // Stop tracking distribution changing operators + child.data = false; + } else { + let streaming_benefit = if child.data { + preserving_order_enables_streaming(&plan, &child.plan)? + } else { + false + }; + // no ordering requirement + match requirement { + // Operator requires specific distribution. + Distribution::SinglePartition | Distribution::HashPartitioned(_) => { + // If the parent doesn't maintain input order, preserving + // ordering is pointless. However, if it does maintain + // input order, we keep order-preserving variants so + // ordering can flow through to ancestors that need it. + if !maintains && !streaming_benefit { + child = replace_order_preserving_variants(child)?; + } + } + Distribution::UnspecifiedDistribution => { + // Since ordering is lost, trying to preserve ordering is pointless + if !maintains || plan.is::() { + child = replace_order_preserving_variants(child)?; + } + } + } + } + Ok(child) + }, + ) + .collect::>>()?; + + let children_plans = children + .iter() + .map(|c| Arc::clone(&c.plan)) + .collect::>(); + + plan = if plan.is::() + && !config.optimizer.prefer_existing_union + && can_interleave(children_plans.iter()) + { + // Add a special case for [`UnionExec`] since we want to "bubble up" + // hash-partitioned data. So instead of + // + // Agg: + // Repartition (hash): + // Union: + // - Agg: + // Repartition (hash): + // Data + // - Agg: + // Repartition (hash): + // Data + // + // we can use: + // + // Agg: + // Interleave: + // - Agg: + // Repartition (hash): + // Data + // - Agg: + // Repartition (hash): + // Data + Arc::new(InterleaveExec::try_new(children_plans)?) + } else { + // Route through `with_new_children_if_necessary` so the common + // case where no child was replaced above skips the expensive + // `with_new_children` rebuild. For nodes like `ProjectionExec`, + // `with_new_children` recomputes schema / equivalence properties / + // output ordering via `try_new` even when the input Arcs are + // identical, which dominates `ensure_distribution` time on deep + // projection stacks over plans where no distribution change + // applies (point queries with no join / aggregate / unmet + // ordering). + with_new_children_if_necessary(plan, children_plans)? + }; + + Ok(Transformed::yes(DistributionContext::new( + plan, data, children, + ))) +} + +/// Keeps track of distribution changing operators (like `RepartitionExec`, +/// `SortPreservingMergeExec`, `CoalescePartitionsExec`) and their ancestors. +/// Using this information, we can optimize distribution of the plan if/when +/// necessary. +pub type DistributionContext = PlanContext; + +fn update_children(mut dist_context: DistributionContext) -> Result { + for child_context in dist_context.children.iter_mut() { + child_context.data = if let Some(repartition) = + child_context.plan.downcast_ref::() + { + !matches!( + repartition.partitioning(), + Partitioning::UnknownPartitioning(_) + ) + } else { + child_context.plan.is::() + || child_context.plan.is::() + || child_context.plan.children().is_empty() + || child_context.children[0].data + || child_context + .plan + .required_input_distribution() + .iter() + .zip(child_context.children.iter()) + .any(|(required_dist, child_context)| { + child_context.data + && matches!( + required_dist, + Distribution::UnspecifiedDistribution + ) + }) + } + } + + dist_context.data = false; + Ok(dist_context) +} + +// See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/mod.rs new file mode 100644 index 0000000000000..53917a51085ec --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/mod.rs @@ -0,0 +1,746 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort enforcement helpers. The standalone `EnforceSorting` rule that +//! previously lived here has been retired in favour of `EnsureRequirements` +//! (which composes distribution and sorting enforcement into a single +//! idempotent pass). The helpers in this module — `ensure_sorting`, +//! `parallelize_sorts`, `PlanWithCorrespondingSort`, and the submodules +//! `replace_with_order_preserving_variants` and `sort_pushdown` — are +//! used directly by `EnsureRequirements`. +//! +//! Sort enforcement inspects the physical plan with respect to local +//! sorting requirements and does the following: +//! - Adds a [`SortExec`] when a requirement is not met, +//! - Removes an already-existing [`SortExec`] if it is possible to prove +//! that this sort is unnecessary +//! +//! The helpers can work on valid *and* invalid physical plans with respect +//! to sorting requirements, but always produce a valid plan in this sense. +//! +//! A non-realistic but easy to follow example for sort removals: assume the +//! fragment +//! +//! ```text +//! SortExec: expr=[nullable_col@0 ASC] +//! SortExec: expr=[non_nullable_col@1 ASC] +//! ``` +//! +//! reaches this stage. The first sort is unnecessary since its result is +//! overwritten by another [`SortExec`], so it is removed. + +pub mod replace_with_order_preserving_variants; +pub mod sort_pushdown; + +use std::sync::Arc; + +use crate::output_requirements::OutputRequirementExec; +use crate::utils::{ + add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, + is_repartition, is_sort, is_sort_preserving_merge, is_window, +}; + +use datafusion_common::Result; +use datafusion_common::plan_err; +use datafusion_common::tree_node::Transformed; +use datafusion_physical_expr::{Distribution, Partitioning}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::partial_sort::PartialSortExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::tree_node::PlanContext; +use datafusion_physical_plan::windows::{ + BoundedWindowAggExec, WindowAggExec, get_best_fitting_window, +}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties, InputOrderMode}; + +use itertools::izip; + +// The `EnforceSorting` rule was retired in favour of `EnsureRequirements`, +// which composes distribution and sorting enforcement into a single idempotent +// pass. The helper functions and contexts below (`ensure_sorting`, +// `parallelize_sorts`, `PlanWithCorrespondingSort`, etc.) remain — +// `EnsureRequirements` calls into them directly. + +/// Context object used by sort enforcement to track the closest +/// [`SortExec`] descendant(s) for every child of a plan. The data attribute +/// stores whether the plan is a `SortExec` or is connected to a `SortExec` +/// via its children. +pub type PlanWithCorrespondingSort = PlanContext; + +/// For a given node, update the `PlanContext.data` attribute. +/// +/// If the node is a `SortExec`, or any of the node's children are a `SortExec`, +/// then set the attribute to true. +/// +/// This requires a bottom-up traversal was previously performed, updating the +/// children previously. +fn update_sort_ctx_children_data( + mut node_and_ctx: PlanWithCorrespondingSort, + data: bool, +) -> Result { + // Update `child.data` for all children. + for child_node in node_and_ctx.children.iter_mut() { + let child_plan = &child_node.plan; + child_node.data = if is_sort(child_plan) { + // child is sort + true + } else if is_limit(child_plan) { + // There is no sort linkage for this path, it starts at a limit. + false + } else { + // If a descendent is a sort, and the child maintains the sort. + let is_spm = is_sort_preserving_merge(child_plan); + let required_orderings = child_plan.required_input_ordering(); + let flags = child_plan.maintains_input_order(); + // Add parent node to the tree if there is at least one child with + // a sort connection: + izip!(flags, required_orderings).any(|(maintains, required_ordering)| { + let propagates_ordering = + (maintains && required_ordering.is_none()) || is_spm; + // `connected_to_sort` only returns the correct answer with bottom-up traversal + let connected_to_sort = + child_node.children.iter().any(|child| child.data); + propagates_ordering && connected_to_sort + }) + } + } + + // set data attribute on current node + node_and_ctx.data = data; + + Ok(node_and_ctx) +} + +/// Tracks the closest +/// [`CoalescePartitionsExec`] descendant(s) for every child of a plan. The data +/// attribute stores whether the plan is a `CoalescePartitionsExec` or is +/// connected to a `CoalescePartitionsExec` via its children. +/// +/// The tracker halts at each [`SortExec`] (where the SPM will act to replace the coalesce). +/// +/// This requires a bottom-up traversal was previously performed, updating the +/// children previously. +pub type PlanWithCorrespondingCoalescePartitions = PlanContext; + +/// Discovers the linked Coalesce->Sort cascades. +/// +/// This linkage is used in [`remove_bottleneck_in_subplan`] to selectively +/// remove the linked coalesces in the subplan. Then afterwards, an SPM is added +/// at the root of the subplan (just after the sort) in order to parallelize sorts. +/// Refer to the [`parallelize_sorts`] for more details on sort parallelization. +/// +/// Example of linked Coalesce->Sort: +/// ```text +/// SortExec ctx.data=false, to halt remove_bottleneck_in_subplan) +/// ...nodes... ctx.data=true (e.g. are linked in cascade) +/// Coalesce ctx.data=true (e.g. is a coalesce) +/// ``` +/// +/// The link should not be continued (and the coalesce not removed) if the distribution +/// is changed between the Coalesce->Sort cascade. Example: +/// ```text +/// SortExec ctx.data=false, to halt remove_bottleneck_in_subplan) +/// AggregateExec ctx.data=false, to stop the link +/// ...nodes... ctx.data=true (e.g. are linked in cascade) +/// Coalesce ctx.data=true (e.g. is a coalesce) +/// ``` +fn update_coalesce_ctx_children( + coalesce_context: &mut PlanWithCorrespondingCoalescePartitions, +) { + let children = &coalesce_context.children; + coalesce_context.data = if children.is_empty() { + // Plan has no children, it cannot be a `CoalescePartitionsExec`. + false + } else if is_coalesce_partitions(&coalesce_context.plan) { + // Initiate a connection: + true + } else { + children.iter().enumerate().any(|(idx, node)| { + // Only consider operators that don't require a single partition, + // and connected to some `CoalescePartitionsExec`: + node.data + && !matches!( + coalesce_context.plan.required_input_distribution()[idx], + Distribution::SinglePartition + ) + }) + }; +} + +/// Only interested with [`SortExec`]s and their unbounded children. +/// If the plan is not a [`SortExec`] or its child is not unbounded, returns the original plan. +/// Otherwise, by checking the requirement satisfaction searches for a replacement chance. +/// If there's one replaces the [`SortExec`] plan with a [`PartialSortExec`] +pub fn replace_with_partial_sort( + plan: Arc, +) -> Result> { + let Some(sort_plan) = plan.downcast_ref::() else { + return Ok(plan); + }; + + // It's safe to get first child of the SortExec + let child = Arc::clone(sort_plan.children()[0]); + if !child.boundedness().is_unbounded() { + return Ok(plan); + } + + // Here we're trying to find the common prefix for sorted columns that is required for the + // sort and already satisfied by the given ordering + let child_eq_properties = child.equivalence_properties(); + let sort_exprs = sort_plan.expr().clone(); + + let mut common_prefix_length = 0; + while child_eq_properties + .ordering_satisfy(sort_exprs[0..common_prefix_length + 1].to_vec())? + { + common_prefix_length += 1; + } + if common_prefix_length > 0 { + return Ok(Arc::new( + PartialSortExec::new( + sort_exprs, + Arc::clone(sort_plan.input()), + common_prefix_length, + ) + .with_preserve_partitioning(sort_plan.preserve_partitioning()) + .with_fetch(sort_plan.fetch()), + )); + } + Ok(plan) +} + +/// Transform [`CoalescePartitionsExec`] + [`SortExec`] cascades into [`SortExec`] +/// + [`SortPreservingMergeExec`] cascades, as illustrated below. +/// +/// A [`CoalescePartitionsExec`] + [`SortExec`] cascade combines partitions +/// first, and then sorts: +/// ```text +/// ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ +/// ││B│A│D│... ├──┐ +/// └─┴─┴─┘ │ +/// └ ─ ─ ─ ─ ─ ┘ │ ┌────────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ +/// Partition 1 │ │ Coalesce │ ┌─┬─┬─┬─┬─┐ │ │ ┌─┬─┬─┬─┬─┐ +/// ├──▶(no ordering guarantees)│──▶││B│E│A│D│C│...───▶ Sort ├───▶││A│B│C│D│E│... │ +/// │ │ │ └─┴─┴─┴─┴─┘ │ │ └─┴─┴─┴─┴─┘ +/// ┌ ─ ─ ─ ─ ─ ┐ │ └────────────────────────┘ └ ─ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ┌─┬─┐ │ Partition Partition +/// ││E│C│ ... ├──┘ +/// └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 +/// ``` +/// +/// +/// A [`SortExec`] + [`SortPreservingMergeExec`] cascade sorts each partition +/// first, then merges partitions while preserving the sort: +/// ```text +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ +/// ┌─┬─┬─┐ │ │ ┌─┬─┬─┐ +/// ││B│A│D│... │──▶│ Sort │──▶││A│B│D│... │──┐ +/// └─┴─┴─┘ │ │ └─┴─┴─┘ │ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ │ ┌─────────────────────┐ ┌ ─ ─ ─ ─ ─ ─ ─ ┐ +/// Partition 1 Partition 1 │ │ │ ┌─┬─┬─┬─┬─┐ +/// ├──▶ SortPreservingMerge ├───▶││A│B│C│D│E│... │ +/// │ │ │ └─┴─┴─┴─┴─┘ +/// ┌ ─ ─ ─ ─ ─ ┐ ┌────────┐ ┌ ─ ─ ─ ─ ─ ┐ │ └─────────────────────┘ └ ─ ─ ─ ─ ─ ─ ─ ┘ +/// ┌─┬─┐ │ │ ┌─┬─┐ │ Partition +/// ││E│C│ ... │──▶│ Sort ├──▶││C│E│ ... │──┘ +/// └─┴─┘ │ │ └─┴─┘ +/// └ ─ ─ ─ ─ ─ ┘ └────────┘ └ ─ ─ ─ ─ ─ ┘ +/// Partition 2 Partition 2 +/// ``` +/// +/// The latter [`SortExec`] + [`SortPreservingMergeExec`] cascade performs +/// sorting first on a per-partition basis, thereby parallelizing the sort. +/// +/// The outcome is that plans of the form +/// ```text +/// "SortExec: expr=\[a@0 ASC\]", +/// " ...nodes..." +/// " CoalescePartitionsExec", +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// ``` +/// are transformed into +/// ```text +/// "SortPreservingMergeExec: \[a@0 ASC\]", +/// " SortExec: expr=\[a@0 ASC\]", +/// " ...nodes..." +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// ``` +/// by following connections from [`CoalescePartitionsExec`]s to [`SortExec`]s. +/// By performing sorting in parallel, we can increase performance in some +/// scenarios. +/// +/// This optimization requires that there are no nodes between the [`SortExec`] +/// and the [`CoalescePartitionsExec`], which requires single partitioning. Do +/// not parallelize when the following scenario occurs: +/// ```text +/// "SortExec: expr=\[a@0 ASC\]", +/// " ...nodes requiring single partitioning..." +/// " CoalescePartitionsExec", +/// " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", +/// ``` +/// +/// **Steps** +/// 1. Checks if the plan is either a [`SortExec`], a [`SortPreservingMergeExec`], +/// or a [`CoalescePartitionsExec`]. Otherwise, does nothing. +/// 2. If the plan is a [`SortExec`] or a final [`SortPreservingMergeExec`] +/// (i.e. output partitioning is 1): +/// - Check for [`CoalescePartitionsExec`] in children. If found, check if +/// it can be removed (with possible [`RepartitionExec`]s). If so, remove +/// (see `remove_bottleneck_in_subplan`). +/// - If the plan is satisfying the ordering requirements, add a `SortExec`. +/// - Add an SPM above the plan and return. +/// 3. If the plan is a [`CoalescePartitionsExec`]: +/// - Check if it can be removed (with possible [`RepartitionExec`]s). +/// If so, remove (see `remove_bottleneck_in_subplan`). +pub fn parallelize_sorts( + mut requirements: PlanWithCorrespondingCoalescePartitions, +) -> Result> { + update_coalesce_ctx_children(&mut requirements); + + if requirements.children.is_empty() || !requirements.children[0].data { + // We only take an action when the plan is either a `SortExec`, a + // `SortPreservingMergeExec` or a `CoalescePartitionsExec`, and they + // all have a single child. Therefore, if the first child has no + // connection, we can return immediately. + Ok(Transformed::no(requirements)) + } else if (is_sort(&requirements.plan) + || is_sort_preserving_merge(&requirements.plan)) + && requirements.plan.output_partitioning().partition_count() <= 1 + { + // Take the initial sort expressions and requirements + let (sort_exprs, fetch) = get_sort_exprs(&requirements.plan)?; + let sort_reqs = LexRequirement::from(sort_exprs.clone()); + let sort_exprs = sort_exprs.clone(); + + // If there is a connection between a `CoalescePartitionsExec` and a + // global sort that satisfy the requirements (i.e. intermediate + // executors don't require single partition), then we can replace + // the `CoalescePartitionsExec` + `SortExec` cascade with a `SortExec` + // + `SortPreservingMergeExec` cascade to parallelize sorting. + requirements = remove_bottleneck_in_subplan(requirements)?; + // We also need to remove the self node since `remove_corresponding_coalesce_in_sub_plan` + // deals with the children and their children and so on. + requirements = requirements.children.swap_remove(0); + + requirements = add_sort_above_with_check(requirements, sort_reqs, fetch)?; + + let spm = + SortPreservingMergeExec::new(sort_exprs, Arc::clone(&requirements.plan)); + Ok(Transformed::yes( + PlanWithCorrespondingCoalescePartitions::new( + Arc::new(spm.with_fetch(fetch)), + false, + vec![requirements], + ), + )) + } else if is_coalesce_partitions(&requirements.plan) { + let fetch = requirements.plan.fetch(); + // There is an unnecessary `CoalescePartitionsExec` in the plan. + // This will handle the recursive `CoalescePartitionsExec` plans. + requirements = remove_bottleneck_in_subplan(requirements)?; + // For the removal of self node which is also a `CoalescePartitionsExec`. + requirements = requirements.children.swap_remove(0); + + Ok(Transformed::yes( + PlanWithCorrespondingCoalescePartitions::new( + Arc::new( + CoalescePartitionsExec::new(Arc::clone(&requirements.plan)) + .with_fetch(fetch), + ), + false, + vec![requirements], + ), + )) + } else { + Ok(Transformed::yes(requirements)) + } +} + +/// This function enforces sorting requirements and makes optimizations without +/// violating these requirements whenever possible. Requires a bottom-up traversal. +/// +/// **Steps** +/// 1. Analyze if there are any immediate removals of [`SortExec`]s. If so, +/// removes them (see `analyze_immediate_sort_removal`). +/// 2. For each child of the plan, if the plan requires an input ordering: +/// - Checks if ordering is satisfied with the child. If not: +/// - If the child has an output ordering, removes the unnecessary +/// `SortExec`. +/// - Adds sort above the child plan. +/// - (Plan not requires input ordering) +/// - Checks if the `SortExec` is neutralized in the plan. If so, +/// removes it. +/// 3. Check and modify window operator: +/// - Checks if the plan is a window operator, and connected with a sort. +/// If so, either tries to update the window definition or removes +/// unnecessary [`SortExec`]s (see `adjust_window_sort_removal`). +/// 4. Check and remove possibly unnecessary SPM: +/// - Checks if the plan is SPM and child 1 output partitions, if so +/// decides this SPM is unnecessary and removes it from the plan. +pub fn ensure_sorting( + mut requirements: PlanWithCorrespondingSort, +) -> Result> { + requirements = update_sort_ctx_children_data(requirements, false)?; + + // Perform naive analysis at the beginning -- remove already-satisfied sorts: + if requirements.children.is_empty() { + return Ok(Transformed::no(requirements)); + } + let maybe_requirements = analyze_immediate_sort_removal(requirements)?; + requirements = if !maybe_requirements.transformed { + maybe_requirements.data + } else { + return Ok(maybe_requirements); + }; + + let plan = &requirements.plan; + let mut updated_children = vec![]; + for (idx, (required_ordering, mut child)) in plan + .required_input_ordering() + .into_iter() + .zip(requirements.children) + .enumerate() + { + let physical_ordering = child.plan.output_ordering(); + + if let Some(required) = required_ordering { + let eq_properties = child.plan.equivalence_properties(); + let req = required.into_single(); + if !eq_properties.ordering_satisfy_requirement(req.clone())? { + // Make sure we preserve the ordering requirements: + if physical_ordering.is_some() { + child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; + } + child = add_sort_above( + child, + req, + plan.downcast_ref::() + .map(|output| output.fetch()) + .unwrap_or(None), + ); + child = update_sort_ctx_children_data(child, true)?; + } + } else if physical_ordering.is_none() || !plan.maintains_input_order()[idx] { + // We have a `SortExec` whose effect may be neutralized by another + // order-imposing operator, remove this sort: + child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; + } + updated_children.push(child); + } + requirements.children = updated_children; + requirements = requirements.update_plan_from_children()?; + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: + let child_node = &requirements.children[0]; + if is_window(&requirements.plan) && child_node.data { + return adjust_window_sort_removal(requirements).map(Transformed::yes); + } else if is_sort_preserving_merge(&requirements.plan) + && child_node.plan.output_partitioning().partition_count() <= 1 + { + // This `SortPreservingMergeExec` is unnecessary, input already has a + // single partition and no fetch is required. + let mut child_node = requirements.children.swap_remove(0); + if let Some(fetch) = requirements.plan.fetch() { + // Add the limit exec if the original SPM had a fetch: + child_node.plan = + Arc::new(LocalLimitExec::new(Arc::clone(&child_node.plan), fetch)); + } + return Ok(Transformed::yes(child_node)); + } + update_sort_ctx_children_data(requirements, false).map(Transformed::yes) +} + +/// Analyzes if there are any immediate sort removals by checking the `SortExec`s +/// and their ordering requirement satisfactions with children +/// If the sort is unnecessary, either replaces it with +/// [`SortPreservingMergeExec`] and/or a limit node, or removes the +/// [`SortExec`]. +/// Otherwise, returns the original plan +fn analyze_immediate_sort_removal( + mut node: PlanWithCorrespondingSort, +) -> Result> { + let Some(sort_exec) = node.plan.downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + let sort_input = sort_exec.input(); + // Check if the sort is unnecessary: + let properties = sort_exec.properties(); + if let Some(ordering) = properties.output_ordering().cloned() { + let eqp = sort_input.equivalence_properties(); + if !eqp.ordering_satisfy(ordering)? { + return Ok(Transformed::no(node)); + } + } + node.plan = if !sort_exec.preserve_partitioning() + && sort_input.output_partitioning().partition_count() > 1 + { + // Replace the sort with a sort-preserving merge: + Arc::new( + SortPreservingMergeExec::new( + sort_exec.expr().clone(), + Arc::clone(sort_input), + ) + .with_fetch(sort_exec.fetch()), + ) as _ + } else { + // Remove the sort: + node.children = node.children.swap_remove(0).children; + if let Some(fetch) = sort_exec.fetch() { + let required_ordering = sort_exec.properties().output_ordering().cloned(); + // If the sort has a fetch, we need to add a limit: + if properties.output_partitioning().partition_count() == 1 { + let mut global_limit = + GlobalLimitExec::new(Arc::clone(sort_input), 0, Some(fetch)); + global_limit.set_required_ordering(required_ordering); + Arc::new(global_limit) + } else { + let mut local_limit = LocalLimitExec::new(Arc::clone(sort_input), fetch); + local_limit.set_required_ordering(required_ordering); + Arc::new(local_limit) + } + } else { + Arc::clone(sort_input) + } + }; + for child in node.children.iter_mut() { + child.data = false; + } + node.data = false; + Ok(Transformed::yes(node)) +} + +/// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine +/// whether it may allow removing a sort. +fn adjust_window_sort_removal( + mut window_tree: PlanWithCorrespondingSort, +) -> Result { + // Window operators have a single child we need to adjust: + let child_node = remove_corresponding_sort_from_sub_plan( + window_tree.children.swap_remove(0), + matches!( + window_tree.plan.required_input_distribution()[0], + Distribution::SinglePartition + ), + )?; + window_tree.children.push(child_node); + + let child_plan = &window_tree.children[0].plan; + let (window_expr, new_window) = if let Some(exec) = + window_tree.plan.downcast_ref::() + { + let window_expr = exec.window_expr(); + let new_window = + get_best_fitting_window(window_expr, child_plan, &exec.partition_keys())?; + (window_expr, new_window) + } else if let Some(exec) = window_tree.plan.downcast_ref::() { + let window_expr = exec.window_expr(); + let new_window = + get_best_fitting_window(window_expr, child_plan, &exec.partition_keys())?; + (window_expr, new_window) + } else { + return plan_err!("Expected WindowAggExec or BoundedWindowAggExec"); + }; + + window_tree.plan = if let Some(new_window) = new_window { + // We were able to change the window to accommodate the input, use it: + new_window + } else { + // We were unable to change the window to accommodate the input, so we + // will insert a sort. + let reqs = window_tree.plan.required_input_ordering().swap_remove(0); + + // Satisfy the ordering requirement so that the window can run: + let mut child_node = window_tree.children.swap_remove(0); + if let Some(reqs) = reqs { + child_node = add_sort_above(child_node, reqs.into_single(), None); + } + let child_plan = Arc::clone(&child_node.plan); + window_tree.children.push(child_node); + + if window_expr.iter().all(|e| e.uses_bounded_memory()) { + Arc::new(BoundedWindowAggExec::try_new( + window_expr.to_vec(), + child_plan, + InputOrderMode::Sorted, + !window_expr[0].partition_by().is_empty(), + )?) as _ + } else { + Arc::new(WindowAggExec::try_new( + window_expr.to_vec(), + child_plan, + !window_expr[0].partition_by().is_empty(), + )?) as _ + } + }; + + window_tree.data = false; + Ok(window_tree) +} + +/// Removes parallelization-reducing, avoidable [`CoalescePartitionsExec`]s from +/// the plan in `node`. After the removal of such `CoalescePartitionsExec`s from +/// the plan, some of the remaining `RepartitionExec`s might become unnecessary. +/// Removes such `RepartitionExec`s from the plan as well. +fn remove_bottleneck_in_subplan( + mut requirements: PlanWithCorrespondingCoalescePartitions, +) -> Result { + let plan = &requirements.plan; + let children = &mut requirements.children; + if is_coalesce_partitions(&children[0].plan) { + // We can safely use the 0th index since we have a `CoalescePartitionsExec`. + let mut new_child_node = children[0].children.swap_remove(0); + while new_child_node.plan.output_partitioning() == plan.output_partitioning() + && is_repartition(&new_child_node.plan) + && is_repartition(plan) + { + new_child_node = new_child_node.children.swap_remove(0) + } + children[0] = new_child_node; + } else { + requirements.children = requirements + .children + .into_iter() + .map(|node| { + if node.data { + remove_bottleneck_in_subplan(node) + } else { + Ok(node) + } + }) + .collect::>()?; + } + let mut new_reqs = requirements.update_plan_from_children()?; + if let Some(repartition) = new_reqs.plan.downcast_ref::() { + let input_partitioning = repartition.input().output_partitioning(); + // We can remove this repartitioning operator if it is now a no-op: + let mut can_remove = input_partitioning.eq(repartition.partitioning()); + // We can also remove it if we ended up with an ineffective RR: + if let Partitioning::RoundRobinBatch(n_out) = repartition.partitioning() { + can_remove |= *n_out == input_partitioning.partition_count(); + } + if can_remove { + new_reqs = new_reqs.children.swap_remove(0) + } + } + Ok(new_reqs) +} + +/// Updates child to remove the unnecessary sort below it. +fn update_child_to_remove_unnecessary_sort( + child_idx: usize, + mut node: PlanWithCorrespondingSort, + parent: &Arc, +) -> Result { + if node.data { + let requires_single_partition = matches!( + parent.required_input_distribution()[child_idx], + Distribution::SinglePartition + ); + node = remove_corresponding_sort_from_sub_plan(node, requires_single_partition)?; + } + node.data = false; + Ok(node) +} + +/// Removes the sort from the plan in `node`. +fn remove_corresponding_sort_from_sub_plan( + mut node: PlanWithCorrespondingSort, + requires_single_partition: bool, +) -> Result { + // A `SortExec` is always at the bottom of the tree. + if let Some(sort_exec) = node.plan.downcast_ref::() { + // Do not remove sorts with fetch: + if sort_exec.fetch().is_none() { + node = node.children.swap_remove(0); + } + } else { + let mut any_connection = false; + let required_dist = node.plan.required_input_distribution(); + node.children = node + .children + .into_iter() + .enumerate() + .map(|(idx, child)| { + if child.data { + any_connection = true; + remove_corresponding_sort_from_sub_plan( + child, + matches!(required_dist[idx], Distribution::SinglePartition), + ) + } else { + Ok(child) + } + }) + .collect::>()?; + node = node.update_plan_from_children()?; + if any_connection || node.children.is_empty() { + node = update_sort_ctx_children_data(node, false)?; + } + + // Replace with variants that do not preserve order. + if is_sort_preserving_merge(&node.plan) { + node.children = node.children.swap_remove(0).children; + node.plan = Arc::clone(node.plan.children().swap_remove(0)); + } else if let Some(repartition) = node.plan.downcast_ref::() { + node.plan = Arc::new(RepartitionExec::try_new( + Arc::clone(&node.children[0].plan), + repartition.properties().output_partitioning().clone(), + )?) as _; + } + }; + // Deleting a merging sort may invalidate distribution requirements. + // Ensure that we stay compliant with such requirements: + if requires_single_partition && node.plan.output_partitioning().partition_count() > 1 + { + // If there is existing ordering, to preserve ordering use + // `SortPreservingMergeExec` instead of a `CoalescePartitionsExec`. + let plan = Arc::clone(&node.plan); + let fetch = plan.fetch(); + let plan = if let Some(ordering) = plan.output_ordering() { + Arc::new( + SortPreservingMergeExec::new(ordering.clone(), plan).with_fetch(fetch), + ) as _ + } else { + Arc::new(CoalescePartitionsExec::new(plan)) as _ + }; + node = PlanWithCorrespondingSort::new(plan, false, vec![node]); + node = update_sort_ctx_children_data(node, false)?; + } + Ok(node) +} + +/// Converts an [ExecutionPlan] trait object to a [LexOrdering] reference when possible. +fn get_sort_exprs( + sort_any: &Arc, +) -> Result<(&LexOrdering, Option)> { + if let Some(sort_exec) = sort_any.downcast_ref::() { + Ok((sort_exec.expr(), sort_exec.fetch())) + } else if let Some(spm) = sort_any.downcast_ref::() { + Ok((spm.expr(), spm.fetch())) + } else { + plan_err!("Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec") + } +} + +// Tests are in tests/cases/enforce_sorting.rs diff --git a/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/replace_with_order_preserving_variants.rs b/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/replace_with_order_preserving_variants.rs new file mode 100644 index 0000000000000..6ab84dc95eab9 --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/replace_with_order_preserving_variants.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule that replaces executors that lose ordering with their +//! order-preserving variants when it is helpful; either in terms of +//! performance or to accommodate unbounded streams by fixing the pipeline. + +use std::sync::Arc; + +use crate::utils::{ + is_coalesce_partitions, is_repartition, is_sort, is_sort_preserving_merge, +}; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Result, assert_or_internal_err}; +use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::execution_plan::EmissionType; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::tree_node::PlanContext; + +use itertools::izip; + +/// For a given `plan`, this object carries the information one needs from its +/// descendants to decide whether it is beneficial to replace order-losing (but +/// somewhat faster) variants of certain operators with their order-preserving +/// (but somewhat slower) cousins. +pub type OrderPreservationContext = PlanContext; + +/// Updates order-preservation data for all children of the given node. +pub fn update_order_preservation_ctx_children_data(opc: &mut OrderPreservationContext) { + for PlanContext { + plan, + children, + data, + } in opc.children.iter_mut() + { + let maintains_input_order = plan.maintains_input_order(); + let inspect_child = |idx| { + maintains_input_order[idx] + || is_coalesce_partitions(plan) + || is_repartition(plan) + }; + + // We cut the path towards nodes that do not maintain ordering. + for (idx, c) in children.iter_mut().enumerate() { + c.data &= inspect_child(idx); + } + + let plan_children = plan.children(); + *data = if plan_children.is_empty() { + false + } else if !children[0].data + && ((is_repartition(plan) && !maintains_input_order[0]) + || (is_coalesce_partitions(plan) + && plan_children[0].output_ordering().is_some())) + { + // We either have a RepartitionExec or a CoalescePartitionsExec + // and they lose their input ordering, so initiate connection: + true + } else { + // Maintain connection if there is a child with a connection, + // and operator can possibly maintain that connection (either + // in its current form or when we replace it with the corresponding + // order preserving operator). + children + .iter() + .enumerate() + .any(|(idx, c)| c.data && inspect_child(idx)) + } + } + opc.data = false; +} + +/// Calculates the updated plan by replacing operators that lose ordering +/// inside `sort_input` with their order-preserving variants. This will +/// generate an alternative plan, which will be accepted or rejected later on +/// depending on whether it helps us remove a `SortExec`. +pub fn plan_with_order_preserving_variants( + mut sort_input: OrderPreservationContext, + // Flag indicating that it is desirable to replace `RepartitionExec`s with + // `SortPreservingRepartitionExec`s: + is_spr_better: bool, + // Flag indicating that it is desirable to replace `CoalescePartitionsExec`s + // with `SortPreservingMergeExec`s: + is_spm_better: bool, + fetch: Option, +) -> Result { + sort_input.children = sort_input + .children + .into_iter() + .map(|node| { + // Update descendants in the given tree if there is a connection: + if node.data { + plan_with_order_preserving_variants( + node, + is_spr_better, + is_spm_better, + fetch, + ) + } else { + Ok(node) + } + }) + .collect::>()?; + sort_input.data = false; + + if is_repartition(&sort_input.plan) + && !sort_input.plan.maintains_input_order()[0] + && is_spr_better + { + // When a `RepartitionExec` doesn't preserve ordering, replace it with + // a sort-preserving variant if appropriate: + let child = Arc::clone(&sort_input.children[0].plan); + let partitioning = sort_input.plan.output_partitioning().clone(); + sort_input.plan = Arc::new( + RepartitionExec::try_new(child, partitioning)?.with_preserve_order(), + ) as _; + sort_input.children[0].data = true; + return Ok(sort_input); + } else if is_coalesce_partitions(&sort_input.plan) && is_spm_better { + let child = &sort_input.children[0].plan; + if let Some(ordering) = child.output_ordering() { + let mut fetch = fetch; + if let Some(coalesce_fetch) = sort_input.plan.fetch() { + fetch = match fetch { + Some(sort_fetch) => { + assert_or_internal_err!( + coalesce_fetch >= sort_fetch, + "CoalescePartitionsExec fetch [{:?}] should be greater than or equal to SortExec fetch [{:?}]", + coalesce_fetch, + sort_fetch + ); + Some(sort_fetch) + } + None => { + // If the sort node does not have a fetch, we need to keep the coalesce node's fetch. + Some(coalesce_fetch) + } + }; + }; + // When the input of a `CoalescePartitionsExec` has an ordering, + // replace it with a `SortPreservingMergeExec` if appropriate: + let spm = SortPreservingMergeExec::new(ordering.clone(), Arc::clone(child)) + .with_fetch(fetch); + sort_input.plan = Arc::new(spm) as _; + sort_input.children[0].data = true; + return Ok(sort_input); + } + } + + sort_input.update_plan_from_children() +} + +/// Calculates the updated plan by replacing operators that preserve ordering +/// inside `sort_input` with their order-breaking variants. This will restore +/// the original plan modified by [`plan_with_order_preserving_variants`]. +pub fn plan_with_order_breaking_variants( + mut sort_input: OrderPreservationContext, +) -> Result { + let plan = &sort_input.plan; + sort_input.children = izip!( + sort_input.children, + plan.maintains_input_order(), + plan.required_input_ordering() + ) + .map(|(node, maintains, required_ordering)| { + // Replace with non-order preserving variants as long as ordering is + // not required by intermediate operators: + if !maintains { + return Ok(node); + } else if is_sort_preserving_merge(plan) { + return plan_with_order_breaking_variants(node); + } else if let Some(required_ordering) = required_ordering { + let eqp = node.plan.equivalence_properties(); + if eqp.ordering_satisfy_requirement(required_ordering.into_single())? { + return Ok(node); + } + } + plan_with_order_breaking_variants(node) + }) + .collect::>()?; + sort_input.data = false; + + if is_repartition(plan) && plan.maintains_input_order()[0] { + // When a `RepartitionExec` preserves ordering, replace it with a + // non-sort-preserving variant: + let child = Arc::clone(&sort_input.children[0].plan); + let partitioning = plan.output_partitioning().clone(); + sort_input.plan = Arc::new(RepartitionExec::try_new(child, partitioning)?) as _; + } else if is_sort_preserving_merge(plan) { + // Replace `SortPreservingMergeExec` with a `CoalescePartitionsExec` + // SPM may have `fetch`, so pass it to the `CoalescePartitionsExec` + let child = Arc::clone(&sort_input.children[0].plan); + let coalesce = + Arc::new(CoalescePartitionsExec::new(child).with_fetch(plan.fetch())); + sort_input.plan = coalesce; + } else { + return sort_input.update_plan_from_children(); + } + + sort_input.children[0].data = false; + Ok(sort_input) +} + +/// The `replace_with_order_preserving_variants` optimizer sub-rule tries to +/// remove `SortExec`s from the physical plan by replacing operators that do +/// not preserve ordering with their order-preserving variants; i.e. by replacing +/// ordinary `RepartitionExec`s with their sort-preserving variants or by replacing +/// `CoalescePartitionsExec`s with `SortPreservingMergeExec`s. +/// +/// If this replacement is helpful for removing a `SortExec`, it updates the plan. +/// Otherwise, it leaves the plan unchanged. +/// +/// NOTE: This optimizer sub-rule will only produce sort-preserving `RepartitionExec`s +/// if the query is bounded or if the config option `prefer_existing_sort` is +/// set to `true`. +/// +/// The algorithm flow is simply like this: +/// 1. Visit nodes of the physical plan bottom-up and look for `SortExec` nodes. +/// During the traversal, keep track of operators that maintain ordering (or +/// can maintain ordering when replaced by an order-preserving variant) until +/// a `SortExec` is found. +/// 2. When a `SortExec` is found, update the child of the `SortExec` by replacing +/// operators that do not preserve ordering in the tree with their order +/// preserving variants. +/// 3. Check if the `SortExec` is still necessary in the updated plan by comparing +/// its input ordering with the output ordering it imposes. We do this because +/// replacing operators that lose ordering with their order-preserving variants +/// enables us to preserve the previously lost ordering at the input of `SortExec`. +/// 4. If the `SortExec` in question turns out to be unnecessary, remove it and +/// use updated plan. Otherwise, use the original plan. +/// 5. Continue the bottom-up traversal until another `SortExec` is seen, or the +/// traversal is complete. +pub fn replace_with_order_preserving_variants( + mut requirements: OrderPreservationContext, + // A flag indicating that replacing `RepartitionExec`s with sort-preserving + // variants is desirable when it helps to remove a `SortExec` from the plan. + // If this flag is `false`, this replacement should only be made to fix the + // pipeline (streaming). + is_spr_better: bool, + // A flag indicating that replacing `CoalescePartitionsExec`s with + // `SortPreservingMergeExec`s is desirable when it helps to remove a + // `SortExec` from the plan. If this flag is `false`, this replacement + // should only be made to fix the pipeline (streaming). + is_spm_better: bool, + config: &ConfigOptions, +) -> Result> { + update_order_preservation_ctx_children_data(&mut requirements); + if !(is_sort(&requirements.plan) && requirements.children[0].data) { + return Ok(Transformed::no(requirements)); + } + + // For unbounded cases, we replace with the order-preserving variant in any + // case, as doing so helps fix the pipeline. Also replace if config allows. + let use_order_preserving_variant = config.optimizer.prefer_existing_sort + || (requirements.plan.boundedness().is_unbounded() + && requirements.plan.pipeline_behavior() == EmissionType::Final); + + // Create an alternate plan with order-preserving variants: + let mut alternate_plan = plan_with_order_preserving_variants( + requirements.children.swap_remove(0), + is_spr_better || use_order_preserving_variant, + is_spm_better || use_order_preserving_variant, + requirements.plan.fetch(), + )?; + + // If the alternate plan makes this sort unnecessary, accept the alternate: + if let Some(ordering) = requirements.plan.output_ordering() { + let eqp = alternate_plan.plan.equivalence_properties(); + if !eqp.ordering_satisfy(ordering.clone())? { + // The alternate plan does not help, use faster order-breaking variants: + alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; + alternate_plan.data = false; + requirements.children = vec![alternate_plan]; + return Ok(Transformed::yes(requirements)); + } + } + for child in alternate_plan.children.iter_mut() { + child.data = false; + } + Ok(Transformed::yes(alternate_plan)) +} diff --git a/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/sort_pushdown.rs new file mode 100644 index 0000000000000..261cf701c870f --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_requirements/enforce_sorting/sort_pushdown.rs @@ -0,0 +1,966 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; +use std::sync::Arc; + +use crate::utils::{ + add_sort_above_with_distribution, is_sort, is_sort_preserving_merge, is_union, + is_window, +}; + +use arrow::datatypes::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{HashSet, JoinSide, Result, internal_err}; +use datafusion_expr::JoinType; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{ + Distribution, EquivalenceProperties, add_offset_to_physical_sort_exprs, +}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, + PhysicalSortRequirement, +}; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::utils::{ + ColumnIndex, calculate_join_output_ordering, +}; +use datafusion_physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::tree_node::PlanContext; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; + +/// "Data class" used by sort pushdown (now driven from `EnsureRequirements`) +/// to push down [`SortExec`] in the plan. In some cases the total +/// computational cost is reduced by pushing down `SortExec`s through certain +/// executors. The object carries the parent required ordering, the (optional) +/// `fetch` value of the parent node, and the parent's distribution requirement +/// (used by the distribution-aware pushdown path) as its data. +#[derive(Clone, Debug)] +pub struct ParentRequirements { + ordering_requirement: Option, + fetch: Option, + /// The distribution required by the consumer above any SortExec we insert. + /// When this is `SinglePartition` and the input has multiple partitions, + /// `add_sort_above_with_distribution` wraps the sort in `SortPreservingMergeExec`. + distribution_requirement: Distribution, +} + +impl Default for ParentRequirements { + fn default() -> Self { + Self { + ordering_requirement: None, + fetch: None, + distribution_requirement: Distribution::UnspecifiedDistribution, + } + } +} + +pub type SortPushDown = PlanContext; + +/// Assigns the ordering requirement of the root node to the its children. +pub fn assign_initial_requirements(sort_push_down: &mut SortPushDown) { + let reqs = sort_push_down.plan.required_input_ordering(); + let dists = sort_push_down.plan.required_input_distribution(); + for (idx, (child, requirement)) in + sort_push_down.children.iter_mut().zip(reqs).enumerate() + { + child.data = ParentRequirements { + ordering_requirement: requirement, + fetch: child.plan.fetch(), + distribution_requirement: dists + .get(idx) + .cloned() + .unwrap_or(Distribution::UnspecifiedDistribution), + }; + } +} + +/// Tries to push down the sort requirements as far as possible, if decides a `SortExec` is unnecessary removes it. +pub fn pushdown_sorts(sort_push_down: SortPushDown) -> Result { + sort_push_down + .transform_down(pushdown_sorts_helper) + .map(|transformed| transformed.data) +} + +fn min_fetch(f1: Option, f2: Option) -> Option { + match (f1, f2) { + (Some(f1), Some(f2)) => Some(f1.min(f2)), + (Some(_), _) => f1, + (_, Some(_)) => f2, + _ => None, + } +} + +/// Returns the stricter of two distribution requirements. +/// `SinglePartition` is the strictest. +fn stronger_distribution(a: &Distribution, b: &Distribution) -> Distribution { + match (a, b) { + (Distribution::SinglePartition, _) | (_, Distribution::SinglePartition) => { + Distribution::SinglePartition + } + (Distribution::HashPartitioned(_), _) => a.clone(), + (_, Distribution::HashPartitioned(_)) => b.clone(), + _ => Distribution::UnspecifiedDistribution, + } +} + +fn pushdown_sorts_helper( + mut sort_push_down: SortPushDown, +) -> Result> { + let plan = sort_push_down.plan; + let parent_fetch = sort_push_down.data.fetch; + let parent_distribution = sort_push_down.data.distribution_requirement.clone(); + + let Some(parent_requirement) = sort_push_down.data.ordering_requirement.clone() + else { + // If there are no ordering requirements from the parent, nothing to do + // unless we have a sort. + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; + // The sort is unnecessary, just propagate the stricter fetch and + // ordering requirements. + let fetch = min_fetch(plan.fetch(), parent_fetch); + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + sort_push_down.data.fetch = fetch; + sort_push_down.data.ordering_requirement = + Some(OrderingRequirements::from(sort_ordering)); + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): + return pushdown_sorts_helper(sort_push_down); + } + sort_push_down.plan = plan; + // No ordering is being pushed; use each child's own distribution requirement + let dists = sort_push_down.plan.required_input_distribution(); + for (idx, child) in sort_push_down.children.iter_mut().enumerate() { + child.data.distribution_requirement = dists + .get(idx) + .cloned() + .unwrap_or(Distribution::UnspecifiedDistribution); + } + return Ok(Transformed::no(sort_push_down)); + }; + + let eqp = plan.equivalence_properties(); + let satisfy_parent = + eqp.ordering_satisfy_requirement(parent_requirement.first().clone())?; + + if is_sort(&plan) { + let Some(sort_ordering) = plan.output_ordering().cloned() else { + return internal_err!("SortExec should have output ordering"); + }; + + let sort_fetch = plan.fetch(); + let parent_is_stricter = eqp.requirements_compatible( + parent_requirement.first().clone(), + sort_ordering.clone().into(), + ); + + // Remove the current sort as we are either going to prove that it is + // unnecessary, or replace it with a stricter sort. + sort_push_down = sort_push_down + .children + .swap_remove(0) + .update_plan_from_children()?; + if !satisfy_parent && !parent_is_stricter { + // The sort was imposing a different ordering than the one being + // pushed down. Replace it with a sort that matches the pushed-down + // ordering, and continue the pushdown. + // Add back the sort (distribution-aware): + sort_push_down = add_sort_above_with_distribution( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + &parent_distribution, + ); + // Update pushdown requirements: + sort_push_down.children[0].data = ParentRequirements { + ordering_requirement: Some(OrderingRequirements::from(sort_ordering)), + fetch: sort_fetch, + distribution_requirement: Distribution::UnspecifiedDistribution, + }; + return Ok(Transformed::yes(sort_push_down)); + } else { + // Sort was unnecessary, just propagate the stricter fetch and + // ordering requirements. Reset distribution to Unspecified + // because the sort we're removing may have been below a + // partition-merging node (like SortPreservingMergeExec) that + // already satisfies SinglePartition. + sort_push_down.data.fetch = min_fetch(sort_fetch, parent_fetch); + sort_push_down.data.distribution_requirement = + Distribution::UnspecifiedDistribution; + let current_is_stricter = eqp.requirements_compatible( + sort_ordering.clone().into(), + parent_requirement.first().clone(), + ); + sort_push_down.data.ordering_requirement = if current_is_stricter { + Some(OrderingRequirements::from(sort_ordering)) + } else { + Some(parent_requirement) + }; + // Recursive call to helper, so it doesn't transform_down and miss + // the new node (previous child of sort): + return pushdown_sorts_helper(sort_push_down); + } + } + + sort_push_down.plan = plan; + if satisfy_parent { + // For non-sort operators which satisfy ordering: + let reqs = sort_push_down.plan.required_input_ordering(); + let dists = sort_push_down.plan.required_input_distribution(); + + // If this node already outputs single partition, don't push SinglePartition + // requirement to children (they're below the merge point). + let effective_parent_dist = + if sort_push_down.plan.output_partitioning().partition_count() == 1 { + Distribution::UnspecifiedDistribution + } else { + parent_distribution.clone() + }; + + for (idx, (child, order)) in + sort_push_down.children.iter_mut().zip(reqs).enumerate() + { + child.data.ordering_requirement = order; + child.data.fetch = min_fetch(parent_fetch, child.data.fetch); + child.data.distribution_requirement = stronger_distribution( + &effective_parent_dist, + dists + .get(idx) + .unwrap_or(&Distribution::UnspecifiedDistribution), + ); + } + } else if let Some(adjusted) = pushdown_requirement_to_children( + &sort_push_down.plan, + parent_requirement.clone(), + parent_fetch, + )? { + // For operators that can take a sort pushdown, continue with updated + // requirements. If this node already outputs single partition (e.g. SPM), + // don't push SinglePartition to children. + let current_fetch = sort_push_down.plan.fetch(); + let dists = sort_push_down.plan.required_input_distribution(); + let effective_dist = + if sort_push_down.plan.output_partitioning().partition_count() == 1 { + Distribution::UnspecifiedDistribution + } else { + parent_distribution.clone() + }; + for (idx, (child, order)) in + sort_push_down.children.iter_mut().zip(adjusted).enumerate() + { + child.data.ordering_requirement = order; + child.data.fetch = min_fetch(current_fetch, parent_fetch); + child.data.distribution_requirement = stronger_distribution( + &effective_dist, + dists + .get(idx) + .unwrap_or(&Distribution::UnspecifiedDistribution), + ); + } + sort_push_down.data.ordering_requirement = None; + } else { + // Can not push down requirements, add new `SortExec` (distribution-aware): + sort_push_down = add_sort_above_with_distribution( + sort_push_down, + parent_requirement.into_single(), + parent_fetch, + &parent_distribution, + ); + assign_initial_requirements(&mut sort_push_down); + } + Ok(Transformed::yes(sort_push_down)) +} + +/// Calculate the pushdown ordering requirements for children. +/// If sort cannot be pushed down, return None. +fn pushdown_requirement_to_children( + plan: &Arc, + parent_required: OrderingRequirements, + parent_fetch: Option, +) -> Result>>> { + // If there is a limit on the parent plan we cannot push it down through operators that change the cardinality. + // E.g. consider if LIMIT 2 is applied below a FilteExec that filters out 1/2 of the rows we'll end up with 1 row instead of 2. + // If the LIMIT is applied after the FilterExec and the FilterExec returns > 2 rows we'll end up with 2 rows (correct). + if parent_fetch.is_some() && !plan.supports_limit_pushdown() { + return Ok(None); + } + // Note: we still need to check the cardinality effect of the plan here, because the + // limit pushdown is not always safe, even if the plan supports it. Here's an example: + // + // UnionExec advertises `supports_limit_pushdown() == true` because it can + // forward a LIMIT k to each of its children—i.e. apply “LIMIT k” separately + // on each branch before merging them together. + // + // However, UnionExec’s `cardinality_effect() == GreaterEqual` (it sums up + // all child row counts), so pushing a global TopK/LIMIT through it would + // break the semantics of “take the first k rows of the combined result.” + // + // For example, with two branches A and B and k = 3: + // — Global LIMIT: take the first 3 rows from (A ∪ B) after merging. + // — Pushed down: take 3 from A, 3 from B, then merge → up to 6 rows! + // + // That’s why we still block on cardinality: even though UnionExec can + // push a LIMIT to its children, its GreaterEqual effect means it cannot + // preserve the global TopK semantics. + if parent_fetch.is_some() { + match plan.cardinality_effect() { + CardinalityEffect::Equal => { + // safe: only true sources (e.g. CoalesceBatchesExec, ProjectionExec) pass + } + _ => return Ok(None), + } + } + + let maintains_input_order = plan.maintains_input_order(); + if is_window(plan) { + let mut required_input_ordering = plan.required_input_ordering(); + let maybe_child_requirement = required_input_ordering.swap_remove(0); + let child_plan = plan.children().swap_remove(0); + let Some(child_req) = maybe_child_requirement else { + return Ok(None); + }; + match determine_children_requirement(&parent_required, &child_req, child_plan) { + RequirementsCompatibility::Satisfy => Ok(Some(vec![Some(child_req)])), + RequirementsCompatibility::Compatible(adjusted) => { + // If parent requirements are more specific than output ordering + // of the window plan, then we can deduce that the parent expects + // an ordering from the columns created by window functions. If + // that's the case, we block the pushdown of sort operation. + if !plan + .equivalence_properties() + .ordering_satisfy_requirement(parent_required.into_single())? + { + return Ok(None); + } + + Ok(Some(vec![adjusted])) + } + RequirementsCompatibility::NonCompatible => Ok(None), + } + } else if let Some(sort_exec) = plan.downcast_ref::() { + let Some(sort_ordering) = sort_exec.properties().output_ordering().cloned() + else { + return internal_err!("SortExec should have output ordering"); + }; + sort_exec + .properties() + .eq_properties + .requirements_compatible( + parent_required.first().clone(), + sort_ordering.into(), + ) + .then(|| Ok(vec![Some(parent_required)])) + .transpose() + } else if plan.fetch().is_some() + && plan.supports_limit_pushdown() + && plan + .maintains_input_order() + .into_iter() + .all(|maintain| maintain) + { + // Push down through operator with fetch when: + // - requirement is aligned with output ordering + // - it preserves ordering during execution + let Some(ordering) = plan.properties().output_ordering() else { + return Ok(Some(vec![Some(parent_required)])); + }; + if plan.properties().eq_properties.requirements_compatible( + parent_required.first().clone(), + ordering.clone().into(), + ) { + Ok(Some(vec![Some(parent_required)])) + } else { + Ok(None) + } + } else if is_union(plan) { + // `UnionExec` does not have real sort requirements for its input, we + // just propagate the sort requirements down: + Ok(Some(vec![Some(parent_required); plan.children().len()])) + } else if let Some(smj) = plan.downcast_ref::() { + let left_columns_len = smj.left().schema().fields().len(); + let parent_ordering: Vec = parent_required + .first() + .iter() + .cloned() + .map(Into::into) + .collect(); + let eqp = smj.properties().equivalence_properties(); + match expr_source_side(eqp, parent_ordering, smj.join_type(), left_columns_len) { + Some((JoinSide::Left, ordering)) => try_pushdown_requirements_to_join( + smj, + parent_required.into_single(), + ordering, + JoinSide::Left, + ), + Some((JoinSide::Right, ordering)) => { + let right_offset = + smj.schema().fields.len() - smj.right().schema().fields.len(); + let ordering = add_offset_to_physical_sort_exprs( + ordering, + -(right_offset as isize), + )?; + try_pushdown_requirements_to_join( + smj, + parent_required.into_single(), + ordering, + JoinSide::Right, + ) + } + _ => { + // Can not decide the expr side for SortMergeJoinExec, can not push down + Ok(None) + } + } + } else if let Some(aggregate_exec) = plan.downcast_ref::() { + handle_aggregate_pushdown(aggregate_exec, parent_required) + } else if maintains_input_order.is_empty() + || !maintains_input_order.iter().any(|o| *o) + || plan.is::() + || plan.is::() + // TODO: Add support for Projection push down + || plan.is::() + || pushdown_would_violate_requirements(&parent_required, plan.as_ref()) + { + // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. + // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. + // Pushing down is not beneficial + Ok(None) + } else if is_sort_preserving_merge(plan) { + let new_ordering = LexOrdering::from(parent_required.first().clone()); + let mut spm_eqs = plan.equivalence_properties().clone(); + let old_ordering = spm_eqs.output_ordering().unwrap(); + // Sort preserving merge will have new ordering, one requirement above is pushed down to its below. + let change = spm_eqs.reorder(new_ordering)?; + if !change || spm_eqs.ordering_satisfy(old_ordering)? { + // Can push-down through SortPreservingMergeExec, because parent requirement is finer + // than SortPreservingMergeExec output ordering. + Ok(Some(vec![Some(parent_required)])) + } else { + // Do not push-down through SortPreservingMergeExec when + // ordering requirement invalidates requirement of sort preserving merge exec. + Ok(None) + } + } else if let Some(hash_join) = plan.downcast_ref::() { + handle_hash_join(hash_join, parent_required) + } else { + handle_custom_pushdown(plan, parent_required, &maintains_input_order) + } + // TODO: Add support for Projection push down +} + +/// Try to push sorting through [`AggregateExec`] +/// +/// `AggregateExec` only preserves the input order of its group by columns +/// (not aggregates in general, which are formed from arbitrary expressions over +/// input) +/// +/// Thus function rewrites the parent required ordering in terms of the +/// aggregate input if possible. This rewritten requirement represents the +/// ordering of the `AggregateExec`'s **input** that would also satisfy the +/// **parent** ordering. +/// +/// If no such mapping is possible (e.g. because the sort references aggregate +/// columns), returns None. +fn handle_aggregate_pushdown( + aggregate_exec: &AggregateExec, + parent_required: OrderingRequirements, +) -> Result>>> { + if !aggregate_exec + .maintains_input_order() + .into_iter() + .any(|o| o) + { + return Ok(None); + } + + let group_expr = aggregate_exec.group_expr(); + // GROUPING SETS introduce additional output columns and NULL substitutions; + // skip pushdown until we can map those cases safely. + if group_expr.has_grouping_set() { + return Ok(None); + } + + let group_input_exprs = group_expr.input_exprs(); + let parent_requirement = parent_required.into_single(); + let mut child_requirement = Vec::with_capacity(parent_requirement.len()); + + for req in parent_requirement { + // Sort above AggregateExec should reference its output columns. Map each + // output group-by column to its original input expression. + let Some(column) = req.expr.downcast_ref::() else { + return Ok(None); + }; + if column.index() >= group_input_exprs.len() { + // AggregateExec does not produce output that is sorted on aggregate + // columns so those can not be pushed through. + return Ok(None); + } + child_requirement.push(PhysicalSortRequirement::new( + Arc::clone(&group_input_exprs[column.index()]), + req.options, + )); + } + + let Some(child_requirement) = LexRequirement::new(child_requirement) else { + return Ok(None); + }; + + // Keep sort above aggregate unless input ordering already satisfies the + // mapped requirement. + if aggregate_exec + .input() + .equivalence_properties() + .ordering_satisfy_requirement(child_requirement.iter().cloned())? + { + let child_requirements = OrderingRequirements::new(child_requirement); + Ok(Some(vec![Some(child_requirements)])) + } else { + Ok(None) + } +} + +/// Return true if pushing the sort requirements through a node would violate +/// the input sorting requirements for the plan +fn pushdown_would_violate_requirements( + parent_required: &OrderingRequirements, + child: &dyn ExecutionPlan, +) -> bool { + child + .required_input_ordering() + .into_iter() + // If there is no requirement, pushing down would not violate anything. + .flatten() + .any(|child_required| { + // Check if the plan's requirements would still be satisfied if we + // pushed down the parent requirements: + child_required + .into_single() + .iter() + .zip(parent_required.first().iter()) + .all(|(c, p)| !c.compatible(p)) + }) +} + +/// Determine children requirements: +/// - If children requirements are more specific, do not push down parent +/// requirements. +/// - If parent requirements are more specific, push down parent requirements. +/// - If they are not compatible, need to add a sort. +fn determine_children_requirement( + parent_required: &OrderingRequirements, + child_requirement: &OrderingRequirements, + child_plan: &Arc, +) -> RequirementsCompatibility { + let eqp = child_plan.equivalence_properties(); + if eqp.requirements_compatible( + child_requirement.first().clone(), + parent_required.first().clone(), + ) { + // Child requirements are more specific, no need to push down. + RequirementsCompatibility::Satisfy + } else if eqp.requirements_compatible( + parent_required.first().clone(), + child_requirement.first().clone(), + ) { + // Parent requirements are more specific, adjust child's requirements + // and push down the new requirements: + RequirementsCompatibility::Compatible(Some(parent_required.clone())) + } else { + RequirementsCompatibility::NonCompatible + } +} + +fn try_pushdown_requirements_to_join( + smj: &SortMergeJoinExec, + parent_required: LexRequirement, + sort_exprs: Vec, + push_side: JoinSide, +) -> Result>>> { + let mut smj_required_orderings = smj.required_input_ordering(); + + let ordering = LexOrdering::new(sort_exprs.clone()); + let (new_left_ordering, new_right_ordering) = match push_side { + JoinSide::Left => { + let mut left_eq_properties = smj.left().equivalence_properties().clone(); + left_eq_properties.reorder(sort_exprs)?; + let Some(left_requirement) = smj_required_orderings.swap_remove(0) else { + return Ok(None); + }; + if !left_eq_properties + .ordering_satisfy_requirement(left_requirement.into_single())? + { + return Ok(None); + } + // After re-ordering, requirement is still satisfied: + (ordering.as_ref(), smj.right().output_ordering()) + } + JoinSide::Right => { + let mut right_eq_properties = smj.right().equivalence_properties().clone(); + right_eq_properties.reorder(sort_exprs)?; + let Some(right_requirement) = smj_required_orderings.swap_remove(1) else { + return Ok(None); + }; + if !right_eq_properties + .ordering_satisfy_requirement(right_requirement.into_single())? + { + return Ok(None); + } + // After re-ordering, requirement is still satisfied: + (smj.left().output_ordering(), ordering.as_ref()) + } + JoinSide::None => return Ok(None), + }; + let join_type = smj.join_type(); + let probe_side = SortMergeJoinExec::probe_side(&join_type); + let new_output_ordering = calculate_join_output_ordering( + new_left_ordering, + new_right_ordering, + join_type, + smj.left().schema().fields.len(), + &smj.maintains_input_order(), + Some(probe_side), + )?; + let mut smj_eqs = smj.properties().equivalence_properties().clone(); + if let Some(new_output_ordering) = new_output_ordering { + // smj will have this ordering when its input changes. + smj_eqs.reorder(new_output_ordering)?; + } + let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required)?; + Ok(should_pushdown.then(|| { + let mut required_input_ordering = smj.required_input_ordering(); + let new_req = ordering.map(Into::into); + match push_side { + JoinSide::Left => { + required_input_ordering[0] = new_req; + } + JoinSide::Right => { + required_input_ordering[1] = new_req; + } + JoinSide::None => unreachable!(), + } + required_input_ordering + })) +} + +fn expr_source_side( + eqp: &EquivalenceProperties, + mut ordering: Vec, + join_type: JoinType, + left_columns_len: usize, +) -> Option<(JoinSide, Vec)> { + // TODO: Handle the case where a prefix of the ordering comes from the left + // and a suffix from the right. + match join_type { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark + | JoinType::RightMark => { + let eq_group = eqp.eq_group(); + let mut right_ordering = ordering.clone(); + let (mut valid_left, mut valid_right) = (true, true); + for (left, right) in ordering.iter_mut().zip(right_ordering.iter_mut()) { + let col = left.expr.downcast_ref::()?; + let eq_class = eq_group.get_equivalence_class(&left.expr); + if col.index() < left_columns_len { + if valid_right { + valid_right = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .downcast_ref::() + .is_some_and(|c| c.index() >= left_columns_len) + { + right.expr = Arc::clone(expr); + return true; + } + } + false + }); + } + } else if valid_left { + valid_left = eq_class.is_some_and(|cls| { + for expr in cls.iter() { + if expr + .downcast_ref::() + .is_some_and(|c| c.index() < left_columns_len) + { + left.expr = Arc::clone(expr); + return true; + } + } + false + }); + }; + if !(valid_left || valid_right) { + return None; + } + } + if valid_left { + Some((JoinSide::Left, ordering)) + } else if valid_right { + Some((JoinSide::Right, right_ordering)) + } else { + // TODO: Handle the case where we can push down to both sides. + None + } + } + JoinType::LeftSemi | JoinType::LeftAnti => ordering + .iter() + .all(|e| e.expr.is::()) + .then_some((JoinSide::Left, ordering)), + JoinType::RightSemi | JoinType::RightAnti => ordering + .iter() + .all(|e| e.expr.is::()) + .then_some((JoinSide::Right, ordering)), + } +} + +/// Handles the custom pushdown of parent-required sorting requirements down to +/// the child execution plans, considering whether the input order is maintained. +/// +/// # Arguments +/// +/// * `plan` - A reference to an `ExecutionPlan` for which the pushdown will be applied. +/// * `parent_required` - The sorting requirements expected by the parent node. +/// * `maintains_input_order` - A vector of booleans indicating whether each child +/// maintains the input order. +/// +/// # Returns +/// +/// Returns `Ok(Some(Vec>))` if the sorting requirements can be +/// pushed down, `Ok(None)` if not. On error, returns a `Result::Err`. +fn handle_custom_pushdown( + plan: &Arc, + parent_required: OrderingRequirements, + maintains_input_order: &[bool], +) -> Result>>> { + let plan_children = plan.children(); + + // If the plan has no children, return early: + if plan_children.is_empty() { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting + // expression: + let requirement = parent_required.into_single(); + let all_indices: HashSet = requirement + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + // Get the number of fields in each child's schema: + let children_schema_lengths: Vec = plan_children + .iter() + .map(|c| c.schema().fields().len()) + .collect(); + + // Find the index of the order-maintaining child: + let Some(maintained_child_idx) = maintains_input_order + .iter() + .enumerate() + .find(|(_, m)| **m) + .map(|pair| pair.0) + else { + return Ok(None); + }; + + // Check if all required columns come from the order-maintaining child: + let start_idx = children_schema_lengths[..maintained_child_idx] + .iter() + .sum::(); + let end_idx = start_idx + children_schema_lengths[maintained_child_idx]; + let all_from_maintained_child = + all_indices.iter().all(|i| i >= &start_idx && i < &end_idx); + + // If all columns are from the maintained child, update the parent requirements: + if all_from_maintained_child { + let sub_offset = children_schema_lengths + .iter() + .take(maintained_child_idx) + .sum::(); + // Transform the parent-required expression for the child schema by + // adjusting columns: + let updated_parent_req = requirement + .into_iter() + .map(|req| { + let child_schema = plan_children[maintained_child_idx].schema(); + let updated_columns = req + .expr + .transform_up(|expr| { + if let Some(col) = expr.downcast_ref::() { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Prepare the result, populating with the updated requirements for children that maintain order + let result = maintains_input_order + .iter() + .map(|&maintains_order| { + if maintains_order { + LexRequirement::new(updated_parent_req.clone()) + .map(OrderingRequirements::new) + } else { + None + } + }) + .collect(); + + Ok(Some(result)) + } else { + Ok(None) + } +} + +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: OrderingRequirements, +) -> Result>>> { + // If the plan has no children or does not maintain the right side ordering, + // return early: + if !plan.maintains_input_order()[1] { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let requirement = parent_required.into_single(); + let all_indices: HashSet<_> = requirement + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .into_iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = plan.projection.as_ref() { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + let plan_children = plan.children(); + + // If all columns are from the right child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = requirement + .into_iter() + .map(|req| { + let child_schema = plan_children[1].schema(); + let updated_columns = req + .expr + .transform_up(|expr| { + if let Some(col) = expr.downcast_ref::() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + LexRequirement::new(updated_parent_req).map(OrderingRequirements::new), + ])) + } else { + Ok(None) + } +} + +// this function is used to build the column index for the hash join +// push down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec { + let map_fields = |schema: SchemaRef, side: JoinSide| { + schema + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { index, side }) + .collect::>() + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + map_fields(plan.left().schema(), JoinSide::Left) + .into_iter() + .chain(map_fields(plan.right().schema(), JoinSide::Right)) + .collect::>() + } + JoinType::RightSemi | JoinType::RightAnti => { + map_fields(plan.right().schema(), JoinSide::Right) + } + _ => unreachable!("unexpected join type: {}", plan.join_type()), + } +} + +/// Define the Requirements Compatibility +#[derive(Debug)] +enum RequirementsCompatibility { + /// Requirements satisfy + Satisfy, + /// Requirements compatible + Compatible(Option), + /// Requirements not compatible + NonCompatible, +} diff --git a/datafusion/physical-optimizer/src/ensure_requirements/mod.rs b/datafusion/physical-optimizer/src/ensure_requirements/mod.rs new file mode 100644 index 0000000000000..41a03bb031629 --- /dev/null +++ b/datafusion/physical-optimizer/src/ensure_requirements/mod.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EnsureRequirements`] optimizer rule that enforces distribution and +//! sorting requirements together so that the two never invalidate each other. +//! +//! This rule replaces the separate `EnforceDistribution` + `EnforceSorting` +//! rules with a unified approach inspired by Apache Spark's `EnsureRequirements` +//! and Presto/Trino's `AddExchanges`. +//! +//! # Motivation +//! +//! The previous two-rule design (`EnforceDistribution` then `EnforceSorting`) +//! suffers from non-idempotent composition: `EnforceSorting`'s `pushdown_sorts` +//! can break distribution invariants established by `EnforceDistribution`, +//! because `SortExec.preserve_partitioning` couples sorting and distribution +//! decisions. See for details. +//! +//! # Architecture +//! +//! `optimize` runs several tree traversals. The defining property of this +//! rule is **Phase 2**: a single combined bottom-up pass that resolves +//! distribution *and* sorting for each node together. The surrounding phases +//! are independent traversals (top-down join-key reorder, then several +//! follow-up sort/order rewrites). Some of those could be consolidated +//! further in a follow-up. +//! +//! ```text +//! EnsureRequirements::optimize(plan) +//! │ +//! ├─ Phase 1: top-down join-key reorder (adjust_input_keys_ordering) +//! │ +//! ├─ Phase 2: combined distribution + sorting (single bottom-up pass) +//! │ └─ For each node (bottom-up), for each child: +//! │ Step 1: ensure distribution requirement +//! │ └─ insert RepartitionExec / CoalescePartitionsExec / +//! │ SortPreservingMergeExec as needed +//! │ Step 2: ensure ordering requirement (distribution-aware) +//! │ └─ insert SortExec with the correct `preserve_partitioning`, +//! │ with SortPreservingMergeExec on top if needed +//! │ +//! └─ Phase 3: small follow-up passes (bottom-up unless noted) +//! ├─ parallelize_sorts +//! ├─ replace_with_order_preserving_variants +//! ├─ pushdown_sorts (recursive walk) +//! └─ replace_with_partial_sort +//! ``` +//! +//! # Key Properties +//! +//! - **Idempotent across the whole rule**: Running `EnsureRequirements` +//! twice produces the same plan. This is the property that fixes +//! , where the old +//! two-rule pipeline could regress a parallel sort plan into a serial one +//! on pass 2. +//! - **Distribution before sorting**: For each child, distribution is +//! resolved before ordering, so sorting decisions always have full +//! distribution context. +//! - **Sort pushdown is implicit**: Phase 2 only adds `SortExec` where the +//! child doesn't already satisfy the ordering requirement, so sorts land +//! at the deepest valid position without a separate destructive pass. +//! +//! # Behavior: parallelism via repartitioning +//! +//! Phase 2 Step 1 inserts `RepartitionExec` to satisfy distribution +//! requirements. When configuration allows, it also increases parallelism by +//! repartitioning over otherwise-serial inputs. For example, given two +//! 1-partition inputs feeding an operator that can run with more +//! parallelism: +//! +//! ```text +//! ┌─────────────────────────────────┐ +//! │ ExecutionPlan │ +//! └─────────────────────────────────┘ +//! ▲ ▲ +//! │ │ +//! ┌───────────┐ ┌───────────┐ +//! │ batch A │ │ batch B │ Input: 2 partitions +//! └───────────┘ └───────────┘ +//! ``` +//! +//! `EnsureRequirements` inserts a `RepartitionExec` so the operator runs +//! with three partitions: +//! +//! ```text +//! ┌─────────────────────────────────┐ +//! │ ExecutionPlan │ Input now has 3 partitions +//! └─────────────────────────────────┘ +//! ▲ ▲ ▲ +//! └──────┼───────┘ +//! │ +//! ┌─────────────────────────────────┐ +//! │ RepartitionExec(3) │ batches are repartitioned +//! │ RoundRobin │ +//! └─────────────────────────────────┘ +//! ▲ ▲ +//! ┌───────────┐ ┌───────────┐ +//! │ batch A │ │ batch B │ +//! └───────────┘ └───────────┘ +//! ``` +//! +//! # Behavior: joint distribution + sorting +//! +//! Resolving distribution and sorting together lets Phase 2 produce a +//! parallel sort plan in cases where the two-rule pipeline historically +//! risked a serial one. Given `Sort(DESC) ← Coalesce ← MultiPartitionSource`, +//! `EnsureRequirements` rewrites it into: +//! +//! ```text +//! SortPreservingMergeExec: [a DESC] (cheap k-way merge of sorted streams) +//! SortExec: [a DESC], preserve_partitioning=true (N sorts run in parallel) +//! MultiPartitionSource +//! ``` +//! +//! Each input partition is sorted in parallel, then a `SortPreservingMergeExec` +//! at the top performs a cheap merge of pre-sorted streams. For TopK queries +//! (`fetch=K`), each parallel sort only keeps K rows per partition, so total +//! memory is `N × K` rather than coalescing the entire stream first. +//! +//! # Behavior: strictest distribution match for joins +//! +//! Distribution requirements are met in the strictest way. For example, a +//! hash join with keys `(a, b, c)` requires `Distribution(a, b, c)`. This +//! can in principle be satisfied by partitioning on any superset of any +//! subset of `(a, b, c)`, but this rule always partitions on the exact key +//! tuple `(a, b, c)`. This is sometimes more aggressive than strictly +//! necessary, but the strictest match helps avoid data skew in joins. + +// Internal implementation modules. Re-exported from `crate` root for tests +// in `core/tests/physical_optimizer/{enforce_distribution,enforce_sorting}.rs`. +pub mod enforce_distribution; +pub mod enforce_sorting; + +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; + +use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::ExecutionPlan; + +/// Optimizer rule that enforces both distribution and sorting requirements. +/// +/// This rule combines the functionality of `EnforceDistribution` and +/// `EnforceSorting` into a coordinated sequence where distribution is +/// always settled before sorting for each operator, preventing the +/// non-idempotent interactions between the two separate rules. +/// +/// See [module level documentation](self) for more details. +#[derive(Default, Debug)] +pub struct EnsureRequirements {} + +impl EnsureRequirements { + /// Create a new `EnsureRequirements` optimizer rule. + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for EnsureRequirements { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + // Phase 1: Join key reordering (top-down, from EnforceDistribution) + use super::enforce_distribution::{ + PlanWithKeyRequirements, adjust_input_keys_ordering, + }; + let top_down_join_key_reordering = config.optimizer.top_down_join_key_reordering; + let plan = if top_down_join_key_reordering { + let ctx = PlanWithKeyRequirements::new_default(plan); + ctx.transform_down(adjust_input_keys_ordering).data()?.plan + } else { + use super::enforce_distribution::reorder_join_keys_to_inputs; + plan.transform_up(|p| Ok(Transformed::yes(reorder_join_keys_to_inputs(p)?))) + .data()? + }; + + // Phase 2: Combined distribution + sorting enforcement (single bottom-up pass) + // For each node: distribution first, then sorting. + use super::enforce_distribution::{DistributionContext, ensure_distribution}; + use super::enforce_sorting::{PlanWithCorrespondingSort, ensure_sorting}; + + // Step 2a: Distribution enforcement (bottom-up) + let dist_ctx = DistributionContext::new_default(plan); + let dist_ctx = dist_ctx + .transform_up(|ctx| ensure_distribution(ctx, config)) + .data()?; + + // Step 2b: Sorting enforcement (bottom-up) — runs on distribution-fixed plan + let sort_ctx = PlanWithCorrespondingSort::new_default(dist_ctx.plan); + let sort_ctx = sort_ctx.transform_up(ensure_sorting)?.data; + + // Phase 3: Optimization passes + // 3a: Parallelize sorts (Coalesce+Sort → SPM+Sort) + use super::enforce_sorting::{ + PlanWithCorrespondingCoalescePartitions, parallelize_sorts, + replace_with_partial_sort, + }; + let plan = if config.optimizer.repartition_sorts { + let ctx = PlanWithCorrespondingCoalescePartitions::new_default(sort_ctx.plan); + ctx.transform_up(parallelize_sorts).data()?.plan + } else { + sort_ctx.plan + }; + + // 3b: Order-preserving variants + use super::enforce_sorting::replace_with_order_preserving_variants::{ + OrderPreservationContext, replace_with_order_preserving_variants, + }; + let ctx = OrderPreservationContext::new_default(plan); + let plan = ctx + .transform_up(|c| { + replace_with_order_preserving_variants(c, false, true, config) + }) + .data()? + .plan; + + // 3c: Sort pushdown (distribution-aware) + use super::enforce_sorting::sort_pushdown::{ + SortPushDown, assign_initial_requirements, pushdown_sorts, + }; + let mut sort_pushdown = SortPushDown::new_default(plan); + assign_initial_requirements(&mut sort_pushdown); + let adjusted = pushdown_sorts(sort_pushdown)?; + + // 3d: Partial sort + adjusted + .plan + .transform_up(|p| Ok(Transformed::yes(replace_with_partial_sort(p)?))) + .data() + } + + fn name(&self) -> &str { + "EnsureRequirements" + } + + fn schema_check(&self) -> bool { + true + } +} + +// See tests in datafusion/core/tests/physical_optimizer/ensure_requirements.rs diff --git a/datafusion/physical-optimizer/src/filter_pushdown.rs b/datafusion/physical-optimizer/src/filter_pushdown.rs index 22cb03fc3e876..28f8155002a50 100644 --- a/datafusion/physical-optimizer/src/filter_pushdown.rs +++ b/datafusion/physical-optimizer/src/filter_pushdown.rs @@ -36,16 +36,16 @@ use std::sync::Arc; use crate::PhysicalOptimizerRule; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{assert_eq_or_internal_err, config::ConfigOptions, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err, config::ConfigOptions}; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr_common::physical_expr::is_volatile; use datafusion_physical_plan::filter_pushdown::{ ChildFilterPushdownResult, ChildPushdownResult, FilterPushdownPhase, FilterPushdownPropagation, PushedDown, }; -use datafusion_physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_physical_plan::{ExecutionPlan, with_new_children_if_necessary}; -use itertools::{izip, Itertools}; +use itertools::{Itertools, izip}; /// Attempts to recursively push given filters from the top of the tree into leaves. /// diff --git a/datafusion/physical-optimizer/src/hash_join_buffering.rs b/datafusion/physical-optimizer/src/hash_join_buffering.rs new file mode 100644 index 0000000000000..7a198cac13fc9 --- /dev/null +++ b/datafusion/physical-optimizer/src/hash_join_buffering.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::PhysicalOptimizerRule; +use datafusion_common::JoinSide; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::buffer::BufferExec; +use datafusion_physical_plan::joins::HashJoinExec; +use std::sync::Arc; + +/// Looks for all the [HashJoinExec]s in the plan and places a [BufferExec] node with the +/// configured capacity in the probe side: +/// +/// ```text +/// ┌───────────────────┐ +/// │ HashJoinExec │ +/// └─────▲────────▲────┘ +/// ┌───────┘ └─────────┐ +/// │ │ +/// ┌────────────────┐ ┌─────────────────┐ +/// │ Build side │ + │ BufferExec │ +/// └────────────────┘ └────────▲────────┘ +/// │ +/// ┌────────┴────────┐ +/// │ Probe side │ +/// └─────────────────┘ +/// ``` +/// +/// Which allows eagerly pulling it even before the build side has completely finished. +#[derive(Debug, Default)] +pub struct HashJoinBuffering {} + +impl HashJoinBuffering { + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for HashJoinBuffering { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> datafusion_common::Result> { + let capacity = config.execution.hash_join_buffering_capacity; + if capacity == 0 { + return Ok(plan); + } + + plan.transform_down(|plan| { + let Some(node) = plan.downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + let plan = Arc::clone(&plan); + Ok(Transformed::yes( + if HashJoinExec::probe_side() == JoinSide::Left { + // Do not stack BufferExec nodes together. + if node.left.is::() { + return Ok(Transformed::no(plan)); + } + plan.with_new_children(vec![ + Arc::new(BufferExec::new(Arc::clone(&node.left), capacity)), + Arc::clone(&node.right), + ])? + } else { + // Do not stack BufferExec nodes together. + if node.right.is::() { + return Ok(Transformed::no(plan)); + } + plan.with_new_children(vec![ + Arc::clone(&node.left), + Arc::new(BufferExec::new(Arc::clone(&node.right), capacity)), + ])? + }, + )) + }) + .data() + } + + fn name(&self) -> &str { + "HashJoinBuffering" + } + + fn schema_check(&self) -> bool { + true + } +} diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index b55c01f62e992..74c6cbb19aea9 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -24,19 +24,22 @@ //! `PartitionMode` and the build side using the available statistics for hash joins. use crate::PhysicalOptimizerRule; +use crate::optimizer::{ConfigOnlyContext, PhysicalOptimizerContext}; +use datafusion_common::Statistics; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, JoinSide, JoinType}; +use datafusion_common::{JoinSide, JoinType, internal_err}; use datafusion_expr_common::sort_properties::SortProperties; -use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; +use datafusion_physical_plan::operator_statistics::StatisticsRegistry; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use std::sync::Arc; @@ -47,26 +50,55 @@ use std::sync::Arc; pub struct JoinSelection {} impl JoinSelection { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } } +/// Get statistics for a plan node, using the registry if available. +fn get_stats( + plan: &dyn ExecutionPlan, + registry: Option<&StatisticsRegistry>, +) -> Result> { + if let Some(reg) = registry { + reg.compute(plan) + .map(|s| Arc::::clone(s.base_arc())) + } else { + plan.partition_statistics(None) + } +} + // TODO: We need some performance test for Right Semi/Right Join swap to Left Semi/Left Join in case that the right side is smaller but not much smaller. // TODO: In PrestoSQL, the optimizer flips join sides only if one side is much smaller than the other by more than SIZE_DIFFERENCE_THRESHOLD times, by default is 8 times. -/// Checks statistics for join swap. +/// Checks whether join inputs should be swapped using available statistics. +/// +/// It follows these steps: +/// 1. If a [`StatisticsRegistry`] is provided, use it for cross-operator estimates +/// (e.g., intermediate join outputs that would otherwise have `Absent` statistics). +/// 2. Compare the in-memory sizes of both sides, and place the smaller side on +/// the left (build) side. +/// 3. If in-memory byte sizes are unavailable, fall back to row counts. +/// 4. Do not reorder the join if neither statistic is available, or if +/// `datafusion.optimizer.join_reordering` is disabled. +/// +/// Used configurations inside arg `config` +/// - `config.optimizer.join_reordering`: allows or forbids statistics-driven join swapping pub(crate) fn should_swap_join_order( left: &dyn ExecutionPlan, right: &dyn ExecutionPlan, + config: &ConfigOptions, + registry: Option<&StatisticsRegistry>, ) -> Result { - // Get the left and right table's total bytes - // If both the left and right tables contain total_byte_size statistics, - // use `total_byte_size` to determine `should_swap_join_order`, else use `num_rows` - let left_stats = left.partition_statistics(None)?; - let right_stats = right.partition_statistics(None)?; - // First compare `total_byte_size` of left and right side, - // if information in this field is insufficient fallback to the `num_rows` + if !config.optimizer.join_reordering { + return Ok(false); + } + + let left_stats = get_stats(left, registry)?; + let right_stats = get_stats(right, registry)?; + + // First compare total_byte_size, then fall back to num_rows if byte + // sizes are unavailable. match ( left_stats.total_byte_size.get_value(), right_stats.total_byte_size.get_value(), @@ -86,17 +118,20 @@ fn supports_collect_by_thresholds( plan: &dyn ExecutionPlan, threshold_byte_size: usize, threshold_num_rows: usize, + registry: Option<&StatisticsRegistry>, ) -> bool { - // Currently we do not trust the 0 value from stats, due to stats collection might have bug - // TODO check the logic in datasource::get_statistics_with_limit() - let Ok(stats) = plan.partition_statistics(None) else { + let Ok(stats) = get_stats(plan, registry) else { return false; }; + // Stats use `Precision` to represent stats, where `Absent` means unknown. + // `Exact(0)` and `Inexact(0)` are both valid stats, and we should not treat + // them as unknown, `Absent` will return None (this is in regards to why + // `!=0` is not checked) if let Some(byte_size) = stats.total_byte_size.get_value() { - *byte_size != 0 && *byte_size < threshold_byte_size + *byte_size < threshold_byte_size } else if let Some(num_rows) = stats.num_rows.get_value() { - *num_rows != 0 && *num_rows < threshold_num_rows + *num_rows < threshold_num_rows } else { false } @@ -108,11 +143,25 @@ impl PhysicalOptimizerRule for JoinSelection { plan: Arc, config: &ConfigOptions, ) -> Result> { - // First, we make pipeline-fixing modifications to joins so as to accommodate - // unbounded inputs. Each pipeline-fixing subrule, which is a function - // of type `PipelineFixerSubrule`, takes a single [`PipelineStatePropagator`] - // argument storing state variables that indicate the unboundedness status - // of the current [`ExecutionPlan`] as we traverse the plan tree. + self.optimize_with_context(plan, &ConfigOnlyContext::new(config)) + } + + fn optimize_with_context( + &self, + plan: Arc, + context: &dyn PhysicalOptimizerContext, + ) -> Result> { + let config = context.config_options(); + let mut default_registry = None; + let registry: Option<&StatisticsRegistry> = + if config.optimizer.use_statistics_registry { + Some(context.statistics_registry().unwrap_or_else(|| { + default_registry + .insert(StatisticsRegistry::default_with_builtin_providers()) + })) + } else { + None + }; let subrules: Vec> = vec![ Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), @@ -120,27 +169,9 @@ impl PhysicalOptimizerRule for JoinSelection { let new_plan = plan .transform_up(|p| apply_subrules(p, &subrules, config)) .data()?; - // Next, we apply another subrule that tries to optimize joins using any - // statistics their inputs might have. - // - For a hash join with partition mode [`PartitionMode::Auto`], we will - // make a cost-based decision to select which `PartitionMode` mode - // (`Partitioned`/`CollectLeft`) is optimal. If the statistics information - // is not available, we will fall back to [`PartitionMode::Partitioned`]. - // - We optimize/swap join sides so that the left (build) side of the join - // is the small side. If the statistics information is not available, we - // do not modify join sides. - // - We will also swap left and right sides for cross joins so that the left - // side is the small side. - let config = &config.optimizer; - let collect_threshold_byte_size = config.hash_join_single_partition_threshold; - let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; new_plan .transform_up(|plan| { - statistical_join_selection_subrule( - plan, - collect_threshold_byte_size, - collect_threshold_num_rows, - ) + statistical_join_selection_subrule(plan, config, registry) }) .data() } @@ -160,59 +191,65 @@ impl PhysicalOptimizerRule for JoinSelection { /// `CollectLeft` mode is applicable. Otherwise, it will try to swap the join sides. /// When the `ignore_threshold` is false, this function will also check left /// and right sizes in bytes or rows. +/// +/// Used configurations inside arg `config` +/// - `config.optimizer.hash_join_single_partition_threshold`: byte threshold for `CollectLeft` +/// - `config.optimizer.hash_join_single_partition_threshold_rows`: row threshold for `CollectLeft` +/// - `config.optimizer.join_reordering`: allows or forbids input swapping pub(crate) fn try_collect_left( hash_join: &HashJoinExec, ignore_threshold: bool, - threshold_byte_size: usize, - threshold_num_rows: usize, + config: &ConfigOptions, + registry: Option<&StatisticsRegistry>, ) -> Result>> { let left = hash_join.left(); let right = hash_join.right(); + let optimizer_config = &config.optimizer; let left_can_collect = ignore_threshold || supports_collect_by_thresholds( &**left, - threshold_byte_size, - threshold_num_rows, + optimizer_config.hash_join_single_partition_threshold, + optimizer_config.hash_join_single_partition_threshold_rows, + registry, ); let right_can_collect = ignore_threshold || supports_collect_by_thresholds( &**right, - threshold_byte_size, - threshold_num_rows, + optimizer_config.hash_join_single_partition_threshold, + optimizer_config.hash_join_single_partition_threshold_rows, + registry, ); match (left_can_collect, right_can_collect) { (true, true) => { + // Don't swap null-aware anti joins as they have specific side requirements if hash_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? + && !hash_join.null_aware + && should_swap_join_order(&**left, &**right, config, registry)? { Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) } else { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))) + Ok(Some(Arc::new( + hash_join + .builder() + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ))) } } - (true, false) => Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))), + (true, false) => Ok(Some(Arc::new( + hash_join + .builder() + .with_partition_mode(PartitionMode::CollectLeft) + .build()?, + ))), (false, true) => { - if hash_join.join_type().supports_swap() { + // Don't swap null-aware anti joins as they have specific side requirements + if optimizer_config.join_reordering + && hash_join.join_type().supports_swap() + && !hash_join.null_aware + { hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) } else { Ok(None) @@ -227,88 +264,105 @@ pub(crate) fn try_collect_left( /// Checks if the join order should be swapped based on the join type and input statistics. /// If swapping is optimal and supported, creates a swapped partitioned hash join; otherwise, /// creates a standard partitioned hash join. +/// +/// Used configurations inside arg `config` +/// - `config.optimizer.join_reordering`: allows or forbids statistics-driven join swapping pub(crate) fn partitioned_hash_join( hash_join: &HashJoinExec, + config: &ConfigOptions, + registry: Option<&StatisticsRegistry>, ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? + // Don't swap null-aware anti joins as they have specific side requirements + if hash_join.join_type().supports_swap() + && !hash_join.null_aware + && should_swap_join_order(&**left, &**right, config, registry)? { hash_join.swap_inputs(PartitionMode::Partitioned) } else { - Ok(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::Partitioned, - hash_join.null_equality(), - )?)) + // Null-aware anti joins must use CollectLeft mode because they track probe-side state + // (probe_side_non_empty, probe_side_has_null) per-partition, but need global knowledge + // for correct null handling. With partitioning, a partition might not see probe rows + // even if the probe side is globally non-empty, leading to incorrect NULL row handling. + let partition_mode = if hash_join.null_aware { + PartitionMode::CollectLeft + } else { + PartitionMode::Partitioned + }; + + Ok(Arc::new( + hash_join + .builder() + .with_partition_mode(partition_mode) + .build()?, + )) } } /// This subrule tries to modify a given plan so that it can -/// optimize hash and cross joins in the plan according to available statistical information. +/// optimize hash and cross joins in the plan according to available statistical +/// information. +/// +/// Used configurations inside arg `config` +/// - `config.optimizer.hash_join_single_partition_threshold`: byte threshold for `CollectLeft` +/// - `config.optimizer.hash_join_single_partition_threshold_rows`: row threshold for `CollectLeft` +/// - `config.optimizer.join_reordering`: allows or forbids input swapping fn statistical_join_selection_subrule( plan: Arc, - collect_threshold_byte_size: usize, - collect_threshold_num_rows: usize, + config: &ConfigOptions, + registry: Option<&StatisticsRegistry>, ) -> Result>> { - let transformed = - if let Some(hash_join) = plan.as_any().downcast_ref::() { - match hash_join.partition_mode() { - PartitionMode::Auto => try_collect_left( - hash_join, - false, - collect_threshold_byte_size, - collect_threshold_num_rows, - )? + let transformed = if let Some(hash_join) = plan.downcast_ref::() { + match hash_join.partition_mode() { + PartitionMode::Auto => try_collect_left(hash_join, false, config, registry)? .map_or_else( - || partitioned_hash_join(hash_join).map(Some), + || partitioned_hash_join(hash_join, config, registry).map(Some), |v| Ok(Some(v)), )?, - PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if hash_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - hash_join - .swap_inputs(PartitionMode::Partitioned) - .map(Some)? - } else { - None - } - } - } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right)? { - cross_join.swap_inputs().map(Some)? - } else { - None + PartitionMode::CollectLeft => { + try_collect_left(hash_join, true, config, registry)?.map_or_else( + || partitioned_hash_join(hash_join, config, registry).map(Some), + |v| Ok(Some(v)), + )? } - } else if let Some(nl_join) = plan.as_any().downcast_ref::() { - let left = nl_join.left(); - let right = nl_join.right(); - if nl_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - nl_join.swap_inputs().map(Some)? - } else { - None + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + // Don't swap null-aware anti joins as they have specific side requirements + if hash_join.join_type().supports_swap() + && !hash_join.null_aware + && should_swap_join_order(&**left, &**right, config, registry)? + { + hash_join + .swap_inputs(PartitionMode::Partitioned) + .map(Some)? + } else { + None + } } + } + } else if let Some(cross_join) = plan.downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right, config, registry)? { + cross_join.swap_inputs().map(Some)? } else { None - }; + } + } else if let Some(nl_join) = plan.downcast_ref::() { + let left = nl_join.left(); + let right = nl_join.right(); + if nl_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right, config, registry)? + { + nl_join.swap_inputs().map(Some)? + } else { + None + } + } else { + None + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) @@ -342,7 +396,7 @@ fn hash_join_convert_symmetric_subrule( config_options: &ConfigOptions, ) -> Result> { // Check if the current plan node is a HashJoinExec. - if let Some(hash_join) = input.as_any().downcast_ref::() { + if let Some(hash_join) = input.downcast_ref::() { let left_unbounded = hash_join.left.boundedness().is_unbounded(); let left_incremental = matches!( hash_join.left.pipeline_behavior(), @@ -481,19 +535,16 @@ pub fn hash_join_swap_subrule( mut input: Arc, _config_options: &ConfigOptions, ) -> Result> { - if let Some(hash_join) = input.as_any().downcast_ref::() { - if hash_join.left.boundedness().is_unbounded() - && !hash_join.right.boundedness().is_unbounded() - && matches!( - *hash_join.join_type(), - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - ) - { - input = swap_join_according_to_unboundedness(hash_join)?; - } + if let Some(hash_join) = input.downcast_ref::() + && hash_join.left.boundedness().is_unbounded() + && !hash_join.right.boundedness().is_unbounded() + && !hash_join.null_aware // Don't swap null-aware anti joins + && matches!( + *hash_join.join_type(), + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti + ) + { + input = swap_join_according_to_unboundedness(hash_join)?; } Ok(input) } @@ -539,10 +590,14 @@ fn apply_subrules( subrules: &Vec>, config_options: &ConfigOptions, ) -> Result>> { + let original = Arc::clone(&input); for subrule in subrules { input = subrule(input, config_options)?; } - Ok(Transformed::yes(input)) + + let transformed = !Arc::ptr_eq(&original, &input); + + Ok(Transformed::new_transformed(input, transformed)) } // See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs index f4b82eed3c400..b9eb248f6e843 100644 --- a/datafusion/physical-optimizer/src/lib.rs +++ b/datafusion/physical-optimizer/src/lib.rs @@ -23,16 +23,16 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] pub mod aggregate_statistics; -pub mod coalesce_batches; pub mod combine_partial_final_agg; -pub mod enforce_distribution; -pub mod enforce_sorting; pub mod ensure_coop; +pub mod ensure_requirements; +// `enforce_distribution` and `enforce_sorting` are now internal implementation +// details of `ensure_requirements`. Re-export at the crate root so external test +// modules keep their public paths. +pub use ensure_requirements::{enforce_distribution, enforce_sorting}; pub mod filter_pushdown; pub mod join_selection; pub mod limit_pushdown; @@ -42,9 +42,13 @@ pub mod optimizer; pub mod output_requirements; pub mod projection_pushdown; pub use datafusion_pruning as pruning; +pub mod hash_join_buffering; +pub mod pushdown_sort; pub mod sanity_checker; pub mod topk_aggregation; +pub mod topk_repartition; pub mod update_aggr_exprs; pub mod utils; +pub mod window_topn; -pub use optimizer::PhysicalOptimizerRule; +pub use optimizer::{ConfigOnlyContext, PhysicalOptimizerContext, PhysicalOptimizerRule}; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 7469c3af9344c..63c4f21bd9d6d 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -17,6 +17,48 @@ //! [`LimitPushdown`] pushes `LIMIT` down through `ExecutionPlan`s to reduce //! data transfer as much as possible. +//! +//! # Plan Limit Absorption +//! In addition to pushing down `GlobalLimitExec` and `LocalLimitExec` nodes in +//! the plan, some operators can "absorb" a limit and stop early during +//! execution. +//! +//! ## Background: vectorized volcano execution model +//! DataFusion uses a batched volcano model. For most operators, output is +//! produced in batches of `datafusion.execution.batch_size` (default 8192), so +//! the batch sizes typically look like: +//! ```text +//! 8192, 8192, ..., 8192, 100 (the final batch may be partial) +//! ``` +//! +//! ## Example +//! For a join with an expensive, selective predicate: +//! ```text +//! GlobalLimitExec: skip=0, fetch=10 +//! -- NestedLoopJoinExec(on=expr_expensive_and_selective) +//! --- DataSourceExec() +//! --- DataSourceExec() +//! ``` +//! +//! Under this model, `NestedLoopJoinExec` would keep working until it can emit +//! a full batch (8192 rows), even though the query only needs 10. If the limit +//! cannot be pushed below the join, we can still embed it inside the join so it +//! stops once the limit is satisfied. The transformed plan looks like: +//! +//! ```text +//! NestedLoopJoinExec(on=expr_expensive_and_selective, fetch=10) +//! --- DataSourceExec() +//! --- DataSourceExec() +//! ``` +//! +//! ## Implementation +//! The current optimizer rule optionally pushes `fetch` requirements into +//! operators via [`ExecutionPlan::with_fetch`]. +//! +//! To support early termination in operators, [`LimitedBatchCoalescer`](https://docs.rs/datafusion/latest/datafusion/physical_plan/coalesce/struct.LimitedBatchCoalescer.html) +//! can help manage the output buffer. +//! +//! Reference implementation in Hash Join: use std::fmt::Debug; use std::sync::Arc; @@ -25,10 +67,14 @@ use crate::PhysicalOptimizerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; +use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; use datafusion_common::utils::combine_limit; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; /// This rule inspects [`ExecutionPlan`]'s and pushes down the fetch limit from @@ -37,23 +83,23 @@ use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; pub struct LimitPushdown {} /// This is a "data class" we use within the [`LimitPushdown`] rule to push -/// down [`LimitExec`] in the plan. GlobalRequirements are hold as a rule-wide state +/// down limits in the plan. GlobalRequirements are hold as a rule-wide state /// and holds the fetch and skip information. The struct also has a field named /// satisfied which means if the "current" plan is valid in terms of limits or not. /// /// For example: If the plan is satisfied with current fetch info, we decide to not add a LocalLimit /// /// [`LimitPushdown`]: crate::limit_pushdown::LimitPushdown -/// [`LimitExec`]: crate::limit_pushdown::LimitExec #[derive(Default, Clone, Debug)] pub struct GlobalRequirements { fetch: Option, skip: usize, satisfied: bool, + preserve_order: bool, } impl LimitPushdown { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -69,6 +115,7 @@ impl PhysicalOptimizerRule for LimitPushdown { fetch: None, skip: 0, satisfied: false, + preserve_order: false, }; pushdown_limits(plan, global_state) } @@ -82,44 +129,11 @@ impl PhysicalOptimizerRule for LimitPushdown { } } -/// This enumeration makes `skip` and `fetch` calculations easier by providing -/// a single API for both local and global limit operators. -#[derive(Debug)] -pub enum LimitExec { - Global(GlobalLimitExec), - Local(LocalLimitExec), -} - -impl LimitExec { - fn input(&self) -> &Arc { - match self { - Self::Global(global) => global.input(), - Self::Local(local) => local.input(), - } - } - - fn fetch(&self) -> Option { - match self { - Self::Global(global) => global.fetch(), - Self::Local(local) => Some(local.fetch()), - } - } - - fn skip(&self) -> usize { - match self { - Self::Global(global) => global.skip(), - Self::Local(_) => 0, - } - } -} - -impl From for Arc { - fn from(limit_exec: LimitExec) -> Self { - match limit_exec { - LimitExec::Global(global) => Arc::new(global), - LimitExec::Local(local) => Arc::new(local), - } - } +struct LimitInfo { + input: Arc, + fetch: Option, + skip: usize, + preserve_order: bool, } /// This function is the main helper function of the `LimitPushDown` rule. @@ -134,24 +148,45 @@ pub fn pushdown_limit_helper( mut global_state: GlobalRequirements, ) -> Result<(Transformed>, GlobalRequirements)> { // Extract limit, if exist, and return child inputs. - if let Some(limit_exec) = extract_limit(&pushdown_plan) { + if let Some(limit_info) = extract_limit(&pushdown_plan) { // If we have fetch/skip info in the global state already, we need to // decide which one to continue with: let (skip, fetch) = combine_limit( global_state.skip, global_state.fetch, - limit_exec.skip(), - limit_exec.fetch(), + limit_info.skip, + limit_info.fetch, ); global_state.skip = skip; global_state.fetch = fetch; + global_state.preserve_order = limit_info.preserve_order; + global_state.satisfied = false; + + if let Some(fetch) = fetch + && limit_satisfied_by_input(&limit_info.input, skip, fetch)? + { + // The input already produces at most `fetch` rows, so no new limit + // node is needed. Mark satisfied so downstream won't re-add one, + // but preserve skip/fetch so any nested limit nodes (e.g. an inner + // GlobalLimitExec) can still be merged with the outer constraint. + global_state.satisfied = true; + + return Ok(( + Transformed { + data: limit_info.input, + transformed: true, + tnr: TreeNodeRecursion::Stop, + }, + global_state, + )); + } // Now the global state has the most recent information, we can remove - // the `LimitExec` plan. We will decide later if we should add it again - // or not. + // the limit node. We will decide later if we should add it again or + // not. return Ok(( Transformed { - data: Arc::clone(limit_exec.input()), + data: limit_info.input, transformed: true, tnr: TreeNodeRecursion::Stop, }, @@ -162,7 +197,7 @@ pub fn pushdown_limit_helper( // If we have a non-limit operator with fetch capability, update global // state as necessary: if pushdown_plan.fetch().is_some() { - if global_state.fetch.is_none() { + if global_state.skip == 0 { global_state.satisfied = true; } (global_state.skip, global_state.fetch) = combine_limit( @@ -201,7 +236,7 @@ pub fn pushdown_limit_helper( Ok((Transformed::no(pushdown_plan), global_state)) } else if let Some(plan_with_fetch) = pushdown_plan.with_fetch(skip_and_fetch) { // This plan is combining input partitions, so we need to add the - // fetch info to plan if possible. If not, we must add a `LimitExec` + // fetch info to plan if possible. If not, we must add a limit node // with the information from the global state. let mut new_plan = plan_with_fetch; // Execution plans can't (yet) handle skip, so if we have one, @@ -241,17 +276,28 @@ pub fn pushdown_limit_helper( let maybe_fetchable = pushdown_plan.with_fetch(skip_and_fetch); if global_state.satisfied { if let Some(plan_with_fetch) = maybe_fetchable { - Ok((Transformed::yes(plan_with_fetch), global_state)) + let plan_with_preserve_order = plan_with_fetch + .with_preserve_order(global_state.preserve_order) + .unwrap_or(plan_with_fetch); + Ok((Transformed::yes(plan_with_preserve_order), global_state)) } else { Ok((Transformed::no(pushdown_plan), global_state)) } } else { global_state.satisfied = true; pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { + let plan_with_preserve_order = plan_with_fetch + .with_preserve_order(global_state.preserve_order) + .unwrap_or(plan_with_fetch); + if global_skip > 0 { - add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) + add_global_limit( + plan_with_preserve_order, + global_skip, + Some(global_fetch), + ) } else { - plan_with_fetch + plan_with_preserve_order } } else { add_limit(pushdown_plan, global_skip, global_fetch) @@ -261,6 +307,59 @@ pub fn pushdown_limit_helper( } } +/// Returns true if exact input statistics prove that applying the limit would +/// not remove any rows. +fn limit_satisfied_by_input( + plan: &Arc, + skip: usize, + fetch: usize, +) -> Result { + if skip > 0 { + return Ok(false); + } + + if plan.output_partitioning().partition_count() != 1 { + return Ok(false); + } + + let Some(num_rows) = limit_eliminable_exact_num_rows(plan)? else { + return Ok(false); + }; + + Ok(num_rows <= fetch) +} + +/// Returns exact row counts only from a conservative whitelist of operators +/// whose row-count guarantees are strong enough to remove a limit. +fn limit_eliminable_exact_num_rows( + plan: &Arc, +) -> Result> { + // Unwrap any wrapping ProjectionExec layers; projections preserve row count + // but may derive statistics in ways that are not trustworthy, so we peek + // through them to the underlying producer. + let mut current = plan; + while let Some(projection) = current.downcast_ref::() { + current = projection.input(); + } + + if current.is::() { + return Ok(Some(0)); + } + + if current.is::() { + return Ok(Some(1)); + } + + if matches!( + current.partition_statistics(None)?.num_rows, + Precision::Exact(0) + ) { + return Ok(Some(0)); + } + + Ok(None) +} + /// Pushes down the limit through the plan. pub(crate) fn pushdown_limits( pushdown_plan: Arc, @@ -276,41 +375,60 @@ pub(crate) fn pushdown_limits( (new_node, global_state) = pushdown_limit_helper(new_node.data, global_state)?; } + // Once a limit has been materialized above the current node, child + // subtrees should not inherit its `skip`. Keep `fetch`, but clear + // `skip` before recursing so child-local limits are not merged with + // an `OFFSET` that has already been applied. + if global_state.satisfied { + global_state.skip = 0; + } + // Apply pushdown limits in children let children = new_node.data.children(); + let mut changed = false; let new_children = children .into_iter() - .map(|child| { - pushdown_limits(Arc::::clone(child), global_state.clone()) + .map(|child: &Arc| { + let new_child = pushdown_limits( + Arc::::clone(child), + global_state.clone(), + )?; + // Tracking if any of the children changed + changed |= !Arc::ptr_eq(child, &new_child); + Ok(new_child) }) .collect::>()?; - new_node.data.with_new_children(new_children) + + if changed { + new_node.data.with_new_children(new_children) + } else { + Ok(new_node.data) + } } -/// Transforms the [`ExecutionPlan`] into a [`LimitExec`] if it is a +/// Extracts limit information from the [`ExecutionPlan`] if it is a /// [`GlobalLimitExec`] or a [`LocalLimitExec`]. -fn extract_limit(plan: &Arc) -> Option { - if let Some(global_limit) = plan.as_any().downcast_ref::() { - Some(LimitExec::Global(GlobalLimitExec::new( - Arc::clone(global_limit.input()), - global_limit.skip(), - global_limit.fetch(), - ))) +fn extract_limit(plan: &Arc) -> Option { + if let Some(global_limit) = plan.downcast_ref::() { + Some(LimitInfo { + input: Arc::clone(global_limit.input()), + fetch: global_limit.fetch(), + skip: global_limit.skip(), + preserve_order: global_limit.required_ordering().is_some(), + }) } else { - plan.as_any() - .downcast_ref::() - .map(|local_limit| { - LimitExec::Local(LocalLimitExec::new( - Arc::clone(local_limit.input()), - local_limit.fetch(), - )) + plan.downcast_ref::() + .map(|local_limit| LimitInfo { + input: Arc::clone(local_limit.input()), + fetch: Some(local_limit.fetch()), + skip: 0, + preserve_order: local_limit.required_ordering().is_some(), }) } } /// Checks if the given plan combines input partitions. fn combines_input_partitions(plan: &Arc) -> bool { - let plan = plan.as_any(); plan.is::() || plan.is::() } diff --git a/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs b/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs index 1c671cd074886..092570b051979 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown_past_window.rs @@ -16,16 +16,16 @@ // under the License. use crate::PhysicalOptimizerRule; +use datafusion_common::ScalarValue; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::ScalarValue; use datafusion_expr::{LimitEffect, WindowFrameBound, WindowFrameUnits}; use datafusion_physical_expr::window::{ PlainAggregateWindowExpr, SlidingAggregateWindowExpr, StandardWindowExpr, StandardWindowFunctionExpr, WindowExpr, }; use datafusion_physical_plan::execution_plan::CardinalityEffect; -use datafusion_physical_plan::limit::GlobalLimitExec; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -104,7 +104,7 @@ impl PhysicalOptimizerRule for LimitPushPastWindows { } // grow the limit if we hit a window function - if let Some(window) = node.as_any().downcast_ref::() { + if let Some(window) = node.downcast_ref::() { phase = Phase::Apply; if !grow_limit(window, &mut ctx) { return reset(node, &mut ctx); @@ -113,17 +113,17 @@ impl PhysicalOptimizerRule for LimitPushPastWindows { } // Apply the limit if we hit a sortpreservingmerge node - if phase == Phase::Apply { - if let Some(out) = apply_limit(&node, &mut ctx) { - return Ok(out); - } + if phase == Phase::Apply + && let Some(out) = apply_limit(&node, &mut ctx) + { + return Ok(out); } // nodes along the way if !node.supports_limit_pushdown() { return reset(node, &mut ctx); } - if let Some(part) = node.as_any().downcast_ref::() { + if let Some(part) = node.downcast_ref::() { let output = part.partitioning().partition_count(); let input = part.input().output_partitioning().partition_count(); if output < input { @@ -185,7 +185,7 @@ fn apply_limit( node: &Arc, ctx: &mut TraverseState, ) -> Option>> { - if !node.as_any().is::() && !node.as_any().is::() { + if !node.is::() && !node.is::() { return None; } let latest = ctx.limit.take(); @@ -202,11 +202,17 @@ fn apply_limit( } fn get_limit(node: &Arc, ctx: &mut TraverseState) -> bool { - if let Some(limit) = node.as_any().downcast_ref::() { + if let Some(limit) = node.downcast_ref::() { ctx.reset_limit(limit.fetch().map(|fetch| fetch + limit.skip())); return true; } - if let Some(limit) = node.as_any().downcast_ref::() { + // In distributed execution, GlobalLimitExec becomes LocalLimitExec + // per partition. Handle it the same way (LocalLimitExec has no skip). + if let Some(limit) = node.downcast_ref::() { + ctx.reset_limit(Some(limit.fetch())); + return true; + } + if let Some(limit) = node.downcast_ref::() { ctx.reset_limit(limit.fetch()); return true; } @@ -254,3 +260,110 @@ fn bound_to_usize(bound: &WindowFrameBound) -> Option { _ => None, } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::WindowFrame; + use datafusion_functions_window::row_number::row_number_udwf; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_plan::InputOrderMode; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; + use datafusion_physical_plan::windows::{ + BoundedWindowAggExec, create_udwf_window_expr, + }; + use insta::assert_snapshot; + + fn plan_str(plan: &dyn ExecutionPlan) -> String { + displayable(plan).indent(true).to_string() + } + + fn schema() -> Arc { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])) + } + + /// Build: LocalLimitExec or GlobalLimitExec → BoundedWindowAggExec(row_number) → SortExec + fn build_window_plan( + use_local_limit: bool, + ) -> datafusion_common::Result> { + let s = schema(); + let input: Arc = + Arc::new(PlaceholderRowExec::new(Arc::clone(&s))); + + let ordering = + LexOrdering::new(vec![PhysicalSortExpr::new_default(col("a", &s)?).asc()]) + .unwrap(); + + let sort: Arc = Arc::new( + SortExec::new(ordering.clone(), input).with_preserve_partitioning(true), + ); + + let window_expr = Arc::new(StandardWindowExpr::new( + create_udwf_window_expr( + &row_number_udwf(), + &[], + &s, + "row_number".to_string(), + false, + )?, + &[], + ordering.as_ref(), + Arc::new(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )), + )); + + let window: Arc = Arc::new(BoundedWindowAggExec::try_new( + vec![window_expr], + sort, + InputOrderMode::Sorted, + true, + )?); + + let limit: Arc = if use_local_limit { + Arc::new(LocalLimitExec::new(window, 100)) + } else { + Arc::new(GlobalLimitExec::new(window, 0, Some(100))) + }; + + Ok(limit) + } + + fn optimize(plan: Arc) -> Arc { + let mut config = ConfigOptions::new(); + config.optimizer.enable_window_limits = true; + LimitPushPastWindows::new().optimize(plan, &config).unwrap() + } + + /// GlobalLimitExec above a windowed sort should push fetch into the SortExec. + #[test] + fn global_limit_pushes_past_window() { + let plan = build_window_plan(false).unwrap(); + let optimized = optimize(plan); + assert_snapshot!(plan_str(optimized.as_ref()), @r#" + GlobalLimitExec: skip=0, fetch=100 + BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: TopK(fetch=100), expr=[a@0 ASC], preserve_partitioning=[true] + PlaceholderRowExec + "#); + } + + /// LocalLimitExec above a windowed sort should also push fetch into the SortExec. + /// This is the case in distributed execution where GlobalLimitExec becomes LocalLimitExec. + #[test] + fn local_limit_pushes_past_window() { + let plan = build_window_plan(true).unwrap(); + let optimized = optimize(plan); + assert_snapshot!(plan_str(optimized.as_ref()), @r#" + LocalLimitExec: fetch=100 + BoundedWindowAggExec: wdw=[row_number: Field { "row_number": UInt64 }, frame: ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], mode=[Sorted] + SortExec: TopK(fetch=100), expr=[a@0 ASC], preserve_partitioning=[true] + PlaceholderRowExec + "#); + } +} diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 3666ff3798b67..852dc2a2a9434 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -20,13 +20,13 @@ use std::sync::Arc; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use datafusion_common::Result; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::Result; use crate::PhysicalOptimizerRule; use itertools::Itertools; @@ -54,16 +54,8 @@ impl LimitedDistinctAggregation { } // We found what we want: clone, copy the limit down, and return modified node - let new_aggr = AggregateExec::try_new( - *aggr.mode(), - aggr.group_expr().clone(), - aggr.aggr_expr().to_vec(), - aggr.filter_expr().to_vec(), - aggr.input().to_owned(), - aggr.input_schema(), - ) - .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); + let new_aggr = aggr.with_new_limit_options(Some(LimitOptions::new(limit))); + Some(Arc::new(new_aggr)) } @@ -77,11 +69,10 @@ impl LimitedDistinctAggregation { let mut global_skip: usize = 0; let children: Vec>; let mut is_global_limit = false; - if let Some(local_limit) = plan.as_any().downcast_ref::() { + if let Some(local_limit) = plan.downcast_ref::() { limit = local_limit.fetch(); children = local_limit.children().into_iter().cloned().collect(); - } else if let Some(global_limit) = plan.as_any().downcast_ref::() - { + } else if let Some(global_limit) = plan.downcast_ref::() { global_fetch = global_limit.fetch(); global_fetch?; global_skip = global_limit.skip(); @@ -112,18 +103,15 @@ impl LimitedDistinctAggregation { if !rewrite_applicable { return Ok(Transformed::no(plan)); } - if let Some(aggr) = plan.as_any().downcast_ref::() { - if found_match_aggr { - if let Some(parent_aggr) = - match_aggr.as_any().downcast_ref::() - { - if !parent_aggr.group_expr().eq(aggr.group_expr()) { - // a partial and final aggregation with different groupings disqualifies - // rewriting the child aggregation - rewrite_applicable = false; - return Ok(Transformed::no(plan)); - } - } + if let Some(aggr) = plan.downcast_ref::() { + if found_match_aggr + && let Some(parent_aggr) = match_aggr.downcast_ref::() + && !parent_aggr.group_expr().eq(aggr.group_expr()) + { + // a partial and final aggregation with different groupings disqualifies + // rewriting the child aggregation + rewrite_applicable = false; + return Ok(Transformed::no(plan)); } // either we run into an Aggregate and transform it, or disable the rewrite // for subsequent children diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs index 03c83bb5a092a..0f81512b61c8e 100644 --- a/datafusion/physical-optimizer/src/optimizer.rs +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -21,11 +21,9 @@ use std::fmt::Debug; use std::sync::Arc; use crate::aggregate_statistics::AggregateStatistics; -use crate::coalesce_batches::CoalesceBatches; use crate::combine_partial_final_agg::CombinePartialFinalAggregate; -use crate::enforce_distribution::EnforceDistribution; -use crate::enforce_sorting::EnforceSorting; use crate::ensure_coop::EnsureCooperative; +use crate::ensure_requirements::EnsureRequirements; use crate::filter_pushdown::FilterPushdown; use crate::join_selection::JoinSelection; use crate::limit_pushdown::LimitPushdown; @@ -34,12 +32,58 @@ use crate::output_requirements::OutputRequirements; use crate::projection_pushdown::ProjectionPushdown; use crate::sanity_checker::SanityCheckPlan; use crate::topk_aggregation::TopKAggregation; +use crate::topk_repartition::TopKRepartition; use crate::update_aggr_exprs::OptimizeAggregateOrder; +use crate::hash_join_buffering::HashJoinBuffering; use crate::limit_pushdown_past_window::LimitPushPastWindows; -use datafusion_common::config::ConfigOptions; +use crate::pushdown_sort::PushdownSort; +use crate::window_topn::WindowTopN; use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::operator_statistics::StatisticsRegistry; + +/// Context available to physical optimizer rules. +/// +/// This trait provides access to configuration options and optional statistics +/// registry for enhanced statistics lookup. It allows optimizer rules to access +/// extended context without changing the core [`PhysicalOptimizerRule::optimize`] +/// signature. +pub trait PhysicalOptimizerContext: Send + Sync { + /// Returns the configuration options. + fn config_options(&self) -> &ConfigOptions; + + /// Returns the statistics registry for enhanced statistics lookup. + /// + /// Returns `None` if no registry is configured, in which case rules + /// should fall back to using `ExecutionPlan::partition_statistics()`. + fn statistics_registry(&self) -> Option<&StatisticsRegistry> { + None + } +} + +/// Simple context wrapping [`ConfigOptions`] for backward compatibility. +/// +/// This struct provides a minimal implementation of [`PhysicalOptimizerContext`] +/// that only supplies configuration options. Used when no statistics registry +/// is available or needed. +pub struct ConfigOnlyContext<'a> { + config: &'a ConfigOptions, +} + +impl<'a> ConfigOnlyContext<'a> { + /// Create a new context wrapping the given config options. + pub fn new(config: &'a ConfigOptions) -> Self { + Self { config } + } +} + +impl PhysicalOptimizerContext for ConfigOnlyContext<'_> { + fn config_options(&self) -> &ConfigOptions { + self.config + } +} /// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which /// computes the same results, but in a potentially more efficient way. @@ -48,14 +92,30 @@ use datafusion_physical_plan::ExecutionPlan; /// `PhysicalOptimizerRule`s. /// /// [`SessionState::add_physical_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_physical_optimizer_rule -pub trait PhysicalOptimizerRule: Debug { - /// Rewrite `plan` to an optimized form +pub trait PhysicalOptimizerRule: Debug + std::any::Any { + /// Rewrite `plan` to an optimized form. + /// + /// This is the primary optimization method. For rules that need access to + /// the statistics registry, override [`optimize_with_context`](Self::optimize_with_context) instead. fn optimize( &self, plan: Arc, config: &ConfigOptions, ) -> Result>; + /// Rewrite `plan` with access to extended context (statistics registry, etc.). + /// + /// Override this method if you need access to the statistics registry for + /// enhanced statistics lookup. The default implementation simply calls + /// [`optimize`](Self::optimize) with the config options from the context. + fn optimize_with_context( + &self, + plan: Arc, + context: &dyn PhysicalOptimizerContext, + ) -> Result> { + self.optimize(plan, context.config_options()) + } + /// A human readable name for this optimizer rule fn name(&self) -> &str; @@ -82,6 +142,12 @@ impl Default for PhysicalOptimizer { impl PhysicalOptimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { + // NOTEs: + // - The order of rules in this list is important, as it determines the + // order in which they are applied. + // - Adding a new rule here is expensive as it will be applied to all + // queries, and will likely increase the optimization time. Please extend + // existing rules when possible, rather than adding a new rule. let rules: Vec> = vec![ // If there is a output requirement of the query, make sure that // this information is not lost across different rules during optimization. @@ -89,11 +155,11 @@ impl PhysicalOptimizer { Arc::new(AggregateStatistics::new()), // Statistics-based join selection will change the Auto mode to a real join implementation, // like collect left, or hash join, or future sort merge join, which will influence the - // EnforceDistribution and EnforceSorting rules as they decide whether to add additional - // repartitioning and local sorting steps to meet distribution and ordering requirements. - // Therefore, it should run before EnforceDistribution and EnforceSorting. + // EnsureRequirements rule as it decides whether to add additional repartitioning and + // local sorting steps to meet distribution and ordering requirements. Therefore, it + // should run before EnsureRequirements. Arc::new(JoinSelection::new()), - // The LimitedDistinctAggregation rule should be applied before the EnforceDistribution rule, + // The LimitedDistinctAggregation rule should be applied before EnsureRequirements, // as that rule may inject other operations in between the different AggregateExecs. // Applying the rule early means only directly-connected AggregateExecs must be examined. Arc::new(LimitedDistinctAggregation::new()), @@ -103,25 +169,36 @@ impl PhysicalOptimizer { // those are handled by the later `FilterPushdown` rule. // See `FilterPushdownPhase` for more details. Arc::new(FilterPushdown::new()), - // The EnforceDistribution rule is for adding essential repartitioning to satisfy distribution - // requirements. Please make sure that the whole plan tree is determined before this rule. - // This rule increases parallelism if doing so is beneficial to the physical plan; i.e. at - // least one of the operators in the plan benefits from increased parallelism. - Arc::new(EnforceDistribution::new()), - // The CombinePartialFinalAggregate rule should be applied after the EnforceDistribution rule + // Ensures each input plan satisfies the distribution and ordering + // requirements declared by `ExecutionPlan::required_input_distribution` + // and `ExecutionPlan::required_input_ordering`. + // + // If the requirements are already satisfied, this rule leaves the plan + // unchanged. For example, it does not add sorting when the input is a + // file scan whose existing order already satisfies the required ordering. + // Otherwise, this rule inserts the necessary repartitioning and sorting + // operators. + // + // This used to be implemented as two separate rules: `EnforceDistribution` + // and `EnforceSorting`. It is now a single idempotent rule that decides + // distribution and sorting together in one bottom-up pass, so the + // `pushdown_sorts` step no longer breaks distribution invariants set + // earlier in the pipeline. See the module-level doc on + // [`EnsureRequirements`](crate::ensure_requirements) for the per-phase + // breakdown, and + // for the original failure mode. + Arc::new(EnsureRequirements::new()), + // The CombinePartialFinalAggregate rule should be applied after distribution enforcement Arc::new(CombinePartialFinalAggregate::new()), - // The EnforceSorting rule is for adding essential local sorting to satisfy the required - // ordering. Please make sure that the whole plan tree is determined before this rule. - // Note that one should always run this rule after running the EnforceDistribution rule - // as the latter may break local sorting requirements. - Arc::new(EnforceSorting::new()), // Run once after the local sorting requirement is changed Arc::new(OptimizeAggregateOrder::new()), + // WindowTopN: replaces Filter(rn<=K) → Window(ROW_NUMBER) → Sort + // with Window(ROW_NUMBER) → PartitionedTopKExec(fetch=K). + // Must run after EnsureRequirements (which inserts SortExec) and before + // ProjectionPushdown (which embeds projections into FilterExec). + Arc::new(WindowTopN::new()), // TODO: `try_embed_to_hash_join` in the ProjectionPushdown rule would be block by the CoalesceBatches, so add it before CoalesceBatches. Maybe optimize it in the future. Arc::new(ProjectionPushdown::new()), - // The CoalesceBatches rule will not influence the distribution and ordering of the - // whole plan tree. Therefore, to avoid influencing other rules, it should run last. - Arc::new(CoalesceBatches::new()), // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), @@ -132,12 +209,21 @@ impl PhysicalOptimizer { Arc::new(TopKAggregation::new()), // Tries to push limits down through window functions, growing as appropriate // This can possibly be combined with [LimitPushdown] - // It needs to come after [EnforceSorting] + // It needs to come after [EnsureRequirements] (which handles sort enforcement) Arc::new(LimitPushPastWindows::new()), + // The HashJoinBuffering rule adds a BufferExec node with the configured capacity + // in the prob side of hash joins. That way, the probe side gets eagerly polled before + // the build side is completely finished. + Arc::new(HashJoinBuffering::new()), // The LimitPushdown rule tries to push limits down as far as possible, // replacing operators with fetching variants, or adding limits // past operators that support limit pushdown. Arc::new(LimitPushdown::new()), + // TopKRepartition pushes TopK (Sort with fetch) below Hash + // repartition when the partition key is a prefix of the sort key. + // This reduces data volume before a hash shuffle. It must run + // after LimitPushdown so that the TopK already exists on the SortExec. + Arc::new(TopKRepartition::new()), // The ProjectionPushdown rule tries to push projections towards // the sources in the execution plan. As a result of this process, // a projection can disappear if it reaches the source providers, and @@ -145,9 +231,11 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + // PushdownSort: Detect sorts that can be pushed down to data sources. + Arc::new(PushdownSort::new()), Arc::new(EnsureCooperative::new()), // This FilterPushdown handles dynamic filters that may have references to the source ExecutionPlan. - // Therefore it should be run at the end of the optimization process since any changes to the plan may break the dynamic filter's references. + // Therefore, it should be run at the end of the optimization process since any changes to the plan may break the dynamic filter's references. // See `FilterPushdownPhase` for more details. Arc::new(FilterPushdown::new_post_optimization()), // The SanityCheckPlan rule checks whether the order and diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 9e5e980219767..899abcc88ba59 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -34,7 +34,7 @@ use datafusion_physical_expr::Distribution; use datafusion_physical_expr_common::sort_expr::OrderingRequirements; use datafusion_physical_plan::execution_plan::Boundedness; use datafusion_physical_plan::projection::{ - make_with_child, update_expr, update_ordering_requirement, ProjectionExec, + ProjectionExec, make_with_child, update_expr, update_ordering_requirement, }; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -61,7 +61,9 @@ impl OutputRequirements { /// Create a new rule which works in `Add` mode; i.e. it simply adds a /// top-level [`OutputRequirementExec`] into the physical plan to keep track /// of global ordering and distribution requirements if there are any. - /// Note that this rule should run at the beginning. + /// Note that this rule should run at the beginning. It is idempotent: when + /// invoked on a plan that is already topped by an `OutputRequirementExec`, + /// it returns the plan unchanged. pub fn new_add_mode() -> Self { Self { mode: RuleMode::Add, @@ -98,7 +100,7 @@ pub struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, - cache: PlanProperties, + cache: Arc, fetch: Option, } @@ -114,7 +116,7 @@ impl OutputRequirementExec { input, order_requirement: requirements, dist_requirement, - cache, + cache: Arc::new(cache), fetch, } } @@ -196,11 +198,7 @@ impl ExecutionPlan for OutputRequirementExec { "OutputRequirementExec" } - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -244,11 +242,7 @@ impl ExecutionPlan for OutputRequirementExec { unreachable!(); } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { self.input.partition_statistics(partition) } @@ -312,9 +306,7 @@ impl PhysicalOptimizerRule for OutputRequirements { RuleMode::Add => require_top_ordering(plan), RuleMode::Remove => plan .transform_up(|plan| { - if let Some(sort_req) = - plan.as_any().downcast_ref::() - { + if let Some(sort_req) = plan.downcast_ref::() { Ok(Transformed::yes(sort_req.input())) } else { Ok(Transformed::no(plan)) @@ -335,7 +327,15 @@ impl PhysicalOptimizerRule for OutputRequirements { /// This functions adds ancillary `OutputRequirementExec` to the physical plan, so that /// global requirements are not lost during optimization. +/// +/// Idempotent: if the plan is already topped by an `OutputRequirementExec`, it +/// is returned unchanged so that re-running this rule (as adaptive execution +/// in datafusion-ballista AQE does after every completed stage, see +/// datafusion-ballista#1359) does not stack wrappers. fn require_top_ordering(plan: Arc) -> Result> { + if plan.downcast_ref::().is_some() { + return Ok(plan); + } let (new_plan, is_changed) = require_top_ordering_helper(plan)?; if is_changed { Ok(new_plan) @@ -361,7 +361,7 @@ fn require_top_ordering_helper( // Global ordering defines desired ordering in the final result. if children.len() != 1 { Ok((plan, false)) - } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { + } else if let Some(sort_exec) = plan.downcast_ref::() { // In case of constant columns, output ordering of the `SortExec` would // be an empty set. Therefore; we check the sort expression field to // assign the requirements. @@ -379,7 +379,7 @@ fn require_top_ordering_helper( )) as _, true, )) - } else if let Some(spm) = plan.as_any().downcast_ref::() { + } else if let Some(spm) = plan.downcast_ref::() { let reqs = OrderingRequirements::from(spm.expr().clone()); let fetch = spm.fetch(); Ok(( @@ -402,7 +402,14 @@ fn require_top_ordering_helper( // be responsible for (i.e. the originator of) the global ordering. let (new_child, is_changed) = require_top_ordering_helper(Arc::clone(children.swap_remove(0)))?; - Ok((plan.with_new_children(vec![new_child])?, is_changed)) + + let plan = if is_changed { + plan.with_new_children(vec![new_child])? + } else { + plan + }; + + Ok((plan, is_changed)) } else { // Stop searching, there is no global ordering desired for the query. Ok((plan, false)) diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index b5e002b51f921..fe71c211769c8 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -32,13 +32,13 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{JoinSide, JoinType, Result}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; +use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, is_volatile}; +use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::joins::NestedLoopJoinExec; +use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::projection::{ - remove_unnecessary_projections, ProjectionExec, + ProjectionExec, remove_unnecessary_projections, }; -use datafusion_physical_plan::ExecutionPlan; /// This rule inspects `ProjectionExec`'s in the given physical plan and tries to /// remove or swap with its child. @@ -50,7 +50,7 @@ use datafusion_physical_plan::ExecutionPlan; pub struct ProjectionPushdown {} impl ProjectionPushdown { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -64,15 +64,13 @@ impl PhysicalOptimizerRule for ProjectionPushdown { ) -> Result> { let alias_generator = AliasGenerator::new(); let plan = plan - .transform_up(|plan| { - match plan.as_any().downcast_ref::() { - None => Ok(Transformed::no(plan)), - Some(hash_join) => try_push_down_join_filter( - Arc::clone(&plan), - hash_join, - &alias_generator, - ), - } + .transform_up(|plan| match plan.downcast_ref::() { + None => Ok(Transformed::no(plan)), + Some(hash_join) => try_push_down_join_filter( + Arc::clone(&plan), + hash_join, + &alias_generator, + ), }) .map(|t| t.data)?; @@ -135,7 +133,7 @@ fn try_push_down_join_filter( ); let new_lhs_length = lhs_rewrite.data.0.schema().fields.len(); - let projections = match projections { + let projections = match projections.as_ref() { None => match join.join_type() { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { // Build projections that ignore the newly projected columns. @@ -244,7 +242,7 @@ fn minimize_join_filter( ) -> JoinFilter { let mut used_columns = HashSet::new(); expr.apply(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { + if let Some(col) = expr.downcast_ref::() { used_columns.insert(col.index()); } Ok(TreeNodeRecursion::Continue) @@ -267,7 +265,7 @@ fn minimize_join_filter( .collect::(); let final_expr = expr - .transform_up(|expr| match expr.as_any().downcast_ref::() { + .transform_up(|expr| match expr.downcast_ref::() { None => Ok(Transformed::no(expr)), Some(column) => { let new_idx = used_columns @@ -349,8 +347,7 @@ impl<'a> JoinFilterRewriter<'a> { // Recurse if there is a dependency to both sides or if the entire expression is volatile. let depends_on_other_side = self.depends_on_join_side(&expr, self.join_side.negate())?; - let is_volatile = is_volatile_expression_tree(expr.as_ref()); - if depends_on_other_side || is_volatile { + if depends_on_other_side || is_volatile(&expr) { return expr.map_children(|expr| self.rewrite(expr)); } @@ -381,7 +378,7 @@ impl<'a> JoinFilterRewriter<'a> { // executed against the filter schema. let new_idx = self.join_side_projections.len(); let rewritten_expr = expr.transform_up(|expr| { - Ok(match expr.as_any().downcast_ref::() { + Ok(match expr.downcast_ref::() { None => Transformed::no(expr), Some(column) => { let intermediate_column = @@ -415,7 +412,7 @@ impl<'a> JoinFilterRewriter<'a> { join_side: JoinSide, ) -> Result { let mut result = false; - expr.apply(|expr| match expr.as_any().downcast_ref::() { + expr.apply(|expr| match expr.downcast_ref::() { None => Ok(TreeNodeRecursion::Continue), Some(c) => { let column_index = &self.intermediate_column_indices[c.index()]; @@ -431,26 +428,14 @@ impl<'a> JoinFilterRewriter<'a> { } } -fn is_volatile_expression_tree(expr: &dyn PhysicalExpr) -> bool { - if expr.is_volatile_node() { - return true; - } - - expr.children() - .iter() - .map(|expr| is_volatile_expression_tree(expr.as_ref())) - .reduce(|lhs, rhs| lhs || rhs) - .unwrap_or(false) -} - #[cfg(test)] mod test { use super::*; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_expr_common::operator::Operator; use datafusion_functions::math::random; - use datafusion_physical_expr::expressions::{binary, lit}; use datafusion_physical_expr::ScalarFunctionExpr; + use datafusion_physical_expr::expressions::{binary, lit}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::displayable; use datafusion_physical_plan::empty::EmptyExec; diff --git a/datafusion/physical-optimizer/src/pushdown_sort.rs b/datafusion/physical-optimizer/src/pushdown_sort.rs new file mode 100644 index 0000000000000..40a6fe2c205c7 --- /dev/null +++ b/datafusion/physical-optimizer/src/pushdown_sort.rs @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort Pushdown Optimization +//! +//! This optimizer attempts to push sort requirements down through the execution plan +//! tree to data sources that can natively handle them (e.g., by scanning files in +//! reverse order). +//! +//! ## How it works +//! +//! 1. Detects `SortExec` nodes in the plan +//! 2. Calls `try_pushdown_sort()` on the input to recursively push the sort requirement +//! 3. Each node type defines its own pushdown behavior: +//! - **Transparent nodes** (CoalesceBatchesExec, RepartitionExec, etc.) delegate to +//! their children and wrap the result +//! - **Data sources** (DataSourceExec) check if they can optimize for the ordering +//! - **Blocking nodes** return `Unsupported` to stop pushdown +//! 4. Based on the result: +//! - `Exact`: Remove the Sort operator (data source guarantees perfect ordering) +//! - `Inexact`: Keep Sort but use optimized input (enables early termination for TopK) +//! - `Unsupported`: No change +//! +//! ## Capabilities +//! +//! - **Sort elimination**: when a data source's natural ordering satisfies the +//! request, return `Exact` and remove the `SortExec` entirely. Preserves +//! `fetch` (LIMIT) from the eliminated `SortExec` for early termination. +//! - **Statistics-based file sorting**: sort files within each partition by +//! min/max statistics. When files are non-overlapping but listed in wrong +//! order (e.g., alphabetical order ≠ sort key order), this fixes the ordering +//! and enables sort elimination. Works for both single-partition and +//! multi-partition plans with multi-file groups. +//! - **Reverse scan optimization**: when required sort is the reverse of the data source's +//! natural ordering, enable reverse scanning (reading row groups in reverse order) +//! - **Prefix matching**: if data has ordering [A DESC, B ASC] and query needs +//! [A DESC], the existing ordering satisfies the requirement (`Exact`). +//! If the query needs [A ASC] (reverse of the prefix), a reverse scan is +//! used (`Inexact`, `SortExec` retained) +//! +//! Related issue: + +use crate::PhysicalOptimizerRule; +use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::SortOrderPushdownResult; +use datafusion_physical_plan::buffer::BufferExec; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use std::sync::Arc; + +/// A PhysicalOptimizerRule that attempts to push down sort requirements to data sources. +/// +/// See module-level documentation for details. +#[derive(Debug, Clone, Default)] +pub struct PushdownSort; + +impl PushdownSort { + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for PushdownSort { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + // Check if sort pushdown optimization is enabled + if !config.optimizer.enable_sort_pushdown { + return Ok(plan); + } + + let buffer_capacity = config.execution.sort_pushdown_buffer_capacity; + + // Use transform_down to find and optimize all SortExec nodes (including nested ones) + // Also handles SPM → SortExec pattern to insert BufferExec when sort is eliminated + plan.transform_down(|plan: Arc| { + // Pattern 1: SPM → SortExec(preserve_partitioning) + // When we eliminate the SortExec, SPM loses its memory buffer and reads + // directly from I/O-bound sources. Insert a BufferExec to compensate. + if let Some(spm) = plan.downcast_ref::() + && let Some(sort_child) = spm.input().downcast_ref::() + && sort_child.preserve_partitioning() + { + let sort_input = Arc::clone(sort_child.input()); + let required_ordering = sort_child.expr(); + match sort_input.try_pushdown_sort(required_ordering)? { + SortOrderPushdownResult::Exact { inner } => { + // Preserve fetch (LIMIT) from the eliminated SortExec. + // Use LocalLimitExec (not Global) since input is multi-partition. + let inner = if let Some(fetch) = sort_child.fetch() { + inner.with_fetch(Some(fetch)).unwrap_or_else(|| { + Arc::new(LocalLimitExec::new(inner, fetch)) + }) + } else { + inner + }; + // Insert BufferExec to replace SortExec's buffering role. + // SortExec buffered all data in memory; BufferExec provides + // bounded buffering so SPM doesn't stall on I/O. + let buffered: Arc = + Arc::new(BufferExec::new(inner, buffer_capacity)); + let new_spm = + SortPreservingMergeExec::new(spm.expr().clone(), buffered) + .with_fetch(spm.fetch()); + return Ok(Transformed::yes(Arc::new(new_spm))); + } + SortOrderPushdownResult::Inexact { inner } => { + let new_sort = SortExec::new(required_ordering.clone(), inner) + .with_fetch(sort_child.fetch()) + .with_preserve_partitioning(true); + let new_spm = SortPreservingMergeExec::new( + spm.expr().clone(), + Arc::new(new_sort), + ) + .with_fetch(spm.fetch()); + return Ok(Transformed::yes(Arc::new(new_spm))); + } + SortOrderPushdownResult::Unsupported => { + return Ok(Transformed::no(plan)); + } + } + } + + // Pattern 2: Standalone SortExec (no SPM parent) + let Some(sort_exec) = plan.downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + let sort_input = Arc::clone(sort_exec.input()); + let required_ordering = sort_exec.expr(); + + // Try to push the sort requirement down through the plan tree + // Each node type defines its own pushdown behavior via try_pushdown_sort() + match sort_input.try_pushdown_sort(required_ordering)? { + SortOrderPushdownResult::Exact { inner } => { + // Data source guarantees perfect ordering - remove the Sort operator. + // + // If the SortExec carried a fetch (LIMIT), we must preserve it. + // First try pushing the limit into the source via `with_fetch()`. + // If the source doesn't support `with_fetch`, fall back to + // wrapping with GlobalLimitExec. + if let Some(fetch) = sort_exec.fetch() { + let inner = inner.with_fetch(Some(fetch)).unwrap_or_else(|| { + Arc::new(GlobalLimitExec::new(inner, 0, Some(fetch))) + }); + Ok(Transformed::yes(inner)) + } else { + Ok(Transformed::yes(inner)) + } + } + SortOrderPushdownResult::Inexact { inner } => { + // Data source is optimized for the ordering but not perfectly sorted + // Keep the Sort operator but use the optimized input + // Benefits: TopK queries can terminate early, better cache locality + Ok(Transformed::yes(Arc::new( + SortExec::new(required_ordering.clone(), inner) + .with_fetch(sort_exec.fetch()) + .with_preserve_partitioning( + sort_exec.preserve_partitioning(), + ), + ))) + } + SortOrderPushdownResult::Unsupported => { + // Cannot optimize for this ordering - no change + Ok(Transformed::no(plan)) + } + } + }) + .data() + } + + fn name(&self) -> &str { + "PushdownSort" + } + + fn schema_check(&self) -> bool { + true + } +} diff --git a/datafusion/physical-optimizer/src/sanity_checker.rs b/datafusion/physical-optimizer/src/sanity_checker.rs index acc70d39f057b..40c6245d894d4 100644 --- a/datafusion/physical-optimizer/src/sanity_checker.rs +++ b/datafusion/physical-optimizer/src/sanity_checker.rs @@ -28,11 +28,11 @@ use datafusion_physical_plan::ExecutionPlan; use datafusion_common::config::{ConfigOptions, OptimizerOptions}; use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; -use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; +use datafusion_physical_plan::{ExecutionPlanProperties, get_plan_string}; use crate::PhysicalOptimizerRule; use datafusion_physical_expr_common::sort_expr::format_physical_sort_requirement_list; @@ -47,7 +47,7 @@ use itertools::izip; pub struct SanityCheckPlan {} impl SanityCheckPlan { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self {} } @@ -59,8 +59,8 @@ impl PhysicalOptimizerRule for SanityCheckPlan { plan: Arc, config: &ConfigOptions, ) -> Result> { - plan.transform_up(|p| check_plan_sanity(p, &config.optimizer)) - .data() + check_plan_sanity_recursive(&plan, &config.optimizer)?; + Ok(plan) } fn name(&self) -> &str { @@ -72,19 +72,31 @@ impl PhysicalOptimizerRule for SanityCheckPlan { } } +/// Bottom-up (post-order) read-only traversal that checks plan sanity. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] +fn check_plan_sanity_recursive( + plan: &Arc, + optimizer_options: &OptimizerOptions, +) -> Result { + plan.apply_children(|child| check_plan_sanity_recursive(child, optimizer_options))?; + check_plan_sanity(plan, optimizer_options)?; + Ok(TreeNodeRecursion::Continue) +} + /// This function propagates finiteness information and rejects any plan with /// pipeline-breaking operators acting on infinite inputs. pub fn check_finiteness_requirements( - input: Arc, + input: &dyn ExecutionPlan, optimizer_options: &OptimizerOptions, -) -> Result>> { - if let Some(exec) = input.as_any().downcast_ref::() { - if !(optimizer_options.allow_symmetric_joins_without_pruning +) -> Result<()> { + if let Some(exec) = input.downcast_ref::() + && !(optimizer_options.allow_symmetric_joins_without_pruning || (exec.check_if_order_information_available()? && is_prunable(exec))) - { - return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ - the 'allow_symmetric_joins_without_pruning' configuration flag"); - } + { + return plan_err!( + "Join operation cannot operate on a non-prunable stream without enabling \ + the 'allow_symmetric_joins_without_pruning' configuration flag" + ); } if matches!( @@ -100,7 +112,7 @@ pub fn check_finiteness_requirements( input ) } else { - Ok(Transformed::no(input)) + Ok(()) } } @@ -125,10 +137,10 @@ fn is_prunable(join: &SymmetricHashJoinExec) -> bool { /// Ensures that the plan is pipeline friendly and the order and /// distribution requirements from its children are satisfied. pub fn check_plan_sanity( - plan: Arc, + plan: &Arc, optimizer_options: &OptimizerOptions, -) -> Result>> { - check_finiteness_requirements(Arc::clone(&plan), optimizer_options)?; +) -> Result<()> { + check_finiteness_requirements(plan.as_ref(), optimizer_options)?; for ((idx, child), sort_req, dist_req) in izip!( plan.children().into_iter().enumerate(), @@ -139,7 +151,7 @@ pub fn check_plan_sanity( if let Some(sort_req) = sort_req { let sort_req = sort_req.into_single(); if !child_eq_props.ordering_satisfy_requirement(sort_req.clone())? { - let plan_str = get_plan_string(&plan); + let plan_str = get_plan_string(plan); return plan_err!( "Plan: {:?} does not satisfy order requirements: {}. Child-{} order: {}", plan_str, @@ -152,9 +164,10 @@ pub fn check_plan_sanity( if !child .output_partitioning() - .satisfy(&dist_req, child_eq_props) + .satisfaction(&dist_req, child_eq_props, true) + .is_satisfied() { - let plan_str = get_plan_string(&plan); + let plan_str = get_plan_string(plan); return plan_err!( "Plan: {:?} does not satisfy distribution requirements: {}. Child-{} output partitioning: {}", plan_str, @@ -165,7 +178,7 @@ pub fn check_plan_sanity( } } - Ok(Transformed::no(plan)) + Ok(()) } // See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index b7505f0df4edb..e1779c04a6a92 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -20,16 +20,16 @@ use std::sync::Arc; use crate::PhysicalOptimizerRule; -use arrow::datatypes::DataType; +use datafusion_common::Result; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::aggregates::LimitOptions; +use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported}; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort::SortExec; -use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed @@ -48,52 +48,59 @@ impl TopKAggregation { order_desc: bool, limit: usize, ) -> Option> { - // ensure the sort direction matches aggregate function - let (field, desc) = aggr.get_minmax_desc()?; - if desc != order_desc { - return None; - } - let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; - let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; - if !kt.is_primitive() - && kt != DataType::Utf8 - && kt != DataType::Utf8View - && kt != DataType::LargeUtf8 - { + // Current only support single group key + let (group_key, group_key_alias) = + aggr.group_expr().expr().iter().exactly_one().ok()?; + let kt = group_key.data_type(&aggr.input().schema()).ok()?; + let vt = if let Some((field, _)) = aggr.get_minmax_desc() { + field.data_type().clone() + } else { + kt.clone() + }; + if !topk_types_supported(&kt, &vt) { return None; } if aggr.filter_expr().iter().any(|e| e.is_some()) { return None; } - // ensure the sort is on the same field as the aggregate output - if order_by != field.name() { + // Check if this is ordering by an aggregate function (MIN/MAX) + if let Some((field, desc)) = aggr.get_minmax_desc() { + // ensure the sort direction matches aggregate function + if desc != order_desc { + return None; + } + // ensure the sort is on the same field as the aggregate output + if order_by != field.name() { + return None; + } + } else if aggr.aggr_expr().is_empty() { + // This is a GROUP BY without aggregates, check if ordering is on the group key itself + if order_by != group_key_alias { + return None; + } + } else { + // Has aggregates but not MIN/MAX, or doesn't DISTINCT return None; } // We found what we want: clone, copy the limit down, and return modified node - let new_aggr = AggregateExec::try_new( - *aggr.mode(), - aggr.group_expr().clone(), - aggr.aggr_expr().to_vec(), - aggr.filter_expr().to_vec(), - Arc::clone(aggr.input()), - aggr.input_schema(), - ) - .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); + let new_aggr = AggregateExec::with_new_limit_options( + aggr, + Some(LimitOptions::new_with_order(limit, order_desc)), + ); Some(Arc::new(new_aggr)) } fn transform_sort(plan: &Arc) -> Option> { - let sort = plan.as_any().downcast_ref::()?; + let sort = plan.downcast_ref::()?; let children = sort.children(); let child = children.into_iter().exactly_one().ok()?; let order = sort.properties().output_ordering()?; let order = order.iter().exactly_one().ok()?; let order_desc = order.options.descending; - let order = order.expr.as_any().downcast_ref::()?; + let order = order.expr.downcast_ref::()?; let mut cur_col_name = order.name().to_string(); let limit = sort.fetch()?; @@ -102,17 +109,16 @@ impl TopKAggregation { if !cardinality_preserved { return Ok(Transformed::no(plan)); } - if let Some(aggr) = plan.as_any().downcast_ref::() { + if let Some(aggr) = plan.downcast_ref::() { // either we run into an Aggregate and transform it match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) { None => cardinality_preserved = false, Some(plan) => return Ok(Transformed::yes(plan)), } - } else if let Some(proj) = plan.as_any().downcast_ref::() { + } else if let Some(proj) = plan.downcast_ref::() { // track renames due to successive projections for proj_expr in proj.expr() { - let Some(src_col) = proj_expr.expr.as_any().downcast_ref::() - else { + let Some(src_col) = proj_expr.expr.downcast_ref::() else { continue; }; if proj_expr.alias == cur_col_name { diff --git a/datafusion/physical-optimizer/src/topk_repartition.rs b/datafusion/physical-optimizer/src/topk_repartition.rs new file mode 100644 index 0000000000000..115bdc3cb535f --- /dev/null +++ b/datafusion/physical-optimizer/src/topk_repartition.rs @@ -0,0 +1,367 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Push TopK (Sort with fetch) past Hash Repartition +//! +//! When a `SortExec` with a fetch limit (TopK) sits above a +//! `RepartitionExec(Hash)`, and the hash partition expressions are a prefix +//! of the sort expressions, this rule inserts a copy of the TopK below +//! the repartition to reduce the volume of data flowing through the shuffle. +//! +//! This is correct because the hash partition key being a prefix of the sort +//! key guarantees that all rows with the same partition key end up in the same +//! output partition. Therefore, rows that survive the final TopK after +//! repartitioning will always survive the pre-repartition TopK as well. +//! +//! ## Example +//! +//! Before: +//! ```text +//! SortExec: TopK(fetch=3), expr=[a ASC, b ASC] +//! RepartitionExec: Hash([a], 4) +//! DataSourceExec +//! ``` +//! +//! After: +//! ```text +//! SortExec: TopK(fetch=3), expr=[a ASC, b ASC] +//! RepartitionExec: Hash([a], 4) +//! SortExec: TopK(fetch=3), expr=[a ASC, b ASC] +//! DataSourceExec +//! ``` + +use crate::PhysicalOptimizerRule; +use datafusion_common::Result; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use std::sync::Arc; +// CoalesceBatchesExec is deprecated on main (replaced by arrow-rs BatchCoalescer), +// but older DataFusion versions may still insert it between SortExec and RepartitionExec. +#[expect(deprecated)] +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::{ExecutionPlan, Partitioning}; + +/// A physical optimizer rule that pushes TopK (Sort with fetch) past +/// hash repartition when the partition key is a prefix of the sort key. +/// +/// See module-level documentation for details. +#[derive(Debug, Clone, Default)] +pub struct TopKRepartition; + +impl TopKRepartition { + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for TopKRepartition { + #[expect(deprecated)] // CoalesceBatchesExec: kept for older DataFusion versions + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + if !config.optimizer.enable_topk_repartition { + return Ok(plan); + } + plan.transform_down(|node| { + // Match SortExec with fetch (TopK) + let Some(sort_exec) = node.downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + let Some(fetch) = sort_exec.fetch() else { + return Ok(Transformed::no(node)); + }; + + // The child might be a CoalesceBatchesExec; look through it + let sort_input = sort_exec.input(); + let (repart_parent, repart_exec) = if let Some(rp) = + sort_input.downcast_ref::() + { + // found a RepartitionExec, use it + (None, rp) + } else if let Some(cb_exec) = sort_input.downcast_ref::() + { + // There's a CoalesceBatchesExec between TopK & RepartitionExec + // in this case we will need to reconstruct both nodes + let cb_input = cb_exec.input(); + let Some(rp) = cb_input.downcast_ref::() else { + return Ok(Transformed::no(node)); + }; + (Some(Arc::clone(sort_input)), rp) + } else { + return Ok(Transformed::no(node)); + }; + + // Only handle Hash partitioning + let Partitioning::Hash(hash_exprs, num_partitions) = + repart_exec.partitioning() + else { + return Ok(Transformed::no(node)); + }; + + let sort_exprs = sort_exec.expr(); + + // Check that hash expressions are a prefix of the sort expressions. + // Each hash expression must match the corresponding sort expression + // (ignoring sort options like ASC/DESC since hash doesn't care about order). + if hash_exprs.len() > sort_exprs.len() { + return Ok(Transformed::no(node)); + } + for (hash_expr, sort_expr) in hash_exprs.iter().zip(sort_exprs.iter()) { + if !hash_expr.eq(&sort_expr.expr) { + return Ok(Transformed::no(node)); + } + } + + // Don't push if the input to the repartition is already bounded + // (e.g., another TopK), as it would be redundant. + let repart_input = repart_exec.input(); + if repart_input.is::() { + return Ok(Transformed::no(node)); + } + + // Insert a copy of the TopK below the repartition + let new_sort: Arc = Arc::new( + SortExec::new(sort_exprs.clone(), Arc::clone(repart_input)) + .with_fetch(Some(fetch)) + .with_preserve_partitioning(sort_exec.preserve_partitioning()), + ); + + let new_partitioning = + Partitioning::Hash(hash_exprs.clone(), *num_partitions); + let new_repartition: Arc = + Arc::new(RepartitionExec::try_new(new_sort, new_partitioning)?); + + // Rebuild the tree above the repartition + let new_sort_input = if let Some(parent) = repart_parent { + parent.with_new_children(vec![new_repartition])? + } else { + new_repartition + }; + + let new_top_sort: Arc = Arc::new( + SortExec::new(sort_exprs.clone(), new_sort_input) + .with_fetch(Some(fetch)) + .with_preserve_partitioning(sort_exec.preserve_partitioning()), + ); + + Ok(Transformed::yes(new_top_sort)) + }) + .data() + } + + fn name(&self) -> &str { + "TopKRepartition" + } + + fn schema_check(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::test::scan_partitioned; + use insta::assert_snapshot; + + fn schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int64, false), + ])) + } + + fn sort_exprs(schema: &Schema) -> LexOrdering { + LexOrdering::new(vec![ + PhysicalSortExpr::new_default(col("a", schema).unwrap()).asc(), + PhysicalSortExpr::new_default(col("b", schema).unwrap()).asc(), + ]) + .unwrap() + } + + /// TopK above Hash(a) repartition should get pushed below it, + /// because `a` is a prefix of the sort key `(a, b)`. + #[test] + fn topk_pushed_below_hash_repartition() { + let s = schema(); + let input = scan_partitioned(1); + let ordering = sort_exprs(&s); + + let repartition = Arc::new( + RepartitionExec::try_new( + input, + Partitioning::Hash(vec![col("a", &s).unwrap()], 4), + ) + .unwrap(), + ); + + let sort = Arc::new( + SortExec::new(ordering, repartition) + .with_fetch(Some(3)) + .with_preserve_partitioning(true), + ); + + let config = ConfigOptions::new(); + let optimized = TopKRepartition::new().optimize(sort, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + assert_snapshot!(display, @r" + SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true + SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } + + /// TopK with no fetch (unbounded sort) should NOT be pushed. + #[test] + fn unbounded_sort_not_pushed() { + let s = schema(); + let input = scan_partitioned(1); + let ordering = sort_exprs(&s); + + let repartition = Arc::new( + RepartitionExec::try_new( + input, + Partitioning::Hash(vec![col("a", &s).unwrap()], 4), + ) + .unwrap(), + ); + + let sort: Arc = Arc::new( + SortExec::new(ordering, repartition).with_preserve_partitioning(true), + ); + + let config = ConfigOptions::new(); + let optimized = TopKRepartition::new().optimize(sort, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + assert_snapshot!(display, @r" + SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } + + /// Hash key NOT a prefix of sort key should NOT be pushed. + #[test] + fn non_prefix_hash_key_not_pushed() { + let s = schema(); + let input = scan_partitioned(1); + let ordering = sort_exprs(&s); + + // Hash by `b`, but sort by `(a, b)` - b is not a prefix + let repartition = Arc::new( + RepartitionExec::try_new( + input, + Partitioning::Hash(vec![col("b", &s).unwrap()], 4), + ) + .unwrap(), + ); + + let sort: Arc = Arc::new( + SortExec::new(ordering, repartition) + .with_fetch(Some(3)) + .with_preserve_partitioning(true), + ); + + let config = ConfigOptions::new(); + let optimized = TopKRepartition::new().optimize(sort, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + assert_snapshot!(display, @r" + SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } + + /// TopK above CoalesceBatchesExec above Hash(a) repartition should + /// push through both, inserting a new TopK below the repartition. + #[expect(deprecated)] + #[test] + fn topk_pushed_through_coalesce_batches() { + let s = schema(); + let input = scan_partitioned(1); + let ordering = sort_exprs(&s); + + let repartition = Arc::new( + RepartitionExec::try_new( + input, + Partitioning::Hash(vec![col("a", &s).unwrap()], 4), + ) + .unwrap(), + ); + + let coalesce: Arc = + Arc::new(CoalesceBatchesExec::new(repartition, 8192)); + + let sort = Arc::new( + SortExec::new(ordering, coalesce) + .with_fetch(Some(3)) + .with_preserve_partitioning(true), + ); + + let config = ConfigOptions::new(); + let optimized = TopKRepartition::new().optimize(sort, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + assert_snapshot!(display, @r" + SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true], sort_prefix=[a@0 ASC] + CoalesceBatchesExec: target_batch_size=8192 + RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1, maintains_sort_order=true + SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } + + /// RoundRobin repartition should NOT be pushed. + #[test] + fn round_robin_not_pushed() { + let s = schema(); + let input = scan_partitioned(1); + let ordering = sort_exprs(&s); + + let repartition = Arc::new( + RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(4)).unwrap(), + ); + + let sort: Arc = Arc::new( + SortExec::new(ordering, repartition) + .with_fetch(Some(3)) + .with_preserve_partitioning(true), + ); + + let config = ConfigOptions::new(); + let optimized = TopKRepartition::new().optimize(sort, &config).unwrap(); + + let display = displayable(optimized.as_ref()).indent(true).to_string(); + assert_snapshot!(display, @r" + SortExec: TopK(fetch=3), expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[true] + RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 + DataSourceExec: partitions=1, partition_sizes=[1] + "); + } +} diff --git a/datafusion/physical-optimizer/src/update_aggr_exprs.rs b/datafusion/physical-optimizer/src/update_aggr_exprs.rs index 61bc715592af6..2430918e2c2db 100644 --- a/datafusion/physical-optimizer/src/update_aggr_exprs.rs +++ b/datafusion/physical-optimizer/src/update_aggr_exprs.rs @@ -22,10 +22,12 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_common::{Result, plan_datafusion_err}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement}; -use datafusion_physical_plan::aggregates::{concat_slices, AggregateExec}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateInputMode, concat_slices, +}; use datafusion_physical_plan::windows::get_ordered_partition_by_indices; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -49,7 +51,7 @@ use crate::PhysicalOptimizerRule; pub struct OptimizeAggregateOrder {} impl OptimizeAggregateOrder { - #[allow(missing_docs)] + #[expect(missing_docs)] pub fn new() -> Self { Self::default() } @@ -76,12 +78,12 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder { _config: &ConfigOptions, ) -> Result> { plan.transform_up(|plan| { - if let Some(aggr_exec) = plan.as_any().downcast_ref::() { + if let Some(aggr_exec) = plan.downcast_ref::() { // Final stage implementations do not rely on ordering -- those // ordering fields may be pruned out by first stage aggregates. // Hence, necessary information for proper merge is added during // the first stage to the state field, which the final stage uses. - if !aggr_exec.mode().is_first_stage() { + if aggr_exec.mode().input_mode() == AggregateInputMode::Partial { return Ok(Transformed::no(plan)); } let input = aggr_exec.input(); diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 13a1745216e83..04229e1cc2737 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use datafusion_common::Result; -use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr::{Distribution, LexOrdering, LexRequirement}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -58,6 +58,56 @@ pub fn add_sort_above( PlanContext::new(Arc::new(new_sort), T::default(), vec![node]) } +/// Like [`add_sort_above`], but also inserts a [`SortPreservingMergeExec`] when +/// the parent distribution requires a single partition and the input has +/// multiple partitions. This prevents `SortExec(preserve_partitioning=true)` +/// from violating `SinglePartition` requirements. +pub fn add_sort_above_with_distribution( + node: PlanContext, + sort_requirements: LexRequirement, + fetch: Option, + required_distribution: &Distribution, +) -> PlanContext { + let mut sort_reqs: Vec<_> = sort_requirements.into(); + sort_reqs.retain(|sort_expr| { + node.plan + .equivalence_properties() + .is_expr_constant(&sort_expr.expr) + .is_none() + }); + let sort_exprs = sort_reqs.into_iter().map(Into::into).collect::>(); + let Some(ordering) = LexOrdering::new(sort_exprs) else { + return node; + }; + let input_has_multiple_partitions = + node.plan.output_partitioning().partition_count() > 1; + + let mut new_sort = + SortExec::new(ordering.clone(), Arc::clone(&node.plan)).with_fetch(fetch); + if input_has_multiple_partitions { + new_sort = new_sort.with_preserve_partitioning(true); + } + + let sort_node = PlanContext::new(Arc::new(new_sort), T::default(), vec![node]); + + // If the parent requires SinglePartition and the input has multiple partitions, + // wrap the partition-preserving sort in SortPreservingMergeExec. + if matches!(required_distribution, Distribution::SinglePartition) + && input_has_multiple_partitions + { + PlanContext::new( + Arc::new( + SortPreservingMergeExec::new(ordering, Arc::clone(&sort_node.plan)) + .with_fetch(fetch), + ), + T::default(), + vec![sort_node], + ) + } else { + sort_node + } +} + /// This utility function adds a `SortExec` above an operator according to the /// given ordering requirements while preserving the original partitioning. If /// requirement is already satisfied no `SortExec` is added. @@ -79,37 +129,37 @@ pub fn add_sort_above_with_check( /// Checks whether the given operator is a [`SortExec`]. pub fn is_sort(plan: &Arc) -> bool { - plan.as_any().is::() + plan.is::() } /// Checks whether the given operator is a window; /// i.e. either a [`WindowAggExec`] or a [`BoundedWindowAggExec`]. pub fn is_window(plan: &Arc) -> bool { - plan.as_any().is::() || plan.as_any().is::() + plan.is::() || plan.is::() } /// Checks whether the given operator is a [`UnionExec`]. pub fn is_union(plan: &Arc) -> bool { - plan.as_any().is::() + plan.is::() } /// Checks whether the given operator is a [`SortPreservingMergeExec`]. pub fn is_sort_preserving_merge(plan: &Arc) -> bool { - plan.as_any().is::() + plan.is::() } /// Checks whether the given operator is a [`CoalescePartitionsExec`]. pub fn is_coalesce_partitions(plan: &Arc) -> bool { - plan.as_any().is::() + plan.is::() } /// Checks whether the given operator is a [`RepartitionExec`]. pub fn is_repartition(plan: &Arc) -> bool { - plan.as_any().is::() + plan.is::() } /// Checks whether the given operator is a limit; /// i.e. either a [`LocalLimitExec`] or a [`GlobalLimitExec`]. pub fn is_limit(plan: &Arc) -> bool { - plan.as_any().is::() || plan.as_any().is::() + plan.is::() || plan.is::() } diff --git a/datafusion/physical-optimizer/src/window_topn.rs b/datafusion/physical-optimizer/src/window_topn.rs new file mode 100644 index 0000000000000..40dbddfbdf9fb --- /dev/null +++ b/datafusion/physical-optimizer/src/window_topn.rs @@ -0,0 +1,331 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`WindowTopN`] optimizer rule for per-partition top-K window queries. +//! +//! Detects queries of the form: +//! +//! ```sql +//! SELECT * FROM ( +//! SELECT *, ROW_NUMBER() OVER (PARTITION BY pk ORDER BY val) as rn +//! FROM t +//! ) WHERE rn <= K; +//! ``` +//! +//! And replaces the `FilterExec → BoundedWindowAggExec → SortExec` pipeline +//! with `BoundedWindowAggExec → PartitionedTopKExec(fetch=K)`, removing both +//! the `FilterExec` and `SortExec`. +//! +//! See [`PartitionedTopKExec`] +//! for details on the replacement operator. + +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; +use datafusion_physical_expr::window::StandardWindowExpr; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::sorts::partitioned_topk::PartitionedTopKExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowUDFExpr}; + +/// Physical optimizer rule that converts per-partition `ROW_NUMBER` top-K +/// queries into a more efficient plan using [`PartitionedTopKExec`]. +/// +/// # Pattern Detected +/// +/// ```text +/// FilterExec(rn <= K) +/// [optional ProjectionExec] +/// BoundedWindowAggExec(ROW_NUMBER PARTITION BY ... ORDER BY ...) +/// SortExec(partition_keys, order_keys) +/// ``` +/// +/// # Replacement +/// +/// ```text +/// [optional ProjectionExec] +/// BoundedWindowAggExec(ROW_NUMBER PARTITION BY ... ORDER BY ...) +/// PartitionedTopKExec(partition_keys, order_keys, fetch=K) +/// ``` +/// +/// The `FilterExec` is removed entirely (all output rows have `rn ∈ {1..K}`). +/// The `SortExec` is replaced by `PartitionedTopKExec` which maintains a +/// per-partition top-K heap instead of sorting the entire dataset. +/// +/// # Supported Predicates +/// +/// - `rn <= K` → fetch = K +/// - `rn < K` → fetch = K - 1 +/// - `K >= rn` (flipped) → fetch = K +/// - `K > rn` (flipped) → fetch = K - 1 +/// +/// # When the Rule Fires +/// +/// All of the following must be true: +/// - Config flag `enable_window_topn` is `true` +/// - The plan matches `FilterExec → [ProjectionExec] → BoundedWindowAggExec → SortExec` +/// - The window function is `ROW_NUMBER` (not `RANK`, `DENSE_RANK`, etc.) +/// - `ROW_NUMBER` has a `PARTITION BY` clause (global top-K is already +/// handled by `SortExec` with `fetch`) +/// - The filter predicate compares the window output column to an integer +/// literal using `<=`, `<`, `>=`, or `>` +/// +/// [`PartitionedTopKExec`]: datafusion_physical_plan::sorts::partitioned_topk::PartitionedTopKExec +#[derive(Default, Clone, Debug)] +pub struct WindowTopN; + +impl WindowTopN { + pub fn new() -> Self { + Self + } + + /// Attempt to transform a single plan node. + /// + /// Returns `Some(new_plan)` if the node matches the + /// `FilterExec → [ProjectionExec] → BoundedWindowAggExec → SortExec` + /// pattern and can be rewritten, or `None` if the node should be + /// left unchanged. + fn try_transform(plan: &Arc) -> Option> { + // Step 1: Match FilterExec at the top + let filter = plan.downcast_ref::()?; + + // Don't handle filters with projections + if filter.projection().is_some() { + return None; + } + + // Step 2: Extract limit from predicate (rn <= K, rn < K, etc.) + let (col_idx, limit_n) = extract_window_limit(filter.predicate())?; + + // Step 3: Walk through optional ProjectionExec to find BoundedWindowAggExec + let child = filter.input(); + let (window_exec, proj_between) = find_window_below(child)?; + + // Step 4: Verify col_idx references a ROW_NUMBER window output column + let input_field_count = window_exec.input().schema().fields().len(); + if col_idx < input_field_count { + return None; // Filter is on an input column, not a window column + } + let window_expr_idx = col_idx - input_field_count; + let window_exprs = window_exec.window_expr(); + if window_expr_idx >= window_exprs.len() { + return None; + } + if !is_row_number(&window_exprs[window_expr_idx]) { + return None; + } + + // Step 5: Verify child of window is SortExec + let sort_exec = window_exec.input().downcast_ref::()?; + let sort_child = sort_exec.input(); + + // Step 6: Determine partition_prefix_len from the window expression + let partition_by = window_exprs[window_expr_idx].partition_by(); + let partition_prefix_len = partition_by.len(); + + // Without PARTITION BY, this is just a global top-K which + // SortExec with fetch already handles efficiently. + if partition_prefix_len == 0 { + return None; + } + + // Step 7: Build PartitionedTopKExec using SortExec's expressions + let partitioned_topk = PartitionedTopKExec::try_new( + Arc::clone(sort_child), + sort_exec.expr().clone(), + partition_prefix_len, + limit_n, + ) + .ok()?; + + // Step 8: Rebuild window with new child + let new_window = Arc::clone(&child_as_arc(window_exec)) + .with_new_children(vec![Arc::new(partitioned_topk)]) + .ok()?; + + // Step 9: If ProjectionExec was between Filter and Window, rebuild it + let result = match proj_between { + Some(proj) => Arc::clone(&child_as_arc(proj)) + .with_new_children(vec![new_window]) + .ok()?, + None => new_window, + }; + + Some(result) + } +} + +/// Helper to get an `Arc` from a reference. +/// We need this because `with_new_children` takes `Arc`. +fn child_as_arc(plan: &T) -> Arc { + Arc::new(plan.clone()) +} + +impl PhysicalOptimizerRule for WindowTopN { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + if !config.optimizer.enable_window_topn { + return Ok(plan); + } + + plan.transform_down(|node| { + Ok( + if let Some(transformed) = WindowTopN::try_transform(&node) { + Transformed::yes(transformed) + } else { + Transformed::no(node) + }, + ) + }) + .data() + } + + fn name(&self) -> &str { + "WindowTopN" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Extract a window limit from a predicate expression. +/// +/// Returns `(column_index, fetch)` if the predicate constrains a column +/// to at most N rows. +/// +/// # Supported Patterns +/// +/// | Predicate | Returns | +/// |-----------|---------| +/// | `Column(idx) <= Literal(N)` | `(idx, N)` | +/// | `Column(idx) < Literal(N)` | `(idx, N-1)` | +/// | `Literal(N) >= Column(idx)` | `(idx, N)` | +/// | `Literal(N) > Column(idx)` | `(idx, N-1)` | +/// +/// # Examples +/// +/// - `rn <= 5` → `Some((2, 5))` (assuming rn is column index 2) +/// - `rn < 3` → `Some((2, 2))` +/// - `10 >= rn` → `Some((2, 10))` +/// - `rn = 1` → `None` (equality not supported) +/// - `val <= 5` → `Some((1, 5))` (caller must verify it's a window column) +fn extract_window_limit( + predicate: &Arc, +) -> Option<(usize, usize)> { + let binary = predicate.downcast_ref::()?; + let op = binary.op(); + let left = binary.left(); + let right = binary.right(); + + // Try Column op Literal + if let (Some(col), Some(lit_val)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + let n = scalar_to_usize(lit_val.value())?; + return match *op { + Operator::LtEq => Some((col.index(), n)), + Operator::Lt => Some((col.index(), n - 1)), + _ => None, + }; + } + + // Try Literal op Column (flipped) + if let (Some(lit_val), Some(col)) = ( + left.downcast_ref::(), + right.downcast_ref::(), + ) { + let n = scalar_to_usize(lit_val.value())?; + return match *op { + Operator::GtEq => Some((col.index(), n)), + Operator::Gt => Some((col.index(), n - 1)), + _ => None, + }; + } + + None +} + +/// Convert a [`ScalarValue`] to `usize` if it's a positive integer. +/// +/// Returns `None` for null values, zero, negative integers, and +/// non-integer types (floats, strings, decimals, etc.). +fn scalar_to_usize(value: &ScalarValue) -> Option { + if !value.data_type().is_integer() { + return None; + } + let casted = value.cast_to(&DataType::UInt64).ok()?; + match casted { + ScalarValue::UInt64(Some(v)) if v > 0 => usize::try_from(v).ok(), + _ => None, + } +} + +/// Check if a window expression is `ROW_NUMBER`. +/// +/// Downcasts through `StandardWindowExpr` → `WindowUDFExpr` and checks +/// that the UDF name is `"row_number"`. Returns `false` for all other +/// window functions (e.g., `RANK`, `DENSE_RANK`, `SUM`). +fn is_row_number(expr: &Arc) -> bool { + let Some(swe) = expr.as_any().downcast_ref::() else { + return false; + }; + let swfe = swe.get_standard_func_expr(); + let Some(udf) = swfe.as_any().downcast_ref::() else { + return false; + }; + udf.fun().name() == "row_number" +} + +/// Walk below a plan node looking for a [`BoundedWindowAggExec`]. +/// +/// Handles two cases: +/// - Direct child: `FilterExec → BoundedWindowAggExec` +/// - With projection: `FilterExec → ProjectionExec → BoundedWindowAggExec` +/// +/// Returns the window exec and an optional `ProjectionExec` in between, +/// or `None` if no `BoundedWindowAggExec` is found within one or two levels. +fn find_window_below( + plan: &Arc, +) -> Option<(&BoundedWindowAggExec, Option<&ProjectionExec>)> { + // Direct child is BoundedWindowAggExec + if let Some(window) = plan.downcast_ref::() { + return Some((window, None)); + } + + // Child is ProjectionExec with BoundedWindowAggExec below + if let Some(proj) = plan.downcast_ref::() { + let proj_child = proj.input(); + if let Some(window) = proj_child.downcast_ref::() { + return Some((window, Some(proj))); + } + } + + None +} diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 5858deb83c83c..0fc75043bf333 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -39,38 +39,57 @@ workspace = true [features] force_hash_collisions = [] +test_utils = ["arrow/test_utils"] tokio_coop = [] tokio_coop_fallback = [] +# Enables `PhysicalExpr::try_to_proto` / `try_from_proto` hooks on the +# physical expressions defined in this crate (e.g. `HashExpr`). Off by +# default so consumers that never serialize plans pay nothing. +proto = [ + "dep:datafusion-proto-models", + "dep:datafusion-proto-common", + "datafusion-physical-expr/proto", + "datafusion-physical-expr-common/proto", +] [lib] name = "datafusion_physical_plan" [dependencies] -ahash = { workspace = true } arrow = { workspace = true } +arrow-data = { workspace = true } +# Spill IPC writes require lz4 and zstd codec support. Keep these features in +# sync with the SpillCompression variants in datafusion-common so codec +# availability is explicit in the crate that owns spill handling. +arrow-ipc = { workspace = true, features = ["lz4", "zstd"] } arrow-ord = { workspace = true } arrow-schema = { workspace = true } async-trait = { workspace = true } -chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-physical-expr-common = { workspace = true } +datafusion-proto-common = { workspace = true, optional = true } +datafusion-proto-models = { workspace = true, optional = true } futures = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } +num-traits = { workspace = true } parking_lot = { workspace = true } pin-project-lite = "^0.2.7" +serde_json = { workspace = true, features = ["preserve_order"] } tokio = { workspace = true } [dev-dependencies] +arrow-data = { workspace = true } criterion = { workspace = true, features = ["async_futures"] } datafusion-functions-aggregate = { workspace = true } datafusion-functions-window = { workspace = true } @@ -96,6 +115,25 @@ name = "spill_io" harness = false name = "sort_preserving_merge" +[[bench]] +harness = false +name = "sort_merge_join" +required-features = ["test_utils"] + [[bench]] harness = false name = "aggregate_vectorized" +required-features = ["test_utils"] + +[[bench]] +harness = false +name = "hash_join_semi_anti" +required-features = ["test_utils"] + +[[bench]] +harness = false +name = "dictionary_group_values" + +[[bench]] +harness = false +name = "multi_group_by" diff --git a/datafusion/physical-plan/benches/aggregate_vectorized.rs b/datafusion/physical-plan/benches/aggregate_vectorized.rs index 3c1899406c985..48ca76d80d2d3 100644 --- a/datafusion/physical-plan/benches/aggregate_vectorized.rs +++ b/datafusion/physical-plan/benches/aggregate_vectorized.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, BooleanBufferBuilder}; use arrow::datatypes::{Int32Type, StringViewType}; use arrow::util::bench_util::{ create_primitive_array, create_string_view_array_with_len, @@ -25,11 +25,11 @@ use arrow::util::test_util::seedable_rng; use arrow_schema::DataType; use criterion::measurement::WallTime; use criterion::{ - criterion_group, criterion_main, BenchmarkGroup, BenchmarkId, Criterion, + BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main, }; +use datafusion_physical_plan::aggregates::group_values::multi_group_by::GroupColumn; use datafusion_physical_plan::aggregates::group_values::multi_group_by::bytes_view::ByteViewGroupValueBuilder; use datafusion_physical_plan::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; -use datafusion_physical_plan::aggregates::group_values::multi_group_by::GroupColumn; use rand::distr::{Bernoulli, Distribution}; use std::hint::black_box; use std::sync::Arc; @@ -271,6 +271,7 @@ fn bench_single_primitive( } /// Test `vectorized_equal_to` with different number of true in the initial results +#[expect(clippy::needless_pass_by_value)] fn vectorized_equal_to( group: &mut BenchmarkGroup, mut builder: GroupColumnBuilder, @@ -288,13 +289,17 @@ fn vectorized_equal_to( builder.vectorized_append(input, rows).unwrap(); b.iter(|| { - // Cloning is a must as `vectorized_equal_to` will modify the input vec - // and without cloning all benchmarks after the first one won't be meaningful - let mut equal_to_results = equal_to_results.clone(); - builder.vectorized_equal_to(rows, input, rows, &mut equal_to_results); + // Rebuild the buffer each iteration as `vectorized_equal_to` mutates + // it, and without a fresh buffer all iterations after the first one + // would not be meaningful. + let mut equal_to_buffer = BooleanBufferBuilder::new(equal_to_results.len()); + for &v in &equal_to_results { + equal_to_buffer.append(v); + } + builder.vectorized_equal_to(rows, input, rows, &mut equal_to_buffer); // Make sure that the compiler does not optimize away the call - black_box(equal_to_results); + black_box(equal_to_buffer); }); }); } diff --git a/datafusion/physical-plan/benches/dictionary_group_values.rs b/datafusion/physical-plan/benches/dictionary_group_values.rs new file mode 100644 index 0000000000000..ded52aebd1100 --- /dev/null +++ b/datafusion/physical-plan/benches/dictionary_group_values.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `GroupValues` over a single `Dictionary` +//! column. Each iteration measures `intern` (once or N times) followed by +//! `emit(EmitTo::All)`. The `Box` returned by +//! `new_group_values` is constructed in the setup closure of +//! `iter_batched_ref` and is not included in the timing. + +use arrow::array::{ArrayRef, DictionaryArray, PrimitiveArray, StringArray}; +use arrow::buffer::{Buffer, NullBuffer}; +use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; +use criterion::{ + BatchSize, BenchmarkId, Criterion, Throughput, criterion_group, criterion_main, +}; +use datafusion_expr::EmitTo; +use datafusion_physical_plan::aggregates::group_values::new_group_values; +use datafusion_physical_plan::aggregates::order::GroupOrdering; +use rand::rngs::StdRng; +use rand::seq::SliceRandom; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +const SIZES: [usize; 2] = [8 * 1024, 64 * 1024]; +const CARDS_RELATIVE: [usize; 4] = [20, 75, 300, 1000]; +const N_BATCHES: usize = 4; +// Fixed for reproducibility. +const SEED: u64 = 0xD1C7; + +fn dict_schema() -> SchemaRef { + let dict_ty = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + Arc::new(Schema::new(vec![Field::new("g", dict_ty, true)])) +} + +/// Build a `Dictionary` column. +fn make_dict(size: usize, cardinality: usize, null_density: f32, seed: u64) -> ArrayRef { + let strings: Vec = (0..cardinality).map(|i| format!("v_{i:08}")).collect(); + let values = Arc::new(StringArray::from( + strings.iter().map(String::as_str).collect::>(), + )); + + let mut rng = StdRng::seed_from_u64(seed); + let keys: Vec = if cardinality == size { + let mut perm: Vec = (0..size as i32).collect(); + perm.shuffle(&mut rng); + perm + } else { + (0..size) + .map(|_| rng.random_range(0..cardinality) as i32) + .collect() + }; + let keys_buf = Buffer::from_slice_ref(&keys); + + let nulls: Option = (null_density > 0.0).then(|| { + (0..size) + .map(|_| !rng.random_bool(null_density as f64)) + .collect() + }); + + let key_array = PrimitiveArray::::new(keys_buf.into(), nulls); + Arc::new(DictionaryArray::::try_new(key_array, values).unwrap()) +} + +fn bench_id( + label: &str, + size: usize, + cardinality: usize, + null_density: f32, +) -> BenchmarkId { + BenchmarkId::new( + label, + format!("size_{size}_card_{cardinality}_null_{null_density:.2}"), + ) +} + +fn bench_intern_emit(c: &mut Criterion) { + let mut group = c.benchmark_group("dict_intern_emit"); + let schema = dict_schema(); + let null_density = 0.0; + + for &size in &SIZES { + let mut cards = CARDS_RELATIVE.to_vec(); + cards.push(size); // all-unique stress case + for cardinality in cards { + let array = make_dict(size, cardinality, null_density, SEED); + group.throughput(Throughput::Elements(size as u64)); + group.bench_function( + bench_id("intern_emit", size, cardinality, null_density), + |b| { + b.iter_batched_ref( + || { + ( + new_group_values(schema.clone(), &GroupOrdering::None) + .unwrap(), + Vec::::with_capacity(size), + ) + }, + |(gv, groups)| { + gv.intern(std::slice::from_ref(&array), groups).unwrap(); + black_box(&*groups); + black_box(gv.emit(EmitTo::All).unwrap()); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } + group.finish(); +} + +fn bench_repeated_intern_emit(c: &mut Criterion) { + let mut group = c.benchmark_group("dict_repeated_intern_emit"); + let schema = dict_schema(); + let null_density = 0.10; + + for &size in &SIZES { + let mut cards = CARDS_RELATIVE.to_vec(); + cards.push(size); + for cardinality in cards { + let batches: Vec = (0..N_BATCHES) + .map(|i| { + make_dict( + size, + cardinality, + null_density, + SEED.wrapping_add(i as u64), + ) + }) + .collect(); + group.throughput(Throughput::Elements((size * N_BATCHES) as u64)); + group.bench_function( + bench_id("repeated_intern_emit", size, cardinality, null_density), + |b| { + b.iter_batched_ref( + || { + ( + new_group_values(schema.clone(), &GroupOrdering::None) + .unwrap(), + Vec::::with_capacity(size), + ) + }, + |(gv, groups)| { + for arr in &batches { + gv.intern(std::slice::from_ref(arr), groups).unwrap(); + black_box(&*groups); + } + black_box(gv.emit(EmitTo::All).unwrap()); + }, + BatchSize::SmallInput, + ); + }, + ); + } + } + group.finish(); +} + +criterion_group!(benches, bench_intern_emit, bench_repeated_intern_emit); +criterion_main!(benches); diff --git a/datafusion/physical-plan/benches/hash_join_semi_anti.rs b/datafusion/physical-plan/benches/hash_join_semi_anti.rs new file mode 100644 index 0000000000000..1e11da36be73c --- /dev/null +++ b/datafusion/physical-plan/benches/hash_join_semi_anti.rs @@ -0,0 +1,387 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Criterion benchmarks for Hash Join with RightSemi/RightAnti joins with Int32 keys. +//! +//! ## Key Benchmark Axes +//! +//! - **Density**: How tightly distinct keys pack into their numeric range. +//! `density = num_distinct_keys / (max_key - min_key + 1)`. +//! Examples for 5 distinct keys: +//! - `[0, 1, 2, 3, 4]` → 5/5 = 100% (fully packed) +//! - `[0, 2, 4, 6, 8]` → 5/9 ≈ 55% (every 2nd slot) +//! - `[0, 10, 20, 30, 40]` → 5/41 ≈ 12% (every 10th slot) +//! +//! Why it matters for this workload: future potential semi/anti-join +//! fast paths could exploit densely packed build keys to outperform the +//! general hash-table path, which is largely insensitive to density. +//! Varying density across benchmarks helps surface those potential gains +//! under different key distributions. Density describes only the +//! build-side key layout; the per-probe match count is tracked +//! separately as fanout. +//! +//! - **Hit Rate**: The percentage of probe rows that find a match in the build side. +//! This controls how often the join produces output rows. +//! +//! Semi/anti joins can short-circuit after finding the first match, so these +//! benchmarks help evaluate optimization strategies for existence checks. + +use std::sync::Arc; + +use arrow::array::{Int32Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::{JoinType, NullEquality}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_plan::collect; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, utils::JoinOn}; +use datafusion_physical_plan::test::TestMemoryExec; +use tokio::runtime::Runtime; + +/// Build RecordBatches with Int32 keys. +/// +/// Schema: (key: Int32, data: Int32, payload: Utf8) +/// +/// `key_mod` controls distinct key count: key = row_index % key_mod. +/// `key_offset` shifts keys to control hit rate. +fn build_batches( + num_rows: usize, + key_mod: usize, + key_offset: i32, + schema: &SchemaRef, +) -> Vec { + let keys: Vec = (0..num_rows) + .map(|i| ((i % key_mod) as i32) + key_offset) + .collect(); + let data: Vec = (0..num_rows).map(|i| i as i32).collect(); + let payload: Vec = data.iter().map(|d| format!("val_{d}")).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(schema), + vec![ + Arc::new(Int32Array::from(keys)), + Arc::new(Int32Array::from(data)), + Arc::new(StringArray::from(payload)), + ], + ) + .unwrap(); + + let batch_size = 8192; + let mut batches = Vec::new(); + let mut offset = 0; + while offset < batch.num_rows() { + let len = (batch.num_rows() - offset).min(batch_size); + batches.push(batch.slice(offset, len)); + offset += len; + } + batches +} + +fn make_exec( + batches: &[RecordBatch], + schema: &SchemaRef, +) -> Arc { + TestMemoryExec::try_new_exec(&[batches.to_vec()], Arc::clone(schema), None).unwrap() +} + +fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, false), + Field::new("data", DataType::Int32, false), + Field::new("payload", DataType::Utf8, false), + ])) +} + +fn do_hash_join( + left: Arc, + right: Arc, + join_type: JoinType, + rt: &Runtime, +) -> usize { + let on: JoinOn = vec![( + col("key", &left.schema()).unwrap(), + col("key", &right.schema()).unwrap(), + )]; + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + ) + .unwrap(); + + let task_ctx = Arc::new(TaskContext::default()); + rt.block_on(async { + let batches = collect(Arc::new(join), task_ctx).await.unwrap(); + batches.iter().map(|b| b.num_rows()).sum() + }) +} + +/// Build batches with sparse keys (key = row_index % key_mod * multiplier + key_offset). +/// The `multiplier` controls density: 1 = 100%, 2 = 50%, 10 = 10%. +fn build_batches_sparse( + num_rows: usize, + key_mod: usize, + key_offset: i32, + multiplier: i32, + schema: &SchemaRef, +) -> Vec { + let keys: Vec = (0..num_rows) + .map(|i| ((i % key_mod) as i32) * multiplier + key_offset) + .collect(); + let data: Vec = (0..num_rows).map(|i| i as i32).collect(); + let payload: Vec = data.iter().map(|d| format!("val_{d}")).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(schema), + vec![ + Arc::new(Int32Array::from(keys)), + Arc::new(Int32Array::from(data)), + Arc::new(StringArray::from(payload)), + ], + ) + .unwrap(); + + let batch_size = 8192; + let mut batches = Vec::new(); + let mut offset = 0; + while offset < batch.num_rows() { + let len = (batch.num_rows() - offset).min(batch_size); + batches.push(batch.slice(offset, len)); + offset += len; + } + batches +} + +fn bench_hash_join_semi_anti(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + let s = schema(); + + let mut group = c.benchmark_group("hash_join_semi_anti"); + + // Build side: 100K rows, Probe side: 1M rows + // Matching ratio: 1:1 (build keys are unique, each probe matches at most 1 build row) + let build_rows = 100_000; + let probe_rows = 1_000_000; + + // ========================================================================= + // RightSemi Join benchmarks + // ========================================================================= + + // RightSemi - 100% Density, 100% hit rate + // Keys: 0..100K contiguous, all probe rows find a match + { + let left_batches = build_batches(build_rows, build_rows, 0, &s); + let right_batches = build_batches(probe_rows, build_rows, 0, &s); + group.bench_function(BenchmarkId::new("right_semi_d100_h100", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }); + } + + // RightSemi - 100% Density, 10% hit rate + // Keys: 0..100K contiguous, only 10% of probe rows find a match + { + let left_batches = build_batches(build_rows, build_rows, 0, &s); + let right_batches = build_batches(probe_rows, build_rows * 10, 0, &s); + group.bench_function(BenchmarkId::new("right_semi_d100_h10", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }); + } + + // RightSemi - 50% Density, 100% hit rate + // Keys: 0, 2, 4, ... (sparse, multiplier=2), all probe rows find a match + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 2, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows, 0, 2, &s); + group.bench_function(BenchmarkId::new("right_semi_d50_h100", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }); + } + + // RightSemi - 50% Density, 10% hit rate + // Keys: 0, 2, 4, ... (sparse), only 10% of probe rows find a match + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 2, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows * 10, 0, 2, &s); + group.bench_function(BenchmarkId::new("right_semi_d50_h10", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }); + } + + // RightSemi - 10% Density, 100% hit rate + // Keys: 0, 10, 20, ... (very sparse, multiplier=10), all probe rows find a match + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 10, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows, 0, 10, &s); + group.bench_function(BenchmarkId::new("right_semi_d10_h100", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }); + } + + // RightSemi - 10% Density, 10% hit rate + // Keys: 0, 10, 20, ... (very sparse), only 10% of probe rows find a match + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 10, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows * 10, 0, 10, &s); + group.bench_function(BenchmarkId::new("right_semi_d10_h10", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }); + } + + // RightSemi - 100% Density, ~1% hit rate, fanout ~100 + // Build keys are duplicated: 100K rows over 1K distinct keys. Matching + // probe rows produce many duplicate probe indices before RightSemi + // deduplication. + { + let fanout_keys = 1_000; + let left_batches = build_batches(build_rows, fanout_keys, 0, &s); + let right_batches = build_batches(probe_rows, build_rows, 0, &s); + group.bench_function( + BenchmarkId::new("right_semi_fanout100_h1", probe_rows), + |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightSemi, &rt) + }) + }, + ); + } + + // ========================================================================= + // RightAnti Join benchmarks + // ========================================================================= + + // RightAnti - 100% Density, 100% hit rate (no output) + // Keys: 0..100K contiguous, all probe rows find a match -> no output + { + let left_batches = build_batches(build_rows, build_rows, 0, &s); + let right_batches = build_batches(probe_rows, build_rows, 0, &s); + group.bench_function(BenchmarkId::new("right_anti_d100_h100", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightAnti, &rt) + }) + }); + } + + // RightAnti - 100% Density, 10% hit rate (90% output) + // Keys: 0..100K contiguous, only 10% of probe rows find a match -> 90% output + { + let left_batches = build_batches(build_rows, build_rows, 0, &s); + let right_batches = build_batches(probe_rows, build_rows * 10, 0, &s); + group.bench_function(BenchmarkId::new("right_anti_d100_h10", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightAnti, &rt) + }) + }); + } + + // RightAnti - 50% Density, 100% hit rate (no output) + // Keys: 0, 2, 4, ... (sparse), all probe rows find a match -> no output + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 2, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows, 0, 2, &s); + group.bench_function(BenchmarkId::new("right_anti_d50_h100", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightAnti, &rt) + }) + }); + } + + // RightAnti - 50% Density, 10% hit rate (90% output) + // Keys: 0, 2, 4, ... (sparse), only 10% of probe rows find a match -> 90% output + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 2, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows * 10, 0, 2, &s); + group.bench_function(BenchmarkId::new("right_anti_d50_h10", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightAnti, &rt) + }) + }); + } + + // RightAnti - 10% Density, 100% hit rate (no output) + // Keys: 0, 10, 20, ... (very sparse), all probe rows find a match -> no output + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 10, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows, 0, 10, &s); + group.bench_function(BenchmarkId::new("right_anti_d10_h100", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightAnti, &rt) + }) + }); + } + + // RightAnti - 10% Density, 10% hit rate (90% output) + // Keys: 0, 10, 20, ... (very sparse), only 10% of probe rows find a match -> 90% output + { + let left_batches = build_batches_sparse(build_rows, build_rows, 0, 10, &s); + let right_batches = build_batches_sparse(probe_rows, build_rows * 10, 0, 10, &s); + group.bench_function(BenchmarkId::new("right_anti_d10_h10", probe_rows), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_hash_join(left, right, JoinType::RightAnti, &rt) + }) + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_hash_join_semi_anti); +criterion_main!(benches); diff --git a/datafusion/physical-plan/benches/multi_group_by.rs b/datafusion/physical-plan/benches/multi_group_by.rs new file mode 100644 index 0000000000000..92d0448775599 --- /dev/null +++ b/datafusion/physical-plan/benches/multi_group_by.rs @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for multi-column GROUP BY performance comparing vectorized +//! (`GroupValuesColumn`) vs row-based (`GroupValuesRows`) implementations. +//! +//! Motivated by which +//! showed vectorized can regress for low-cardinality, high-row-count scenarios. +//! +//! Uses the direct `GroupValues::intern()` API with identical Int32 data for +//! both implementations — a fair apples-to-apples comparison with the same +//! hashing and data layout. + +use arrow::array::{ArrayRef, Int32Array}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_physical_plan::aggregates::group_values::GroupValues; +use datafusion_physical_plan::aggregates::group_values::GroupValuesRows; +use datafusion_physical_plan::aggregates::group_values::multi_group_by::GroupValuesColumn; +use std::hint::black_box; +use std::sync::Arc; + +const DEFAULT_BATCH_SIZE: usize = 8192; + +fn make_schema(num_cols: usize) -> SchemaRef { + let fields: Vec = (0..num_cols) + .map(|i| Field::new(format!("col_{i}"), DataType::Int32, false)) + .collect(); + Arc::new(Schema::new(fields)) +} + +fn generate_batches( + num_cols: usize, + num_distinct_groups: usize, + num_rows: usize, + batch_size: usize, +) -> Vec> { + let per_col_card = (num_distinct_groups as f64) + .powf(1.0 / num_cols as f64) + .ceil() as usize; + + let num_full_batches = num_rows / batch_size; + let remainder = num_rows % batch_size; + let num_batches = num_full_batches + if remainder > 0 { 1 } else { 0 }; + + (0..num_batches) + .map(|batch_idx| { + let batch_start = batch_idx * batch_size; + let current_batch_size = if batch_idx == num_batches - 1 && remainder > 0 { + remainder + } else { + batch_size + }; + (0..num_cols) + .map(|col_idx| { + let values: Vec = (0..current_batch_size) + .map(|row| { + let global_row = batch_start + row; + let group_id = global_row % num_distinct_groups; + let divisor = per_col_card.pow(col_idx as u32); + ((group_id / divisor) % per_col_card) as i32 + }) + .collect(); + Arc::new(Int32Array::from(values)) as ArrayRef + }) + .collect() + }) + .collect() +} + +fn create_group_values(schema: &SchemaRef, vectorized: bool) -> Box { + if vectorized { + Box::new(GroupValuesColumn::::try_new(Arc::clone(schema)).unwrap()) + } else { + Box::new(GroupValuesRows::try_new(Arc::clone(schema)).unwrap()) + } +} + +fn bench_intern( + gv: &mut Box, + batches: &[Vec], + groups: &mut Vec, +) { + for batch in batches { + groups.clear(); + gv.intern(batch, groups).unwrap(); + } + black_box(&*groups); +} + +/// Experiment 1: Issue #17850 regression scenario. +/// 3 columns, 64 groups (4^3), scaling row count. +fn bench_issue_17850_regression(c: &mut Criterion) { + let mut group = c.benchmark_group("issue_17850_regression"); + group.sample_size(10); + + let num_cols = 3; + let num_groups = 64; + let schema = make_schema(num_cols); + + for num_rows in [1_000_000, 5_000_000, 10_000_000, 20_000_000, 50_000_000] { + let batches = + generate_batches(num_cols, num_groups, num_rows, DEFAULT_BATCH_SIZE); + + for vectorized in [true, false] { + let label = if vectorized { + "vectorized" + } else { + "row_based" + }; + group.bench_with_input( + BenchmarkId::new(label, format!("{num_rows}_rows")), + &batches, + |b, batches| { + b.iter_batched_ref( + || { + ( + create_group_values(&schema, vectorized), + Vec::::with_capacity(DEFAULT_BATCH_SIZE), + ) + }, + |(gv, groups)| bench_intern(gv, batches, groups), + criterion::BatchSize::LargeInput, + ); + }, + ); + } + } + group.finish(); +} + +/// Experiment 2: Low cardinality sweep. +fn bench_low_cardinality(c: &mut Criterion) { + let mut group = c.benchmark_group("low_cardinality"); + group.sample_size(15); + + for (num_cols, per_col_card) in + [(3usize, 2usize), (3, 4), (3, 8), (4, 2), (4, 4), (4, 8)] + { + let num_groups = per_col_card.pow(num_cols as u32); + let schema = make_schema(num_cols); + let batches = + generate_batches(num_cols, num_groups, 1_000_000, DEFAULT_BATCH_SIZE); + + for vectorized in [true, false] { + let label = if vectorized { + "vectorized" + } else { + "row_based" + }; + group.bench_with_input( + BenchmarkId::new( + label, + format!("cols_{num_cols}_card_{per_col_card}_grp_{num_groups}"), + ), + &batches, + |b, batches| { + b.iter_batched_ref( + || { + ( + create_group_values(&schema, vectorized), + Vec::::with_capacity(DEFAULT_BATCH_SIZE), + ) + }, + |(gv, groups)| bench_intern(gv, batches, groups), + criterion::BatchSize::LargeInput, + ); + }, + ); + } + } + group.finish(); +} + +/// Experiment 3: Batch size sensitivity. +fn bench_batch_size_sensitivity(c: &mut Criterion) { + let mut group = c.benchmark_group("batch_size_sensitivity"); + group.sample_size(10); + + let num_cols = 3; + let num_groups = 64; + let schema = make_schema(num_cols); + + for batch_size in [1024, 4096, 8192, 16384, 32768] { + let batches = generate_batches(num_cols, num_groups, 1_000_000, batch_size); + + for vectorized in [true, false] { + let label = if vectorized { + "vectorized" + } else { + "row_based" + }; + group.bench_with_input( + BenchmarkId::new(label, format!("batch_{batch_size}")), + &batches, + |b, batches| { + b.iter_batched_ref( + || { + ( + create_group_values(&schema, vectorized), + Vec::::with_capacity(batch_size), + ) + }, + |(gv, groups)| bench_intern(gv, batches, groups), + criterion::BatchSize::LargeInput, + ); + }, + ); + } + } + group.finish(); +} + +/// Experiment 4: Column count scaling with low groups. +fn bench_column_scaling(c: &mut Criterion) { + let mut group = c.benchmark_group("column_scaling"); + group.sample_size(15); + + let cases: &[(usize, usize)] = + &[(2, 100), (3, 125), (4, 81), (6, 729), (8, 256), (10, 1024)]; + + for &(num_cols, num_groups) in cases { + let schema = make_schema(num_cols); + let batches = + generate_batches(num_cols, num_groups, 1_000_000, DEFAULT_BATCH_SIZE); + + for vectorized in [true, false] { + let label = if vectorized { + "vectorized" + } else { + "row_based" + }; + group.bench_with_input( + BenchmarkId::new(label, format!("cols_{num_cols}_grp_{num_groups}")), + &batches, + |b, batches| { + b.iter_batched_ref( + || { + ( + create_group_values(&schema, vectorized), + Vec::::with_capacity(DEFAULT_BATCH_SIZE), + ) + }, + |(gv, groups)| bench_intern(gv, batches, groups), + criterion::BatchSize::LargeInput, + ); + }, + ); + } + } + group.finish(); +} + +/// Experiment 5: High cardinality column scaling (~1M groups). +fn bench_high_cardinality_scaling(c: &mut Criterion) { + let mut group = c.benchmark_group("high_cardinality_scaling"); + group.sample_size(10); + + for num_cols in [2, 3, 4, 6, 8, 10] { + let num_groups = 1_000_000; + let schema = make_schema(num_cols); + let batches = + generate_batches(num_cols, num_groups, 1_000_000, DEFAULT_BATCH_SIZE); + + for vectorized in [true, false] { + let label = if vectorized { + "vectorized" + } else { + "row_based" + }; + group.bench_with_input( + BenchmarkId::new(label, format!("cols_{num_cols}_grp_1M")), + &batches, + |b, batches| { + b.iter_batched_ref( + || { + ( + create_group_values(&schema, vectorized), + Vec::::with_capacity(DEFAULT_BATCH_SIZE), + ) + }, + |(gv, groups)| bench_intern(gv, batches, groups), + criterion::BatchSize::LargeInput, + ); + }, + ); + } + } + group.finish(); +} + +/// Experiment 6: Group count sweep with fixed 4 columns. +fn bench_group_count_sweep(c: &mut Criterion) { + let mut group = c.benchmark_group("group_count_sweep"); + group.sample_size(15); + + let num_cols = 4; + let schema = make_schema(num_cols); + + for num_groups in [ + 16, 64, 256, 1000, 5000, 10_000, 50_000, 100_000, 500_000, 1_000_000, + ] { + let batches = + generate_batches(num_cols, num_groups, 1_000_000, DEFAULT_BATCH_SIZE); + + for vectorized in [true, false] { + let label = if vectorized { + "vectorized" + } else { + "row_based" + }; + group.bench_with_input( + BenchmarkId::new(label, format!("grp_{num_groups}")), + &batches, + |b, batches| { + b.iter_batched_ref( + || { + ( + create_group_values(&schema, vectorized), + Vec::::with_capacity(DEFAULT_BATCH_SIZE), + ) + }, + |(gv, groups)| bench_intern(gv, batches, groups), + criterion::BatchSize::LargeInput, + ); + }, + ); + } + } + group.finish(); +} + +criterion_group!( + benches, + bench_issue_17850_regression, + bench_low_cardinality, + bench_batch_size_sensitivity, + bench_column_scaling, + bench_high_cardinality_scaling, + bench_group_count_sweep, +); +criterion_main!(benches); diff --git a/datafusion/physical-plan/benches/partial_ordering.rs b/datafusion/physical-plan/benches/partial_ordering.rs index e1a9d0b583e98..bdadd6274b75e 100644 --- a/datafusion/physical-plan/benches/partial_ordering.rs +++ b/datafusion/physical-plan/benches/partial_ordering.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array}; use datafusion_physical_plan::aggregates::order::GroupOrderingPartial; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; const BATCH_SIZE: usize = 8192; diff --git a/datafusion/physical-plan/benches/sort_merge_join.rs b/datafusion/physical-plan/benches/sort_merge_join.rs new file mode 100644 index 0000000000000..82610b2a54c2b --- /dev/null +++ b/datafusion/physical-plan/benches/sort_merge_join.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Criterion benchmarks for Sort Merge Join +//! +//! These benchmarks measure the join kernel in isolation by feeding +//! pre-sorted RecordBatches directly into SortMergeJoinExec, avoiding +//! sort / scan overhead. + +use std::sync::Arc; + +use arrow::array::{Int64Array, RecordBatch, StringArray}; +use arrow::compute::SortOptions; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use datafusion_common::NullEquality; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_plan::collect; +use datafusion_physical_plan::joins::{SortMergeJoinExec, utils::JoinOn}; +use datafusion_physical_plan::test::TestMemoryExec; +use tokio::runtime::Runtime; + +/// Build pre-sorted RecordBatches (split into ~8192-row chunks). +/// +/// Schema: (key: Int64, data: Int64, payload: Utf8) +/// +/// `key_mod` controls distinct key count: key = row_index % key_mod. +fn build_sorted_batches( + num_rows: usize, + key_mod: usize, + schema: &SchemaRef, +) -> Vec { + let mut rows: Vec<(i64, i64)> = (0..num_rows) + .map(|i| ((i % key_mod) as i64, i as i64)) + .collect(); + rows.sort(); + + let keys: Vec = rows.iter().map(|(k, _)| *k).collect(); + let data: Vec = rows.iter().map(|(_, d)| *d).collect(); + let payload: Vec = data.iter().map(|d| format!("val_{d}")).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(schema), + vec![ + Arc::new(Int64Array::from(keys)), + Arc::new(Int64Array::from(data)), + Arc::new(StringArray::from(payload)), + ], + ) + .unwrap(); + + let batch_size = 8192; + let mut batches = Vec::new(); + let mut offset = 0; + while offset < batch.num_rows() { + let len = (batch.num_rows() - offset).min(batch_size); + batches.push(batch.slice(offset, len)); + offset += len; + } + batches +} + +fn make_exec( + batches: &[RecordBatch], + schema: &SchemaRef, +) -> Arc { + TestMemoryExec::try_new_exec(&[batches.to_vec()], Arc::clone(schema), None).unwrap() +} + +fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("data", DataType::Int64, false), + Field::new("payload", DataType::Utf8, false), + ])) +} + +fn do_join( + left: Arc, + right: Arc, + join_type: datafusion_common::JoinType, + rt: &Runtime, +) -> usize { + let on: JoinOn = vec![( + col("key", &left.schema()).unwrap(), + col("key", &right.schema()).unwrap(), + )]; + let join = SortMergeJoinExec::try_new( + left, + right, + on, + None, + join_type, + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + let task_ctx = Arc::new(TaskContext::default()); + rt.block_on(async { + let batches = collect(Arc::new(join), task_ctx).await.unwrap(); + batches.iter().map(|b| b.num_rows()).sum() + }) +} + +fn bench_smj(c: &mut Criterion) { + let rt = Runtime::new().unwrap(); + let s = schema(); + + let mut group = c.benchmark_group("sort_merge_join"); + + // 1:1 Inner Join — 100K rows each, unique keys + // Best case for contiguous-range optimization: every index array is [0,1,2,...]. + { + let n = 100_000; + let left_batches = build_sorted_batches(n, n, &s); + let right_batches = build_sorted_batches(n, n, &s); + group.bench_function(BenchmarkId::new("inner_1to1", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::Inner, &rt) + }) + }); + } + + // 1:10 Inner Join — 100K left, 100K right, 10K distinct keys + { + let n = 100_000; + let key_mod = 10_000; + let left_batches = build_sorted_batches(n, key_mod, &s); + let right_batches = build_sorted_batches(n, key_mod, &s); + group.bench_function(BenchmarkId::new("inner_1to10", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::Inner, &rt) + }) + }); + } + + // Left Join — 100K each, ~5% unmatched on left + { + let n = 100_000; + let left_batches = build_sorted_batches(n, n + n / 20, &s); + let right_batches = build_sorted_batches(n, n, &s); + group.bench_function(BenchmarkId::new("left_1to1_unmatched", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::Left, &rt) + }) + }); + } + + // Left Semi Join — 100K left, 100K right, 10K keys + { + let n = 100_000; + let key_mod = 10_000; + let left_batches = build_sorted_batches(n, key_mod, &s); + let right_batches = build_sorted_batches(n, key_mod, &s); + group.bench_function(BenchmarkId::new("left_semi_1to10", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::LeftSemi, &rt) + }) + }); + } + + // Left Anti Join — 100K left, 100K right, partial match + { + let n = 100_000; + let left_batches = build_sorted_batches(n, n + n / 5, &s); + let right_batches = build_sorted_batches(n, n, &s); + group.bench_function(BenchmarkId::new("left_anti_partial", n), |b| { + b.iter(|| { + let left = make_exec(&left_batches, &s); + let right = make_exec(&right_batches, &s); + do_join(left, right, datafusion_common::JoinType::LeftAnti, &rt) + }) + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_smj); +criterion_main!(benches); diff --git a/datafusion/physical-plan/benches/sort_preserving_merge.rs b/datafusion/physical-plan/benches/sort_preserving_merge.rs index f223fd806b694..76ebf230a30e0 100644 --- a/datafusion/physical-plan/benches/sort_preserving_merge.rs +++ b/datafusion/physical-plan/benches/sort_preserving_merge.rs @@ -20,9 +20,9 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_schema::{SchemaRef, SortOptions}; -use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use criterion::{BatchSize, Criterion, criterion_group, criterion_main}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{expressions::col, LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, expressions::col}; use datafusion_physical_plan::test::TestMemoryExec; use datafusion_physical_plan::{ collect, sorts::sort_preserving_merge::SortPreservingMergeExec, diff --git a/datafusion/physical-plan/benches/spill_io.rs b/datafusion/physical-plan/benches/spill_io.rs index 40c8f7634c8c4..fac2547a131b4 100644 --- a/datafusion/physical-plan/benches/spill_io.rs +++ b/datafusion/physical-plan/benches/spill_io.rs @@ -22,15 +22,15 @@ use arrow::array::{ use arrow::datatypes::{DataType, Field, Schema}; use criterion::measurement::WallTime; use criterion::{ - criterion_group, criterion_main, BatchSize, BenchmarkGroup, BenchmarkId, Criterion, + BatchSize, BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main, }; use datafusion_common::config::SpillCompression; +use datafusion_common::human_readable_size; use datafusion_common::instant::Instant; -use datafusion_execution::memory_pool::human_readable_size; use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_physical_plan::SpillManager; use datafusion_physical_plan::common::collect; use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; -use datafusion_physical_plan::SpillManager; use rand::{Rng, SeedableRng}; use std::sync::Arc; use tokio::runtime::Runtime; @@ -490,6 +490,7 @@ fn bench_spill_compression(c: &mut Criterion) { group.finish(); } +#[expect(clippy::needless_pass_by_value)] fn benchmark_spill_batches_for_all_codec( group: &mut BenchmarkGroup<'_, WallTime>, batch_label: &str, diff --git a/datafusion/physical-plan/src/aggregates/group_values/metrics.rs b/datafusion/physical-plan/src/aggregates/group_values/metrics.rs index c4e29ea71060b..a0934b976ea79 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/metrics.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/metrics.rs @@ -53,12 +53,13 @@ mod tests { use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::metrics::MetricsSet; use crate::test::TestMemoryExec; - use crate::{collect, ExecutionPlan}; + use crate::{ExecutionPlan, collect}; use arrow::array::{Float64Array, UInt32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; @@ -135,7 +136,13 @@ mod tests { schema, )?); - let task_ctx = Arc::new(TaskContext::default()); + // This test is for `GroupByMetrics`, which are maintained by + // `GroupedHashAggregateStream`. Use a finite memory pool so the partial + // aggregate does not take the initial-partial stream path. + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc()?; + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); let _result = collect(Arc::clone(&aggregate_exec) as _, Arc::clone(&task_ctx)).await?; diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 4bd7f03506a15..ee253e5d7afdd 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -22,7 +22,7 @@ use arrow::array::types::{ Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow::array::{downcast_primitive, ArrayRef, RecordBatch}; +use arrow::array::{ArrayRef, downcast_primitive}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use datafusion_common::Result; @@ -31,10 +31,10 @@ use datafusion_expr::EmitTo; pub mod multi_group_by; mod row; +pub use row::GroupValuesRows; mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; use multi_group_by::GroupValuesColumn; -use row::GroupValuesRows; pub(crate) use single_group_by::primitive::HashValue; @@ -112,7 +112,7 @@ pub trait GroupValues: Send { fn emit(&mut self, emit_to: EmitTo) -> Result>; /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) - fn clear_shrink(&mut self, batch: &RecordBatch); + fn clear_shrink(&mut self, num_rows: usize); } /// Return a specialized implementation of [`GroupValues`] for the given schema. @@ -130,7 +130,7 @@ pub trait GroupValues: Send { /// /// `GroupColumn`: crate::aggregates::group_values::multi_group_by::GroupColumn /// `GroupValuesColumn`: crate::aggregates::group_values::multi_group_by::GroupValuesColumn -/// `GroupValuesRows`: crate::aggregates::group_values::row::GroupValuesRows +/// `GroupValuesRows`: crate::aggregates::group_values::GroupValuesRows pub fn new_group_values( schema: SchemaRef, group_ordering: &GroupOrdering, diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs index 03e26446f5751..5fdbe434f9f30 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/boolean.rs @@ -18,11 +18,10 @@ use std::sync::Arc; use crate::aggregates::group_values::multi_group_by::Nulls; -use crate::aggregates::group_values::multi_group_by::{nulls_equal_to, GroupColumn}; +use crate::aggregates::group_values::multi_group_by::{GroupColumn, nulls_equal_to}; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{Array as _, ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder}; use datafusion_common::Result; -use itertools::izip; /// An implementation of [`GroupColumn`] for booleans /// @@ -81,19 +80,14 @@ impl GroupColumn for BooleanGroupValueBuilder { lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { let array = array.as_boolean(); - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to in previous column, don't need to check - if !*equal_to_result { + for (idx, (&lhs_row, &rhs_row)) in + lhs_rows.iter().zip(rhs_rows.iter()).enumerate() + { + if !equal_to_results.get_bit(idx) { continue; } @@ -101,12 +95,16 @@ impl GroupColumn for BooleanGroupValueBuilder { let exist_null = self.nulls.is_null(lhs_row); let input_null = array.is_null(rhs_row); if let Some(result) = nulls_equal_to(exist_null, input_null) { - *equal_to_result = result; + if !result { + equal_to_results.set_bit(idx, false); + } continue; } } - *equal_to_result = self.buffer.get_bit(lhs_row) == array.value(rhs_row); + if self.buffer.get_bit(lhs_row) != array.value(rhs_row) { + equal_to_results.set_bit(idx, false); + } } } @@ -195,10 +193,20 @@ impl GroupColumn for BooleanGroupValueBuilder { #[cfg(test)] mod tests { - use arrow::array::NullBufferBuilder; + use arrow::array::{BooleanBufferBuilder, NullBufferBuilder}; use super::*; + fn make_true_buffer(n: usize) -> BooleanBufferBuilder { + let mut buf = BooleanBufferBuilder::new(n); + buf.append_n(n, true); + buf + } + + fn to_vec(buf: &BooleanBufferBuilder) -> Vec { + (0..buf.len()).map(|i| buf.get_bit(i)).collect() + } + #[test] fn test_nullable_boolean_equal_to() { let append = |builder: &mut BooleanGroupValueBuilder, @@ -209,16 +217,18 @@ mod tests { } }; - let equal_to = |builder: &BooleanGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; + let equal_to = + |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results + .set_bit(idx, builder.equal_to(lhs_row, input_array, rhs_row)); + } + }; test_nullable_boolean_equal_to_internal(append, equal_to); } @@ -233,18 +243,19 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &BooleanGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; + let equal_to = + |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; test_nullable_boolean_equal_to_internal(append, equal_to); } @@ -257,7 +268,7 @@ mod tests { &[usize], &ArrayRef, &[usize], - &mut Vec, + &mut BooleanBufferBuilder, ), { // Will cover such cases: @@ -268,7 +279,7 @@ mod tests { // - exist not null, input not null; values not equal // - exist not null, input not null; values equal - // Define PrimitiveGroupValueBuilder + // Define BooleanGroupValueBuilder let mut builder = BooleanGroupValueBuilder::::new(); let builder_array = Arc::new(BooleanArray::from(vec![ None, @@ -294,7 +305,7 @@ mod tests { // explicitly build a null buffer where one of the null values also happens to match let mut nulls = NullBufferBuilder::new(6); nulls.append_non_null(); - nulls.append_null(); // this sets Some(false) to null above + nulls.append_null(); nulls.append_null(); nulls.append_null(); nulls.append_non_null(); @@ -302,7 +313,7 @@ mod tests { let input_array = Arc::new(BooleanArray::new(values, nulls.finish())) as ArrayRef; // Check - let mut equal_to_results = vec![true; builder.len()]; + let mut equal_to_results = make_true_buffer(builder.len()); equal_to( &builder, &[0, 1, 2, 3, 4, 5], @@ -310,13 +321,14 @@ mod tests { &[0, 1, 2, 3, 4, 5], &mut equal_to_results, ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(equal_to_results[5]); + let results = to_vec(&equal_to_results); + + assert!(!results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(!results[3]); + assert!(!results[4]); + assert!(results[5]); } #[test] @@ -329,16 +341,18 @@ mod tests { } }; - let equal_to = |builder: &BooleanGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; + let equal_to = + |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results + .set_bit(idx, builder.equal_to(lhs_row, input_array, rhs_row)); + } + }; test_not_nullable_boolean_equal_to_internal(append, equal_to); } @@ -353,18 +367,19 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &BooleanGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; + let equal_to = + |builder: &BooleanGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; test_not_nullable_boolean_equal_to_internal(append, equal_to); } @@ -377,14 +392,14 @@ mod tests { &[usize], &ArrayRef, &[usize], - &mut Vec, + &mut BooleanBufferBuilder, ), { // Will cover such cases: // - values equal // - values not equal - // Define PrimitiveGroupValueBuilder + // Define BooleanGroupValueBuilder let mut builder = BooleanGroupValueBuilder::::new(); let builder_array = Arc::new(BooleanArray::from(vec![ Some(false), @@ -403,7 +418,7 @@ mod tests { ])) as ArrayRef; // Check - let mut equal_to_results = vec![true; builder.len()]; + let mut equal_to_results = make_true_buffer(builder.len()); equal_to( &builder, &[0, 1, 2, 3], @@ -411,11 +426,12 @@ mod tests { &[0, 1, 2, 3], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(!equal_to_results[1]); - assert!(!equal_to_results[2]); - assert!(equal_to_results[3]); + assert!(results[0]); + assert!(!results[1]); + assert!(!results[2]); + assert!(results[3]); } #[test] @@ -432,19 +448,20 @@ mod tests { .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_nulls_input_array.len()); builder.vectorized_equal_to( &[0, 1, 2, 3, 4], &all_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); // All not nulls input array let all_not_nulls_input_array = Arc::new(BooleanArray::from(vec![ @@ -458,18 +475,19 @@ mod tests { .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_not_nulls_input_array.len()); builder.vectorized_equal_to( &[5, 6, 7, 8, 9], &all_not_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs index d52721c2ee6c3..c83b1da4049bc 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs @@ -16,19 +16,19 @@ // under the License. use crate::aggregates::group_values::multi_group_by::{ - nulls_equal_to, GroupColumn, Nulls, + GroupColumn, Nulls, nulls_equal_to, }; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{ - types::GenericStringType, Array, ArrayRef, AsArray, BufferBuilder, - GenericBinaryArray, GenericByteArray, GenericStringArray, OffsetSizeTrait, + Array, ArrayRef, AsArray, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, + GenericByteArray, GenericStringArray, OffsetSizeTrait, types::GenericStringType, }; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::datatypes::{ByteArrayType, DataType, GenericBinaryType}; use datafusion_common::utils::proxy::VecAllocExt; -use datafusion_common::{exec_datafusion_err, Result}; -use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; -use itertools::izip; +use datafusion_common::utils::split_vec_min_alloc; +use datafusion_common::{Result, exec_datafusion_err}; +use datafusion_physical_expr_common::binary_map::{INITIAL_BUFFER_CAPACITY, OutputType}; use std::mem::size_of; use std::sync::Arc; use std::vec; @@ -106,25 +106,22 @@ where lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) where B: ByteArrayType, { let array = array.as_bytes::(); - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to, don't need to check - if !*equal_to_result { + for (idx, (&lhs_row, &rhs_row)) in + lhs_rows.iter().zip(rhs_rows.iter()).enumerate() + { + if !equal_to_results.get_bit(idx) { continue; } - *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); + if !self.do_equal_to_inner(lhs_row, array, rhs_row) { + equal_to_results.set_bit(idx, false); + } } } @@ -275,7 +272,7 @@ where lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { // Sanity array type match self.output_type { @@ -384,11 +381,10 @@ where // Given offsets like [0, 2, 4, 5] and n = 1, we expect to get // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5]. // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3]. - let mut first_n_offsets = self.offsets.drain(0..n).collect::>(); - let offset_n = *self.offsets.first().unwrap(); - self.offsets - .iter_mut() - .for_each(|offset| *offset = offset.sub(offset_n)); + let offset_n = self.offsets[n]; + let mut first_n_offsets = split_vec_min_alloc(&mut self.offsets, n); + // After the split, self.offsets[0] == offset_n in both branches; normalize in-place. + self.offsets.iter_mut().for_each(|o| *o = o.sub(offset_n)); first_n_offsets.push(offset_n); // SAFETY: the offsets were constructed correctly in `insert_if_new` -- @@ -433,12 +429,22 @@ mod tests { use std::sync::Arc; use crate::aggregates::group_values::multi_group_by::bytes::ByteGroupValueBuilder; - use arrow::array::{ArrayRef, NullBufferBuilder, StringArray}; + use arrow::array::{ArrayRef, BooleanBufferBuilder, NullBufferBuilder, StringArray}; use datafusion_common::DataFusionError; use datafusion_physical_expr::binary_map::OutputType; use super::GroupColumn; + fn make_true_buffer(n: usize) -> BooleanBufferBuilder { + let mut buf = BooleanBufferBuilder::new(n); + buf.append_n(n, true); + buf + } + + fn to_vec(buf: &BooleanBufferBuilder) -> Vec { + (0..buf.len()).map(|i| buf.get_bit(i)).collect() + } + #[test] fn test_byte_group_value_builder_overflow() { let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); @@ -516,16 +522,18 @@ mod tests { } }; - let equal_to = |builder: &ByteGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; + let equal_to = + |builder: &ByteGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results + .set_bit(idx, builder.equal_to(lhs_row, input_array, rhs_row)); + } + }; test_byte_equal_to_internal(append, equal_to); } @@ -540,18 +548,19 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &ByteGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; + let equal_to = + |builder: &ByteGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; test_byte_equal_to_internal(append, equal_to); } @@ -575,19 +584,20 @@ mod tests { .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_nulls_input_array.len()); builder.vectorized_equal_to( &[0, 1, 2, 3, 4], &all_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); // All not nulls input array let all_not_nulls_input_array = Arc::new(StringArray::from(vec![ @@ -601,19 +611,20 @@ mod tests { .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_not_nulls_input_array.len()); builder.vectorized_equal_to( &[5, 6, 7, 8, 9], &all_not_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); } fn test_byte_equal_to_internal(mut append: A, mut equal_to: E) @@ -624,7 +635,7 @@ mod tests { &[usize], &ArrayRef, &[usize], - &mut Vec, + &mut BooleanBufferBuilder, ), { // Will cover such cases: @@ -658,10 +669,10 @@ mod tests { ]) .into_parts(); - // explicitly build a boolean buffer where one of the null values also happens to match + // explicitly build a null buffer where one of the null values also happens to match let mut nulls = NullBufferBuilder::new(6); nulls.append_non_null(); - nulls.append_null(); // this sets Some("bar") to null above + nulls.append_null(); nulls.append_null(); nulls.append_null(); nulls.append_non_null(); @@ -670,7 +681,7 @@ mod tests { Arc::new(StringArray::new(offsets, buffer, nulls.finish())) as ArrayRef; // Check - let mut equal_to_results = vec![true; builder.len()]; + let mut equal_to_results = make_true_buffer(builder.len()); equal_to( &builder, &[0, 1, 2, 3, 4, 5], @@ -678,12 +689,13 @@ mod tests { &[0, 1, 2, 3, 4, 5], &mut equal_to_results, ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(equal_to_results[5]); + let results = to_vec(&equal_to_results); + + assert!(!results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(!results[3]); + assert!(!results[4]); + assert!(results[5]); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs index fde477c2cf7b5..e94e4547e1a75 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes_view.rs @@ -16,14 +16,17 @@ // under the License. use crate::aggregates::group_values::multi_group_by::{ - nulls_equal_to, GroupColumn, Nulls, + GroupColumn, Nulls, nulls_equal_to, }; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; -use arrow::array::{make_view, Array, ArrayRef, AsArray, ByteView, GenericByteViewArray}; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanBufferBuilder, ByteView, GenericByteViewArray, + make_view, +}; use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::ByteViewType; use datafusion_common::Result; -use itertools::izip; +use datafusion_common::utils::split_vec_min_alloc; use std::marker::PhantomData; use std::mem::{replace, size_of}; use std::sync::Arc; @@ -99,7 +102,8 @@ impl ByteViewGroupValueBuilder { fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { let array = array.as_byte_view::(); - self.do_equal_to_inner(lhs_row, array, rhs_row) + // since this is a single row comparison, don't bother specializing for nulls/buffers + self.do_equal_to_inner::(lhs_row, array, rhs_row) } fn append_val_inner(&mut self, array: &ArrayRef, row: usize) { @@ -117,28 +121,27 @@ impl ByteViewGroupValueBuilder { self.do_append_val_inner(arr, row); } - fn vectorized_equal_to_inner( + // Don't inline to keep the code small and give LLVM the best chance of + // vectorizing the inner loop + #[inline(never)] + fn vectorized_equal_to_inner( &self, lhs_rows: &[usize], - array: &ArrayRef, + array: &GenericByteViewArray, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { - let array = array.as_byte_view::(); - - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to, don't need to check - if !*equal_to_result { + for (idx, (&lhs_row, &rhs_row)) in + lhs_rows.iter().zip(rhs_rows.iter()).enumerate() + { + if !equal_to_results.get_bit(idx) { continue; } - *equal_to_result = self.do_equal_to_inner(lhs_row, array, rhs_row); + if !self.do_equal_to_inner::(lhs_row, array, rhs_row) + { + equal_to_results.set_bit(idx, false); + } } } @@ -216,26 +219,42 @@ impl ByteViewGroupValueBuilder { } } - fn do_equal_to_inner( + /// Compare the value at `lhs_row` in this builder with + /// the value at `rhs_row` in input `array` + /// + /// Templated so that the inner compare loop can be + /// specialized based on the input array + #[inline(always)] + fn do_equal_to_inner( &self, lhs_row: usize, array: &GenericByteViewArray, rhs_row: usize, ) -> bool { // Check if nulls equal firstly - let exist_null = self.nulls.is_null(lhs_row); - let input_null = array.is_null(rhs_row); - if let Some(result) = nulls_equal_to(exist_null, input_null) { - return result; + if HAS_NULLS { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } } // Otherwise, we need to check their values - let exist_view = self.views[lhs_row]; + + // SAFETY: the `lhs_row` and rhs_row` are valid + let exist_view = unsafe { *self.views.get_unchecked(lhs_row) }; let exist_view_len = exist_view as u32; - let input_view = array.views()[rhs_row]; + let input_view = unsafe { *array.views().get_unchecked(rhs_row) }; let input_view_len = input_view as u32; + // fast path, if we know there are no buffers, then the view must be inlined + // so we can simply compare the u128 views + if !HAS_BUFFERS { + return exist_view == input_view; + } + // The check logic // - Check len equality // - If inlined, check inlined value @@ -246,19 +265,8 @@ impl ByteViewGroupValueBuilder { } if exist_view_len <= 12 { - let exist_inline = unsafe { - GenericByteViewArray::::inline_value( - &exist_view, - exist_view_len as usize, - ) - }; - let input_inline = unsafe { - GenericByteViewArray::::inline_value( - &input_view, - input_view_len as usize, - ) - }; - exist_inline == input_inline + // both inlined, so compare inlined value + exist_view == input_view } else { let exist_prefix = unsafe { GenericByteViewArray::::inline_value(&exist_view, 4) }; @@ -269,30 +277,28 @@ impl ByteViewGroupValueBuilder { return false; } + // get the full values and compare let exist_full = { let byte_view = ByteView::from(exist_view); - self.value( - byte_view.buffer_index as usize, - byte_view.offset as usize, - byte_view.length as usize, - ) + let buffer_index = byte_view.buffer_index as usize; + let offset = byte_view.offset as usize; + let length = byte_view.length as usize; + debug_assert!(buffer_index <= self.completed.len()); + + unsafe { + if buffer_index < self.completed.len() { + let block = self.completed.get_unchecked(buffer_index); + block.as_slice().get_unchecked(offset..offset + length) + } else { + self.in_progress.get_unchecked(offset..offset + length) + } + } }; let input_full: &[u8] = unsafe { array.value_unchecked(rhs_row).as_ref() }; exist_full == input_full } } - fn value(&self, buffer_index: usize, offset: usize, length: usize) -> &[u8] { - debug_assert!(buffer_index <= self.completed.len()); - - if buffer_index < self.completed.len() { - let block = &self.completed[buffer_index]; - &block[offset..offset + length] - } else { - &self.in_progress[offset..offset + length] - } - } - fn build_inner(self) -> ArrayRef { let Self { views, @@ -358,7 +364,7 @@ impl ByteViewGroupValueBuilder { // // - Shift the `buffer index` of remaining non-inlined `views` // - let first_n_views = self.views.drain(0..n).collect::>(); + let first_n_views = split_vec_min_alloc(&mut self.views, n); let last_non_inlined_view = first_n_views .iter() @@ -451,21 +457,23 @@ impl ByteViewGroupValueBuilder { last_take_len: usize, ) -> Vec { let mut take_buffers = Vec::with_capacity(last_remaining_buffer_index + 1); + debug_assert!(last_remaining_buffer_index <= self.completed.len()); - // Take `0 ~ last_remaining_buffer_index - 1` buffers - if !self.completed.is_empty() || last_remaining_buffer_index == 0 { - take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); - } - - // Process the `last_remaining_buffer_index` buffers + // Process the `last_remaining_buffer_index` buffer before draining so the index is valid. let last_buffer = if last_remaining_buffer_index < self.completed.len() { // If it is in `completed`, simply clone self.completed[last_remaining_buffer_index].clone() } else { // If it is `in_progress`, copied `0 ~ offset` part + debug_assert!(last_take_len <= self.in_progress.len()); let taken_last_buffer = self.in_progress[0..last_take_len].to_vec(); Buffer::from_vec(taken_last_buffer) }; + + // Take `0 ~ last_remaining_buffer_index - 1` buffers + if last_remaining_buffer_index > 0 { + take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); + } take_buffers.push(last_buffer); take_buffers @@ -505,9 +513,38 @@ impl GroupColumn for ByteViewGroupValueBuilder { group_indices: &[usize], array: &ArrayRef, rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { - self.vectorized_equal_to_inner(group_indices, array, rows, equal_to_results); + let has_nulls = array.null_count() != 0; + let array = array.as_byte_view::(); + let has_buffers = !array.data_buffers().is_empty(); + // call specialized version based on nulls and buffers presence + match (has_nulls, has_buffers) { + (true, true) => self.vectorized_equal_to_inner::( + group_indices, + array, + rows, + equal_to_results, + ), + (true, false) => self.vectorized_equal_to_inner::( + group_indices, + array, + rows, + equal_to_results, + ), + (false, true) => self.vectorized_equal_to_inner::( + group_indices, + array, + rows, + equal_to_results, + ), + (false, false) => self.vectorized_equal_to_inner::( + group_indices, + array, + rows, + equal_to_results, + ), + } } fn vectorized_append(&mut self, array: &ArrayRef, rows: &[usize]) -> Result<()> { @@ -547,11 +584,23 @@ mod tests { use std::sync::Arc; use crate::aggregates::group_values::multi_group_by::bytes_view::ByteViewGroupValueBuilder; - use arrow::array::{ArrayRef, AsArray, NullBufferBuilder, StringViewArray}; + use arrow::array::{ + ArrayRef, AsArray, BooleanBufferBuilder, NullBufferBuilder, StringViewArray, + }; use arrow::datatypes::StringViewType; use super::GroupColumn; + fn make_true_buffer(n: usize) -> BooleanBufferBuilder { + let mut buf = BooleanBufferBuilder::new(n); + buf.append_n(n, true); + buf + } + + fn to_vec(buf: &BooleanBufferBuilder) -> Vec { + (0..buf.len()).map(|i| buf.get_bit(i)).collect() + } + #[test] fn test_byte_view_append_val() { let mut builder = @@ -586,16 +635,18 @@ mod tests { } }; - let equal_to = |builder: &ByteViewGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; + let equal_to = + |builder: &ByteViewGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results + .set_bit(idx, builder.equal_to(lhs_row, input_array, rhs_row)); + } + }; test_byte_view_equal_to_internal(append, equal_to); } @@ -610,18 +661,19 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &ByteViewGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; + let equal_to = + |builder: &ByteViewGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; test_byte_view_equal_to_internal(append, equal_to); } @@ -646,19 +698,20 @@ mod tests { .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_nulls_input_array.len()); builder.vectorized_equal_to( &[0, 1, 2, 3, 4], &all_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); // All not nulls input array let all_not_nulls_input_array = Arc::new(StringViewArray::from(vec![ @@ -672,19 +725,20 @@ mod tests { .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_not_nulls_input_array.len()); builder.vectorized_equal_to( &[5, 6, 7, 8, 9], &all_not_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); } fn test_byte_view_equal_to_internal(mut append: A, mut equal_to: E) @@ -695,7 +749,7 @@ mod tests { &[usize], &ArrayRef, &[usize], - &mut Vec, + &mut BooleanBufferBuilder, ), { // Will cover such cases: @@ -742,7 +796,7 @@ mod tests { // Define input array let (views, buffer, _nulls) = StringViewArray::from(vec![ Some("foo"), - Some("bar"), // set to null + Some("bar"), None, None, Some("baz"), @@ -756,10 +810,10 @@ mod tests { ]) .into_parts(); - // explicitly build a boolean buffer where one of the null values also happens to match + // explicitly build a null buffer where one of the null values also happens to match let mut nulls = NullBufferBuilder::new(9); nulls.append_non_null(); - nulls.append_null(); // this sets Some("bar") to null above + nulls.append_null(); nulls.append_null(); nulls.append_null(); nulls.append_non_null(); @@ -774,7 +828,7 @@ mod tests { Arc::new(StringViewArray::new(views, buffer, nulls.finish())) as ArrayRef; // Check - let mut equal_to_results = vec![true; input_array.len()]; + let mut equal_to_results = make_true_buffer(input_array.len()); equal_to( &builder, &[0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 8, 8], @@ -782,19 +836,20 @@ mod tests { &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &mut equal_to_results, ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(!equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(!equal_to_results[5]); - assert!(equal_to_results[6]); - assert!(!equal_to_results[7]); - assert!(!equal_to_results[8]); - assert!(equal_to_results[9]); - assert!(!equal_to_results[10]); - assert!(equal_to_results[11]); + let results = to_vec(&equal_to_results); + + assert!(!results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(!results[3]); + assert!(!results[4]); + assert!(!results[5]); + assert!(results[6]); + assert!(!results[7]); + assert!(!results[8]); + assert!(results[9]); + assert!(!results[10]); + assert!(results[11]); } #[test] @@ -913,4 +968,28 @@ mod tests { let taken_array = builder.take_n(final_ones_to_append); assert_eq!(&taken_array, &input_array); } + + #[test] + fn test_byte_view_take_n_partial_completed_nonzero_index() { + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(30); + let input_array = StringViewArray::from(vec![ + Some("aaaaaaaaaaaaaa"), + Some("bbbbbbbbbbbbbb"), + Some("cccccccccccccc"), + Some("dddddddddddddd"), + Some("eeeeeeeeeeeeee"), + ]); + let input_array: ArrayRef = Arc::new(input_array); + + for row in 0..input_array.len() { + builder.append_val(&input_array, row).unwrap(); + } + + assert_eq!(builder.completed.len(), 2); + assert_eq!(builder.in_progress.len(), 14); + + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(0, 3)); + } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 9adf028eca7f6..f275d777c3279 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -24,24 +24,24 @@ pub mod primitive; use std::mem::{self, size_of}; +use crate::aggregates::group_values::GroupValues; use crate::aggregates::group_values::multi_group_by::{ boolean::BooleanGroupValueBuilder, bytes::ByteGroupValueBuilder, bytes_view::ByteViewGroupValueBuilder, primitive::PrimitiveGroupValueBuilder, }; -use crate::aggregates::group_values::GroupValues; -use ahash::RandomState; -use arrow::array::{Array, ArrayRef, RecordBatch}; +use arrow::array::{Array, ArrayRef, BooleanBufferBuilder}; use arrow::compute::cast; use arrow::datatypes::{ - BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Schema, SchemaRef, + BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Field, Float32Type, + Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, Schema, SchemaRef, StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, + UInt64Type, }; +use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{internal_datafusion_err, not_impl_err, Result}; +use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; @@ -82,7 +82,7 @@ pub trait GroupColumn: Send + Sync { lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ); /// The vectorized version `append_val` @@ -212,7 +212,7 @@ pub struct GroupValuesColumn { /// more general purpose [`GroupValuesRows`]. See the ticket for details: /// /// - /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows + /// [`GroupValuesRows`]: crate::aggregates::group_values::GroupValuesRows group_values: Vec>, /// reused buffer to store hashes @@ -224,7 +224,6 @@ pub struct GroupValuesColumn { /// Buffers to store intermediate results in `vectorized_append` /// and `vectorized_equal_to`, for reducing memory allocation -#[derive(Default)] struct VectorizedOperationBuffers { /// The `vectorized append` row indices buffer append_row_indices: Vec, @@ -235,8 +234,8 @@ struct VectorizedOperationBuffers { /// The `vectorized_equal_to` group indices buffer equal_to_group_indices: Vec, - /// The `vectorized_equal_to` result buffer - equal_to_results: Vec, + /// The `vectorized_equal_to` result buffer (bitmask) + equal_to_results: BooleanBufferBuilder, /// The buffer for storing row indices found not equal to /// exist groups in `group_values` in `vectorized_equal_to`. @@ -244,12 +243,23 @@ struct VectorizedOperationBuffers { remaining_row_indices: Vec, } +impl Default for VectorizedOperationBuffers { + fn default() -> Self { + Self { + append_row_indices: Vec::new(), + equal_to_row_indices: Vec::new(), + equal_to_group_indices: Vec::new(), + equal_to_results: BooleanBufferBuilder::new(0), + remaining_row_indices: Vec::new(), + } + } +} + impl VectorizedOperationBuffers { fn clear(&mut self) { self.append_row_indices.clear(); self.equal_to_row_indices.clear(); self.equal_to_group_indices.clear(); - self.equal_to_results.clear(); self.remaining_row_indices.clear(); } } @@ -262,6 +272,7 @@ impl GroupValuesColumn { /// Create a new instance of GroupValuesColumn if supported for the specified schema pub fn try_new(schema: SchemaRef) -> Result { let map = HashTable::with_capacity(0); + let group_values = Self::build_group_columns(&schema)?; Ok(Self { schema, map, @@ -269,12 +280,27 @@ impl GroupValuesColumn { emit_group_index_list_buffer: Vec::new(), vectorized_operation_buffers: VectorizedOperationBuffers::default(), map_size: 0, - group_values: vec![], + group_values, hashes_buffer: Default::default(), random_state: crate::aggregates::AGGREGATION_HASH_SEED, }) } + /// Build one fresh [`GroupColumn`] per field in the schema. + /// + /// Used at construction time (`try_new`) and to repopulate the column + /// vector after operations that drain it (`emit(EmitTo::All)`, + /// `clear_shrink`). Centralising it keeps the post-condition that + /// `self.group_values` always contains exactly one builder per schema + /// field outside of those transient drain points. + fn build_group_columns(schema: &Schema) -> Result>> { + let mut v: Vec> = Vec::with_capacity(schema.fields().len()); + for f in schema.fields().iter() { + v.push(make_group_column(f.as_ref())?); + } + Ok(v) + } + // ======================================================================== // Scalarized intern // ======================================================================== @@ -499,7 +525,6 @@ impl GroupValuesColumn { .equal_to_group_indices .clear(); - let mut group_values_len = self.group_values[0].len(); for (row, &target_hash) in batch_hashes.iter().enumerate() { let entry = self .map @@ -508,7 +533,8 @@ impl GroupValuesColumn { let Some((_, group_index_view)) = entry else { // 1. Bucket not found case // Build `new inlined group index view` - let current_group_idx = group_values_len; + let current_group_idx = self.group_values[0].len() + + self.vectorized_operation_buffers.append_row_indices.len(); let group_index_view = GroupIndexView::new_inlined(current_group_idx as u64); @@ -528,7 +554,6 @@ impl GroupValuesColumn { // Set group index to row in `groups` groups[row] = current_group_idx; - group_values_len += 1; continue; }; @@ -540,14 +565,13 @@ impl GroupValuesColumn { // into `vectorized_equal_to_row_indices` and `vectorized_equal_to_group_indices`. let list_offset = group_index_view.value() as usize; let group_index_list = &self.group_index_lists[list_offset]; - for &group_index in group_index_list { - self.vectorized_operation_buffers - .equal_to_row_indices - .push(row); - self.vectorized_operation_buffers - .equal_to_group_indices - .push(group_index); - } + + self.vectorized_operation_buffers + .equal_to_group_indices + .extend_from_slice(group_index_list); + self.vectorized_operation_buffers + .equal_to_row_indices + .extend(std::iter::repeat_n(row, group_index_list.len())); } else { let group_index = group_index_view.value() as usize; self.vectorized_operation_buffers @@ -616,15 +640,16 @@ impl GroupValuesColumn { // 1. Perform `vectorized_equal_to` for `rows` in `vectorized_equal_to_group_indices` // and `group_indices` in `vectorized_equal_to_group_indices` - let mut equal_to_results = - mem::take(&mut self.vectorized_operation_buffers.equal_to_results); - equal_to_results.clear(); - equal_to_results.resize( - self.vectorized_operation_buffers - .equal_to_group_indices - .len(), - true, + let n = self + .vectorized_operation_buffers + .equal_to_group_indices + .len(); + let mut equal_to_results = mem::replace( + &mut self.vectorized_operation_buffers.equal_to_results, + BooleanBufferBuilder::new(0), ); + equal_to_results.truncate(0); + equal_to_results.append_n(n, true); for (col_idx, group_col) in self.group_values.iter().enumerate() { group_col.vectorized_equal_to( @@ -644,7 +669,7 @@ impl GroupValuesColumn { .iter() .enumerate() { - let equal_to_result = equal_to_results[idx]; + let equal_to_result = equal_to_results.get_bit(idx); // Equal to case, set the `group_indices` to `rows` in `groups` if equal_to_result { @@ -712,7 +737,7 @@ impl GroupValuesColumn { /// /// The hash collision may be not frequent, so the fallback will indeed hardly happen. /// In most situations, `scalarized_indices` will found to be empty after finishing to - /// preform `vectorized_equal_to`. + /// perform `vectorized_equal_to`. fn scalarized_intern_remaining( &mut self, cols: &[ArrayRef], @@ -889,172 +914,174 @@ macro_rules! instantiate_primitive { }; } -impl GroupValues for GroupValuesColumn { - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - if self.group_values.is_empty() { - let mut v = Vec::with_capacity(cols.len()); - - for f in self.schema.fields().iter() { - let nullable = f.is_nullable(); - let data_type = f.data_type(); - match data_type { - &DataType::Int8 => { - instantiate_primitive!(v, nullable, Int8Type, data_type) - } - &DataType::Int16 => { - instantiate_primitive!(v, nullable, Int16Type, data_type) - } - &DataType::Int32 => { - instantiate_primitive!(v, nullable, Int32Type, data_type) - } - &DataType::Int64 => { - instantiate_primitive!(v, nullable, Int64Type, data_type) - } - &DataType::UInt8 => { - instantiate_primitive!(v, nullable, UInt8Type, data_type) - } - &DataType::UInt16 => { - instantiate_primitive!(v, nullable, UInt16Type, data_type) - } - &DataType::UInt32 => { - instantiate_primitive!(v, nullable, UInt32Type, data_type) - } - &DataType::UInt64 => { - instantiate_primitive!(v, nullable, UInt64Type, data_type) - } - &DataType::Float32 => { - instantiate_primitive!(v, nullable, Float32Type, data_type) - } - &DataType::Float64 => { - instantiate_primitive!(v, nullable, Float64Type, data_type) - } - &DataType::Date32 => { - instantiate_primitive!(v, nullable, Date32Type, data_type) - } - &DataType::Date64 => { - instantiate_primitive!(v, nullable, Date64Type, data_type) - } - &DataType::Time32(t) => match t { - TimeUnit::Second => { - instantiate_primitive!( - v, - nullable, - Time32SecondType, - data_type - ) - } - TimeUnit::Millisecond => { - instantiate_primitive!( - v, - nullable, - Time32MillisecondType, - data_type - ) - } - _ => {} - }, - &DataType::Time64(t) => match t { - TimeUnit::Microsecond => { - instantiate_primitive!( - v, - nullable, - Time64MicrosecondType, - data_type - ) - } - TimeUnit::Nanosecond => { - instantiate_primitive!( - v, - nullable, - Time64NanosecondType, - data_type - ) - } - _ => {} - }, - &DataType::Timestamp(t, _) => match t { - TimeUnit::Second => { - instantiate_primitive!( - v, - nullable, - TimestampSecondType, - data_type - ) - } - TimeUnit::Millisecond => { - instantiate_primitive!( - v, - nullable, - TimestampMillisecondType, - data_type - ) - } - TimeUnit::Microsecond => { - instantiate_primitive!( - v, - nullable, - TimestampMicrosecondType, - data_type - ) - } - TimeUnit::Nanosecond => { - instantiate_primitive!( - v, - nullable, - TimestampNanosecondType, - data_type - ) - } - }, - &DataType::Decimal128(_, _) => { - instantiate_primitive! { - v, - nullable, - Decimal128Type, - data_type - } - } - &DataType::Utf8 => { - let b = ByteGroupValueBuilder::::new(OutputType::Utf8); - v.push(Box::new(b) as _) - } - &DataType::LargeUtf8 => { - let b = ByteGroupValueBuilder::::new(OutputType::Utf8); - v.push(Box::new(b) as _) - } - &DataType::Binary => { - let b = ByteGroupValueBuilder::::new(OutputType::Binary); - v.push(Box::new(b) as _) - } - &DataType::LargeBinary => { - let b = ByteGroupValueBuilder::::new(OutputType::Binary); - v.push(Box::new(b) as _) - } - &DataType::Utf8View => { - let b = ByteViewGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } - &DataType::BinaryView => { - let b = ByteViewGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } - &DataType::Boolean => { - if nullable { - let b = BooleanGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } else { - let b = BooleanGroupValueBuilder::::new(); - v.push(Box::new(b) as _) - } - } - dt => { - return not_impl_err!("{dt} not supported in GroupValuesColumn") - } - } +/// Returns true if the specified data type has a specialized +/// [`GroupColumn`] builder in [`make_group_column`]. +/// +/// This is the allow-list that gates the `GroupValuesRows` fallback in +/// [`crate::aggregates::group_values::new_group_values`]: it must accept +/// exactly the set of types that [`make_group_column`] constructs a +/// builder for. The `group_column_supported_type_matches_make_group_column` +/// test below pins this biconditional. +fn group_column_supported_type(data_type: &DataType) -> bool { + matches!( + *data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + // Only the semantically valid Time variants per the Arrow spec. + // The dispatcher in `make_group_column` returns NotImpl for the + // other unit combinations, so accepting them here would cause a + // schema to be routed into GroupValuesColumn and then fail at + // intern. Keep these two arms in lockstep with the dispatcher. + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Timestamp(_, _) + | DataType::Utf8View + | DataType::BinaryView + | DataType::Boolean + ) +} + +/// Build a [`GroupColumn`] for a single schema field. +/// +/// Extracted from the inline match that used to live in +/// [`GroupValuesColumn::intern`] so the per-field dispatch lives in one +/// place. This factory is the single source of truth for which Arrow types +/// map to which builder, and it is the function that future nested-type +/// specializations (e.g. `Struct`, `List`, `LargeList`) plug into without +/// having to enumerate every combination inline. +/// +/// Returns `Err(not_impl_err!(...))` for any type not in the supported set; +/// callers (`GroupValues::intern`) propagate that error so the +/// `GroupValuesRows` fallback can take over upstream of this builder. +/// +/// The allow-list that gates this dispatcher lives in +/// [`group_column_supported_type`] directly above. +fn make_group_column(field: &Field) -> Result> { + let nullable = field.is_nullable(); + let data_type = field.data_type(); + let mut v: Vec> = Vec::with_capacity(1); + match *data_type { + DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type, data_type), + DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type, data_type), + DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type, data_type), + DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type, data_type), + DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type, data_type), + DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type, data_type), + DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type, data_type), + DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type, data_type), + DataType::Float32 => { + instantiate_primitive!(v, nullable, Float32Type, data_type) + } + DataType::Float64 => { + instantiate_primitive!(v, nullable, Float64Type, data_type) + } + DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type, data_type), + DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type, data_type), + DataType::Time32(t) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, Time32SecondType, data_type) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, Time32MillisecondType, data_type) } - self.group_values = v; + // Time32 with Microsecond / Nanosecond is not a valid Arrow type + // combination; reject explicitly so group_column_supported_type + // and this dispatcher stay in lockstep (see consistency fuzz below). + _ => return not_impl_err!("{data_type} not supported in GroupValuesColumn"), + }, + DataType::Time64(t) => match t { + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, Time64MicrosecondType, data_type) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, Time64NanosecondType, data_type) + } + // Time64 with Second / Millisecond is not a valid Arrow type + // combination; reject explicitly. + _ => return not_impl_err!("{data_type} not supported in GroupValuesColumn"), + }, + DataType::Timestamp(t, _) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, TimestampSecondType, data_type) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, TimestampMillisecondType, data_type) + } + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, TimestampMicrosecondType, data_type) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, TimestampNanosecondType, data_type) + } + }, + DataType::Decimal128(_, _) => { + instantiate_primitive!(v, nullable, Decimal128Type, data_type) + } + DataType::Utf8 => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Utf8, + ))); + } + DataType::LargeUtf8 => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Utf8, + ))); + } + DataType::Binary => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Binary, + ))); + } + DataType::LargeBinary => { + v.push(Box::new(ByteGroupValueBuilder::::new( + OutputType::Binary, + ))); } + DataType::Utf8View => { + v.push(Box::new(ByteViewGroupValueBuilder::::new())); + } + DataType::BinaryView => { + v.push(Box::new(ByteViewGroupValueBuilder::::new())); + } + DataType::Boolean => { + if nullable { + v.push(Box::new(BooleanGroupValueBuilder::::new())); + } else { + v.push(Box::new(BooleanGroupValueBuilder::::new())); + } + } + _ => return not_impl_err!("{data_type} not supported in GroupValuesColumn"), + } + debug_assert_eq!( + v.len(), + 1, + "make_group_column must push exactly one builder" + ); + Ok(v.into_iter().next().unwrap()) +} +impl GroupValues for GroupValuesColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + // `try_new` and the reset points in `emit` / `clear_shrink` keep + // `self.group_values` populated with one builder per schema field, + // so no lazy initialization is needed here. if !STREAMING { self.vectorized_intern(cols, groups) } else { @@ -1082,8 +1109,14 @@ impl GroupValues for GroupValuesColumn { fn emit(&mut self, emit_to: EmitTo) -> Result> { let mut output = match emit_to { EmitTo::All => { - let group_values = mem::take(&mut self.group_values); - debug_assert!(self.group_values.is_empty()); + // Replace the column builders with a fresh set so the + // aggregator is immediately reusable after the drain. + // Same `self.schema` was already validated by `try_new`, + // so `build_group_columns` would only error here if some + // out-of-band schema mutation occurred — propagate it as + // a real Result rather than panicking. + let fresh = Self::build_group_columns(&self.schema)?; + let group_values = mem::replace(&mut self.group_values, fresh); group_values .into_iter() @@ -1181,14 +1214,18 @@ impl GroupValues for GroupValuesColumn { Ok(output) } - fn clear_shrink(&mut self, batch: &RecordBatch) { - let count = batch.num_rows(); - self.group_values.clear(); + fn clear_shrink(&mut self, num_rows: usize) { + // Reset to a fresh column-builder vector. The schema was validated + // in `try_new`, so rebuilding cannot fail unless something else + // mutated the schema out-of-band — surface that as a panic since + // `clear_shrink` is infallible by trait signature. + self.group_values = Self::build_group_columns(&self.schema) + .expect("schema previously validated in try_new"); self.map.clear(); - self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); - self.hashes_buffer.shrink_to(count); + self.hashes_buffer.shrink_to(num_rows); // Such structures are only used in `non-streaming` case if !STREAMING { @@ -1205,39 +1242,7 @@ pub fn supported_schema(schema: &Schema) -> bool { .fields() .iter() .map(|f| f.data_type()) - .all(supported_type) -} - -/// Returns true if the specified data type is supported by [`GroupValuesColumn`] -/// -/// In order to be supported, there must be a specialized implementation of -/// [`GroupColumn`] for the data type, instantiated in [`GroupValuesColumn::intern`] -fn supported_type(data_type: &DataType) -> bool { - matches!( - *data_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Utf8 - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary - | DataType::Date32 - | DataType::Date64 - | DataType::Time32(_) - | DataType::Timestamp(_, _) - | DataType::Utf8View - | DataType::BinaryView - | DataType::Boolean - ) + .all(group_column_supported_type) } ///Shows how many `null`s there are in an array @@ -1261,10 +1266,131 @@ mod tests { use datafusion_expr::EmitTo; use crate::aggregates::group_values::{ - multi_group_by::GroupValuesColumn, GroupValues, + GroupValues, multi_group_by::GroupValuesColumn, }; - use super::GroupIndexView; + use super::{ + GroupIndexView, group_column_supported_type, make_group_column, supported_schema, + }; + + /// CRITICAL invariant: if `group_column_supported_type(t)` returns true + /// the dispatcher must accept that type at intern time, and conversely + /// if `group_column_supported_type(t)` returns false the planner must + /// NOT route it through `GroupValuesColumn`. A divergence here would + /// let the planner select `GroupValuesColumn` for a type whose + /// dispatcher arm is missing, producing a runtime `not_impl_err` after + /// the field reaches the builder factory. + /// + /// This test fuzzes a representative cross-section of types and asserts + /// both directions of the biconditional. When a new specialization is + /// added (`Float16`, `FixedSizeList`, `Struct`, ...) it should be added + /// to the supported_cases vector; when a type is intentionally rejected + /// it should be added to unsupported_cases. + #[test] + fn group_column_supported_type_matches_make_group_column() { + let supported_cases: Vec = vec![ + DataType::Int8, + DataType::Int64, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + DataType::Decimal128(38, 10), + DataType::Utf8, + DataType::LargeUtf8, + DataType::Utf8View, + DataType::Binary, + DataType::LargeBinary, + DataType::BinaryView, + DataType::Boolean, + DataType::Date32, + DataType::Date64, + DataType::Time32(arrow::datatypes::TimeUnit::Second), + DataType::Time32(arrow::datatypes::TimeUnit::Millisecond), + DataType::Time64(arrow::datatypes::TimeUnit::Microsecond), + DataType::Time64(arrow::datatypes::TimeUnit::Nanosecond), + DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), + ]; + + for dt in &supported_cases { + assert!( + group_column_supported_type(dt), + "expected group_column_supported_type=true for {dt:?}" + ); + let field = Field::new("col", dt.clone(), true); + make_group_column(&field).unwrap_or_else(|e| { + panic!( + "group_column_supported_type accepted {dt:?} but make_group_column rejected: {e}" + ) + }); + } + + let unsupported_cases: Vec = vec![ + DataType::Float16, + DataType::Decimal256(76, 10), + // Invalid Time-unit combinations: Time32 is defined only for + // Second / Millisecond and Time64 only for Microsecond / + // Nanosecond. The TimeUnit enum allows constructing the other + // combinations programmatically, but they are not valid Arrow + // types and must be rejected by both group_column_supported_type + // and the dispatcher. + DataType::Time64(arrow::datatypes::TimeUnit::Second), + DataType::Time64(arrow::datatypes::TimeUnit::Millisecond), + DataType::Time32(arrow::datatypes::TimeUnit::Microsecond), + DataType::Time32(arrow::datatypes::TimeUnit::Nanosecond), + ]; + + for dt in &unsupported_cases { + assert!( + !group_column_supported_type(dt), + "expected group_column_supported_type=false for {dt:?}" + ); + let field = Field::new("col", dt.clone(), true); + assert!( + make_group_column(&field).is_err(), + "group_column_supported_type rejected {dt:?} but make_group_column accepted it" + ); + } + } + + #[test] + fn supported_schema_rejects_mix_of_supported_and_unsupported() { + // One Float16 column among supported columns flips the whole + // schema to GroupValuesRows fallback. + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float16, true), + ]); + assert!(!supported_schema(&schema)); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Boolean, true), + ]); + assert!(supported_schema(&schema)); + } + + #[test] + fn try_new_returns_not_impl_for_unsupported_top_level_type() { + // `try_new` now eagerly constructs the per-field GroupColumn + // builders via `make_group_column`, so an unsupported schema is + // rejected at construction time rather than at first `intern`. + // `GroupValuesColumn` doesn't implement `Debug`, so explicit match + // instead of `unwrap_err`. + let schema = + Arc::new(Schema::new(vec![Field::new("x", DataType::Float16, true)])); + match GroupValuesColumn::::try_new(schema) { + Ok(_) => panic!("expected NotImpl error, but try_new succeeded"), + Err(e) => { + let msg = e.to_string(); + assert!( + msg.contains("not supported in GroupValuesColumn"), + "expected NotImpl error from dispatcher, got: {msg}" + ); + } + } + } #[test] fn test_intern_for_vectorized_group_values() { @@ -1336,6 +1462,17 @@ mod tests { let schema = Arc::new(Schema::new_with_metadata(vec![field], HashMap::new())); let mut group_values = GroupValuesColumn::::try_new(schema).unwrap(); + // Seed the column with 12 placeholder rows so the upcoming + // `emit(EmitTo::First(4))` calls can `take_n` without panicking. + // The hashmap entries below reference group indices 0..=11, so the + // single column builder needs at least 12 rows to back them. + let seed: ArrayRef = Arc::new(arrow::array::Int32Array::from(vec![0_i32; 12])); + for row in 0..12 { + group_values.group_values[0] + .append_val(&seed, row) + .expect("seed append"); + } + // Insert group index views and check if success to insert insert_inline_group_index_view(&mut group_values, 0, 0); insert_non_inline_group_index_view(&mut group_values, 1, vec![1, 2]); diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index df2cf4bdecce5..068b849cb240f 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -15,17 +15,22 @@ // specific language governing permissions and limitations // under the License. +use crate::aggregates::group_values::HashValue; use crate::aggregates::group_values::multi_group_by::{ - nulls_equal_to, GroupColumn, Nulls, + GroupColumn, Nulls, nulls_equal_to, }; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::ArrowNativeTypeOp; -use arrow::array::{cast::AsArray, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, BooleanBufferBuilder, PrimitiveArray, + cast::AsArray, +}; use arrow::buffer::ScalarBuffer; use arrow::datatypes::DataType; +use arrow::util::bit_util::apply_bitwise_binary_op; use datafusion_common::Result; +use datafusion_common::utils::split_vec_min_alloc; use datafusion_execution::memory_pool::proxy::VecAllocExt; -use itertools::izip; use std::iter; use std::sync::Arc; @@ -47,6 +52,7 @@ pub struct PrimitiveGroupValueBuilder PrimitiveGroupValueBuilder where T: ArrowPrimitiveType, + T::Native: HashValue, { /// Create a new `PrimitiveGroupValueBuilder` pub fn new(data_type: DataType) -> Self { @@ -62,43 +68,47 @@ where lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { assert!( !NULLABLE || (array.null_count() == 0 && !self.nulls.might_have_nulls()), "called with nullable input" ); let array_values = array.as_primitive::().values(); + let n = lhs_rows.len(); - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); + // Build a packed comparison bitmask, then AND it into equal_to_results + let num_bytes = n.div_ceil(8); + let mut cmp_buf = vec![0u8; num_bytes]; - for (&lhs_row, &rhs_row, equal_to_result) in iter { - let result = { - // Getting unchecked not only for bound checks but because the bound checks are - // what prevents auto-vectorization - let left = if cfg!(debug_assertions) { - self.group_values[lhs_row] - } else { - // SAFETY: indices are guaranteed to be in bounds - unsafe { *self.group_values.get_unchecked(lhs_row) } - }; - let right = if cfg!(debug_assertions) { - array_values[rhs_row] - } else { - // SAFETY: indices are guaranteed to be in bounds - unsafe { *array_values.get_unchecked(rhs_row) } - }; - - // Always evaluate, to allow for auto-vectorization - left.is_eq(right) + for (i, (&lhs_row, &rhs_row)) in lhs_rows.iter().zip(rhs_rows.iter()).enumerate() + { + let left = if cfg!(debug_assertions) { + self.group_values[lhs_row] + } else { + unsafe { *self.group_values.get_unchecked(lhs_row) } }; - - *equal_to_result = result && *equal_to_result; + let right = if cfg!(debug_assertions) { + array_values[rhs_row] + } else { + unsafe { *array_values.get_unchecked(rhs_row) } + }; + // `left` was already canonicalized on append; canonicalize the + // input so ±0 (and any future equivalence class) compares equal. + if left.is_eq(right.canonicalize()) { + cmp_buf[i / 8] |= 1 << (i % 8); + } } + + // AND the comparison result into the existing equal_to_results bitmask + apply_bitwise_binary_op( + equal_to_results.as_slice_mut(), + 0, + &cmp_buf, + 0, + n, + |a, b| a & b, + ); } pub fn vectorized_equal_nullable( @@ -106,39 +116,38 @@ where lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { assert!(NULLABLE, "called with non-nullable input"); let array = array.as_primitive::(); - let iter = izip!( - lhs_rows.iter(), - rhs_rows.iter(), - equal_to_results.iter_mut(), - ); - - for (&lhs_row, &rhs_row, equal_to_result) in iter { - // Has found not equal to in previous column, don't need to check - if !*equal_to_result { + for (idx, (&lhs_row, &rhs_row)) in + lhs_rows.iter().zip(rhs_rows.iter()).enumerate() + { + if !equal_to_results.get_bit(idx) { continue; } - // Perf: skip null check (by short circuit) if input is not nullable let exist_null = self.nulls.is_null(lhs_row); let input_null = array.is_null(rhs_row); if let Some(result) = nulls_equal_to(exist_null, input_null) { - *equal_to_result = result; + if !result { + equal_to_results.set_bit(idx, false); + } continue; } - // Otherwise, we need to check their values - *equal_to_result = self.group_values[lhs_row].is_eq(array.value(rhs_row)); + if !self.group_values[lhs_row].is_eq(array.value(rhs_row).canonicalize()) { + equal_to_results.set_bit(idx, false); + } } } } impl GroupColumn for PrimitiveGroupValueBuilder +where + T::Native: HashValue, { fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { // Perf: skip null check (by short circuit) if input is not nullable @@ -151,7 +160,8 @@ impl GroupColumn // Otherwise, we need to check their values } - self.group_values[lhs_row].is_eq(array.as_primitive::().value(rhs_row)) + self.group_values[lhs_row] + .is_eq(array.as_primitive::().value(rhs_row).canonicalize()) } fn append_val(&mut self, array: &ArrayRef, row: usize) -> Result<()> { @@ -162,10 +172,12 @@ impl GroupColumn self.group_values.push(T::default_value()); } else { self.nulls.append(false); - self.group_values.push(array.as_primitive::().value(row)); + self.group_values + .push(array.as_primitive::().value(row).canonicalize()); } } else { - self.group_values.push(array.as_primitive::().value(row)); + self.group_values + .push(array.as_primitive::().value(row).canonicalize()); } Ok(()) @@ -176,7 +188,7 @@ impl GroupColumn lhs_rows: &[usize], array: &ArrayRef, rhs_rows: &[usize], - equal_to_results: &mut [bool], + equal_to_results: &mut BooleanBufferBuilder, ) { if !NULLABLE || (array.null_count() == 0 && !self.nulls.might_have_nulls()) { self.vectorized_equal_to_non_nullable( @@ -211,7 +223,7 @@ impl GroupColumn self.group_values.push(T::default_value()); } else { self.nulls.append(false); - self.group_values.push(arr.value(row)); + self.group_values.push(arr.value(row).canonicalize()); } } } @@ -219,7 +231,7 @@ impl GroupColumn (true, Nulls::None) => { self.nulls.append_n(rows.len(), false); for &row in rows { - self.group_values.push(arr.value(row)); + self.group_values.push(arr.value(row).canonicalize()); } } @@ -231,7 +243,7 @@ impl GroupColumn (false, _) => { for &row in rows { - self.group_values.push(arr.value(row)); + self.group_values.push(arr.value(row).canonicalize()); } } } @@ -265,8 +277,7 @@ impl GroupColumn } fn take_n(&mut self, n: usize) -> ArrayRef { - let first_n = self.group_values.drain(0..n).collect::>(); - + let first_n = split_vec_min_alloc(&mut self.group_values, n); let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; Arc::new( @@ -281,11 +292,23 @@ mod tests { use std::sync::Arc; use crate::aggregates::group_values::multi_group_by::primitive::PrimitiveGroupValueBuilder; - use arrow::array::{ArrayRef, Float32Array, Int64Array, NullBufferBuilder}; + use arrow::array::{ + ArrayRef, BooleanBufferBuilder, Float32Array, Int64Array, NullBufferBuilder, + }; use arrow::datatypes::{DataType, Float32Type, Int64Type}; use super::GroupColumn; + fn make_true_buffer(n: usize) -> BooleanBufferBuilder { + let mut buf = BooleanBufferBuilder::new(n); + buf.append_n(n, true); + buf + } + + fn to_vec(buf: &BooleanBufferBuilder) -> Vec { + (0..buf.len()).map(|i| buf.get_bit(i)).collect() + } + #[test] fn test_nullable_primitive_equal_to() { let append = |builder: &mut PrimitiveGroupValueBuilder, @@ -296,16 +319,18 @@ mod tests { } }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; + let equal_to = + |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results + .set_bit(idx, builder.equal_to(lhs_row, input_array, rhs_row)); + } + }; test_nullable_primitive_equal_to_internal(append, equal_to); } @@ -320,18 +345,19 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; + let equal_to = + |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; test_nullable_primitive_equal_to_internal(append, equal_to); } @@ -344,7 +370,7 @@ mod tests { &[usize], &ArrayRef, &[usize], - &mut Vec, + &mut BooleanBufferBuilder, ), { // Will cover such cases: @@ -384,7 +410,7 @@ mod tests { // explicitly build a null buffer where one of the null values also happens to match let mut nulls = NullBufferBuilder::new(6); nulls.append_non_null(); - nulls.append_null(); // this sets Some(2) to null above + nulls.append_null(); nulls.append_null(); nulls.append_non_null(); nulls.append_null(); @@ -393,7 +419,7 @@ mod tests { let input_array = Arc::new(Float32Array::new(values, nulls.finish())) as ArrayRef; // Check - let mut equal_to_results = vec![true; builder.len()]; + let mut equal_to_results = make_true_buffer(builder.len()); equal_to( &builder, &[0, 1, 2, 3, 4, 5, 6], @@ -401,14 +427,15 @@ mod tests { &[0, 1, 2, 3, 4, 5, 6], &mut equal_to_results, ); - - assert!(!equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(!equal_to_results[4]); - assert!(equal_to_results[5]); - assert!(!equal_to_results[6]); + let results = to_vec(&equal_to_results); + + assert!(!results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(!results[4]); + assert!(results[5]); + assert!(!results[6]); } #[test] @@ -421,16 +448,18 @@ mod tests { } }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - let iter = lhs_rows.iter().zip(rhs_rows.iter()); - for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { - equal_to_results[idx] = builder.equal_to(lhs_row, input_array, rhs_row); - } - }; + let equal_to = + |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + let iter = lhs_rows.iter().zip(rhs_rows.iter()); + for (idx, (&lhs_row, &rhs_row)) in iter.enumerate() { + equal_to_results + .set_bit(idx, builder.equal_to(lhs_row, input_array, rhs_row)); + } + }; test_not_nullable_primitive_equal_to_internal(append, equal_to); } @@ -445,18 +474,19 @@ mod tests { .unwrap(); }; - let equal_to = |builder: &PrimitiveGroupValueBuilder, - lhs_rows: &[usize], - input_array: &ArrayRef, - rhs_rows: &[usize], - equal_to_results: &mut Vec| { - builder.vectorized_equal_to( - lhs_rows, - input_array, - rhs_rows, - equal_to_results, - ); - }; + let equal_to = + |builder: &PrimitiveGroupValueBuilder, + lhs_rows: &[usize], + input_array: &ArrayRef, + rhs_rows: &[usize], + equal_to_results: &mut BooleanBufferBuilder| { + builder.vectorized_equal_to( + lhs_rows, + input_array, + rhs_rows, + equal_to_results, + ); + }; test_not_nullable_primitive_equal_to_internal(append, equal_to); } @@ -469,7 +499,7 @@ mod tests { &[usize], &ArrayRef, &[usize], - &mut Vec, + &mut BooleanBufferBuilder, ), { // Will cover such cases: @@ -487,7 +517,7 @@ mod tests { let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; // Check - let mut equal_to_results = vec![true; builder.len()]; + let mut equal_to_results = make_true_buffer(builder.len()); equal_to( &builder, &[0, 1], @@ -495,9 +525,10 @@ mod tests { &[0, 1], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(!equal_to_results[1]); + assert!(results[0]); + assert!(!results[1]); } #[test] @@ -520,19 +551,20 @@ mod tests { .vectorized_append(&all_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_nulls_input_array.len()); builder.vectorized_equal_to( &[0, 1, 2, 3, 4], &all_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); // All not nulls input array let all_not_nulls_input_array = Arc::new(Int64Array::from(vec![ @@ -546,18 +578,55 @@ mod tests { .vectorized_append(&all_not_nulls_input_array, &[0, 1, 2, 3, 4]) .unwrap(); - let mut equal_to_results = vec![true; all_not_nulls_input_array.len()]; + let mut equal_to_results = make_true_buffer(all_not_nulls_input_array.len()); builder.vectorized_equal_to( &[5, 6, 7, 8, 9], &all_not_nulls_input_array, &[0, 1, 2, 3, 4], &mut equal_to_results, ); + let results = to_vec(&equal_to_results); + + assert!(results[0]); + assert!(results[1]); + assert!(results[2]); + assert!(results[3]); + assert!(results[4]); + } - assert!(equal_to_results[0]); - assert!(equal_to_results[1]); - assert!(equal_to_results[2]); - assert!(equal_to_results[3]); - assert!(equal_to_results[4]); + #[test] + fn test_primitive_take_n() { + // drain branch: n * 2 <= len + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); + let array = Arc::new(Int64Array::from(vec![ + Some(10), + None, + Some(30), + Some(40), + Some(50), + ])) as ArrayRef; + for i in 0..5 { + builder.append_val(&array, i).unwrap(); + } + // len=5, n=2, n*2=4 <= 5 → drain branch + let out = builder.take_n(2); + let expected = Arc::new(Int64Array::from(vec![Some(10), None])) as ArrayRef; + assert_eq!(&out, &expected); + // remaining: [30, 40, 50] + assert_eq!(builder.len(), 3); + + // split_off branch: remaining < n (len=3, n=2, n*2=4 > 3) + let out2 = builder.take_n(2); + let expected2 = Arc::new(Int64Array::from(vec![Some(30), Some(40)])) as ArrayRef; + assert_eq!(&out2, &expected2); + // remaining: [50] + assert_eq!(builder.len(), 1); + + // take the last element + let out3 = builder.take_n(1); + let expected3 = Arc::new(Int64Array::from(vec![Some(50)])) as ArrayRef; + assert_eq!(&out3, &expected3); + assert_eq!(builder.len(), 0); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index d632a7f0ad8ac..4976a098ecee5 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -16,13 +16,17 @@ // under the License. use crate::aggregates::group_values::GroupValues; -use ahash::RandomState; -use arrow::array::{Array, ArrayRef, ListArray, RecordBatch, StructArray}; +use arrow::array::{ + Array, ArrayRef, ListArray, PrimitiveArray, RunArray, StructArray, + downcast_run_end_index, +}; use arrow::compute::cast; use arrow::datatypes::{DataType, SchemaRef}; use arrow::row::{RowConverter, Rows, SortField}; -use datafusion_common::hash_utils::create_hashes; use datafusion_common::Result; +use datafusion_common::hash_utils::RandomState; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::utils::normalize_float_zero; use datafusion_execution::memory_pool::proxy::{HashTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use hashbrown::hash_table::HashTable; @@ -113,6 +117,13 @@ impl GroupValuesRows { impl GroupValues for GroupValuesRows { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + // Normalize -0.0 → +0.0 so RowConverter (IEEE 754 totalOrder) and + // primitive hashing both group ±0 together. No-op for non-float + // columns. + let normalized_cols: Vec = + cols.iter().map(normalize_float_zero).collect(); + let cols = normalized_cols.as_slice(); + // Convert the group keys into the row format let group_rows = &mut self.rows_buffer; group_rows.clear(); @@ -243,17 +254,16 @@ impl GroupValues for GroupValuesRows { Ok(output) } - fn clear_shrink(&mut self, batch: &RecordBatch) { - let count = batch.num_rows(); + fn clear_shrink(&mut self, num_rows: usize) { self.group_values = self.group_values.take().map(|mut rows| { rows.clear(); rows }); self.map.clear(); - self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); - self.hashes_buffer.shrink_to(count); + self.hashes_buffer.shrink_to(num_rows); } } @@ -292,6 +302,33 @@ fn dictionary_encode_if_necessary( )?)) } (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?), + ( + DataType::RunEndEncoded(run_ends_field, expected_values_field), + &DataType::RunEndEncoded(_, _), + ) => { + macro_rules! reencode_ree { + ($run_end_type:ty) => {{ + let run_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let values = dictionary_encode_if_necessary( + &(Arc::clone(run_array.values()) as ArrayRef), + expected_values_field.data_type(), + )?; + let run_ends = PrimitiveArray::<$run_end_type>::new( + run_array.run_ends().inner().clone(), + None, + ); + Ok(Arc::new(RunArray::try_new(&run_ends, &values)?)) + }}; + } + downcast_run_end_index! { + run_ends_field.data_type() => (reencode_ree), + _ => unreachable!("unsupported run end type: {}", run_ends_field.data_type()), + } + } + (DataType::RunEndEncoded(_, _), _) => Ok(cast(array.as_ref(), expected)?), (_, _) => Ok(Arc::::clone(array)), } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index 44b763a91f523..e993c0c53d199 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -19,7 +19,6 @@ use crate::aggregates::group_values::GroupValues; use arrow::array::{ ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder, - RecordBatch, }; use datafusion_common::Result; use datafusion_expr::EmitTo; @@ -146,7 +145,7 @@ impl GroupValues for GroupValuesBoolean { Ok(vec![Arc::new(BooleanArray::new(values, nulls)) as _]) } - fn clear_shrink(&mut self, _batch: &RecordBatch) { + fn clear_shrink(&mut self, _num_rows: usize) { self.false_group = None; self.true_group = None; self.null_group = None; diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index b901aee313fb7..b881a51b25474 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -19,7 +19,7 @@ use std::mem::size_of; use crate::aggregates::group_values::GroupValues; -use arrow::array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; +use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use datafusion_common::Result; use datafusion_expr::EmitTo; use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; @@ -120,7 +120,7 @@ impl GroupValues for GroupValuesBytes { Ok(vec![group_values]) } - fn clear_shrink(&mut self, _batch: &RecordBatch) { + fn clear_shrink(&mut self, _num_rows: usize) { // in theory we could potentially avoid this reallocation and clear the // contents of the maps, but for now we just reset the map from the beginning self.map.take(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index be9a0334e3ee6..7a56f7c52c11a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -16,7 +16,7 @@ // under the License. use crate::aggregates::group_values::GroupValues; -use arrow::array::{Array, ArrayRef, RecordBatch}; +use arrow::array::{Array, ArrayRef}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; @@ -122,7 +122,7 @@ impl GroupValues for GroupValuesBytesView { Ok(vec![group_values]) } - fn clear_shrink(&mut self, _batch: &RecordBatch) { + fn clear_shrink(&mut self, _num_rows: usize) { // in theory we could potentially avoid this reallocation and clear the // contents of the maps, but for now we just reset the map from the beginning self.map.take(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index f35c580b0e632..e254aebcfd7ce 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -16,25 +16,40 @@ // under the License. use crate::aggregates::group_values::GroupValues; -use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::{ - cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, - PrimitiveArray, + ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray, + cast::AsArray, }; -use arrow::datatypes::{i256, DataType}; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::{DataType, i256}; use datafusion_common::Result; +use datafusion_common::hash_utils::RandomState; +use datafusion_common::utils::split_vec_min_alloc; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; use hashbrown::hash_table::HashTable; +#[cfg(not(feature = "force_hash_collisions"))] +use std::hash::BuildHasher; use std::mem::size_of; use std::sync::Arc; /// A trait to allow hashing of floating point numbers -pub(crate) trait HashValue { +pub trait HashValue { fn hash(&self, state: &RandomState) -> u64; + + /// Return a canonical representative whose bit pattern is identical for + /// all values that should be grouped together. Default is the identity; + /// floats override this to fold `-0.0` into `+0.0` so the bit-equal + /// `is_eq` check used during insertion treats them as the same group. + /// NaN payload bits are preserved. + #[inline] + fn canonicalize(self) -> Self + where + Self: Sized, + { + self + } } macro_rules! hash_integer { @@ -61,13 +76,20 @@ macro_rules! hash_float { $(impl HashValue for $t { #[cfg(not(feature = "force_hash_collisions"))] fn hash(&self, state: &RandomState) -> u64 { - state.hash_one(self.to_bits()) + state.hash_one(self.canonicalize().to_bits()) } #[cfg(feature = "force_hash_collisions")] fn hash(&self, _state: &RandomState) -> u64 { 0 } + + #[inline] + fn canonicalize(self) -> Self { + let bits = self.to_bits(); + let bits = if bits << 1 == 0 { 0 } else { bits }; + Self::from_bits(bits) + } })+ }; } @@ -125,11 +147,17 @@ where group_id }), Some(key) => { + // Fold equivalence-class duplicates (e.g. `-0.0` → `+0.0`) + // so the bit-equal `is_eq` matches and the stored value is + // the canonical representative. + let key = key.canonicalize(); let state = &self.random_state; let hash = key.hash(state); let insert = self.map.entry( hash, - |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(g, h)| unsafe { + hash == h && self.values.get_unchecked(g).is_eq(key) + }, |&(_, h)| h, ); @@ -204,20 +232,65 @@ where Some(_) => self.null_group.take(), None => None, }; - let mut split = self.values.split_off(n); - std::mem::swap(&mut self.values, &mut split); - build_primitive(split, null_group) + build_primitive(split_vec_min_alloc(&mut self.values, n), null_group) } }; Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) } - fn clear_shrink(&mut self, batch: &RecordBatch) { - let count = batch.num_rows(); + fn clear_shrink(&mut self, num_rows: usize) { self.values.clear(); - self.values.shrink_to(count); + self.values.shrink_to(num_rows); self.map.clear(); - self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::types::Int32Type; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::datatypes::DataType; + use datafusion_expr::EmitTo; + use std::sync::Arc; + + /// Mirror of the `EmitTo::take_needed` regression test, applied to the + /// concrete `GroupValuesPrimitive` accumulator. + /// + /// When `n` is small, the old `split_off(n) + swap` pattern used inside + /// `emit(EmitTo::First(n))` left `self.values` with a small fresh allocation + /// and returned the emitted prefix carrying the original large backing. + /// + /// With `split_vec_min_alloc` and `n * 2 <= len`, the drain branch is taken: + /// the emitted prefix gets a compact allocation and `self.values` retains the + /// original large one. + #[test] + fn emit_first_small_n_allocates_minimally() -> Result<()> { + let mut gv = GroupValuesPrimitive::::new(DataType::Int32); + + // Intern 20 distinct values; `new()` pre-allocates capacity 128 for `values`. + let arr: ArrayRef = Arc::new(Int32Array::from_iter_values(0..20i32)); + let mut groups = vec![]; + gv.intern(&[arr], &mut groups)?; + let capacity_before = gv.values.capacity(); // 128 + + // n=4, n*2=8 <= len=20 -> drain branch + let emitted = gv.emit(EmitTo::First(4))?; + + assert_eq!(emitted[0].len(), 4); + + // `self.values` must retain its original large allocation. + // Old split_off+swap left it with a fresh small allocation (~16). + assert_eq!( + gv.values.capacity(), + capacity_before, + "self.values capacity {} should equal original {} after small First(n) emit", + gv.values.capacity(), + capacity_before, + ); + + Ok(()) } } diff --git a/datafusion/physical-plan/src/aggregates/hash_aggregate.rs b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs new file mode 100644 index 0000000000000..0c8593efd05bb --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/hash_aggregate.rs @@ -0,0 +1,425 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 2-stage hash aggregation stream implementation. +//! +//! See comments in [`PartialHashAggregateStream`] and [`FinalHashAggregateStream`] +//! for details. +//! +//! Note these streams are an incremental migration of the existing +//! [`crate::aggregates::row_hash::GroupedHashAggregateStream`]. +//! +//! See issue for details: + +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use futures::stream::{Stream, StreamExt}; + +use super::AggregateExec; +use super::hash_table::{AggregateHashTable, Final, Partial}; +use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput, SpillMetrics}; +use crate::stream::EmptyRecordBatchStream; +use crate::{InputOrderMode, RecordBatchStream, SendableRecordBatchStream, metrics}; + +/// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream +/// is for the partial stage. +/// +/// # Example +/// +/// select k, avg(v) from t group by k; +/// +/// ## Plan +/// AggregateExec(stage=final) +/// -- RepartitionExec(hash(k)) +/// ---- AggregateExec(stage=partial) +/// +/// ## Partial Stage Behavior +/// Input: raw rows +/// Output: partial states for all groups (e.g. for avg(x), it's sum(x), count(x)) +/// +/// ## Final Stage Behavior +/// Input: partial states +/// Output: results for all groups (e.g. for avg(x), it's avg(x) calculated from the state) +/// +/// # Optimization: DISTINCT LIMIT Soft Limit +/// +/// This optimization applies to both [`PartialHashAggregateStream`] and [`FinalHashAggregateStream`] +/// +/// Unordered distinct queries such as: +/// +/// ```sql +/// SELECT DISTINCT x FROM t LIMIT 10; +/// ``` +/// +/// are optimized into a two-stage aggregate like: +/// +/// ```txt +/// LimitExec, limit=10 +/// --AggregateExec(Final), group_by=[x], aggr=[], soft_limit=10 +/// ---- RepartitionExec, partitioning=hash(x) +/// ------ AggregateExec(Partial), group_by=[x], aggr=[], soft_limit=10 +/// -------- Scan(t) +/// ``` +/// +/// After each input batch, the stream checks whether the soft limit has been +/// reached. If so, it emits the accumulated groups and stops reading input. +/// +/// This operator does not guarantee an exact limit because a single batch can +/// cross the threshold. The downstream limit operator enforces the exact result +/// size. +pub(crate) struct PartialHashAggregateStream { + /// Output schema: group columns followed by partial aggregate state columns. + schema: SchemaRef, + + /// Input batches containing raw rows, not partial aggregate state. + input: SendableRecordBatchStream, + + /// Hash table state for this aggregate stream. + hash_table: AggregateHashTable, + + /// Memory reservation for group keys and accumulators. + reservation: MemoryReservation, + + /// Execution metrics shared with the aggregate plan node. + baseline_metrics: BaselineMetrics, + + /// Tracks partial aggregation row reduction, matching `GroupedHashAggregateStream`. + reduction_factor: metrics::RatioMetrics, + + /// Optional soft limit on the number of groups to accumulate before output. + /// + /// Invariant: when this is `Some(..)`, the accumulators inside `hash_table` must + /// be empty. See struct comments for details. + group_values_soft_limit: Option, +} + +/// Hash aggregation uses a 2-stage (partial and final) hash aggregation, this stream +/// is for the final stage. +/// +/// See [`PartialHashAggregateStream`] for details. +pub(crate) struct FinalHashAggregateStream { + /// Output schema: group columns followed by final aggregate value columns. + schema: SchemaRef, + + /// Input batches containing partial aggregate state rows. + input: SendableRecordBatchStream, + + /// Hash table state for this aggregate stream. + hash_table: AggregateHashTable, + + /// Execution metrics shared with the aggregate plan node. + baseline_metrics: BaselineMetrics, + + /// Memory reservation for group keys and accumulators. + reservation: MemoryReservation, + + /// See comments for the same variable in [`PartialHashAggregateStream`] + group_values_soft_limit: Option, +} + +impl PartialHashAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert_eq!(agg.mode, super::AggregateMode::Partial); + debug_assert_eq!(agg.input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input = agg.input.execute(partition, Arc::clone(context))?; + let batch_size = context.session_config().batch_size(); + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. + let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + let reduction_factor = MetricBuilder::new(&agg.metrics) + .with_type(metrics::MetricType::Summary) + .ratio_metrics("reduction_factor", partition); + + let hash_table = AggregateHashTable::::new( + agg, + partition, + Arc::clone(&schema), + batch_size, + )?; + + let reservation = + MemoryConsumer::new(format!("PartialHashAggregateStream[{partition}]")) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + hash_table, + baseline_metrics, + reservation, + reduction_factor, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), + }) + } + + /// See comments in [`Self::group_values_soft_limit`] for details. + fn hit_soft_group_limit(&self) -> bool { + self.group_values_soft_limit + .is_some_and(|limit| limit <= self.hash_table.building_group_count()) + } + + fn start_output(&mut self) -> Result<()> { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + self.hash_table.start_output() + } +} + +impl Stream for PartialHashAggregateStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + loop { + if self.hash_table.is_done() { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } else if self.hash_table.is_building() { + match self.input.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(batch))) => { + let timer = elapsed_compute.timer(); + self.reduction_factor.add_total(batch.num_rows()); + let result = self.hash_table.aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + if self.hit_soft_group_limit() { + let timer = elapsed_compute.timer(); + let result = self.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + continue; + } + + // TODO: impl memory-limited aggr, when OOM directly send + // partial state to final aggregate stage + if let Err(e) = + self.reservation.try_resize(self.hash_table.memory_size()) + { + return Poll::Ready(Some(Err(e))); + } + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + let timer = elapsed_compute.timer(); + let result = self.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + } + } + } else { + let timer = elapsed_compute.timer(); + let result = self.hash_table.next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + let _ = + self.reservation.try_resize(self.hash_table.memory_size()); + self.reduction_factor.add_part(batch.num_rows()); + debug_assert!(batch.num_rows() > 0); + return Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))); + } + Ok(None) => { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + } + } +} + +impl RecordBatchStream for PartialHashAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl FinalHashAggregateStream { + pub fn new( + agg: &AggregateExec, + context: &Arc, + partition: usize, + ) -> Result { + debug_assert!(matches!( + agg.mode, + super::AggregateMode::Final | super::AggregateMode::FinalPartitioned + )); + debug_assert_eq!(agg.input_order_mode, InputOrderMode::Linear); + + let schema = Arc::clone(&agg.schema); + let input = agg.input.execute(partition, Arc::clone(context))?; + let batch_size = context.session_config().batch_size(); + let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); + + // Preserve the existing aggregate metric surface for this plan node. + let _spill_metrics = SpillMetrics::new(&agg.metrics, partition); + + let hash_table = AggregateHashTable::::new( + agg, + partition, + Arc::clone(&schema), + batch_size, + )?; + + let reservation = + MemoryConsumer::new(format!("FinalHashAggregateStream[{partition}]")) + .register(context.memory_pool()); + + Ok(Self { + schema, + input, + hash_table, + baseline_metrics, + reservation, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), + }) + } + + /// See comments in [`Self::group_values_soft_limit`] for details. + fn hit_soft_group_limit(&self) -> bool { + self.group_values_soft_limit + .is_some_and(|limit| limit <= self.hash_table.building_group_count()) + } + + fn start_output(&mut self) -> Result<()> { + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + self.hash_table.start_output() + } +} + +impl Stream for FinalHashAggregateStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + + loop { + if self.hash_table.is_done() { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } else if self.hash_table.is_building() { + match self.input.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Some(Ok(batch))) => { + let timer = elapsed_compute.timer(); + let result = self.hash_table.aggregate_batch(&batch); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + if self.hit_soft_group_limit() { + let timer = elapsed_compute.timer(); + let result = self.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + + continue; + } + + if let Err(e) = + self.reservation.try_resize(self.hash_table.memory_size()) + { + return Poll::Ready(Some(Err(e))); + } + } + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some(Err(e))); + } + Poll::Ready(None) => { + let timer = elapsed_compute.timer(); + let result = self.start_output(); + timer.done(); + + if let Err(e) = result { + return Poll::Ready(Some(Err(e))); + } + } + } + } else { + let timer = elapsed_compute.timer(); + let result = self.hash_table.next_output_batch(); + timer.done(); + + match result { + Ok(Some(batch)) => { + let _ = + self.reservation.try_resize(self.hash_table.memory_size()); + debug_assert!(batch.num_rows() > 0); + return Poll::Ready(Some(Ok( + batch.record_output(&self.baseline_metrics) + ))); + } + Ok(None) => { + let _ = self.reservation.try_resize(0); + return Poll::Ready(None); + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + } + } +} + +impl RecordBatchStream for FinalHashAggregateStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/physical-plan/src/aggregates/hash_table.rs b/datafusion/physical-plan/src/aggregates/hash_table.rs new file mode 100644 index 0000000000000..87f16d0eebe6f --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/hash_table.rs @@ -0,0 +1,623 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::Arc; + +use arrow::array::{ArrayRef, AsArray, BooleanArray, new_null_array}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_err}; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::{EmitTo, GroupsAccumulator}; + +use super::group_values::{GroupByMetrics, GroupValues, new_group_values}; +use super::order::GroupOrdering; +use super::row_hash::create_group_accumulator; +use super::{ + AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, + group_id_array, max_duplicate_ordinal, +}; +use crate::PhysicalExpr; +use crate::metrics::{MetricBuilder, MetricCategory}; + +/// Marker for raw rows -> partial state aggregation. +pub(super) struct Partial; +/// Marker for partial state -> final value aggregation. +pub(super) struct Final; + +/// Grouped hash table shared by the partial and final paths. +/// +/// While building, it consumes input batches and updates group / accumulator +/// state. While outputting, it incrementally output the materialized batches. +/// +/// # Marker Type +/// `AggrMode` selects the aggregate semantics. +/// +/// e.g. `AggregateHashTable::::new(...)` creates an aggregate hash table +/// for the partial hash aggregate stage, the input schema is raw rows and output +/// schema is intermediate states. +/// +/// It is a zero-sized compile-time marker, so each stage keeps its update logic +/// in a separate impl block, to make the behavior difference explicit. +pub(super) struct AggregateHashTable { + /// Grouping and accumulator-specific timing metrics. + group_by_metrics: GroupByMetrics, + + /// Raw input schema, used to evaluate expressions and synthesize empty + /// grouping-set rows. + input_schema: SchemaRef, + + /// Output schema: group columns followed by aggregate state or final values. + output_schema: SchemaRef, + + /// Maximum rows per emitted output batch. + batch_size: usize, + + /// Lifecycle-specific state: building stage / outputting stage + state: AggregateHashTableState, + + _mode: PhantomData, +} + +struct HashAggregateAccumulator { + /// Arguments to pass to this accumulator. + /// + /// Example: `CORR(x, y)` stores two expressions here, while `SUM(x)` stores one. + arguments: Vec>, + + /// Optional `FILTER` expression for this accumulator. + /// + /// Example: `SUM(x) FILTER (WHERE x > 10)` stores the `x > 10` predicate. + filter: Option>, + + /// Accumulator state for all groups for one aggregate expression. + accumulator: Box, +} + +struct EvaluatedHashAggregateAccumulator { + arguments: Vec, + filter: Option, +} + +/// Evaluated all group by keys and accumulator args. +/// +/// e.g., `select k+1, sum(v*v) from t group by (k+1)`, this function evaluates +/// `k+1`, `v*v` +struct EvaluatedAggregateBatch { + /// One entry per grouping set; each entry contains all evaluated group key + /// arrays for the current input batch. + grouping_set_args: Vec>, + + /// Evaluated arguments and filters, one entry per aggregate expression. + accumulator_args: Vec, +} + +/// Hash table state while grouped aggregation is consuming input. +/// +/// This owns the coupled state for: +/// - evaluating group keys, +/// - interning each distinct group, +/// - mapping each input row to its group index, +/// - evaluating aggregate inputs, +/// - updating per-group accumulator state. +struct BuildingHashTableState { + /// GROUP BY expressions evaluated for each input batch. + group_by: Arc, + + /// Interned group keys. Accumulator state is stored separately by group index. + group_values: Box, + + /// Group index for each row in the current input batch. + /// + /// Each value indexes into `group_values`, and the same index is used by every + /// accumulator to update that group's aggregate state. + batch_group_indices: Vec, + + /// One item per aggregate expression. + /// + /// Example: `COUNT(x), SUM(y)` creates two items. Each item owns the input + /// expressions, optional filter, and accumulator state for all groups. + accumulators: Vec, +} + +enum AggregateHashTableState { + Building(BuildingHashTableState), + Outputting { + output_batch: Option, + output_batch_offset: usize, + }, + Done, +} + +impl HashAggregateAccumulator { + fn new( + arguments: Vec>, + filter: Option>, + accumulator: Box, + ) -> Self { + Self { + arguments, + filter, + accumulator, + } + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arguments = self + .arguments + .iter() + .map(|expr| { + expr.evaluate(batch) + .and_then(|value| value.into_array(batch.num_rows())) + }) + .collect::>()?; + + let filter = self + .filter + .as_ref() + .map(|filter| { + filter + .evaluate(batch) + .and_then(|value| value.into_array(batch.num_rows())) + }) + .transpose()?; + + Ok(EvaluatedHashAggregateAccumulator { arguments, filter }) + } + + fn update_batch( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + let filter = values.filter.as_ref().map(|filter| filter.as_boolean()); + self.accumulator.update_batch( + &values.arguments, + group_indices, + filter, + total_num_groups, + ) + } + + fn merge_batch( + &mut self, + values: &EvaluatedHashAggregateAccumulator, + group_indices: &[usize], + total_num_groups: usize, + ) -> Result<()> { + debug_assert!(values.filter.is_none()); + self.accumulator.merge_batch( + &values.arguments, + group_indices, + None, + total_num_groups, + ) + } + + fn evaluate_final(&mut self, emit_to: EmitTo) -> Result { + self.accumulator.evaluate(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + self.accumulator.state(emit_to) + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state() + } + + fn null_arguments(&self, input_schema: &SchemaRef) -> Result> { + self.arguments + .iter() + .map(|expr| { + let data_type = expr.data_type(input_schema)?; + Ok(new_null_array(&data_type, 1)) + }) + .collect() + } +} + +impl AggregateHashTableState { + fn building(&self) -> &BuildingHashTableState { + let Self::Building(state) = self else { + unreachable!("hash aggregate table is not building") + }; + state + } + + fn building_mut(&mut self) -> &mut BuildingHashTableState { + let Self::Building(state) = self else { + unreachable!("hash aggregate table is not building") + }; + state + } +} + +impl AggregateHashTable { + fn new_with_filters( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + filters: Vec>>, + ) -> Result { + let input_schema = agg.input().schema(); + let aggregate_arguments = aggregate_expressions( + &agg.aggr_expr, + &agg.mode, + agg.group_by.num_group_exprs(), + )?; + let accumulators: Vec<_> = agg + .aggr_expr + .iter() + .zip(aggregate_arguments) + .zip(filters) + .map(|((agg_expr, arguments), filter)| { + let accumulator = create_group_accumulator(agg_expr)?; + Ok(HashAggregateAccumulator::new( + arguments, + filter, + accumulator, + )) + }) + .collect::>()?; + + let group_schema = agg.group_by.group_schema(&input_schema)?; + let group_values = new_group_values(group_schema, &GroupOrdering::None)?; + + Ok(Self { + group_by_metrics: GroupByMetrics::new(&agg.metrics, partition), + input_schema, + output_schema, + batch_size, + state: AggregateHashTableState::Building(BuildingHashTableState { + group_by: Arc::clone(&agg.group_by), + group_values, + batch_group_indices: Default::default(), + accumulators, + }), + _mode: PhantomData, + }) + } + + /// See comments in [`EvaluatedAggregateBatch`] + fn evaluate_batch(&self, batch: &RecordBatch) -> Result { + let state = self.state.building(); + let timer = self.group_by_metrics.time_calculating_group_ids.timer(); + // outer vec: one per each grouping set + // inner vec: all group by exprs for the current grouping set + let grouping_set_args = evaluate_group_by(&state.group_by, batch)?; + drop(timer); + + let timer = self.group_by_metrics.aggregate_arguments_time.timer(); + // The evaluated args for each accumulator + let accumulator_args = self + .state + .building() + .accumulators + .iter() + .map(|acc| acc.evaluate(batch)) + .collect::>>()?; + drop(timer); + + Ok(EvaluatedAggregateBatch { + grouping_set_args, + accumulator_args, + }) + } + + pub(super) fn memory_size(&self) -> usize { + match &self.state { + AggregateHashTableState::Building(state) => { + let acc = state + .accumulators + .iter() + .map(|acc| acc.accumulator.size()) + .sum::(); + + acc + state.group_values.size() + + state.batch_group_indices.allocated_size() + } + AggregateHashTableState::Outputting { output_batch, .. } => { + output_batch_memory_size(output_batch) + } + AggregateHashTableState::Done => 0, + } + } + + pub(super) fn building_group_count(&self) -> usize { + self.state.building().group_values.len() + } + + pub(super) fn is_building(&self) -> bool { + matches!(self.state, AggregateHashTableState::Building(_)) + } + + pub(super) fn is_done(&self) -> bool { + matches!(self.state, AggregateHashTableState::Done) + } + + fn set_output_batch(&mut self, output_batch: Option) { + self.state = AggregateHashTableState::Outputting { + output_batch, + output_batch_offset: 0, + }; + } + + pub(super) fn next_output_batch(&mut self) -> Result> { + match std::mem::replace(&mut self.state, AggregateHashTableState::Done) { + AggregateHashTableState::Outputting { + output_batch, + mut output_batch_offset, + } => { + let Some(batch) = output_batch.as_ref() else { + return Ok(None); + }; + + let num_rows = batch.num_rows(); + if output_batch_offset >= num_rows { + return Ok(None); + } + + debug_assert!(self.batch_size > 0); + let output_len = + self.batch_size.max(1).min(num_rows - output_batch_offset); + let output = batch.slice(output_batch_offset, output_len); + output_batch_offset += output_len; + + if output_batch_offset == num_rows { + self.state = AggregateHashTableState::Done; + } else { + self.state = AggregateHashTableState::Outputting { + output_batch, + output_batch_offset, + }; + } + + debug_assert!(output.num_rows() > 0); + debug_assert!(output.num_rows() <= self.batch_size.max(1)); + Ok(Some(output)) + } + _ => { + self.state = AggregateHashTableState::Done; + internal_err!("next_output_batch must be called in the outputting state") + } + } + } +} + +impl AggregateHashTable { + pub(super) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + let table = Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + agg.filter_expr.iter().cloned().collect(), + )?; + + if table + .state + .building() + .accumulators + .iter() + .all(|acc| acc.supports_convert_to_state()) + { + let _skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) + .with_category(MetricCategory::Rows) + .counter("skipped_aggregation_rows", partition); + } + + Ok(table) + } + + pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.update_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(super) fn start_output(&mut self) -> Result<()> { + self.init_empty_grouping_sets()?; + let state = self.state.building_mut(); + + let output_batch = if state.group_values.is_empty() { + None + } else { + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(EmitTo::All)?; + + for acc in state.accumulators.iter_mut() { + output.extend(acc.state(EmitTo::All)?); + } + + let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; + debug_assert!(batch.num_rows() > 0); + drop(timer); + Some(batch) + }; + + self.set_output_batch(output_batch); + Ok(()) + } + + /// Creates the required empty grouping-set rows when the input is empty. + /// + /// For example, this query must still produce one grand-total group even if + /// `t` has no rows: + /// + /// ```sql + /// SELECT COUNT(v) + /// FROM t + /// GROUP BY GROUPING SETS (()); + /// ``` + /// + /// The synthetic row is filtered out before accumulator update so aggregates + /// see the same state they would see for an empty input, rather than a real + /// null-valued row. + fn init_empty_grouping_sets(&mut self) -> Result<()> { + let state = self.state.building_mut(); + if !state.group_by.has_grouping_set() || !state.group_values.is_empty() { + return Ok(()); + } + + let max_ordinal = max_duplicate_ordinal(state.group_by.groups()); + let mut ordinals: HashMap<&[bool], usize> = HashMap::new(); + let group_schema = state.group_by.group_schema(&self.input_schema)?; + let n_expr = state.group_by.expr().len(); + let mut any_interned = false; + + for group in state.group_by.groups() { + let ordinal = { + let entry = ordinals.entry(group.as_slice()).or_insert(0); + let ordinal = *entry; + *entry += 1; + ordinal + }; + + if !group.iter().all(|&is_null| is_null) { + continue; + } + + let mut cols: Vec = group_schema + .fields() + .iter() + .take(n_expr) + .map(|field| new_null_array(field.data_type(), 1)) + .collect(); + cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); + + state + .group_values + .intern(&cols, &mut state.batch_group_indices)?; + any_interned = true; + } + + if any_interned { + let total_groups = state.group_values.len(); + let false_filter = BooleanArray::from(vec![false]); + for acc in state.accumulators.iter_mut() { + let null_args = acc.null_arguments(&self.input_schema)?; + let values = EvaluatedHashAggregateAccumulator { + arguments: null_args, + filter: Some(Arc::new(false_filter.clone())), + }; + acc.update_batch(&values, &[0], total_groups)?; + } + } + + Ok(()) + } +} + +impl AggregateHashTable { + pub(super) fn new( + agg: &AggregateExec, + partition: usize, + output_schema: SchemaRef, + batch_size: usize, + ) -> Result { + Self::new_with_filters( + agg, + partition, + output_schema, + batch_size, + vec![None; agg.aggr_expr.len()], + ) + } + + pub(super) fn aggregate_batch(&mut self, batch: &RecordBatch) -> Result<()> { + let evaluated_batch = self.evaluate_batch(batch)?; + let state = self.state.building_mut(); + + let timer = self.group_by_metrics.aggregation_time.timer(); + for group_values in &evaluated_batch.grouping_set_args { + state + .group_values + .intern(group_values, &mut state.batch_group_indices)?; + let group_indices = &state.batch_group_indices; + let total_num_groups = state.group_values.len(); + + for (acc, values) in state + .accumulators + .iter_mut() + .zip(evaluated_batch.accumulator_args.iter()) + { + acc.merge_batch(values, group_indices, total_num_groups)?; + } + } + drop(timer); + + Ok(()) + } + + pub(super) fn start_output(&mut self) -> Result<()> { + let state = self.state.building_mut(); + let output_batch = if state.group_values.is_empty() { + None + } else { + let timer = self.group_by_metrics.emitting_time.timer(); + let mut output = state.group_values.emit(EmitTo::All)?; + + for acc in state.accumulators.iter_mut() { + output.push(acc.evaluate_final(EmitTo::All)?); + } + + let batch = RecordBatch::try_new(Arc::clone(&self.output_schema), output)?; + debug_assert!(batch.num_rows() > 0); + drop(timer); + Some(batch) + }; + + self.set_output_batch(output_batch); + Ok(()) + } +} + +fn output_batch_memory_size(output_batch: &Option) -> usize { + output_batch + .as_ref() + .map(RecordBatch::get_array_memory_size) + .unwrap_or_default() +} diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 6c59195f76358..a5f1621812561 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -17,45 +17,50 @@ //! Aggregates functionalities -use std::any::Any; +use std::borrow::Cow; use std::sync::Arc; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::aggregates::{ - no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream, + hash_aggregate::{FinalHashAggregateStream, PartialHashAggregateStream}, + no_grouping::AggregateStream, + row_hash::GroupedHashAggregateStream, topk_stream::GroupedTopKAggregateStream, }; use crate::execution_plan::{CardinalityEffect, EmissionType}; use crate::filter_pushdown::{ - ChildFilterDescription, FilterDescription, FilterPushdownPhase, PushedDownPredicate, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, PushedDownPredicate, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, Statistics, check_if_same_properties, }; use datafusion_common::config::ConfigOptions; use datafusion_physical_expr::utils::collect_columns; -use std::collections::HashSet; +use parking_lot::Mutex; +use std::collections::{HashMap, HashSet}; -use arrow::array::{ArrayRef, UInt16Array, UInt32Array, UInt64Array, UInt8Array}; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::{ - assert_eq_or_internal_err, not_impl_err, Constraint, Constraints, Result, + Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, + internal_err, not_impl_err, }; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryLimit; use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit}; use datafusion_physical_expr::{ - physical_exprs_contains, ConstExpr, EquivalenceProperties, + ConstExpr, EquivalenceProperties, physical_exprs_contains, }; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExpr}; +use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, fmt_sql}; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement, }; @@ -63,23 +68,83 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use itertools::Itertools; +use topk::hash_table::is_supported_hash_key_type; +use topk::heap::is_supported_heap_type; pub mod group_values; +mod hash_aggregate; +mod hash_table; mod no_grouping; pub mod order; mod row_hash; mod topk; mod topk_stream; +/// Returns true if TopK aggregation data structures support the provided key and value types. +/// +/// This function checks whether both the key type (used for grouping) and value type +/// (used in min/max aggregation) can be handled by the TopK aggregation heap and hash table. +/// Supported types include Arrow primitives (integers, floats, decimals, intervals) and +/// UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`). +/// ```text +pub fn topk_types_supported(key_type: &DataType, value_type: &DataType) -> bool { + is_supported_hash_key_type(key_type) && is_supported_heap_type(value_type) +} + /// Hard-coded seed for aggregations to ensure hash values differ from `RepartitionExec`, avoiding collisions. -const AGGREGATION_HASH_SEED: ahash::RandomState = - ahash::RandomState::with_seeds('A' as u64, 'G' as u64, 'G' as u64, 'R' as u64); +const AGGREGATION_HASH_SEED: datafusion_common::hash_utils::RandomState = + // This seed is chosen to be a large 64-bit number + datafusion_common::hash_utils::RandomState::with_seed(15395726432021054657); + +/// Whether an aggregate stage consumes raw input data or intermediate +/// accumulator state from a previous aggregation stage. +/// +/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes) +/// for how this relates to aggregate modes. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum AggregateInputMode { + /// The stage consumes raw, unaggregated input data and calls + /// [`Accumulator::update_batch`]. + Raw, + /// The stage consumes intermediate accumulator state from a previous + /// aggregation stage and calls [`Accumulator::merge_batch`]. + Partial, +} + +/// Whether an aggregate stage produces intermediate accumulator state +/// or final output values. +/// +/// See the [table on `AggregateMode`](AggregateMode#variants-and-their-inputoutput-modes) +/// for how this relates to aggregate modes. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum AggregateOutputMode { + /// The stage produces intermediate accumulator state, serialized via + /// [`Accumulator::state`]. + Partial, + /// The stage produces final output values via + /// [`Accumulator::evaluate`]. + Final, +} /// Aggregation modes /// /// See [`Accumulator::state`] for background information on multi-phase /// aggregation and how these modes are used. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +/// +/// # Variants and their input/output modes +/// +/// Each variant can be characterized by its [`AggregateInputMode`] and +/// [`AggregateOutputMode`]: +/// +/// ```text +/// | Input: Raw data | Input: Partial state +/// Output: Final values | Single, SinglePartitioned | Final, FinalPartitioned +/// Output: Partial state | Partial | PartialReduce +/// ``` +/// +/// Use [`AggregateMode::input_mode`] and [`AggregateMode::output_mode`] +/// to query these properties. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum AggregateMode { /// One of multiple layers of aggregation, any input partitioning /// @@ -130,18 +195,62 @@ pub enum AggregateMode { /// This mode requires that the input has more than one partition, and is /// partitioned by group key (like FinalPartitioned). SinglePartitioned, + /// Combine multiple partial aggregations to produce a new partial + /// aggregation. + /// + /// Input is intermediate accumulator state (like Final), but output is + /// also intermediate accumulator state (like Partial). This enables + /// tree-reduce aggregation strategies where partial results from + /// multiple workers are combined in multiple stages before a final + /// evaluation. + /// + /// ```text + /// Final + /// / \ + /// PartialReduce PartialReduce + /// / \ / \ + /// Partial Partial Partial Partial + /// ``` + /// + /// # Motivation + /// + /// This reduces shuffling traffic in a distributed setting. See + /// + /// for details. + PartialReduce, } impl AggregateMode { - /// Checks whether this aggregation step describes a "first stage" calculation. - /// In other words, its input is not another aggregation result and the - /// `merge_batch` method will not be called for these modes. - pub fn is_first_stage(&self) -> bool { + /// Returns the [`AggregateInputMode`] for this mode: whether this + /// stage consumes raw input data or intermediate accumulator state. + /// + /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes) + /// for details. + pub fn input_mode(&self) -> AggregateInputMode { match self { AggregateMode::Partial | AggregateMode::Single - | AggregateMode::SinglePartitioned => true, - AggregateMode::Final | AggregateMode::FinalPartitioned => false, + | AggregateMode::SinglePartitioned => AggregateInputMode::Raw, + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::PartialReduce => AggregateInputMode::Partial, + } + } + + /// Returns the [`AggregateOutputMode`] for this mode: whether this + /// stage produces intermediate accumulator state or final output values. + /// + /// See the [table above](AggregateMode#variants-and-their-inputoutput-modes) + /// for details. + pub fn output_mode(&self) -> AggregateOutputMode { + match self { + AggregateMode::Final + | AggregateMode::FinalPartitioned + | AggregateMode::Single + | AggregateMode::SinglePartitioned => AggregateOutputMode::Final, + AggregateMode::Partial | AggregateMode::PartialReduce => { + AggregateOutputMode::Partial + } } } } @@ -175,6 +284,9 @@ pub struct PhysicalGroupBy { /// expression in null_expr. If `groups[i][j]` is true, then the /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`. groups: Vec>, + /// True when GROUPING SETS/CUBE/ROLLUP are used so `__grouping_id` should + /// be included in the output schema. + has_grouping_set: bool, } impl PhysicalGroupBy { @@ -183,11 +295,13 @@ impl PhysicalGroupBy { expr: Vec<(Arc, String)>, null_expr: Vec<(Arc, String)>, groups: Vec>, + has_grouping_set: bool, ) -> Self { Self { expr, null_expr, groups, + has_grouping_set, } } @@ -199,6 +313,7 @@ impl PhysicalGroupBy { expr, null_expr: vec![], groups: vec![vec![false; num_exprs]], + has_grouping_set: false, } } @@ -215,6 +330,11 @@ impl PhysicalGroupBy { exprs_nullable } + /// Returns true if this has no grouping at all (including no GROUPING SETS) + pub fn is_true_no_grouping(&self) -> bool { + self.is_empty() && !self.has_grouping_set + } + /// Returns the group expressions pub fn expr(&self) -> &[(Arc, String)] { &self.expr @@ -230,14 +350,20 @@ impl PhysicalGroupBy { &self.groups } + /// Returns true if this grouping uses GROUPING SETS, CUBE or ROLLUP. + pub fn has_grouping_set(&self) -> bool { + self.has_grouping_set + } + /// Returns true if this `PhysicalGroupBy` has no group expressions pub fn is_empty(&self) -> bool { self.expr.is_empty() } - /// Check whether grouping set is single group + /// Returns true if this is a "simple" GROUP BY (not using GROUPING SETS/CUBE/ROLLUP). + /// This determines whether the `__grouping_id` column is included in the output schema. pub fn is_single(&self) -> bool { - self.null_expr.is_empty() + !self.has_grouping_set } /// Calculate GROUP BY expressions according to input schema. @@ -251,7 +377,7 @@ impl PhysicalGroupBy { /// The number of expressions in the output schema. fn num_output_exprs(&self) -> usize { let mut num_exprs = self.expr.len(); - if !self.is_single() { + if self.has_grouping_set { num_exprs += 1 } num_exprs @@ -268,7 +394,7 @@ impl PhysicalGroupBy { .take(num_output_exprs) .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _), ); - if !self.is_single() { + if self.has_grouping_set { output_exprs.push(Arc::new(Column::new( Aggregate::INTERNAL_GROUPING_ID, self.expr.len(), @@ -279,11 +405,16 @@ impl PhysicalGroupBy { /// Returns the number expression as grouping keys. pub fn num_group_exprs(&self) -> usize { - if self.is_single() { - self.expr.len() - } else { - self.expr.len() + 1 - } + self.expr.len() + usize::from(self.has_grouping_set) + } + + /// Returns the Arrow data type of the `__grouping_id` column. + /// + /// The type is chosen to be wide enough to hold both the semantic bitmask + /// (in the low `n` bits, where `n` is the number of grouping expressions) + /// and the duplicate ordinal (in the high bits). + fn grouping_id_data_type(&self) -> DataType { + Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups)) } pub fn group_schema(&self, schema: &Schema) -> Result { @@ -294,7 +425,7 @@ impl PhysicalGroupBy { fn group_fields(&self, input_schema: &Schema) -> Result> { let mut fields = Vec::with_capacity(self.num_group_exprs()); for ((expr, name), group_expr_nullable) in - self.expr.iter().zip(self.exprs_nullable().into_iter()) + self.expr.iter().zip(self.exprs_nullable()) { fields.push( Field::new( @@ -306,11 +437,11 @@ impl PhysicalGroupBy { .into(), ); } - if !self.is_single() { + if self.has_grouping_set { fields.push( Field::new( Aggregate::INTERNAL_GROUPING_ID, - Aggregate::grouping_id_type(self.expr.len()), + self.grouping_id_data_type(), false, ) .into(), @@ -342,17 +473,17 @@ impl PhysicalGroupBy { ) .collect(); let num_exprs = expr.len(); - let groups = if self.expr.is_empty() { + let groups = if self.expr.is_empty() && !self.has_grouping_set { // No GROUP BY expressions - should have no groups vec![] } else { - // Has GROUP BY expressions - create a single group vec![vec![false; num_exprs]] }; Self { expr, null_expr: vec![], groups, + has_grouping_set: false, } } } @@ -372,13 +503,43 @@ impl PartialEq for PhysicalGroupBy { .zip(other.null_expr.iter()) .all(|((expr1, name1), (expr2, name2))| expr1.eq(expr2) && name1 == name2) && self.groups == other.groups + && self.has_grouping_set == other.has_grouping_set } } -#[allow(clippy::large_enum_variant)] +/// Streams used by [`AggregateExec`]. +/// +/// # Stream Variant Schema Notation +/// For example, `SELECT g, AVG(x) FROM t GROUP BY g` uses these schemas: +/// +/// ```text +/// initial input: [g, x] +/// partial state: [g, AVG(x) state columns, e.g. sum/count] +/// final result: [g, AVG(x)] +/// ``` +#[expect(clippy::large_enum_variant)] enum StreamType { + /// Single group (no group by) aggregate stream. + /// Input output scheme: initial input -> final result AggregateStream(AggregateStream), + /// Partial stage of the hash aggregation + /// Input output scheme: initial input -> partial state + PartialHash(PartialHashAggregateStream), + /// Final stage of the hash aggregation + /// Input output scheme: partial state -> final result + FinalHash(FinalHashAggregateStream), + /// Hash aggregation reused for multiple stages + /// + /// Note this is being incrementally migrated to dedicated streams like + /// [`StreamType::PartialHash`] and [`StreamType::FinalHash`] + /// + /// See issue for details: GroupedHash(GroupedHashAggregateStream), + /// Grouped TopK aggregate stream. + /// Input output scheme: initial input -> final result + /// + /// Used for grouped aggregation with LIMIT / ordering, where the stream keeps + /// only the top groups required by the query. GroupedPriorityQueue(GroupedTopKAggregateStream), } @@ -386,28 +547,152 @@ impl From for SendableRecordBatchStream { fn from(stream: StreamType) -> Self { match stream { StreamType::AggregateStream(stream) => Box::pin(stream), + StreamType::PartialHash(stream) => Box::pin(stream), + StreamType::FinalHash(stream) => Box::pin(stream), StreamType::GroupedHash(stream) => Box::pin(stream), StreamType::GroupedPriorityQueue(stream) => Box::pin(stream), } } } +/// # Aggregate Dynamic Filter Pushdown Overview +/// +/// For queries like +/// -- `example_table(type TEXT, val INT)` +/// SELECT min(val) +/// FROM example_table +/// WHERE type='A'; +/// +/// And `example_table`'s physical representation is a partitioned parquet file with +/// column statistics +/// - part-0.parquet: val {min=0, max=100} +/// - part-1.parquet: val {min=100, max=200} +/// - ... +/// - part-100.parquet: val {min=10000, max=10100} +/// +/// After scanning the 1st file, we know we only have to read files if their minimal +/// value on `val` column is less than 0, the minimal `val` value in the 1st file. +/// +/// We can skip scanning the remaining file by implementing dynamic filter, the +/// intuition is we keep a shared data structure for current min in both `AggregateExec` +/// and `DataSourceExec`, and let it update during execution, so the scanner can +/// know during execution if it's possible to skip scanning certain files. See +/// physical optimizer rule `FilterPushdown` for details. +/// +/// # Implementation +/// +/// ## Enable Condition +/// - No grouping (no `GROUP BY` clause in the sql, only a single global group to aggregate) +/// - The aggregate expression must be `min`/`max`, and evaluate directly on columns. +/// Note multiple aggregate expressions that satisfy this requirement are allowed, +/// and a dynamic filter will be constructed combining all applicable expr's +/// states. See more in the following example with dynamic filter on multiple columns. +/// +/// ## Filter Construction +/// The filter is kept in the `DataSourceExec`, and it will gets update during execution, +/// the reader will interpret it as "the upstream only needs rows that such filter +/// predicate is evaluated to true", and certain scanner implementation like `parquet` +/// can evaluate column statistics on those dynamic filters, to decide if they can +/// prune a whole range. +/// +/// ### Examples +/// - Expr: `min(a)`, Dynamic Filter: `a < a_cur_min` +/// - Expr: `min(a), max(a), min(b)`, Dynamic Filter: `(a < a_cur_min) OR (a > a_cur_max) OR (b < b_cur_min)` +#[derive(Debug, Clone)] +struct AggrDynFilter { + /// The physical expr for the dynamic filter shared between the `AggregateExec` + /// and the parquet scanner. + filter: Arc, + /// The current bounds for the dynamic filter, updates during the execution to + /// tighten the bound for more effective pruning. + /// + /// Each vector element is for the accumulators that support dynamic filter. + /// e.g. This `AggregateExec` has accumulator: + /// min(a), avg(a), max(b) + /// And this field stores [PerAccumulatorDynFilter(min(a)), PerAccumulatorDynFilter(min(b))] + supported_accumulators_info: Vec, +} + +// ---- Aggregate Dynamic Filter Utility Structs ---- + +/// Aggregate expressions that support the dynamic filter pushdown in aggregation. +/// See comments in [`AggrDynFilter`] for conditions. +#[derive(Debug, Clone)] +struct PerAccumulatorDynFilter { + aggr_type: DynamicFilterAggregateType, + /// During planning and optimization, the parent structure is kept in `AggregateExec`, + /// this index is into `aggr_expr` vec inside `AggregateExec`. + /// During execution, the parent struct is moved into `AggregateStream` (stream + /// for no grouping aggregate execution), and this index is into `aggregate_expressions` + /// vec inside `AggregateStreamInner` + aggr_index: usize, + // The current bound. Shared among all streams. + shared_bound: Arc>, +} + +/// Aggregate types that are supported for dynamic filter in `AggregateExec` +#[derive(Debug, Clone)] +enum DynamicFilterAggregateType { + Min, + Max, +} + +/// Configuration for limit-based optimizations in aggregation +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LimitOptions { + /// The maximum number of rows to return + pub limit: usize, + /// Optional ordering direction (true = descending, false = ascending) + /// This is used for TopK aggregation to maintain a priority queue with the correct ordering + pub descending: Option, +} + +impl LimitOptions { + /// Create a new LimitOptions with a limit and no specific ordering + pub fn new(limit: usize) -> Self { + Self { + limit, + descending: None, + } + } + + /// Create a new LimitOptions with a limit and ordering direction + pub fn new_with_order(limit: usize, descending: bool) -> Self { + Self { + limit, + descending: Some(descending), + } + } + + pub fn limit(&self) -> usize { + self.limit + } + + pub fn descending(&self) -> Option { + self.descending + } +} + /// Hash aggregate execution plan #[derive(Debug, Clone)] pub struct AggregateExec { /// Aggregation mode (full, partial) mode: AggregateMode, /// Group by expressions - group_by: PhysicalGroupBy, + /// [`Arc`] used for a cheap clone, which improves physical plan optimization performance. + group_by: Arc, /// Aggregate expressions - aggr_expr: Vec>, + /// The same reason to [`Arc`] it as for [`Self::group_by`]. + aggr_expr: Arc<[Arc]>, /// FILTER (WHERE clause) expression for each aggregate expression - filter_expr: Vec>>, - /// Set if the output of this aggregation is truncated by a upstream sort/limit clause - limit: Option, + /// The same reason to [`Arc`] it as for [`Self::group_by`]. + filter_expr: Arc<[Option>]>, + /// Configuration for limit-based optimizations + limit_options: Option, /// Input plan, could be a partial aggregate or the input to the aggregate pub input: Arc, - /// Schema after the aggregate is applied + /// Schema after the aggregate is applied. Contains the group by columns followed by the + /// aggregate outputs. schema: SchemaRef, /// Input schema before any aggregation is applied. For partial aggregate this will be the /// same as input.schema() but for the final aggregate it will be the same as the input @@ -420,7 +705,14 @@ pub struct AggregateExec { required_input_ordering: Option, /// Describes how the input is ordered relative to the group by columns input_order_mode: InputOrderMode, - cache: PlanProperties, + cache: Arc, + /// During initialization, if the plan supports dynamic filtering (see [`AggrDynFilter`]), + /// it is set to `Some(..)` regardless of whether it can be pushed down to a child node. + /// + /// During filter pushdown optimization, if a child node can accept this filter, + /// it remains `Some(..)` to enable dynamic filtering during aggregate execution; + /// otherwise, it is cleared to `None`. + dynamic_filter: Option>, } impl AggregateExec { @@ -429,22 +721,43 @@ impl AggregateExec { /// Rewrites aggregate exec with new aggregate expressions. pub fn with_new_aggr_exprs( &self, - aggr_expr: Vec>, + aggr_expr: impl Into]>>, ) -> Self { Self { - aggr_expr, + aggr_expr: aggr_expr.into(), + // clone the rest of the fields + required_input_ordering: self.required_input_ordering.clone(), + metrics: ExecutionPlanMetricsSet::new(), + input_order_mode: self.input_order_mode.clone(), + cache: Arc::clone(&self.cache), + mode: self.mode, + group_by: Arc::clone(&self.group_by), + filter_expr: Arc::clone(&self.filter_expr), + limit_options: self.limit_options, + input: Arc::clone(&self.input), + schema: Arc::clone(&self.schema), + input_schema: Arc::clone(&self.input_schema), + dynamic_filter: self.dynamic_filter.clone(), + } + } + + /// Clone this exec, overriding only the limit hint. + pub fn with_new_limit_options(&self, limit_options: Option) -> Self { + Self { + limit_options, // clone the rest of the fields required_input_ordering: self.required_input_ordering.clone(), metrics: ExecutionPlanMetricsSet::new(), input_order_mode: self.input_order_mode.clone(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), mode: self.mode, - group_by: self.group_by.clone(), - filter_expr: self.filter_expr.clone(), - limit: self.limit, + group_by: Arc::clone(&self.group_by), + aggr_expr: Arc::clone(&self.aggr_expr), + filter_expr: Arc::clone(&self.filter_expr), input: Arc::clone(&self.input), schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), + dynamic_filter: self.dynamic_filter.clone(), } } @@ -455,12 +768,13 @@ impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( mode: AggregateMode, - group_by: PhysicalGroupBy, + group_by: impl Into>, aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { + let group_by = group_by.into(); let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); @@ -483,16 +797,18 @@ impl AggregateExec { /// a rule may re-write aggregate expressions (e.g. reverse them) during /// initialization, field names may change inadvertently if one re-creates /// the schema in such cases. - #[allow(clippy::too_many_arguments)] fn try_new_with_schema( mode: AggregateMode, - group_by: PhysicalGroupBy, + group_by: impl Into>, mut aggr_expr: Vec>, - filter_expr: Vec>>, + filter_expr: impl Into>]>>, input: Arc, input_schema: SchemaRef, schema: SchemaRef, ) -> Result { + let group_by = group_by.into(); + let filter_expr = filter_expr.into(); + // Make sure arguments are consistent in size assert_eq_or_internal_err!( aggr_expr.len(), @@ -508,12 +824,13 @@ impl AggregateExec { // If existing ordering satisfies a prefix of the GROUP BY expressions, // prefix requirements with this section. In this case, aggregation will // work more efficiently. - let indices = get_ordered_partition_by_indices(&groupby_exprs, &input)?; - let mut new_requirements = indices - .iter() - .map(|&idx| { - PhysicalSortRequirement::new(Arc::clone(&groupby_exprs[idx]), None) - }) + // Copy the `PhysicalSortExpr`s to retain the sort options. + let (new_sort_exprs, indices) = + input_eq_properties.find_longest_permutation(&groupby_exprs)?; + + let mut new_requirements = new_sort_exprs + .into_iter() + .map(PhysicalSortRequirement::from) .collect::>(); let req = get_finer_aggregate_exprs_requirement( @@ -558,23 +875,28 @@ impl AggregateExec { &group_expr_mapping, &mode, &input_order_mode, - aggr_expr.as_slice(), + aggr_expr.as_ref(), )?; - Ok(AggregateExec { + let mut exec = AggregateExec { mode, group_by, - aggr_expr, + aggr_expr: aggr_expr.into(), filter_expr, input, schema, input_schema, metrics: ExecutionPlanMetricsSet::new(), required_input_ordering, - limit: None, + limit_options: None, input_order_mode, - cache, - }) + cache: Arc::new(cache), + dynamic_filter: None, + }; + + exec.init_dynamic_filter(); + + Ok(exec) } /// Aggregation mode (full, partial) @@ -582,11 +904,17 @@ impl AggregateExec { &self.mode } - /// Set the `limit` of this AggExec - pub fn with_limit(mut self, limit: Option) -> Self { - self.limit = limit; + /// Set the limit options for this AggExec + pub fn with_limit_options(mut self, limit_options: Option) -> Self { + self.limit_options = limit_options; self } + + /// Get the limit options (if set) + pub fn limit_options(&self) -> Option { + self.limit_options + } + /// Grouping expressions pub fn group_expr(&self) -> &PhysicalGroupBy { &self.group_by @@ -607,6 +935,47 @@ impl AggregateExec { &self.filter_expr } + /// Returns the dynamic filter expression for this aggregate, if set. + pub fn dynamic_filter_expr(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) + } + + /// Replace the dynamic filter expression. This method errors if the aggregate does not + /// support dynamic filtering or if the filter expression is incompatible with this + /// [`AggregateExec`]. + pub fn with_dynamic_filter_expr( + mut self, + filter: Arc, + ) -> Result { + // If there is no dynamic filter state initialized via `try_new`, then + // we can safely assume that the aggregate does not support dynamic filtering. + let Some(dyn_filter) = self.dynamic_filter.as_ref() else { + return internal_err!("Aggregate does not support dynamic filtering"); + }; + + // Validate that the filter is compatible with the aggregation columns. + let cols = self.cols_for_dynamic_filter(&dyn_filter.supported_accumulators_info); + if cols.len() != filter.children().len() { + return internal_err!( + "Dynamic filter expression is incompatible with aggregate due to mismatched number of columns" + ); + } + for (col, child) in cols.iter().zip(filter.children()) { + if !col.eq(child) { + return internal_err!( + "Dynamic filter expression is incompatible with aggregate due to mismatched column references {col} != {child}" + ); + } + } + + // Overwrite our filter + self.dynamic_filter = Some(Arc::new(AggrDynFilter { + filter, + supported_accumulators_info: dyn_filter.supported_accumulators_info.clone(), + })); + Ok(self) + } + /// Input plan pub fn input(&self) -> &Arc { &self.input @@ -617,29 +986,42 @@ impl AggregateExec { Arc::clone(&self.input_schema) } - /// number of rows soft limit of the AggregateExec - pub fn limit(&self) -> Option { - self.limit - } - fn execute_typed( &self, partition: usize, context: &Arc, ) -> Result { - // no group by at all - if self.group_by.expr.is_empty() { + if self.group_by.is_true_no_grouping() { return Ok(StreamType::AggregateStream(AggregateStream::new( self, context, partition, )?)); } // grouping by an expression that has a sort/limit upstream - if let Some(limit) = self.limit { - if !self.is_unordered_unfiltered_group_by_distinct() { - return Ok(StreamType::GroupedPriorityQueue( - GroupedTopKAggregateStream::new(self, context, partition, limit)?, - )); + if let Some(config) = self.limit_options + && !self.is_unordered_unfiltered_group_by_distinct() + { + return Ok(StreamType::GroupedPriorityQueue( + GroupedTopKAggregateStream::new(self, context, partition, config.limit)?, + )); + } + + if context + .session_config() + .options() + .execution + .enable_migration_aggregate + { + if self.should_use_partial_hash_stream(context) { + return Ok(StreamType::PartialHash(PartialHashAggregateStream::new( + self, context, partition, + )?)); + } + + if self.should_use_final_hash_stream(context) { + return Ok(StreamType::FinalHash(FinalHashAggregateStream::new( + self, context, partition, + )?)); } } @@ -649,6 +1031,39 @@ impl AggregateExec { )?)) } + fn should_use_partial_hash_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { + return false; + } + + self.mode == AggregateMode::Partial + && self.input_order_mode == InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + && self.limit_options_supported_by_hash_stream() + } + + fn should_use_final_hash_stream(&self, context: &TaskContext) -> bool { + // TODO: implement memory-limited path and remove this limitation + if matches!(context.memory_pool().memory_limit(), MemoryLimit::Finite(_)) { + return false; + } + + matches!( + self.mode, + AggregateMode::Final | AggregateMode::FinalPartitioned + ) && self.limit_options_supported_by_hash_stream() + && self.input_order_mode == InputOrderMode::Linear + && !self.group_by.is_true_no_grouping() + && self.group_by.is_single() + } + + /// See comments in `PartialHashAggregateStream` limit optimization section + fn limit_options_supported_by_hash_stream(&self) -> bool { + self.limit_options.is_none() || self.is_unordered_unfiltered_group_by_distinct() + } + /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; @@ -660,8 +1075,15 @@ impl AggregateExec { /// This method qualifies the use of the LimitedDistinctAggregation rewrite rule /// on an AggregateExec. pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { + if self + .limit_options() + .and_then(|config| config.descending) + .is_some() + { + return false; + } // ensure there is a group by - if self.group_expr().is_empty() { + if self.group_expr().is_empty() && !self.group_expr().has_grouping_set() { return false; } // ensure there are no aggregate expressions @@ -720,7 +1142,7 @@ impl AggregateExec { .iter() .flat_map(|(_, target_cols)| { target_cols.iter().flat_map(|(expr, _)| { - expr.as_any().downcast_ref::().map(|c| c.index()) + expr.downcast_ref::().map(|c| c.index()) }) }) .collect(), @@ -731,14 +1153,15 @@ impl AggregateExec { // Get output partitioning: let input_partitioning = input.output_partitioning().clone(); - let output_partitioning = if mode.is_first_stage() { - // First stage aggregation will not change the output partitioning, - // but needs to respect aliases (e.g. mapping in the GROUP BY - // expression). - let input_eq_properties = input.equivalence_properties(); - input_partitioning.project(group_expr_mapping, input_eq_properties) - } else { - input_partitioning.clone() + let output_partitioning = match mode.input_mode() { + AggregateInputMode::Raw => { + // First stage aggregation will not change the output partitioning, + // but needs to respect aliases (e.g. mapping in the GROUP BY + // expression). + let input_eq_properties = input.equivalence_properties(); + input_partitioning.project(group_expr_mapping, input_eq_properties) + } + AggregateInputMode::Partial => input_partitioning.clone(), }; // TODO: Emission type and boundedness information can be enhanced here @@ -760,6 +1183,45 @@ impl AggregateExec { &self.input_order_mode } + /// Estimates output statistics for this aggregate node. + /// + /// For grouped aggregations with known input row count > 1, the output row + /// count is estimated as: + /// + /// ```text + /// ndv = sum over each grouping set of product(max(NDV_i + nulls_i, 1)) + /// output_rows = input_rows // baseline + /// output_rows = min(output_rows, ndv) // if NDV available + /// output_rows = min(output_rows, limit) // if TopK active + /// ``` + /// + /// **Example 1 — single group key:** + /// `GROUP BY city` where input_rows = 10,000, NDV(city) = 200 + /// → output_rows = min(10_000, 200) = 200 + /// + /// **Example 2 — two group keys with TopK:** + /// `GROUP BY city, category` where input_rows = 10,000, NDV(city) = 200, + /// NDV(category) = 5, limit = 100 + /// → ndv = 200 × 5 = 1,000 + /// → output_rows = min(10_000, 1_000) = 1,000 + /// → output_rows = min(1_000, 100) = 100 + /// + /// When `input_rows` is absent but NDV is available, falls back to: + /// + /// ```text + /// output_rows = min(ndv, limit) // if both available + /// output_rows = ndv // if only NDV available + /// output_rows = limit // if only limit available + /// ``` + /// + /// NDV estimation details (see [`Self::compute_group_ndv`]): + /// - For each grouping set, only active (non-NULL) columns contribute + /// - Per-column contribution is `max(NDV + null_adj, 1)` where `null_adj` + /// is 1 when nulls are present, 0 otherwise (a null group is a distinct + /// output row; `.max(1)` prevents a zero NDV from zeroing the product) + /// - Per-set products are summed across all grouping sets + /// - Requires NDV stats for ALL active group-by columns; if any lacks stats, + /// falls back to `input_rows` (or `Absent` if that is also unknown) fn statistics_inner(&self, child_statistics: &Statistics) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here @@ -772,16 +1234,13 @@ impl AggregateExec { let mut column_statistics = Statistics::unknown_column(&self.schema()); for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() { - if let Some(col) = expr.as_any().downcast_ref::() { - column_statistics[idx].max_value = child_statistics.column_statistics - [col.index()] - .max_value - .clone(); - - column_statistics[idx].min_value = child_statistics.column_statistics - [col.index()] - .min_value - .clone(); + if let Some(col) = expr.downcast_ref::() { + let child_col_stats = + &child_statistics.column_statistics[col.index()]; + column_statistics[idx].max_value = child_col_stats.max_value.clone(); + column_statistics[idx].min_value = child_col_stats.min_value.clone(); + column_statistics[idx].distinct_count = + child_col_stats.distinct_count; } } @@ -791,37 +1250,232 @@ impl AggregateExec { AggregateMode::Final | AggregateMode::FinalPartitioned if self.group_by.expr.is_empty() => { + let total_byte_size = + Self::calculate_scaled_byte_size(child_statistics, 1); + Ok(Statistics { num_rows: Precision::Exact(1), column_statistics, - total_byte_size: Precision::Absent, + total_byte_size, }) } _ => { - // When the input row count is 1, we can adopt that statistic keeping its reliability. - // When it is larger than 1, we degrade the precision since it may decrease after aggregation. - let num_rows = if let Some(value) = child_statistics.num_rows.get_value() - { - if *value > 1 { - child_statistics.num_rows.to_inexact() - } else if *value == 0 { - child_statistics.num_rows - } else { - // num_rows = 1 case - let grouping_set_num = self.group_by.groups.len(); - child_statistics.num_rows.map(|x| x * grouping_set_num) - } - } else { - Precision::Absent - }; + let num_rows = self.estimate_num_rows(child_statistics); + + let total_byte_size = num_rows + .get_value() + .and_then(|&output_rows| { + Self::calculate_scaled_byte_size(child_statistics, output_rows) + .get_value() + .map(|&bytes| Precision::Inexact(bytes)) + }) + .unwrap_or(Precision::Absent); + Ok(Statistics { num_rows, column_statistics, - total_byte_size: Precision::Absent, + total_byte_size, }) } } } + + /// Estimates the output row count for grouped aggregations, combining NDV, + /// input row count, and TopK limit into a single [`Precision`]. + fn estimate_num_rows(&self, child_statistics: &Statistics) -> Precision { + let ndv = if !self.group_by.expr.is_empty() { + self.compute_group_ndv(child_statistics) + } else { + None + }; + let limit = self.limit_options.as_ref().map(|lo| lo.limit); + + if let Some(&value) = child_statistics.num_rows.get_value() { + if value > 1 { + let mut num_rows = child_statistics.num_rows.to_inexact(); + if let Some(ndv) = ndv { + num_rows = num_rows.map(|n| n.min(ndv)); + } + if let Some(limit) = limit { + num_rows = num_rows.map(|n| n.min(limit)); + } + num_rows + } else if value == 0 { + child_statistics.num_rows + } else { + let grouping_set_num = self.group_by.groups.len(); + let mut num_rows = + child_statistics.num_rows.map(|x| x * grouping_set_num); + if let Some(limit) = limit { + num_rows = num_rows.map(|n| n.min(limit)); + } + num_rows + } + } else { + match (ndv, limit) { + (Some(n), Some(l)) => Precision::Inexact(n.min(l)), + (Some(n), None) => Precision::Inexact(n), + (None, Some(l)) => Precision::Inexact(l), + (None, None) => Precision::Absent, + } + } + } + + /// Computes the estimated number of distinct groups across all grouping sets. + /// For each grouping set, computes `product(NDV_i + null_adj_i)` for active columns, + /// then sums across all sets. Returns `None` if any active column is not a direct + /// column reference or lacks `distinct_count` stats. Non-column expressions + /// (e.g. `abs(a)`) are not yet supported because expression-level statistics + /// propagation is still in progress (see ). + /// When `null_count` is absent or unknown, null_adjustment defaults to 0. + /// + /// **Single key:** `GROUP BY a` where NDV(a) = 100, null_count(a) = 5 + /// → product = max(100 + 1, 1) = 101, total = 101 + /// + /// **Two keys:** `GROUP BY a, b` where NDV(a) = 100, NDV(b) = 50, no nulls + /// → product = 100 × 50 = 5,000, total = 5,000 + /// + /// **Grouping sets:** `GROUPING SETS ((a), (b), (a, b))` with NDV(a) = 100, NDV(b) = 50 + /// → set(a) = 100, set(b) = 50, set(a, b) = 100 × 50 = 5,000 + /// → total = 100 + 50 + 5,000 = 5,150 + fn compute_group_ndv(&self, child_statistics: &Statistics) -> Option { + let mut total: usize = 0; + for group_mask in &self.group_by.groups { + let mut set_product: usize = 1; + for (j, (expr, _)) in self.group_by.expr.iter().enumerate() { + if group_mask[j] { + continue; + } + let col = expr.downcast_ref::()?; + let col_stats = &child_statistics.column_statistics[col.index()]; + let ndv = *col_stats.distinct_count.get_value()?; + let null_adjustment = match col_stats.null_count.get_value() { + Some(&n) if n > 0 => 1usize, + _ => 0, + }; + set_product = set_product + .saturating_mul(ndv.saturating_add(null_adjustment).max(1)); + } + total = total.saturating_add(set_product); + } + Some(total) + } + + /// Check if dynamic filter is possible for the current plan node. + /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field. + /// - If not supported, `self.dynamic_filter` should be kept `None` + fn init_dynamic_filter(&mut self) { + if (!self.group_by.is_empty()) || (self.mode != AggregateMode::Partial) { + debug_assert!( + self.dynamic_filter.is_none(), + "The current operator node does not support dynamic filter" + ); + return; + } + + // Already initialized. + if self.dynamic_filter.is_some() { + return; + } + + // Collect supported accumulators + // It is assumed the order of aggregate expressions are not changed from `AggregateExec` + // to `AggregateStream` + let mut aggr_dyn_filters = Vec::new(); + // All column references in the dynamic filter, used when initializing the dynamic + // filter, and it's used to decide if this dynamic filter is able to get push + // through certain node during optimization. + let mut all_cols: Vec> = Vec::new(); + for (i, aggr_expr) in self.aggr_expr.iter().enumerate() { + // 1. Only `min` or `max` aggregate function + let fun_name = aggr_expr.fun().name(); + // HACK: Should check the function type more precisely + // Issue: + let aggr_type = if fun_name.eq_ignore_ascii_case("min") { + DynamicFilterAggregateType::Min + } else if fun_name.eq_ignore_ascii_case("max") { + DynamicFilterAggregateType::Max + } else { + return; + }; + + // 2. arg should be only 1 column reference + if let [arg] = aggr_expr.expressions().as_slice() + && arg.is::() + { + all_cols.push(Arc::clone(arg)); + aggr_dyn_filters.push(PerAccumulatorDynFilter { + aggr_type, + aggr_index: i, + shared_bound: Arc::new(Mutex::new(ScalarValue::Null)), + }); + } + } + + if !aggr_dyn_filters.is_empty() { + self.dynamic_filter = Some(Arc::new(AggrDynFilter { + filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))), + supported_accumulators_info: aggr_dyn_filters, + })) + } + } + + // Collect column references for the dynamic filter expression from the supported accumulators. + fn cols_for_dynamic_filter( + &self, + supported_accumulators_info: &[PerAccumulatorDynFilter], + ) -> Vec> { + let all_cols: Vec> = supported_accumulators_info + .iter() + .filter_map(|info| { + // This should always be true due to how the supported accumulators + // are constructed. See `init_dynamic_filter` for more details. + if let [arg] = &self.aggr_expr[info.aggr_index].expressions().as_slice() + && arg.is::() + { + return Some(Arc::clone(arg)); + } + None + }) + .collect(); + debug_assert!(all_cols.len() == supported_accumulators_info.len()); + all_cols + } + + /// Calculate scaled byte size based on row count ratio. + /// Returns `Precision::Absent` if input statistics are insufficient. + /// Returns `Precision::Inexact` with the scaled value otherwise. + /// + /// This is a simple heuristic that assumes uniform row sizes. + #[inline] + fn calculate_scaled_byte_size( + input_stats: &Statistics, + target_row_count: usize, + ) -> Precision { + match ( + input_stats.num_rows.get_value(), + input_stats.total_byte_size.get_value(), + ) { + (Some(&input_rows), Some(&input_bytes)) if input_rows > 0 => { + let bytes_per_row = input_bytes as f64 / input_rows as f64; + let scaled_bytes = + (bytes_per_row * target_row_count as f64).ceil() as usize; + Precision::Inexact(scaled_bytes) + } + _ => Precision::Absent, + } + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for AggregateExec { @@ -878,11 +1532,11 @@ impl DisplayAs for AggregateExec { let a: Vec = self .aggr_expr .iter() - .map(|agg| agg.name().to_string()) + .map(|agg| format_aggregate_exec_expr(agg).to_string()) .collect(); write!(f, ", aggr=[{}]", a.join(", "))?; - if let Some(limit) = self.limit { - write!(f, ", lim=[{limit}]")?; + if let Some(config) = self.limit_options { + write!(f, ", lim=[{}]", config.limit)?; } if self.input_order_mode != InputOrderMode::Linear { @@ -932,7 +1586,7 @@ impl DisplayAs for AggregateExec { let a: Vec = self .aggr_expr .iter() - .map(|agg| agg.human_display().to_string()) + .map(|agg| format_tree_aggregate_expr(agg).to_string()) .collect(); writeln!(f, "mode={:?}", self.mode)?; if !g.is_empty() { @@ -941,29 +1595,51 @@ impl DisplayAs for AggregateExec { if !a.is_empty() { writeln!(f, "aggr={}", a.join(", "))?; } + if let Some(config) = self.limit_options { + writeln!(f, "limit={}", config.limit)?; + } } } Ok(()) } } +fn format_aggregate_exec_expr(agg: &AggregateFunctionExpr) -> Cow<'_, str> { + match agg.human_display_alias() { + Some(_) => format_human_display(agg.human_display(), agg.human_display_alias()) + .unwrap_or_else(|| Cow::Borrowed(agg.name())), + None => Cow::Borrowed(agg.name()), + } +} + +fn format_tree_aggregate_expr(agg: &AggregateFunctionExpr) -> Cow<'_, str> { + format_human_display(agg.human_display(), agg.human_display_alias()) + .unwrap_or_else(|| Cow::Borrowed(agg.name())) +} + +fn format_human_display<'a>( + human_display: Option<&'a str>, + alias: Option<&'a str>, +) -> Option> { + human_display.map(|human_display| match alias { + Some(alias) => Cow::Owned(format!("{human_display} as {alias}")), + None => Cow::Borrowed(human_display), + }) +} + impl ExecutionPlan for AggregateExec { fn name(&self) -> &'static str { "AggregateExec" } /// Return a reference to Any that can be used for down-casting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } fn required_input_distribution(&self) -> Vec { match &self.mode { - AggregateMode::Partial => { + AggregateMode::Partial | AggregateMode::PartialReduce => { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { @@ -1000,16 +1676,19 @@ impl ExecutionPlan for AggregateExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); + let mut me = AggregateExec::try_new_with_schema( self.mode, - self.group_by.clone(), - self.aggr_expr.clone(), - self.filter_expr.clone(), + Arc::clone(&self.group_by), + self.aggr_expr.to_vec(), + Arc::clone(&self.filter_expr), Arc::clone(&children[0]), Arc::clone(&self.input_schema), Arc::clone(&self.schema), )?; - me.limit = self.limit; + me.limit_options = self.limit_options; + me.dynamic_filter.clone_from(&self.dynamic_filter); Ok(Arc::new(me)) } @@ -1027,13 +1706,9 @@ impl ExecutionPlan for AggregateExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { let child_statistics = self.input().partition_statistics(partition)?; - self.statistics_inner(&child_statistics) + Ok(Arc::new(self.statistics_inner(&child_statistics)?)) } fn cardinality_effect(&self) -> CardinalityEffect { @@ -1041,12 +1716,12 @@ impl ExecutionPlan for AggregateExec { } /// Push down parent filters when possible (see implementation comment for details), - /// but do not introduce any new self filters. + /// and also pushdown self dynamic filters (see `AggrDynFilter` for details) fn gather_filters_for_pushdown( &self, - _phase: FilterPushdownPhase, + phase: FilterPushdownPhase, parent_filters: Vec>, - _config: &ConfigOptions, + config: &ConfigOptions, ) -> Result { // It's safe to push down filters through aggregates when filters only reference // grouping columns, because such filters determine which groups to compute, not @@ -1056,11 +1731,15 @@ impl ExecutionPlan for AggregateExec { // This optimization is NOT safe for filters on aggregated columns (like filtering on // the result of SUM or COUNT), as those require computing all groups first. - let grouping_columns: HashSet<_> = self - .group_by - .expr() - .iter() - .flat_map(|(expr, _)| collect_columns(expr)) + // Build grouping columns using output indices because parent filters reference the + // AggregateExec's output schema where grouping columns in the output schema. The + // grouping expressions reference input columns which may not match the output schema. + // + // It is safe to assume that the output_schema contains group by columns in the same order + // as the group by expression. See [`create_schema`] and [`AggregateExec`]. + let output_schema = self.schema(); + let grouping_columns: HashSet<_> = (0..self.group_by.expr().len()) + .map(|i| Column::new(output_schema.field(i).name(), i)) .collect(); // Analyze each filter separately to determine if it can be pushed down @@ -1085,9 +1764,7 @@ impl ExecutionPlan for AggregateExec { let filter_column_indices: Vec = filter_columns .iter() .filter_map(|filter_col| { - self.group_by.expr().iter().position(|(expr, _)| { - collect_columns(expr).contains(filter_col) - }) + grouping_columns.get(filter_col).map(|col| col.index()) }) .collect(); @@ -1119,33 +1796,92 @@ impl ExecutionPlan for AggregateExec { .map(PushedDownPredicate::unsupported), ); + // Include self dynamic filter when it's possible + if phase == FilterPushdownPhase::Post + && config.optimizer.enable_aggregate_dynamic_filter_pushdown + && let Some(self_dyn_filter) = &self.dynamic_filter + { + let dyn_filter = Arc::clone(&self_dyn_filter.filter); + child_desc = child_desc.with_self_filter(dyn_filter); + } + Ok(FilterDescription::new().with_child(child_desc)) } -} -fn create_schema( - input_schema: &Schema, - group_by: &PhysicalGroupBy, - aggr_expr: &[Arc], - mode: AggregateMode, + /// If child accepts self's dynamic filter, keep `self.dynamic_filter` with Some, + /// otherwise clear it to None. + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone()); + + // If this node tried to pushdown some dynamic filter before, now we check + // if the child accept the filter + if phase == FilterPushdownPhase::Post + && let Some(dyn_filter) = &self.dynamic_filter + { + // let child_accepts_dyn_filter = child_pushdown_result + // .self_filters + // .first() + // .map(|filters| { + // assert_eq_or_internal_err!( + // filters.len(), + // 1, + // "Aggregate only pushdown one self dynamic filter" + // ); + // let filter = filters.get(0).unwrap(); // Asserted above + // Ok(matches!(filter.discriminant, PushedDown::Yes)) + // }) + // .unwrap_or_else(|| internal_err!("The length of self filters equals to the number of child of this ExecutionPlan, so it must be 1"))?; + + // HACK: The above snippet should be used, however, now the child reply + // `PushDown::No` can indicate they're not able to push down row-level + // filter, but still keep the filter for statistics pruning. + // So here, we try to use ref count to determine if the dynamic filter + // has actually be pushed down. + // Issue: + let child_accepts_dyn_filter = Arc::strong_count(dyn_filter) > 1; + + if !child_accepts_dyn_filter { + // Child can't consume the self dynamic filter, so disable it by setting + // to `None` + let mut new_node = self.clone(); + new_node.dynamic_filter = None; + + result = result + .with_updated_node(Arc::new(new_node) as Arc); + } + } + + Ok(result) + } +} + +/// Creates the output schema for an [`AggregateExec`] containing the group by columns followed +/// by the aggregate columns. +fn create_schema( + input_schema: &Schema, + group_by: &PhysicalGroupBy, + aggr_expr: &[Arc], + mode: AggregateMode, ) -> Result { let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); fields.extend(group_by.output_fields(input_schema)?); - match mode { - AggregateMode::Partial => { - // in partial mode, the fields of the accumulator's state + match mode.output_mode() { + AggregateOutputMode::Final => { + // in final mode, the field with the final result of the accumulator for expr in aggr_expr { - fields.extend(expr.state_fields()?.iter().cloned()); + fields.push(expr.field()) } } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - // in final mode, the field with the final result of the accumulator + AggregateOutputMode::Partial => { + // in partial mode, the fields of the accumulator's state for expr in aggr_expr { - fields.push(expr.field()) + fields.extend(expr.state_fields()?.iter().cloned()); } } } @@ -1185,7 +1921,7 @@ fn get_aggregate_expr_req( // If the aggregation is performing a "second stage" calculation, // then ignore the ordering requirement. Ordering requirement applies // only to the aggregation input data. - if !agg_mode.is_first_stage() { + if agg_mode.input_mode() == AggregateInputMode::Partial { return None; } @@ -1351,10 +2087,8 @@ pub fn aggregate_expressions( mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { - match mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => Ok(aggr_expr + match mode.input_mode() { + AggregateInputMode::Raw => Ok(aggr_expr .iter() .map(|agg| { let mut result = agg.expressions(); @@ -1365,8 +2099,8 @@ pub fn aggregate_expressions( result }) .collect()), - // In this mode, we build the merge expressions of the aggregation. - AggregateMode::Final | AggregateMode::FinalPartitioned => { + AggregateInputMode::Partial => { + // In merge mode, we build the merge expressions of the aggregation. let mut col_idx_base = col_idx_base; aggr_expr .iter() @@ -1414,8 +2148,15 @@ pub fn finalize_aggregation( accumulators: &mut [AccumulatorItem], mode: &AggregateMode, ) -> Result> { - match mode { - AggregateMode::Partial => { + match mode.output_mode() { + AggregateOutputMode::Final => { + // Merge the state to the final value + accumulators + .iter_mut() + .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .collect() + } + AggregateOutputMode::Partial => { // Build the vector of states accumulators .iter_mut() @@ -1429,16 +2170,6 @@ pub fn finalize_aggregation( .flatten_ok() .collect() } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => { - // Merge the state to the final value - accumulators - .iter_mut() - .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) - .collect() - } } } @@ -1468,25 +2199,69 @@ fn evaluate_optional( .collect() } -fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { - if group.len() > 64 { +/// Builds the internal `__grouping_id` array for a single grouping set. +/// +/// The returned array packs two values into a single integer: +/// +/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask. A `1` bit +/// at position `i` means that the `i`-th grouping column (counting from the +/// least significant bit, i.e. the *last* column in the `group` slice) is +/// `NULL` for this grouping set. +/// - High bits (positions n and above): the duplicate `ordinal`, which +/// distinguishes multiple occurrences of the same grouping-set pattern. The +/// ordinal is `0` for the first occurrence, `1` for the second, and so on. +/// +/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 / +/// UInt64` that can represent both parts. It matches the type returned by +/// [`Aggregate::grouping_id_type`]. +pub(crate) fn group_id_array( + group: &[bool], + ordinal: usize, + max_ordinal: usize, + num_rows: usize, +) -> Result { + let n = group.len(); + if n > 64 { return not_impl_err!( "Grouping sets with more than 64 columns are not supported" ); } - let group_id = group.iter().fold(0u64, |acc, &is_null| { + let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize; + let total_bits = n + ordinal_bits; + if total_bits > 64 { + return not_impl_err!( + "Grouping sets with {n} columns and a maximum duplicate ordinal of \ + {max_ordinal} require {total_bits} bits, which exceeds 64" + ); + } + let semantic_id = group.iter().fold(0u64, |acc, &is_null| { (acc << 1) | if is_null { 1 } else { 0 } }); - let num_rows = batch.num_rows(); - if group.len() <= 8 { - Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) - } else if group.len() <= 16 { - Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) - } else if group.len() <= 32 { - Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) + let full_id = semantic_id | ((ordinal as u64) << n); + if total_bits <= 8 { + Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows]))) + } else if total_bits <= 16 { + Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows]))) + } else if total_bits <= 32 { + Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows]))) } else { - Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) + Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows]))) + } +} + +/// Returns the highest duplicate ordinal across all grouping sets. +/// +/// At the call-site, the ordinal is the 0-based index assigned to each +/// occurrence of a repeated grouping-set pattern: the first occurrence gets +/// ordinal 0, the second gets 1, and so on. If the same `Vec` appears +/// three times the ordinals are 0, 1, 2 and this function returns 2. +/// Returns 0 when no grouping set is duplicated. +pub(crate) fn max_duplicate_ordinal(groups: &[Vec]) -> usize { + let mut counts: HashMap<&[bool], usize> = HashMap::new(); + for group in groups { + *counts.entry(group).or_insert(0) += 1; } + counts.into_values().max().unwrap_or(0).saturating_sub(1) } /// Evaluate a group by expression against a `RecordBatch` @@ -1499,10 +2274,38 @@ fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { /// The outer Vec appears to be for grouping sets /// The inner Vec contains the results per expression /// The inner-inner Array contains the results per row +/// +/// For example, for `GROUP BY GROUPING SETS ((a, b), (a))` with input: +/// +/// ```text +/// a b +/// 1 1 +/// 1 2 +/// 2 1 +/// ``` +/// +/// The output is: +/// +/// ```text +/// [ +/// [ +/// a: [1, 1, 2] +/// b: [1, 2, 1] +/// grouping_id: [0, 0, 0] +/// ], +/// [ +/// a: [1, 1, 2] +/// b: [NULL, NULL, NULL] +/// grouping_id: [1, 1, 1] +/// ] +/// ] +/// ``` pub fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, ) -> Result>> { + let max_ordinal = max_duplicate_ordinal(&group_by.groups); + let mut ordinal_per_pattern: HashMap<&[bool], usize> = HashMap::new(); let exprs = evaluate_expressions_to_arrays( group_by.expr.iter().map(|(expr, _)| expr), batch, @@ -1516,6 +2319,10 @@ pub fn evaluate_group_by( .groups .iter() .map(|group| { + let ordinal = ordinal_per_pattern.entry(group).or_insert(0); + let current_ordinal = *ordinal; + *ordinal += 1; + let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { if *is_null { @@ -1525,7 +2332,12 @@ pub fn evaluate_group_by( } })); if !group_by.is_single() { - group_values.push(group_id_array(group, batch)?); + group_values.push(group_id_array( + group, + current_ordinal, + max_ordinal, + batch.num_rows(), + )?); } Ok(group_values) }) @@ -1537,41 +2349,48 @@ mod tests { use std::task::{Context, Poll}; use super::*; - use crate::coalesce_batches::CoalesceBatchesExec; + use crate::RecordBatchStream; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; use crate::common::collect; + use crate::empty::EmptyExec; use crate::execution_plan::Boundedness; use crate::expressions::col; use crate::metrics::MetricValue; - use crate::test::assert_is_pending; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::TestMemoryExec; - use crate::RecordBatchStream; + use crate::test::assert_is_pending; + use crate::test::exec::{ + BlockingExec, StatisticsExec, assert_strong_count_converges_to_zero, + }; use arrow::array::{ - DictionaryArray, Float32Array, Float64Array, Int32Array, StructArray, + DictionaryArray, Float32Array, Float64Array, Int32Array, Int64Array, StructArray, UInt32Array, UInt64Array, }; - use arrow::compute::{concat_batches, SortOptions}; - use arrow::datatypes::{DataType, Int32Type}; + use arrow::compute::{SortOptions, concat_batches}; + use arrow::datatypes::Int32Type; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; - use datafusion_common::{internal_err, DataFusionError, ScalarValue}; + use datafusion_common::{DataFusionError, internal_err}; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; + use datafusion_expr::{AggregateUDF, AggregateUDFImpl, Signature, Volatility}; + use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf}; use datafusion_functions_aggregate::median::median_udaf; + use datafusion_functions_aggregate::min_max::min_udaf; use datafusion_functions_aggregate::sum::sum_udaf; - use datafusion_physical_expr::aggregate::AggregateExprBuilder; - use datafusion_physical_expr::expressions::lit; - use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::Partitioning; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::Literal; + use crate::projection::ProjectionExec; + use datafusion_physical_expr::projection::ProjectionExpr; use futures::{FutureExt, Stream}; use insta::{allow_duplicates, assert_snapshot}; @@ -1702,6 +2521,7 @@ mod tests { vec![true, false], // (NULL, b) vec![false, false], // (a,b) ], + true, ); let aggregates = vec![Arc::new( @@ -1736,30 +2556,30 @@ mod tests { allow_duplicates! { assert_snapshot!(batches_to_sort_string(&result), @r" -+---+-----+---------------+-----------------+ -| a | b | __grouping_id | COUNT(1)[count] | -+---+-----+---------------+-----------------+ -| | 1.0 | 2 | 1 | -| | 1.0 | 2 | 1 | -| | 2.0 | 2 | 1 | -| | 2.0 | 2 | 1 | -| | 3.0 | 2 | 1 | -| | 3.0 | 2 | 1 | -| | 4.0 | 2 | 1 | -| | 4.0 | 2 | 1 | -| 2 | | 1 | 1 | -| 2 | | 1 | 1 | -| 2 | 1.0 | 0 | 1 | -| 2 | 1.0 | 0 | 1 | -| 3 | | 1 | 1 | -| 3 | | 1 | 2 | -| 3 | 2.0 | 0 | 2 | -| 3 | 3.0 | 0 | 1 | -| 4 | | 1 | 1 | -| 4 | | 1 | 2 | -| 4 | 3.0 | 0 | 1 | -| 4 | 4.0 | 0 | 2 | -+---+-----+---------------+-----------------+ + +---+-----+---------------+-----------------+ + | a | b | __grouping_id | COUNT(1)[count] | + +---+-----+---------------+-----------------+ + | | 1.0 | 2 | 1 | + | | 1.0 | 2 | 1 | + | | 2.0 | 2 | 1 | + | | 2.0 | 2 | 1 | + | | 3.0 | 2 | 1 | + | | 3.0 | 2 | 1 | + | | 4.0 | 2 | 1 | + | | 4.0 | 2 | 1 | + | 2 | | 1 | 1 | + | 2 | | 1 | 1 | + | 2 | 1.0 | 0 | 1 | + | 2 | 1.0 | 0 | 1 | + | 3 | | 1 | 1 | + | 3 | | 1 | 2 | + | 3 | 2.0 | 0 | 2 | + | 3 | 3.0 | 0 | 1 | + | 4 | | 1 | 1 | + | 4 | | 1 | 2 | + | 4 | 3.0 | 0 | 1 | + | 4 | 4.0 | 0 | 2 | + +---+-----+---------------+-----------------+ " ); } @@ -1767,22 +2587,22 @@ mod tests { allow_duplicates! { assert_snapshot!(batches_to_sort_string(&result), @r" -+---+-----+---------------+-----------------+ -| a | b | __grouping_id | COUNT(1)[count] | -+---+-----+---------------+-----------------+ -| | 1.0 | 2 | 2 | -| | 2.0 | 2 | 2 | -| | 3.0 | 2 | 2 | -| | 4.0 | 2 | 2 | -| 2 | | 1 | 2 | -| 2 | 1.0 | 0 | 2 | -| 3 | | 1 | 3 | -| 3 | 2.0 | 0 | 2 | -| 3 | 3.0 | 0 | 1 | -| 4 | | 1 | 3 | -| 4 | 3.0 | 0 | 1 | -| 4 | 4.0 | 0 | 2 | -+---+-----+---------------+-----------------+ + +---+-----+---------------+-----------------+ + | a | b | __grouping_id | COUNT(1)[count] | + +---+-----+---------------+-----------------+ + | | 1.0 | 2 | 2 | + | | 2.0 | 2 | 2 | + | | 3.0 | 2 | 2 | + | | 4.0 | 2 | 2 | + | 2 | | 1 | 2 | + | 2 | 1.0 | 0 | 2 | + | 3 | | 1 | 3 | + | 3 | 2.0 | 0 | 2 | + | 3 | 3.0 | 0 | 1 | + | 4 | | 1 | 3 | + | 4 | 3.0 | 0 | 1 | + | 4 | 4.0 | 0 | 2 | + +---+-----+---------------+-----------------+ " ); } @@ -1816,23 +2636,23 @@ mod tests { assert_snapshot!( batches_to_sort_string(&result), @r" - +---+-----+---------------+----------+ - | a | b | __grouping_id | COUNT(1) | - +---+-----+---------------+----------+ - | | 1.0 | 2 | 2 | - | | 2.0 | 2 | 2 | - | | 3.0 | 2 | 2 | - | | 4.0 | 2 | 2 | - | 2 | | 1 | 2 | - | 2 | 1.0 | 0 | 2 | - | 3 | | 1 | 3 | - | 3 | 2.0 | 0 | 2 | - | 3 | 3.0 | 0 | 1 | - | 4 | | 1 | 3 | - | 4 | 3.0 | 0 | 1 | - | 4 | 4.0 | 0 | 2 | - +---+-----+---------------+----------+ - " + +---+-----+---------------+----------+ + | a | b | __grouping_id | COUNT(1) | + +---+-----+---------------+----------+ + | | 1.0 | 2 | 2 | + | | 2.0 | 2 | 2 | + | | 3.0 | 2 | 2 | + | | 4.0 | 2 | 2 | + | 2 | | 1 | 2 | + | 2 | 1.0 | 0 | 2 | + | 3 | | 1 | 3 | + | 3 | 2.0 | 0 | 2 | + | 3 | 3.0 | 0 | 1 | + | 4 | | 1 | 3 | + | 4 | 3.0 | 0 | 1 | + | 4 | 4.0 | 0 | 2 | + +---+-----+---------------+----------+ + " ); } @@ -1851,6 +2671,7 @@ mod tests { vec![(col("a", &input_schema)?, "a".to_string())], vec![], vec![vec![false]], + false, ); let aggregates: Vec> = vec![Arc::new( @@ -1882,27 +2703,27 @@ mod tests { if spill { allow_duplicates! { assert_snapshot!(batches_to_sort_string(&result), @r" - +---+---------------+-------------+ - | a | AVG(b)[count] | AVG(b)[sum] | - +---+---------------+-------------+ - | 2 | 1 | 1.0 | - | 2 | 1 | 1.0 | - | 3 | 1 | 2.0 | - | 3 | 2 | 5.0 | - | 4 | 3 | 11.0 | - +---+---------------+-------------+ + +---+---------------+-------------+ + | a | AVG(b)[count] | AVG(b)[sum] | + +---+---------------+-------------+ + | 2 | 1 | 1.0 | + | 2 | 1 | 1.0 | + | 3 | 1 | 2.0 | + | 3 | 2 | 5.0 | + | 4 | 3 | 11.0 | + +---+---------------+-------------+ "); } } else { allow_duplicates! { assert_snapshot!(batches_to_sort_string(&result), @r" - +---+---------------+-------------+ - | a | AVG(b)[count] | AVG(b)[sum] | - +---+---------------+-------------+ - | 2 | 2 | 2.0 | - | 3 | 3 | 7.0 | - | 4 | 3 | 11.0 | - +---+---------------+-------------+ + +---+---------------+-------------+ + | a | AVG(b)[count] | AVG(b)[sum] | + +---+---------------+-------------+ + | 2 | 2 | 2.0 | + | 3 | 3 | 7.0 | + | 4 | 3 | 11.0 | + +---+---------------+-------------+ "); } }; @@ -1920,6 +2741,10 @@ mod tests { input_schema, )?); + // Verify statistics are preserved proportionally through aggregation + let final_stats = merged_aggregate.partition_statistics(None)?; + assert!(final_stats.total_byte_size.get_value().is_some()); + let task_ctx = if spill { // enlarge memory limit to let the final aggregation finish new_spill_ctx(2, 2600) @@ -1933,14 +2758,14 @@ mod tests { allow_duplicates! { assert_snapshot!(batches_to_sort_string(&result), @r" - +---+--------------------+ - | a | AVG(b) | - +---+--------------------+ - | 2 | 1.0 | - | 3 | 2.3333333333333335 | - | 4 | 3.6666666666666665 | - +---+--------------------+ - "); + +---+--------------------+ + | a | AVG(b) | + +---+--------------------+ + | 2 | 1.0 | + | 3 | 2.3333333333333335 | + | 4 | 3.6666666666666665 | + +---+--------------------+ + "); // For row 2: 3, (2 + 3 + 2) / 3 // For row 3: 4, (3 + 4 + 4) / 3 } @@ -1976,14 +2801,17 @@ mod tests { struct TestYieldingExec { /// True if this exec should yield back to runtime the first time it is polled pub yield_first: bool, - cache: PlanProperties, + cache: Arc, } impl TestYieldingExec { fn new(yield_first: bool) -> Self { let schema = some_data().0; let cache = Self::compute_properties(schema); - Self { yield_first, cache } + Self { + yield_first, + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -2020,11 +2848,7 @@ mod tests { "TestYieldingExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -2053,20 +2877,19 @@ mod tests { Ok(Box::pin(stream)) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics( + &self, + partition: Option, + ) -> Result> { if partition.is_some() { - return Ok(Statistics::new_unknown(self.schema().as_ref())); + return Ok(Arc::new(Statistics::new_unknown(self.schema().as_ref()))); } let (_, batches) = some_data(); - Ok(common::compute_record_batch_statistics( + Ok(Arc::new(common::compute_record_batch_statistics( &[batches], &self.schema(), None, - )) + ))) } } @@ -2192,6 +3015,7 @@ mod tests { vec![(col("a", &input_schema)?, "a".to_string())], vec![], vec![vec![false]], + false, ); // something that allocates within the aggregator @@ -2212,7 +3036,7 @@ mod tests { ] { let n_aggr = aggregates.len(); let partial_aggregate = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, + AggregateMode::Single, groups, aggregates, vec![None; n_aggr], @@ -2250,6 +3074,191 @@ mod tests { Ok(()) } + #[tokio::test] + async fn partial_grouped_aggregate_uses_raw_partial_stream() -> Result<()> { + let (schema, batches) = some_data(); + let input = TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?; + let group_by = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![DataType::Float64], + vec![DataType::Int32], + DataType::Int64, + ))); + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(udaf, vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("input_type_asserting(b)") + .build()?, + )]; + + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + aggregates.clone(), + vec![None], + input, + Arc::clone(&schema), + )?); + let task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(2) + .set_bool("datafusion.execution.enable_migration_aggregate", true), + ), + ); + + let partial_stream = partial_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(partial_stream, StreamType::PartialHash(_))); + + let fallback_task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(2) + .set_bool("datafusion.execution.enable_migration_aggregate", false), + ), + ); + let stream = partial_aggregate.execute_typed(0, &fallback_task_ctx)?; + assert!(matches!(stream, StreamType::GroupedHash(_))); + + let stream: SendableRecordBatchStream = partial_stream.into(); + let batches = collect(stream).await?; + assert_eq!( + batches + .iter() + .map(RecordBatch::num_rows) + .collect::>(), + vec![2, 1] + ); + assert_eq!(batches.iter().map(RecordBatch::num_rows).sum::(), 3); + + let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); + let final_aggregate = AggregateExec::try_new( + AggregateMode::Final, + group_by.as_final(), + aggregates, + vec![None], + merge, + Arc::clone(&schema), + )?; + + let final_stream = final_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(final_stream, StreamType::FinalHash(_))); + + let stream = final_aggregate.execute_typed(0, &fallback_task_ctx)?; + assert!(matches!(stream, StreamType::GroupedHash(_))); + + let stream: SendableRecordBatchStream = final_stream.into(); + let batches = collect(stream).await?; + assert_eq!( + batches + .iter() + .map(RecordBatch::num_rows) + .collect::>(), + vec![2, 1] + ); + assert_eq!(batches.iter().map(RecordBatch::num_rows).sum::(), 3); + + Ok(()) + } + + #[tokio::test] + async fn limited_distinct_aggregate_uses_migrated_hash_streams() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, false)])); + let input_batches = vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from(vec![1, 2, 1]))], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from(vec![3, 4]))], + )?, + ]; + let group_by = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .set_bool("datafusion.execution.enable_migration_aggregate", true), + ), + ); + + let partial_input = TestMemoryExec::try_new_exec( + std::slice::from_ref(&input_batches), + Arc::clone(&schema), + None, + )?; + let partial_aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by.clone(), + vec![], + vec![], + partial_input, + Arc::clone(&schema), + )? + .with_limit_options(Some(LimitOptions::new(2))), + ); + + let partial_stream = partial_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(partial_stream, StreamType::PartialHash(_))); + let stream: SendableRecordBatchStream = partial_stream.into(); + let partial_output = collect(stream).await?; + assert_eq!( + partial_output + .iter() + .map(RecordBatch::num_rows) + .sum::(), + 2 + ); + assert_snapshot!(batches_to_sort_string(&partial_output), @r" ++---+ +| a | ++---+ +| 1 | +| 2 | ++---+ +"); + + let final_input = + TestMemoryExec::try_new_exec(&[input_batches], Arc::clone(&schema), None)?; + let final_aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by.as_final(), + vec![], + vec![], + final_input, + Arc::clone(&schema), + )? + .with_limit_options(Some(LimitOptions::new(2))), + ); + + let final_stream = final_aggregate.execute_typed(0, &task_ctx)?; + assert!(matches!(final_stream, StreamType::FinalHash(_))); + let stream: SendableRecordBatchStream = final_stream.into(); + let final_output = collect(stream).await?; + assert_eq!( + final_output + .iter() + .map(RecordBatch::num_rows) + .sum::(), + 2 + ); + assert_snapshot!(batches_to_sort_string(&final_output), @r" ++---+ +| a | ++---+ +| 1 | +| 2 | ++---+ +"); + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -2327,17 +3336,9 @@ mod tests { #[tokio::test] async fn run_first_last_multi_partitions() -> Result<()> { - for use_coalesce_batches in [false, true] { - for is_first_acc in [false, true] { - for spill in [false, true] { - first_last_multi_partitions( - use_coalesce_batches, - is_first_acc, - spill, - 4200, - ) - .await? - } + for is_first_acc in [false, true] { + for spill in [false, true] { + first_last_multi_partitions(is_first_acc, spill, 4200).await? } } Ok(()) @@ -2380,15 +3381,148 @@ mod tests { .map(Arc::new) } - // This function either constructs the physical plan below, - // - // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", - // " CoalesceBatchesExec: target_batch_size=1024", - // " CoalescePartitionsExec", - // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", - // " DataSourceExec: partitions=4, partition_sizes=[1, 1, 1, 1]", - // - // or + fn first_value_agg_expr( + schema: &SchemaRef, + column: &str, + alias: &str, + human_display: Option<&str>, + human_display_alias: Option<&str>, + ) -> Result { + let mut builder = + AggregateExprBuilder::new(first_value_udaf(), vec![col(column, schema)?]) + .order_by(vec![PhysicalSortExpr { + expr: col(column, schema)?, + options: SortOptions::new(false, false), + }]) + .schema(Arc::clone(schema)) + .alias(alias); + + if let Some(human_display) = human_display { + builder = builder.human_display(human_display); + } + if let Some(human_display_alias) = human_display_alias { + builder = builder.human_display_alias(human_display_alias); + } + + builder.build() + } + + #[test] + fn test_reverse_expr_preserves_aliased_human_display() -> Result<()> { + let schema = create_test_schema()?; + let agg = first_value_agg_expr( + &schema, + "b", + "agg", + Some("first_value(b) ORDER BY [b ASC NULLS LAST]"), + Some("agg"), + )?; + + let reversed = agg.reverse_expr().expect("expected reverse expr"); + + assert_eq!(reversed.name(), "agg"); + assert_eq!(reversed.human_display_alias(), Some("agg")); + assert_eq!( + format_tree_aggregate_expr(&reversed), + "last_value(b) ORDER BY [b DESC NULLS FIRST] as agg" + ); + assert_eq!( + reversed.human_display(), + Some("last_value(b) ORDER BY [b DESC NULLS FIRST]") + ); + + Ok(()) + } + + #[test] + fn test_reverse_expr_does_not_rewrite_column_names_in_human_display() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "first_value_col", + DataType::Int32, + true, + )])); + let agg = first_value_agg_expr( + &schema, + "first_value_col", + "agg", + Some( + "first_value(first_value_col) ORDER BY [first_value_col ASC NULLS LAST]", + ), + Some("agg"), + )?; + + let reversed = agg.reverse_expr().expect("expected reverse expr"); + + assert_eq!(reversed.name(), "agg"); + assert_eq!( + reversed.human_display(), + Some( + "last_value(first_value_col) ORDER BY [first_value_col DESC NULLS FIRST]" + ) + ); + assert_eq!( + format_tree_aggregate_expr(&reversed), + "last_value(first_value_col) ORDER BY [first_value_col DESC NULLS FIRST] as agg" + ); + + Ok(()) + } + + #[test] + fn test_empty_human_display_is_treated_as_absent() -> Result<()> { + let schema = create_test_schema()?; + let agg = first_value_agg_expr(&schema, "b", "agg", Some(""), None)?; + + assert_eq!(agg.human_display(), None); + assert_eq!(format_tree_aggregate_expr(&agg), "agg"); + + Ok(()) + } + + #[test] + fn test_human_display_alias_must_match_name() -> Result<()> { + let schema = create_test_schema()?; + let error = first_value_agg_expr( + &schema, + "b", + "agg", + Some("first_value(b) ORDER BY [b ASC NULLS LAST]"), + Some("other_alias"), + ) + .unwrap_err(); + + assert!( + error + .to_string() + .contains("aggregate human_display_alias must match") + ); + + Ok(()) + } + + #[test] + fn test_reverse_expr_preserves_non_aliased_display_path() -> Result<()> { + let schema = create_test_schema()?; + let agg = first_value_agg_expr( + &schema, + "b", + "first_value(b) ORDER BY [b ASC NULLS LAST]", + None, + None, + )?; + + let reversed = agg.reverse_expr().expect("expected reverse expr"); + + assert_eq!( + reversed.name(), + "last_value(b) ORDER BY [b DESC NULLS FIRST]" + ); + assert_eq!(reversed.human_display(), None); + + Ok(()) + } + + // This function constructs the physical plan below, // // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", // " CoalescePartitionsExec", @@ -2398,7 +3532,6 @@ mod tests { // and checks whether the function `merge_batch` works correctly for // FIRST_VALUE and LAST_VALUE functions. async fn first_last_multi_partitions( - use_coalesce_batches: bool, is_first_acc: bool, spill: bool, max_memory: usize, @@ -2446,13 +3579,8 @@ mod tests { memory_exec, Arc::clone(&schema), )?); - let coalesce = if use_coalesce_batches { - let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); - Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc - } else { - Arc::new(CoalescePartitionsExec::new(aggregate_exec)) - as Arc - }; + let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)) + as Arc; let aggregate_final = Arc::new(AggregateExec::try_new( AggregateMode::Final, groups, @@ -2466,26 +3594,26 @@ mod tests { if is_first_acc { allow_duplicates! { assert_snapshot!(batches_to_string(&result), @r" - +---+--------------------------------------------+ - | a | first_value(b) ORDER BY [b ASC NULLS LAST] | - +---+--------------------------------------------+ - | 2 | 0.0 | - | 3 | 1.0 | - | 4 | 3.0 | - +---+--------------------------------------------+ - "); + +---+--------------------------------------------+ + | a | first_value(b) ORDER BY [b ASC NULLS LAST] | + +---+--------------------------------------------+ + | 2 | 0.0 | + | 3 | 1.0 | + | 4 | 3.0 | + +---+--------------------------------------------+ + "); } } else { allow_duplicates! { assert_snapshot!(batches_to_string(&result), @r" - +---+-------------------------------------------+ - | a | last_value(b) ORDER BY [b ASC NULLS LAST] | - +---+-------------------------------------------+ - | 2 | 3.0 | - | 3 | 5.0 | - | 4 | 6.0 | - +---+-------------------------------------------+ - "); + +---+-------------------------------------------+ + | a | last_value(b) ORDER BY [b ASC NULLS LAST] | + +---+-------------------------------------------+ + | 2 | 3.0 | + | 3 | 5.0 | + | 4 | 6.0 | + +---+-------------------------------------------+ + "); } }; Ok(()) @@ -2636,14 +3764,16 @@ mod tests { vec![true, false, true], vec![true, true, false], ], + true, ); - let aggregates: Vec> = - vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + let aggregates: Vec> = vec![ + AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) .schema(Arc::clone(&schema)) .alias("1") .build() - .map(Arc::new)?]; + .map(Arc::new)?, + ]; let input_batches = (0..4) .map(|_| { @@ -2672,13 +3802,13 @@ mod tests { allow_duplicates! { assert_snapshot!(batches_to_sort_string(&output), @r" - +-----+-----+-------+---------------+-------+ - | a | b | const | __grouping_id | 1 | - +-----+-----+-------+---------------+-------+ - | | | 1 | 6 | 32768 | - | | 0.0 | | 5 | 32768 | - | 0.0 | | | 3 | 32768 | - +-----+-----+-------+---------------+-------+ + +-----+-----+-------+---------------+-------+ + | a | b | const | __grouping_id | 1 | + +-----+-----+-------+---------------+-------+ + | | | 1 | 6 | 32768 | + | | 0.0 | | 5 | 32768 | + | 0.0 | | | 3 | 32768 | + +-----+-----+-------+---------------+-------+ "); } @@ -2759,14 +3889,13 @@ mod tests { "labels".to_string(), )]); - let aggr_expr = vec![AggregateExprBuilder::new( - sum_udaf(), - vec![col("value", &batch.schema())?], - ) - .schema(Arc::clone(&batch.schema())) - .alias(String::from("SUM(value)")) - .build() - .map(Arc::new)?]; + let aggr_expr = vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("value", &batch.schema())?]) + .schema(Arc::clone(&batch.schema())) + .alias(String::from("SUM(value)")) + .build() + .map(Arc::new)?, + ]; let input = TestMemoryExec::try_new_exec( &[vec![batch.clone()]], @@ -2788,13 +3917,13 @@ mod tests { allow_duplicates! { assert_snapshot!(batches_to_string(&output), @r" - +--------------+------------+ - | labels | SUM(value) | - +--------------+------------+ - | {a: a, b: b} | 2 | - | {a: , b: c} | 1 | - +--------------+------------+ - "); + +--------------+------------+ + | labels | SUM(value) | + +--------------+------------+ + | {a: a, b: b} | 2 | + | {a: , b: c} | 1 | + +--------------+------------+ + "); } Ok(()) @@ -2810,14 +3939,13 @@ mod tests { let group_by = PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); - let aggr_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) - .schema(Arc::clone(&schema)) - .alias(String::from("COUNT(val)")) - .build() - .map(Arc::new)?, - ]; + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build() + .map(Arc::new)?, + ]; let input_data = vec![ RecordBatch::try_new( @@ -2859,8 +3987,11 @@ mod tests { &ScalarValue::Float64(Some(0.1)), ); - let ctx = TaskContext::default().with_session_config(session_config); - let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + let ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let stream: SendableRecordBatchStream = Box::pin( + GroupedHashAggregateStream::new(aggregate_exec.as_ref(), &ctx, 0)?, + ); + let output = collect(stream).await?; allow_duplicates! { assert_snapshot!(batches_to_string(&output), @r" @@ -2890,14 +4021,13 @@ mod tests { let group_by = PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); - let aggr_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) - .schema(Arc::clone(&schema)) - .alias(String::from("COUNT(val)")) - .build() - .map(Arc::new)?, - ]; + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build() + .map(Arc::new)?, + ]; let input_data = vec![ RecordBatch::try_new( @@ -2947,8 +4077,11 @@ mod tests { &ScalarValue::Float64(Some(0.1)), ); - let ctx = TaskContext::default().with_session_config(session_config); - let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + let ctx = Arc::new(TaskContext::default().with_session_config(session_config)); + let stream: SendableRecordBatchStream = Box::pin( + GroupedHashAggregateStream::new(aggregate_exec.as_ref(), &ctx, 0)?, + ); + let output = collect(stream).await?; allow_duplicates! { assert_snapshot!(batches_to_string(&output), @r" @@ -2969,35 +4102,119 @@ mod tests { Ok(()) } - #[test] - fn group_exprs_nullable() -> Result<()> { - let input_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float32, false), + /// When `skip_partial_aggregation_probe_ratio_threshold` is set to 1.0, + /// the feature must be effectively disabled: even with 100% cardinality + /// (every row is a unique group), no rows should be skipped. + #[tokio::test] + async fn test_skip_aggregation_disabled_at_threshold_one() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, true), + Field::new("val", DataType::Int32, true), ])); - let aggr_expr = - vec![ - AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) - .schema(Arc::clone(&input_schema)) - .alias("COUNT(a)") - .build() - .map(Arc::new)?, - ]; + let group_by = + PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); - let grouping_set = PhysicalGroupBy::new( - vec![ - (col("a", &input_schema)?, "a".to_string()), - (col("b", &input_schema)?, "b".to_string()), - ], - vec![ - (lit(ScalarValue::Float32(None)), "a".to_string()), - (lit(ScalarValue::Float32(None)), "b".to_string()), + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build() + .map(Arc::new)?, + ]; + + // Two batches are required: batch 1 triggers the probe threshold so the + // skip decision is evaluated; batch 2 is what would be skipped on main + // (where >= caused threshold=1.0 to still skip at 100% cardinality). + // All rows have unique keys => ratio = 1.0 (100% cardinality). + let input_data = vec![ + // Batch 1: fires the probe check (ratio = 5/5 = 1.0) + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0])), + ], + ) + .unwrap(), + // Batch 2: would be skipped if threshold=1.0 did not disable the feature + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10])), + Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0])), + ], + ) + .unwrap(), + ]; + + let input = + TestMemoryExec::try_new_exec(&[input_data], Arc::clone(&schema), None)?; + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggr_expr, + vec![None], + Arc::clone(&input) as Arc, + schema, + )?); + + let session_config = SessionConfig::default() + .set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::Int64(Some(1)), + ) + .set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(1.0)), + ); + + let ctx = TaskContext::default().with_session_config(session_config); + collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + + let metrics = aggregate_exec.metrics().unwrap(); + let skipped_rows = metrics + .sum_by_name("skipped_aggregation_rows") + .map(|m| m.as_usize()) + .unwrap_or(0); + + assert_eq!( + skipped_rows, 0, + "threshold=1.0 should disable skip aggregation, but {skipped_rows} rows were skipped" + ); + + Ok(()) + } + + #[test] + fn group_exprs_nullable() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + let aggr_expr = vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(a)") + .build() + .map(Arc::new)?, + ]; + + let grouping_set = PhysicalGroupBy::new( + vec![ + (col("a", &input_schema)?, "a".to_string()), + (col("b", &input_schema)?, "b".to_string()), + ], + vec![ + (lit(ScalarValue::Float32(None)), "a".to_string()), + (lit(ScalarValue::Float32(None)), "b".to_string()), ], vec![ vec![false, true], // (a, NULL) vec![false, false], // (a,b) ], + true, ); let aggr_schema = create_schema( &input_schema, @@ -3049,18 +4266,16 @@ mod tests { vec![(col("a", &schema)?, "a".to_string())], vec![], vec![vec![false]], + false, ); // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). let aggregates: Vec> = vec![ Arc::new( - AggregateExprBuilder::new( - datafusion_functions_aggregate::min_max::min_udaf(), - vec![col("b", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("MIN(b)") - .build()?, + AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, ), Arc::new( AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) @@ -3097,13 +4312,13 @@ mod tests { allow_duplicates! { assert_snapshot!(batches_to_string(&result), @r" - +---+--------+--------+ - | a | MIN(b) | AVG(b) | - +---+--------+--------+ - | 2 | 1.0 | 1.0 | - | 3 | 2.0 | 2.0 | - | 4 | 3.0 | 3.5 | - +---+--------+--------+ + +---+--------+--------+ + | a | MIN(b) | AVG(b) | + +---+--------+--------+ + | 2 | 1.0 | 1.0 | + | 3 | 2.0 | 2.0 | + | 4 | 3.0 | 3.5 | + +---+--------+--------+ "); } @@ -3130,7 +4345,9 @@ mod tests { "Expected spill but SpillCount metric not found or SpillCount was 0." ); } else if !expect_spill && spill_count > 0 { - panic!("Expected no spill but found SpillCount metric with value greater than 0."); + panic!( + "Expected no spill but found SpillCount metric with value greater than 0." + ); } } else { panic!("No metrics returned from the operator; cannot verify spilling."); @@ -3145,4 +4362,1648 @@ mod tests { run_test_with_spill_pool_if_necessary(20_000, false).await?; Ok(()) } + + #[tokio::test] + async fn test_grouped_aggregation_respects_memory_limit() -> Result<()> { + // test with spill + fn create_record_batch( + schema: &Arc, + data: (Vec, Vec), + ) -> Result { + Ok(RecordBatch::try_new( + Arc::clone(schema), + vec![ + Arc::new(UInt32Array::from(data.0)), + Arc::new(Float64Array::from(data.1)), + ], + )?) + } + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + let batches = vec![ + create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?, + create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?, + ]; + let plan: Arc = + TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?; + let proj = ProjectionExec::try_new( + vec![ + ProjectionExpr::new(lit("0"), "l".to_string()), + ProjectionExpr::new_from_expression(col("a", &schema)?, &schema)?, + ProjectionExpr::new_from_expression(col("b", &schema)?, &schema)?, + ], + plan, + )?; + let plan: Arc = Arc::new(proj); + let schema = plan.schema(); + + let grouping_set = PhysicalGroupBy::new( + vec![ + (col("l", &schema)?, "l".to_string()), + (col("a", &schema)?, "a".to_string()), + ], + vec![], + vec![vec![false, false]], + false, + ); + + // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count). + let aggregates: Vec> = vec![ + Arc::new( + AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("MIN(b)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + ), + ]; + + let single_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + grouping_set, + aggregates, + vec![None, None], + plan, + Arc::clone(&schema), + )?); + + let batch_size = 2; + let memory_pool = Arc::new(FairSpillPool::new(2000)); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::new().with_batch_size(batch_size)) + .with_runtime(Arc::new( + RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build()?, + )), + ); + + let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await; + match result { + Ok(result) => { + assert_spill_count_metric(true, single_aggregate); + + allow_duplicates! { + assert_snapshot!(batches_to_string(&result), @r" + +---+---+--------+--------+ + | l | a | MIN(b) | AVG(b) | + +---+---+--------+--------+ + | 0 | 2 | 1.0 | 1.0 | + | 0 | 3 | 2.0 | 2.0 | + | 0 | 4 | 3.0 | 3.5 | + +---+---+--------+--------+ + "); + } + } + Err(e) => assert!(matches!(e, DataFusionError::ResourcesExhausted(_))), + } + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_statistics_edge_cases() -> Result<()> { + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Float64, false), + ])); + + let absent_byte_stats = Statistics { + num_rows: Precision::Exact(100), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + let agg = build_test_aggregate( + &schema, + absent_byte_stats, + PhysicalGroupBy::default(), + None, + )?; + let stats = agg.partition_statistics(None)?; + assert_eq!(stats.total_byte_size, Precision::Absent); + + let zero_row_stats = Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: vec![ + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }; + let agg_zero = build_test_aggregate( + &schema, + zero_row_stats, + PhysicalGroupBy::default(), + None, + )?; + let stats_zero = agg_zero.partition_statistics(None)?; + assert_eq!(stats_zero.total_byte_size, Precision::Absent); + + Ok(()) + } + + fn build_test_aggregate( + schema: &SchemaRef, + stats: Statistics, + group_by: PhysicalGroupBy, + limit: Option, + ) -> Result { + let input = Arc::new(StatisticsExec::new(stats, (**schema).clone())) + as Arc; + + let mut agg = AggregateExec::try_new( + AggregateMode::Final, + group_by, + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?]) + .schema(Arc::clone(schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(schema), + )?; + + if let Some(limit) = limit { + agg = agg.with_limit_options(Some(limit)); + } + + Ok(agg) + } + + fn simple_group_by(schema: &SchemaRef, cols: &[&str]) -> PhysicalGroupBy { + if cols.is_empty() { + PhysicalGroupBy::default() + } else { + PhysicalGroupBy::new_single( + cols.iter() + .map(|name| { + ( + col(name, schema).unwrap() as Arc, + name.to_string(), + ) + }) + .collect(), + ) + } + } + + #[test] + fn test_aggregate_cardinality_estimation() -> Result<()> { + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + struct TestCase { + name: &'static str, + input_rows: Precision, + col_a_stats: ColumnStatistics, + col_b_stats: ColumnStatistics, + group_by_cols: Vec<&'static str>, + limit_options: Option, + expected_num_rows: Precision, + } + + let cases = vec![ + // --- NDV-based estimation --- + TestCase { + name: "single group-by col with NDV tightens estimate", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(500), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(500), + }, + TestCase { + name: "multi-col group-by multiplies NDVs", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + expected_num_rows: Precision::Inexact(5_000), + }, + TestCase { + name: "NDV product capped by input rows", + input_rows: Precision::Exact(200), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + expected_num_rows: Precision::Inexact(200), + }, + TestCase { + name: "null adjustment adds +1 per column", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(99), + null_count: Precision::Exact(10), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + // 99 + 1 (null adjustment) = 100 + expected_num_rows: Precision::Inexact(100), + }, + TestCase { + name: "null adjustment on multiple columns", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(99), + null_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(49), + null_count: Precision::Exact(3), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + // (99+1) * (49+1) = 100 * 50 = 5000 + expected_num_rows: Precision::Inexact(5_000), + }, + TestCase { + name: "zero null_count means no adjustment", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + null_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(100), + }, + // --- Bail-out: partial NDV stats (Spark-style) --- + TestCase { + name: "bail out when one group-by col lacks NDV", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a", "b"], + limit_options: None, + expected_num_rows: Precision::Inexact(1_000_000), + }, + TestCase { + name: "bail out when all group-by cols lack NDV", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(1_000_000), + }, + // --- TopK limit capping --- + TestCase { + name: "TopK limit caps output rows", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + TestCase { + name: "NDV + TopK limit: min(NDV, limit) when NDV < limit", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(5), + }, + TestCase { + name: "NDV + TopK limit: min(NDV, limit) when limit < NDV", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(500), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + // --- Absent input rows --- + TestCase { + name: "absent input rows without limit stays absent", + input_rows: Precision::Absent, + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Absent, + }, + TestCase { + name: "absent input rows with TopK limit gives inexact(limit)", + input_rows: Precision::Absent, + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + // --- No group-by (global aggregation) --- + TestCase { + name: "no group-by cols (Final mode) returns Exact(1)", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec![], + limit_options: None, + expected_num_rows: Precision::Exact(1), + }, + // --- One input row --- + TestCase { + name: "one input row returns Exact(1)", + input_rows: Precision::Exact(1), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(1), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Exact(1), + }, + // --- Zero input rows --- + TestCase { + name: "zero input rows returns Exact(0)", + input_rows: Precision::Exact(0), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Exact(0), + }, + // --- Inexact NDV stats --- + TestCase { + name: "inexact NDV still used for estimation", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Inexact(200), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(200), + }, + TestCase { + name: "inexact NDV combined with limit", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Inexact(200), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + // --- NDV zero column (all-null) --- + TestCase { + name: "all-null column contributes 1 to the product, not 0", + input_rows: Precision::Exact(1_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(0), + null_count: Precision::Exact(1_000), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + // NDV(a)=0 with nulls => max(0+1, 1)=1, NDV(b)=50 => 1*50=50 + expected_num_rows: Precision::Inexact(50), + }, + // --- Absent num_rows with NDV --- + TestCase { + name: "absent num_rows falls back to NDV estimate", + input_rows: Precision::Absent, + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(100), + }, + TestCase { + name: "absent num_rows with NDV and limit returns min(ndv, limit)", + input_rows: Precision::Absent, + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + ]; + + for case in cases { + let input_stats = Statistics { + num_rows: case.input_rows, + total_byte_size: Precision::Inexact(1_000_000), + column_statistics: vec![ + case.col_a_stats.clone(), + case.col_b_stats.clone(), + ], + }; + + let group_by = simple_group_by(&schema, &case.group_by_cols); + let agg = + build_test_aggregate(&schema, input_stats, group_by, case.limit_options)?; + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.num_rows, case.expected_num_rows, + "FAILED: '{}' — expected {:?}, got {:?}", + case.name, case.expected_num_rows, stats.num_rows + ); + } + + Ok(()) + } + + #[test] + fn test_aggregate_stats_distinct_count_propagation() -> Result<()> { + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input_stats = Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Inexact(10000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(100), + null_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics::new_unknown(), + ], + }; + let agg = build_test_aggregate( + &schema, + input_stats, + simple_group_by(&schema, &["a"]), + None, + )?; + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.column_statistics[0].distinct_count, + Precision::Exact(100), + "distinct_count should be propagated from child for group-by columns" + ); + + Ok(()) + } + + #[test] + fn test_aggregate_stats_grouping_sets() -> Result<()> { + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input_stats = Statistics { + num_rows: Precision::Exact(1_000_000), + total_byte_size: Precision::Inexact(1_000_000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + // CUBE-like grouping set: (a, NULL), (NULL, b), (a, b) — 3 groups + let grouping_set = PhysicalGroupBy::new( + vec![ + (col("a", &schema)? as Arc, "a".to_string()), + (col("b", &schema)? as Arc, "b".to_string()), + ], + vec![ + (lit(ScalarValue::Int32(None)), "a".to_string()), + (lit(ScalarValue::Int32(None)), "b".to_string()), + ], + vec![ + vec![false, true], // (a, NULL) + vec![true, false], // (NULL, b) + vec![false, false], // (a, b) + ], + true, + ); + + let agg = build_test_aggregate(&schema, input_stats, grouping_set, None)?; + + let stats = agg.partition_statistics(None)?; + // Per-set NDV: (a,NULL)=100, (NULL,b)=50, (a,b)=100*50=5000 + // Total = 100 + 50 + 5000 = 5150 + assert_eq!( + stats.num_rows, + Precision::Inexact(5_150), + "grouping sets should sum per-set NDV products" + ); + + Ok(()) + } + + #[test] + fn test_aggregate_stats_non_column_expr_bails_out() -> Result<()> { + use datafusion_common::ColumnStatistics; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input_stats = Statistics { + num_rows: Precision::Exact(1_000_000), + total_byte_size: Precision::Inexact(1_000_000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + // GROUP BY (a + b) — not a direct column reference + let expr_a_plus_b: Arc = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Plus, + col("b", &schema)?, + )); + + let group_by = + PhysicalGroupBy::new_single(vec![(expr_a_plus_b, "a+b".to_string())]); + let agg = build_test_aggregate(&schema, input_stats, group_by, None)?; + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.num_rows, + Precision::Inexact(1_000_000), + "non-column group-by expression should bail out to input_rows" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_order_is_retained_when_spilling() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + Field::new("c", DataType::Int64, false), + ])); + + let batches = vec![vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2])), + Arc::new(Int64Array::from(vec![2])), + Arc::new(Int64Array::from(vec![1])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![1])), + Arc::new(Int64Array::from(vec![1])), + Arc::new(Int64Array::from(vec![1])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![0])), + Arc::new(Int64Array::from(vec![0])), + Arc::new(Int64Array::from(vec![1])), + ], + )?, + ]]; + let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?; + let scan = scan.try_with_sort_information(vec![ + LexOrdering::new([PhysicalSortExpr::new( + col("b", schema.as_ref())?, + SortOptions::default().desc(), + )]) + .unwrap(), + ])?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![ + (col("b", schema.as_ref())?, "b".to_string()), + (col("c", schema.as_ref())?, "c".to_string()), + ], + vec![], + vec![vec![false, false]], + false, + ), + vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("c", schema.as_ref())?]) + .schema(Arc::clone(&schema)) + .alias("SUM(c)") + .build()?, + )], + vec![None], + Arc::new(scan) as Arc, + Arc::clone(&schema), + )?); + + let task_ctx = new_spill_ctx(1, 600); + let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await?; + assert_spill_count_metric(true, aggr); + + allow_duplicates! { + assert_snapshot!(batches_to_string(&result), @r" + +---+---+--------+ + | b | c | SUM(c) | + +---+---+--------+ + | 2 | 1 | 1 | + | 1 | 1 | 1 | + | 0 | 1 | 1 | + +---+---+--------+ + "); + } + Ok(()) + } + + /// Tests that when the memory pool is too small to accommodate the sort + /// reservation during spill, the error is properly propagated as + /// ResourcesExhausted rather than silently exceeding memory limits. + #[tokio::test] + async fn test_sort_reservation_fails_during_spill() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int64, false), + Field::new("a", DataType::Float64, false), + Field::new("b", DataType::Float64, false), + Field::new("c", DataType::Float64, false), + Field::new("d", DataType::Float64, false), + Field::new("e", DataType::Float64, false), + ])); + + let batches = vec![vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![1])), + Arc::new(Float64Array::from(vec![10.0])), + Arc::new(Float64Array::from(vec![20.0])), + Arc::new(Float64Array::from(vec![30.0])), + Arc::new(Float64Array::from(vec![40.0])), + Arc::new(Float64Array::from(vec![50.0])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2])), + Arc::new(Float64Array::from(vec![11.0])), + Arc::new(Float64Array::from(vec![21.0])), + Arc::new(Float64Array::from(vec![31.0])), + Arc::new(Float64Array::from(vec![41.0])), + Arc::new(Float64Array::from(vec![51.0])), + ], + )?, + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![3])), + Arc::new(Float64Array::from(vec![12.0])), + Arc::new(Float64Array::from(vec![22.0])), + Arc::new(Float64Array::from(vec![32.0])), + Arc::new(Float64Array::from(vec![42.0])), + Arc::new(Float64Array::from(vec![52.0])), + ], + )?, + ]]; + + let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(col("g", schema.as_ref())?, "g".to_string())], + vec![], + vec![vec![false]], + false, + ), + vec![ + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("a", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("b", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("c", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(c)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("d", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(d)") + .build()?, + ), + Arc::new( + AggregateExprBuilder::new( + avg_udaf(), + vec![col("e", schema.as_ref())?], + ) + .schema(Arc::clone(&schema)) + .alias("AVG(e)") + .build()?, + ), + ], + vec![None, None, None, None, None], + Arc::new(scan) as Arc, + Arc::clone(&schema), + )?); + + // Pool must be large enough for accumulation to start but too small for + // sort_memory after clearing. + let task_ctx = new_spill_ctx(1, 500); + let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await; + + match &result { + Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"), + Err(e) => { + let root = e.find_root(); + assert!( + matches!(root, DataFusionError::ResourcesExhausted(_)), + "Expected ResourcesExhausted, got: {root}", + ); + let msg = root.to_string(); + assert!( + msg.contains("Failed to reserve memory for sort during spill"), + "Expected sort reservation error, got: {msg}", + ); + } + } + + Ok(()) + } + + /// Tests that PartialReduce mode: + /// 1. Accepts state as input (like Final) + /// 2. Produces state as output (like Partial) + /// 3. Can be followed by a Final stage to get the correct result + /// + /// This simulates a tree-reduce pattern: + /// Partial -> PartialReduce -> Final + async fn evaluate_partial_reduce( + groups: PhysicalGroupBy, + aggregates: Vec>, + partition_1_and_2_batches: [Vec; 2], + ) -> Result> { + let schema = partition_1_and_2_batches + .iter() + .flatten() + .next() + .expect("Must have at least 1 batch") + .schema(); + + let [partition_1, partition_2] = partition_1_and_2_batches; + + // Step 1: Partial aggregation on partition 1 + let input1 = + TestMemoryExec::try_new_exec(&[partition_1], Arc::clone(&schema), None)?; + let partial1 = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None; aggregates.len()], + input1, + Arc::clone(&schema), + )?); + + // Step 2: Partial aggregation on partition 2 + let input2 = + TestMemoryExec::try_new_exec(&[partition_2], Arc::clone(&schema), None)?; + let partial2 = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None; aggregates.len()], + input2, + Arc::clone(&schema), + )?); + + // Collect partial results + let task_ctx = Arc::new(TaskContext::default()); + let partial_result1 = + crate::collect(Arc::clone(&partial1) as _, Arc::clone(&task_ctx)).await?; + let partial_result2 = + crate::collect(Arc::clone(&partial2) as _, Arc::clone(&task_ctx)).await?; + + // The partial results have state schema (group cols + accumulator state) + let partial_schema = partial1.schema(); + + // Step 3: PartialReduce — combine partial results, still producing state + let combined_input = TestMemoryExec::try_new_exec( + &[partial_result1, partial_result2], + Arc::clone(&partial_schema), + None, + )?; + // Coalesce into a single partition for the PartialReduce + let coalesced = Arc::new(CoalescePartitionsExec::new(combined_input)); + + let partial_reduce = Arc::new(AggregateExec::try_new( + AggregateMode::PartialReduce, + groups.clone(), + aggregates.clone(), + vec![None; aggregates.len()], + coalesced, + Arc::clone(&partial_schema), + )?); + + // Verify PartialReduce output schema matches Partial output schema + // (both produce state, not final values) + assert_eq!(partial_reduce.schema(), partial_schema); + + // Collect PartialReduce results + let reduce_result = + crate::collect(Arc::clone(&partial_reduce) as _, Arc::clone(&task_ctx)) + .await?; + + // Step 4: Final aggregation on the PartialReduce output + let final_input = TestMemoryExec::try_new_exec( + &[reduce_result], + Arc::clone(&partial_schema), + None, + )?; + let final_agg = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + groups.clone(), + aggregates.clone(), + vec![None; aggregates.len()], + final_input, + Arc::clone(&partial_schema), + )?); + + let result = crate::collect(final_agg, Arc::clone(&task_ctx)).await?; + + Ok(result) + } + + /// Builds the shared `Partial -> PartialReduce -> Final` fixture used by + /// the `test_partial_reduce_*` tests below and runs the pipeline against + /// the aggregate produced by `build_aggregates`. + /// + /// Each test only needs to supply the UDAF/alias under test, so the test + /// body stays focused on which aggregate shape is being exercised. + async fn run_partial_reduce_pipeline( + build_aggregates: F, + ) -> Result> + where + F: FnOnce(&Arc) -> Result>>, + { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + // Two partitions of input data so the Partial stage produces multiple + // partial states that PartialReduce must combine. + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3])), + Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(UInt32Array::from(vec![1, 2, 3])), + Arc::new(Float64Array::from(vec![40.0, 50.0, 60.0])), + ], + )?; + + let groups = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + let aggregates = build_aggregates(&schema)?; + + evaluate_partial_reduce(groups, aggregates, [vec![batch1], vec![batch2]]).await + } + + // ------------------------------------------------------------------- + // PartialReduce regression coverage. + // + // Each shape (single state field / single input arg, multi-state / + // single-input, more-state-than-input) is covered twice: + // * once against a real UDAF, to round-trip an actual aggregate end + // to end through `Partial -> PartialReduce -> Final`; and + // * once against [`InputTypeAssertingUdaf`], whose input / state / + // output types are deliberately pairwise-disjoint within each test + // so a regression that swapped state-field types for input-field + // types (or vice versa) fails the assertion instead of slipping + // through on a coincidental type match. + // + // The stub variants do the heavy lifting on the contract; the real + // ones make sure no real aggregate is broken by it. + // ------------------------------------------------------------------- + + /// Real-UDAF round-trip: aggregate with a single state field and a + /// single input argument (`SUM(b)` — state and input are both `Float64`). + #[tokio::test] + async fn test_partial_reduce_with_single_state_field_and_single_input_arg() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + Ok(vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("b", schema)?]) + .schema(Arc::clone(schema)) + .alias("SUM(b)") + .build()?, + )]) + }) + .await?; + + // Expected: group 1 -> 10+40=50, group 2 -> 20+50=70, group 3 -> 30+60=90 + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+--------+ + | a | SUM(b) | + +---+--------+ + | 1 | 50.0 | + | 2 | 70.0 | + | 3 | 90.0 | + +---+--------+ + "); + + Ok(()) + } + + /// Real-UDAF round-trip: aggregate with multiple state fields and a + /// single input argument (`AVG(b)` — state is `[sum: Float64, count: + /// UInt64]`). + #[tokio::test] + async fn test_partial_reduce_with_multiple_state_fields_and_single_input_arg() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + Ok(vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", schema)?]) + .schema(Arc::clone(schema)) + .alias("AVG(b)") + .build()?, + )]) + }) + .await?; + + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+--------+ + | a | AVG(b) | + +---+--------+ + | 1 | 25.0 | + | 2 | 35.0 | + | 3 | 45.0 | + +---+--------+ + "); + + Ok(()) + } + + /// Real-UDAF round-trip: aggregate whose state has more fields than the + /// input has arguments (`approx_percentile_cont` carries a t-digest). + #[tokio::test] + async fn test_partial_reduce_with_more_state_fields_than_input_args() -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + Ok(vec![Arc::new( + AggregateExprBuilder::new( + approx_percentile_cont_udaf(), + vec![col("b", schema)?, lit(0.75f32)], + ) + .schema(Arc::clone(schema)) + .alias("approx_percentile_cont(b, 0.75)") + .build()?, + )]) + }) + .await?; + + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+---------------------------------+ + | a | approx_percentile_cont(b, 0.75) | + +---+---------------------------------+ + | 1 | 40.0 | + | 2 | 50.0 | + | 3 | 60.0 | + +---+---------------------------------+ + "); + + Ok(()) + } + + /// Stub variant of + /// [`test_partial_reduce_with_single_state_field_and_single_input_arg`] + /// with disjoint input / state / output types. + /// + /// - input: `Float64` + /// - state: `Int32` + /// - output: `Int64` + /// + /// Any mode that accidentally forwarded state-field types in place of + /// input-field types would fail the assertion in + /// [`InputTypeAssertingUdaf`] instead of being masked by a coincidental + /// type match. + #[tokio::test] + async fn test_partial_reduce_with_single_state_field_and_single_input_arg_using_unique_types() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![DataType::Float64], + vec![DataType::Int32], + DataType::Int64, + ))); + Ok(vec![Arc::new( + AggregateExprBuilder::new(udaf, vec![col("b", schema)?]) + .schema(Arc::clone(schema)) + .alias("input_type_asserting(b)") + .build()?, + )]) + }) + .await?; + + // Pipeline completing without error is the real assertion. The + // snapshot guards against silent regressions in the row shape. + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+-------------------------+ + | a | input_type_asserting(b) | + +---+-------------------------+ + | 1 | 0 | + | 2 | 0 | + | 3 | 0 | + +---+-------------------------+ + "); + + Ok(()) + } + + /// Stub variant of + /// [`test_partial_reduce_with_multiple_state_fields_and_single_input_arg`] + /// with disjoint input / state / output types. + /// + /// - input: `Float64` + /// - state: `[Int32, Utf8]` + /// - output: `Int64` + #[tokio::test] + async fn test_partial_reduce_with_multiple_state_fields_and_single_input_arg_using_unique_types() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![DataType::Float64], + vec![DataType::Int32, DataType::Utf8], + DataType::Int64, + ))); + Ok(vec![Arc::new( + AggregateExprBuilder::new(udaf, vec![col("b", schema)?]) + .schema(Arc::clone(schema)) + .alias("input_type_asserting(b)") + .build()?, + )]) + }) + .await?; + + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+-------------------------+ + | a | input_type_asserting(b) | + +---+-------------------------+ + | 1 | 0 | + | 2 | 0 | + | 3 | 0 | + +---+-------------------------+ + "); + + Ok(()) + } + + /// Stub variant of + /// [`test_partial_reduce_with_more_state_fields_than_input_args`] with + /// disjoint input / state / output types — and with multiple input + /// arguments to exercise the multi-arg path explicitly. + /// + /// - input: `[Float64, Date32]` + /// - state: `[Int32, Utf8, Boolean]` + /// - output: `Int64` + #[tokio::test] + async fn test_partial_reduce_with_more_state_fields_than_input_args_using_unique_types() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![DataType::Float64, DataType::Date32], + vec![DataType::Int32, DataType::Utf8, DataType::Boolean], + DataType::Int64, + ))); + Ok(vec![Arc::new( + AggregateExprBuilder::new( + udaf, + vec![col("b", schema)?, lit(ScalarValue::Date32(Some(1)))], + ) + .schema(Arc::clone(schema)) + .alias("input_type_asserting(b, lit)") + .build()?, + )]) + }) + .await?; + + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+------------------------------+ + | a | input_type_asserting(b, lit) | + +---+------------------------------+ + | 1 | 0 | + | 2 | 0 | + | 3 | 0 | + +---+------------------------------+ + "); + + Ok(()) + } + + /// Stub test: many input args, few state fields (5 inputs / 2 state). + /// + /// All eight types involved are pairwise-disjoint: + /// - input: `[Float64, Date32, UInt16, Boolean, Int32]` + /// - state: `[Utf8, Int64]` + /// - output: `Float32` + #[tokio::test] + async fn test_partial_reduce_with_5_input_args_and_2_state_fields_using_unique_types() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![ + DataType::Float64, + DataType::Date32, + DataType::UInt16, + DataType::Boolean, + DataType::Int32, + ], + vec![DataType::Utf8, DataType::Int64], + DataType::Float32, + ))); + Ok(vec![Arc::new( + AggregateExprBuilder::new( + udaf, + vec![ + col("b", schema)?, + lit(ScalarValue::Date32(Some(1))), + lit(ScalarValue::UInt16(Some(1))), + lit(ScalarValue::Boolean(Some(false))), + lit(ScalarValue::Int32(Some(1))), + ], + ) + .schema(Arc::clone(schema)) + .alias("input_type_asserting(b, l1, l2, l3, l4)") + .build()?, + )]) + }) + .await?; + + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+-----------------------------------------+ + | a | input_type_asserting(b, l1, l2, l3, l4) | + +---+-----------------------------------------+ + | 1 | 0.0 | + | 2 | 0.0 | + | 3 | 0.0 | + +---+-----------------------------------------+ + "); + + Ok(()) + } + + /// Stub test: few input args, many state fields (2 inputs / 5 state). + /// + /// All eight types involved are pairwise-disjoint: + /// - input: `[Float64, Date32]` + /// - state: `[Boolean, Int32, Utf8, Int64, UInt16]` + /// - output: `Float32` + #[tokio::test] + async fn test_partial_reduce_with_2_input_args_and_5_state_fields_using_unique_types() + -> Result<()> { + let result = run_partial_reduce_pipeline(|schema| { + let udaf = Arc::new(AggregateUDF::from(InputTypeAssertingUdaf::new( + vec![DataType::Float64, DataType::Date32], + vec![ + DataType::Boolean, + DataType::Int32, + DataType::Utf8, + DataType::Int64, + DataType::UInt16, + ], + DataType::Float32, + ))); + Ok(vec![Arc::new( + AggregateExprBuilder::new( + udaf, + vec![col("b", schema)?, lit(ScalarValue::Date32(Some(1)))], + ) + .schema(Arc::clone(schema)) + .alias("input_type_asserting(b, lit)") + .build()?, + )]) + }) + .await?; + + assert_snapshot!(batches_to_sort_string(&result), @r" + +---+------------------------------+ + | a | input_type_asserting(b, lit) | + +---+------------------------------+ + | 1 | 0.0 | + | 2 | 0.0 | + | 3 | 0.0 | + +---+------------------------------+ + "); + + Ok(()) + } + + /// Test-only aggregate whose `return_type`, `state_fields`, and + /// `accumulator` hooks all assert that they receive the originally- + /// declared input types; the companion accumulator further asserts + /// `update_batch` sees inputs and `merge_batch` sees state. + /// + /// Each test instantiates it with input / state / output types that + /// are pairwise-disjoint, so a regression that forwarded the wrong + /// types fails on type mismatch rather than passing by accident. + #[derive(Debug, PartialEq, Eq, Hash)] + struct InputTypeAssertingUdaf { + signature: Signature, + input_types: Vec, + state_types: Vec, + output_type: DataType, + } + + fn assert_data_types( + what: &str, + expected: &[DataType], + actual: &[DataType], + ) -> Result<()> { + if actual != expected { + return internal_err!( + "InputTypeAssertingUdaf: {} expected types {:?} but got {:?} — a regression is leaking the wrong types into the accumulator contract", + what, + expected, + actual + ); + } + Ok(()) + } + + /// Produce a zeroed [`ScalarValue`] for `dt`. Only the data types the + /// tests above plug into [`InputTypeAssertingUdaf`] are listed; adding + /// a new type to a test requires extending this match. + fn zero_scalar_for(dt: &DataType) -> Result { + match dt { + DataType::Boolean => Ok(ScalarValue::Boolean(Some(false))), + DataType::Int32 => Ok(ScalarValue::Int32(Some(0))), + DataType::Int64 => Ok(ScalarValue::Int64(Some(0))), + DataType::UInt16 => Ok(ScalarValue::UInt16(Some(0))), + DataType::Float32 => Ok(ScalarValue::Float32(Some(0.0))), + DataType::Utf8 => Ok(ScalarValue::Utf8(Some(String::new()))), + other => internal_err!( + "InputTypeAssertingUdaf: no zero ScalarValue registered for {other:?} \ + — extend `zero_scalar_for` when adding a new state/output type" + ), + } + } + + impl InputTypeAssertingUdaf { + fn new( + input_types: Vec, + state_types: Vec, + output_type: DataType, + ) -> Self { + // Within-test type-disjointness is enforced by construction so + // a future test author can't quietly reintroduce overlap. + assert!( + all_pairwise_distinct(&input_types, &state_types, &output_type), + "InputTypeAssertingUdaf::new: input ({input_types:?}), state \ + ({state_types:?}), and output ({output_type:?}) types must be \ + pairwise-disjoint to avoid accidental passes", + ); + Self { + signature: Signature::exact(input_types.clone(), Volatility::Immutable), + input_types, + state_types, + output_type, + } + } + } + + /// True iff every type in `inputs ∪ states ∪ {output}` is unique. + fn all_pairwise_distinct( + inputs: &[DataType], + states: &[DataType], + output: &DataType, + ) -> bool { + let mut seen = HashSet::new(); + for dt in inputs + .iter() + .chain(states.iter()) + .chain(std::iter::once(output)) + { + if !seen.insert(dt) { + return false; + } + } + true + } + + impl AggregateUDFImpl for InputTypeAssertingUdaf { + fn name(&self) -> &str { + "input_type_asserting" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + assert_data_types("return_type(arg_types)", &self.input_types, arg_types)?; + Ok(self.output_type.clone()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let actual: Vec = args + .input_fields + .iter() + .map(|f| f.data_type().clone()) + .collect(); + assert_data_types( + "state_fields(args.input_fields)", + &self.input_types, + &actual, + )?; + Ok(self + .state_types + .iter() + .enumerate() + .map(|(i, dt)| { + Field::new(format!("{}[s{i}]", args.name), dt.clone(), true).into() + }) + .collect()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let actual: Vec = acc_args + .expr_fields + .iter() + .map(|f| f.data_type().clone()) + .collect(); + assert_data_types( + "accumulator(acc_args.expr_fields)", + &self.input_types, + &actual, + )?; + Ok(Box::new(InputTypeAssertingAccumulator { + input_types: self.input_types.clone(), + state_types: self.state_types.clone(), + output_type: self.output_type.clone(), + })) + } + } + + /// Companion accumulator for [`InputTypeAssertingUdaf`]. + /// + /// - `update_batch` must always receive arrays of the original input + /// types. + /// - `merge_batch` must always receive arrays of the declared state + /// types. + /// + /// Anything else means a non-input mode is calling the wrong path. + #[derive(Debug)] + struct InputTypeAssertingAccumulator { + input_types: Vec, + state_types: Vec, + output_type: DataType, + } + + impl Accumulator for InputTypeAssertingAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let actual: Vec = + values.iter().map(|a| a.data_type().clone()).collect(); + assert_data_types("update_batch(values)", &self.input_types, &actual) + } + + fn evaluate(&mut self) -> Result { + zero_scalar_for(&self.output_type) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + self.state_types.iter().map(zero_scalar_for).collect() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let actual: Vec = + states.iter().map(|a| a.data_type().clone()).collect(); + assert_data_types("merge_batch(states)", &self.state_types, &actual) + } + } + + /// Test that [`AggregateExec::with_dynamic_filter_expr`] overrides the existing dynamic filter + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Partial min aggregate supports dynamic filtering + let agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![Arc::new( + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build()?, + )], + vec![None], + child, + Arc::clone(&schema), + )?; + + // Assertion 1: A filter with the same children can override the existing + // dynamic filter. + let new_df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &schema)?], + lit(false), + )); + let agg = agg.with_dynamic_filter_expr(Arc::clone(&new_df))?; + + // The aggregate's filter should now resolve to the new inner expression. + let swapped = agg + .dynamic_filter_expr() + .expect("should still have dynamic filter") + .current()?; + assert_eq!(format!("{swapped}"), format!("{}", lit(false))); + + // Assertion 2: A filter that has been through `PhysicalExpr::with_new_children` + // should still be accepted when the new children are equivalent to the originals. + let new_df_as_pexpr: Arc = + Arc::::clone(&new_df); + let remapped_pexpr = + new_df_as_pexpr.with_new_children(vec![col("a", &schema)?])?; + let Ok(remapped_df) = (remapped_pexpr as Arc) + .downcast::() + else { + panic!("should be DynamicFilterPhysicalExpr after with_new_children"); + }; + // Hard to assert this because the filter is identical. No error means + // the filter was accepted. That's a good enough assertion for now. + let _agg = agg.with_dynamic_filter_expr(remapped_df)?; + Ok(()) + } + + /// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the aggregate does not support dynamic filtering + #[test] + fn test_with_dynamic_filter_error_unsupported() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int64, false), + ])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Final mode with a group-by does not support dynamic filters. + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]), + vec![Arc::new( + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("sum_b") + .build()?, + )], + vec![None], + child, + Arc::clone(&schema), + )?; + assert!(agg.dynamic_filter_expr().is_none()); + + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![col("a", &schema)?], + lit(true), + )); + assert!(agg.with_dynamic_filter_expr(df).is_err()); + Ok(()) + } + + /// Test that [`AggregateExec::with_dynamic_filter_expr`] errors when the column is not in the schema + #[test] + fn test_with_dynamic_filter_error_column_mismatch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let agg = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![Arc::new( + AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build()?, + )], + vec![None], + child, + Arc::clone(&schema), + )?; + + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(agg.with_dynamic_filter_expr(df).is_err()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index a4b202f1ae2a7..ac7727b459300 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -18,18 +18,23 @@ //! Aggregate without grouping columns use crate::aggregates::{ - aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem, - AggregateMode, + AccumulatorItem, AggrDynFilter, AggregateInputMode, AggregateMode, + DynamicFilterAggregateType, aggregate_expressions, create_accumulators, + finalize_aggregation, }; use crate::metrics::{BaselineMetrics, RecordOutput}; +use crate::stream::EmptyRecordBatchStream; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{Result, ScalarValue, internal_datafusion_err, internal_err}; use datafusion_execution::TaskContext; +use datafusion_expr::Operator; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{BinaryExpr, lit}; use futures::stream::BoxStream; use std::borrow::Cow; +use std::cmp::Ordering; use std::sync::Arc; use std::task::{Context, Poll}; @@ -53,15 +58,225 @@ pub(crate) struct AggregateStream { /// /// The latter requires a state object, which is [`AggregateStreamInner`]. struct AggregateStreamInner { + // ==== Properties ==== schema: SchemaRef, mode: AggregateMode, input: SendableRecordBatchStream, - baseline_metrics: BaselineMetrics, aggregate_expressions: Vec>>, - filter_expressions: Vec>>, + filter_expressions: Arc<[Option>]>, + + // ==== Runtime States/Buffers ==== accumulators: Vec, - reservation: MemoryReservation, + // None if the dynamic filter is not applicable. See details in `AggrDynFilter`. + agg_dyn_filter_state: Option>, finished: bool, + + // ==== Execution Resources ==== + baseline_metrics: BaselineMetrics, + reservation: MemoryReservation, +} + +impl AggregateStreamInner { + // TODO: check if we get Null handling correct + /// # Examples + /// - Example 1 + /// Accumulators: min(c1) + /// Current Bounds: min(c1)=10 + /// --> dynamic filter PhysicalExpr: c1 < 10 + /// + /// - Example 2 + /// Accumulators: min(c1), max(c1), min(c2) + /// Current Bounds: min(c1)=10, max(c1)=100, min(c2)=20 + /// --> dynamic filter PhysicalExpr: (c1 < 10) OR (c1>100) OR (c2 < 20) + /// + /// # Errors + /// Returns internal errors if the dynamic filter is not enabled, or other + /// invariant check fails. + fn build_dynamic_filter_from_accumulator_bounds( + &self, + ) -> Result> { + let Some(filter_state) = self.agg_dyn_filter_state.as_ref() else { + return internal_err!( + "`build_dynamic_filter_from_accumulator_bounds()` is only called when dynamic filter is enabled" + ); + }; + + let mut predicates: Vec> = + Vec::with_capacity(filter_state.supported_accumulators_info.len()); + + for acc_info in &filter_state.supported_accumulators_info { + // Skip if we don't yet have a meaningful bound + let bound = { + let guard = acc_info.shared_bound.lock(); + if (*guard).is_null() { + continue; + } + guard.clone() + }; + + let agg_exprs = self + .aggregate_expressions + .get(acc_info.aggr_index) + .ok_or_else(|| { + internal_datafusion_err!( + "Invalid aggregate expression index {} for dynamic filter", + acc_info.aggr_index + ) + })?; + // Only aggregates with a single argument are supported. + let column_expr = agg_exprs.first().ok_or_else(|| { + internal_datafusion_err!( + "Aggregate expression at index {} expected a single argument", + acc_info.aggr_index + ) + })?; + + let literal = lit(bound); + let predicate: Arc = match acc_info.aggr_type { + DynamicFilterAggregateType::Min => Arc::new(BinaryExpr::new( + Arc::clone(column_expr), + Operator::Lt, + literal, + )), + DynamicFilterAggregateType::Max => Arc::new(BinaryExpr::new( + Arc::clone(column_expr), + Operator::Gt, + literal, + )), + }; + predicates.push(predicate); + } + + let combined = predicates.into_iter().reduce(|acc, pred| { + Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) as Arc + }); + + Ok(combined.unwrap_or_else(|| lit(true))) + } + + // If the dynamic filter is enabled, update it using the current accumulator's + // values + fn maybe_update_dyn_filter(&mut self) -> Result<()> { + // Step 1: Update each partition's current bound + let Some(filter_state) = self.agg_dyn_filter_state.as_ref() else { + return Ok(()); + }; + + let mut bounds_changed = false; + + for acc_info in &filter_state.supported_accumulators_info { + let acc = + self.accumulators + .get_mut(acc_info.aggr_index) + .ok_or_else(|| { + internal_datafusion_err!( + "Invalid accumulator index {} for dynamic filter", + acc_info.aggr_index + ) + })?; + // First get current partition's bound, then update the shared bound among + // all partitions. + let current_bound = acc.evaluate()?; + { + let mut bound = acc_info.shared_bound.lock(); + let new_bound = match acc_info.aggr_type { + DynamicFilterAggregateType::Max => { + scalar_max(&bound, ¤t_bound)? + } + DynamicFilterAggregateType::Min => { + scalar_min(&bound, ¤t_bound)? + } + }; + if new_bound != *bound { + *bound = new_bound; + bounds_changed = true; + } + } + } + + // Step 2: Sync the dynamic filter physical expression with reader, + // but only if any bound actually changed. + if bounds_changed { + let predicate = self.build_dynamic_filter_from_accumulator_bounds()?; + filter_state.filter.update(predicate)?; + } + + Ok(()) + } +} + +/// Returns the element-wise minimum of two `ScalarValue`s. +/// +/// # Null semantics +/// - `min(NULL, NULL) = NULL` +/// - `min(NULL, x) = x` +/// - `min(x, NULL) = x` +/// +/// # Errors +/// Returns internal error if v1 and v2 has incompatible types. +fn scalar_min(v1: &ScalarValue, v2: &ScalarValue) -> Result { + if let Some(result) = scalar_cmp_null_short_circuit(v1, v2) { + return Ok(result); + } + + match v1.partial_cmp(v2) { + Some(Ordering::Less | Ordering::Equal) => Ok(v1.clone()), + Some(Ordering::Greater) => Ok(v2.clone()), + None => datafusion_common::internal_err!( + "cannot compare values of different or incompatible types: {v1:?} vs {v2:?}" + ), + } +} + +/// Returns the element-wise maximum of two `ScalarValue`s. +/// +/// # Null semantics +/// - `max(NULL, NULL) = NULL` +/// - `max(NULL, x) = x` +/// - `max(x, NULL) = x` +/// +/// # Errors +/// Returns internal error if v1 and v2 has incompatible types. +fn scalar_max(v1: &ScalarValue, v2: &ScalarValue) -> Result { + if let Some(result) = scalar_cmp_null_short_circuit(v1, v2) { + return Ok(result); + } + + match v1.partial_cmp(v2) { + Some(Ordering::Greater | Ordering::Equal) => Ok(v1.clone()), + Some(Ordering::Less) => Ok(v2.clone()), + None => datafusion_common::internal_err!( + "cannot compare values of different or incompatible types: {v1:?} vs {v2:?}" + ), + } +} + +fn scalar_cmp_null_short_circuit( + v1: &ScalarValue, + v2: &ScalarValue, +) -> Option { + match (v1, v2) { + (ScalarValue::Null, ScalarValue::Null) => Some(ScalarValue::Null), + (ScalarValue::Null, other) | (other, ScalarValue::Null) => Some(other.clone()), + _ => None, + } +} + +/// Prepend the grouping ID column to the output columns if present. +/// +/// For GROUPING SETS with no GROUP BY expressions, the schema includes a `__grouping_id` +/// column that must be present in the output. This function inserts it at the beginning +/// of the columns array to maintain schema alignment. +fn prepend_grouping_id_column( + mut columns: Vec>, + grouping_id: Option<&ScalarValue>, +) -> Result>> { + if let Some(id) = grouping_id { + let num_rows = columns.first().map(|array| array.len()).unwrap_or(1); + let grouping_ids = id.to_array_of_size(num_rows)?; + columns.insert(0, grouping_ids); + } + Ok(columns) } impl AggregateStream { @@ -72,25 +287,39 @@ impl AggregateStream { partition: usize, ) -> Result { let agg_schema = Arc::clone(&agg.schema); - let agg_filter_expr = agg.filter_expr.clone(); + let agg_filter_expr = Arc::clone(&agg.filter_expr); let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition); let input = agg.input.execute(partition, Arc::clone(context))?; let aggregate_expressions = aggregate_expressions(&agg.aggr_expr, &agg.mode, 0)?; - let filter_expressions = match agg.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } + let filter_expressions = match agg.mode.input_mode() { + AggregateInputMode::Raw => agg_filter_expr, + AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; let accumulators = create_accumulators(&agg.aggr_expr)?; let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]")) .register(context.memory_pool()); + // Enable dynamic filter if: + // 1. AggregateExec did the check and ensure it supports the dynamic filter + // (its dynamic_filter field will be Some(..)) + // 2. Aggregate dynamic filter is enabled from the config + let mut maybe_dynamic_filter = match agg.dynamic_filter.as_ref() { + Some(filter) => Some(Arc::clone(filter)), + _ => None, + }; + + if !context + .session_config() + .options() + .optimizer + .enable_aggregate_dynamic_filter_pushdown + { + maybe_dynamic_filter = None; + } + let inner = AggregateStreamInner { schema: Arc::clone(&agg.schema), mode: agg.mode, @@ -101,27 +330,33 @@ impl AggregateStream { accumulators, reservation, finished: false, + agg_dyn_filter_state: maybe_dynamic_filter, }; + let stream = futures::stream::unfold(inner, |mut this| async move { if this.finished { return None; } - let elapsed_compute = this.baseline_metrics.elapsed_compute(); - loop { let result = match this.input.next().await { Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = aggregate_batch( - &this.mode, - &batch, - &mut this.accumulators, - &this.aggregate_expressions, - &this.filter_expressions, - ); + let result = { + let elapsed_compute = this.baseline_metrics.elapsed_compute(); + let _timer = elapsed_compute.timer(); // Stops on drop + aggregate_batch( + &this.mode, + &batch, + &mut this.accumulators, + &this.aggregate_expressions, + &this.filter_expressions, + ) + }; - timer.done(); + let result = result.and_then(|allocated| { + this.maybe_update_dyn_filter()?; + Ok(allocated) + }); // allocate memory // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with @@ -136,9 +371,15 @@ impl AggregateStream { Some(Err(e)) => Err(e), None => { this.finished = true; + // Release the input pipeline's resources before finalization. + let input_schema = this.input.schema(); + this.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); let timer = this.baseline_metrics.elapsed_compute().timer(); let result = finalize_aggregation(&mut this.accumulators, &this.mode) + .and_then(|columns| { + prepend_grouping_id_column(columns, None) + }) .and_then(|columns| { RecordBatch::try_new( Arc::clone(&this.schema), @@ -224,13 +465,9 @@ fn aggregate_batch( // 1.4 let size_pre = accum.size(); - let res = match mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => accum.update_batch(&values), - AggregateMode::Final | AggregateMode::FinalPartitioned => { - accum.merge_batch(&values) - } + let res = match mode.input_mode() { + AggregateInputMode::Raw => accum.update_batch(&values), + AggregateInputMode::Partial => accum.merge_batch(&values), }; let size_post = accum.size(); allocated += size_post.saturating_sub(size_pre); diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index bbcb30d877cf0..97fbd519c825c 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -52,7 +52,8 @@ impl GroupOrdering { } } - // How many groups be emitted, or None if no data can be emitted + /// Returns how many groups can be emitted while respecting the current + /// ordering guarantees, or `None` if no data can be emitted. pub fn emit_to(&self) -> Option { match self { GroupOrdering::None => None, @@ -61,7 +62,29 @@ impl GroupOrdering { } } - /// Updates the state the input is done + /// Returns the emit strategy to use under memory pressure (OOM). + /// + /// Returns the strategy that must be used when emitting up to `n` groups + /// while respecting the current ordering guarantees. + /// + /// Returns `None` if no data can be emitted. + pub fn oom_emit_to(&self, n: usize) -> Option { + if n == 0 { + return None; + } + + match self { + GroupOrdering::None => Some(EmitTo::First(n)), + GroupOrdering::Partial(_) | GroupOrdering::Full(_) => { + self.emit_to().map(|emit_to| match emit_to { + EmitTo::First(max) => EmitTo::First(n.min(max)), + EmitTo::All => EmitTo::First(n), + }) + } + } + } + + /// Updates the state to indicate that the input is complete. pub fn input_done(&mut self) { match self { GroupOrdering::None => {} @@ -70,8 +93,8 @@ impl GroupOrdering { } } - /// remove the first n groups from the internal state, shifting - /// all existing indexes down by `n` + /// Removes the first `n` groups from the internal state, shifting all + /// existing indexes down by `n`. pub fn remove_groups(&mut self, n: usize) { match self { GroupOrdering::None => {} @@ -80,16 +103,14 @@ impl GroupOrdering { } } - /// Called when new groups are added in a batch + /// Called when new groups are added in a batch. /// - /// * `total_num_groups`: total number of groups (so max - /// group_index is total_num_groups - 1). - /// - /// * `group_values`: group key values for *each row* in the batch + /// * `batch_group_values`: group key values for each row in the batch /// /// * `group_indices`: indices for each row in the batch /// - /// * `hashes`: hash values for each row in the batch + /// * `total_num_groups`: total number of groups (so max + /// group_index is total_num_groups - 1). pub fn new_groups( &mut self, batch_group_values: &[ArrayRef], @@ -112,7 +133,7 @@ impl GroupOrdering { Ok(()) } - /// Return the size of memory used by the ordering state, in bytes + /// Returns the size of memory used by the ordering state, in bytes. pub fn size(&self) -> usize { size_of::() + match self { @@ -122,3 +143,63 @@ impl GroupOrdering { } } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow::array::Int32Array; + + #[test] + fn test_oom_emit_to_none_ordering() { + let group_ordering = GroupOrdering::None; + + assert_eq!(group_ordering.oom_emit_to(0), None); + assert_eq!(group_ordering.oom_emit_to(5), Some(EmitTo::First(5))); + } + + /// Creates a partially ordered grouping state with three groups. + /// + /// `sort_key_values` controls whether a sort boundary exists in the batch: + /// distinct values such as `[1, 2, 3]` create boundaries, while repeated + /// values such as `[1, 1, 1]` do not. + fn partial_ordering(sort_key_values: Vec) -> Result { + let mut group_ordering = + GroupOrdering::Partial(GroupOrderingPartial::try_new(vec![0])?); + + let batch_group_values: Vec = vec![ + Arc::new(Int32Array::from(sort_key_values)), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ]; + let group_indices = vec![0, 1, 2]; + + group_ordering.new_groups(&batch_group_values, &group_indices, 3)?; + + Ok(group_ordering) + } + + #[test] + fn test_oom_emit_to_partial_clamps_to_boundary() -> Result<()> { + let group_ordering = partial_ordering(vec![1, 2, 3])?; + + // Can emit both `1` and `2` groups because we have seen `3` + assert_eq!(group_ordering.emit_to(), Some(EmitTo::First(2))); + assert_eq!(group_ordering.oom_emit_to(1), Some(EmitTo::First(1))); + assert_eq!(group_ordering.oom_emit_to(3), Some(EmitTo::First(2))); + + Ok(()) + } + + #[test] + fn test_oom_emit_to_partial_without_boundary() -> Result<()> { + let group_ordering = partial_ordering(vec![1, 1, 1])?; + + // Can't emit the last `1` group as it may have more values + assert_eq!(group_ordering.emit_to(), None); + assert_eq!(group_ordering.oom_emit_to(3), None); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 3a1560d9ef16e..c3f73976c721a 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -22,37 +22,39 @@ use std::task::{Context, Poll}; use std::vec; use super::order::GroupOrdering; -use super::AggregateExec; -use crate::aggregates::group_values::{new_group_values, GroupByMetrics, GroupValues}; +use super::{AggregateExec, format_human_display}; +use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ - create_schema, evaluate_group_by, evaluate_many, evaluate_optional, AggregateMode, - PhysicalGroupBy, + AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy, + create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_id_array, + max_duplicate_ordinal, }; -use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; -use crate::sorts::sort::sort_batch; +use crate::metrics::{BaselineMetrics, MetricBuilder, MetricCategory, RecordOutput}; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; -use crate::spill::spill_manager::SpillManager; -use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, metrics, PhysicalExpr}; +use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; +use crate::stream::EmptyRecordBatchStream; +use crate::{PhysicalExpr, aggregates, metrics}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::datatypes::SchemaRef; use datafusion_common::{ - assert_eq_or_internal_err, assert_or_internal_err, internal_err, DataFusionError, - Result, + DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err, + internal_err, resources_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion_execution::TaskContext; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::sorts::IncrementalSortIterator; use datafusion_common::instant::Instant; +use datafusion_common::utils::memory::get_record_batch_memory_size; use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -193,8 +195,10 @@ impl SkipAggregationProbe { self.num_groups = num_groups; if self.input_rows >= self.probe_rows_threshold { self.should_skip = self.num_groups as f64 / self.input_rows as f64 - >= self.probe_ratio_threshold; - self.is_locked = true; + > self.probe_ratio_threshold; + // Set is_locked to true only if we have decided to skip, otherwise we can try to skip + // during processing the next record_batch. + self.is_locked = self.should_skip; } } @@ -208,6 +212,17 @@ impl SkipAggregationProbe { } } +/// Controls the behavior when an out-of-memory condition occurs. +#[derive(PartialEq, Debug)] +enum OutOfMemoryMode { + /// When out of memory occurs, spill state to disk + Spill, + /// When out of memory occurs, attempt to emit group values early + EmitEarly, + /// When out of memory occurs, immediately report the error + ReportError, +} + /// HashTable based Grouping Aggregator /// /// # Design Goals @@ -247,7 +262,7 @@ impl SkipAggregationProbe { /// /// group_values accumulators /// -/// ``` +/// ``` /// /// For example, given a query like `COUNT(x), SUM(y) ... GROUP BY z`, /// [`group_values`] will store the distinct values of `z`. There will @@ -347,6 +362,7 @@ pub(crate) struct GroupedHashAggregateStream { // the execution. // ======================================================================== schema: SchemaRef, + input_schema: SchemaRef, input: SendableRecordBatchStream, mode: AggregateMode, @@ -365,10 +381,10 @@ pub(crate) struct GroupedHashAggregateStream { /// /// For example, for an aggregate like `SUM(x) FILTER (WHERE x >= 100)`, /// the filter expression is `x > 100`. - filter_expressions: Vec>>, + filter_expressions: Arc<[Option>]>, /// GROUP BY expressions - group_by: PhysicalGroupBy, + group_by: Arc, /// max rows in output RecordBatches batch_size: usize, @@ -431,6 +447,9 @@ pub(crate) struct GroupedHashAggregateStream { /// The memory reservation for this grouping reservation: MemoryReservation, + /// The behavior to trigger when out of memory occurs + oom_mode: OutOfMemoryMode, + /// Execution metrics baseline_metrics: BaselineMetrics, @@ -450,8 +469,8 @@ impl GroupedHashAggregateStream { ) -> Result { debug!("Creating GroupedHashAggregateStream"); let agg_schema = Arc::clone(&agg.schema); - let agg_group_by = agg.group_by.clone(); - let agg_filter_expr = agg.filter_expr.clone(); + let agg_group_by = Arc::clone(&agg.group_by); + let agg_filter_expr = Arc::clone(&agg.filter_expr); let batch_size = context.session_config().batch_size(); let input = agg.input.execute(partition, Arc::clone(context))?; @@ -460,7 +479,7 @@ impl GroupedHashAggregateStream { let timer = baseline_metrics.elapsed_compute().timer(); - let aggregate_exprs = agg.aggr_expr.clone(); + let aggregate_exprs = Arc::clone(&agg.aggr_expr); // arguments for each aggregate, one vec of expressions per // aggregate @@ -476,13 +495,9 @@ impl GroupedHashAggregateStream { agg_group_by.num_group_exprs(), )?; - let filter_expressions = match agg.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned => agg_filter_expr, - AggregateMode::Final | AggregateMode::FinalPartitioned => { - vec![None; agg.aggr_expr.len()] - } + let filter_expressions = match agg.mode.input_mode() { + AggregateInputMode::Raw => agg_filter_expr, + AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; // Instantiate the accumulators @@ -508,12 +523,12 @@ impl GroupedHashAggregateStream { // Therefore, when we spill these intermediate states or pass them to another // aggregation operator, we must use a schema that includes both the group // columns **and** the partial-state columns. - let partial_agg_schema = create_schema( + let spill_schema = Arc::new(create_schema( &agg.input().schema(), &agg_group_by, &aggregate_exprs, AggregateMode::Partial, - )?; + )?); // Need to update the GROUP BY expressions to point to the correct column after schema change let merging_group_by_expr = agg_group_by @@ -525,34 +540,67 @@ impl GroupedHashAggregateStream { }) .collect(); - let partial_agg_schema = Arc::new(partial_agg_schema); + let output_ordering = agg.cache.output_ordering(); - let spill_expr = + let spill_sort_exprs = group_schema .fields .into_iter() .enumerate() .map(|(idx, field)| { - PhysicalSortExpr::new_default(Arc::new(Column::new( - field.name().as_str(), - idx, - )) as _) + let output_expr = Column::new(field.name().as_str(), idx); + + // Try to use the sort options from the output ordering, if available. + // This ensures that spilled state is sorted in the required order as well. + let sort_options = output_ordering + .and_then(|o| o.get_sort_options(&output_expr)) + .unwrap_or_default(); + + PhysicalSortExpr::new(Arc::new(output_expr), sort_options) }); - let Some(spill_expr) = LexOrdering::new(spill_expr) else { + let Some(spill_ordering) = LexOrdering::new(spill_sort_exprs) else { return internal_err!("Spill expression is empty"); }; let agg_fn_names = aggregate_exprs .iter() - .map(|expr| expr.human_display()) + .map(|expr| { + format_human_display(expr.human_display(), expr.human_display_alias()) + .map(|display| display.into_owned()) + .unwrap_or_else(|| expr.name().to_string()) + }) .collect::>() .join(", "); let name = format!("GroupedHashAggregateStream[{partition}] ({agg_fn_names})"); - let reservation = MemoryConsumer::new(name) - .with_can_spill(true) - .register(context.memory_pool()); let group_ordering = GroupOrdering::try_new(&agg.input_order_mode)?; + let oom_mode = match (agg.mode, &group_ordering) { + // In partial aggregation mode, always prefer to emit incomplete results early. + (AggregateMode::Partial, _) => OutOfMemoryMode::EmitEarly, + // For non-partial aggregation modes, emitting incomplete results is not an option. + // Instead, use disk spilling to store sorted, incomplete results, and merge them + // afterwards. + (_, GroupOrdering::None | GroupOrdering::Partial(_)) + if context.runtime_env().disk_manager.tmp_files_enabled() => + { + OutOfMemoryMode::Spill + } + // For `GroupOrdering::Full`, the incoming stream is already sorted. This ensures the + // number of incomplete groups can be kept small at all times. If we still hit + // an out-of-memory condition, spilling to disk would not be beneficial since the same + // situation is likely to reoccur when reading back the spilled data. + // Therefore, we fall back to simply reporting the error immediately. + // This mode will also be used if the `DiskManager` is not configured to allow spilling + // to disk. + _ => OutOfMemoryMode::ReportError, + }; + let group_values = new_group_values(group_schema, &group_ordering)?; + let reservation = MemoryConsumer::new(name) + // We interpret 'can spill' as 'can handle memory back pressure'. + // This value needs to be set to true for the default memory pool implementations + // to ensure fair application of back pressure amongst the memory consumers. + .with_can_spill(oom_mode != OutOfMemoryMode::ReportError) + .register(context.memory_pool()); timer.done(); let exec_state = ExecutionState::ReadingInput; @@ -560,18 +608,19 @@ impl GroupedHashAggregateStream { let spill_manager = SpillManager::new( context.runtime_env(), metrics::SpillMetrics::new(&agg.metrics, partition), - Arc::clone(&partial_agg_schema), + Arc::clone(&spill_schema), ) .with_compression_type(context.session_config().spill_compression()); let spill_state = SpillState { spills: vec![], - spill_expr, - spill_schema: partial_agg_schema, + spill_expr: spill_ordering, + spill_schema, is_stream_merging: false, merging_aggregate_arguments, merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr), peak_mem_used: MetricBuilder::new(&agg.metrics) + .with_category(MetricCategory::Bytes) .gauge("peak_mem_used", partition), spill_manager, }; @@ -595,13 +644,20 @@ impl GroupedHashAggregateStream { options.skip_partial_aggregation_probe_rows_threshold; let probe_ratio_threshold = options.skip_partial_aggregation_probe_ratio_threshold; - let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) - .counter("skipped_aggregation_rows", partition); - Some(SkipAggregationProbe::new( - probe_rows_threshold, - probe_ratio_threshold, - skipped_aggregation_rows, - )) + // A threshold >= 1.0 means the ratio (num_groups / input_rows) can + // never exceed it, so the feature is effectively disabled. + if probe_ratio_threshold >= 1.0 { + None + } else { + let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) + .with_category(MetricCategory::Rows) + .counter("skipped_aggregation_rows", partition); + Some(SkipAggregationProbe::new( + probe_rows_threshold, + probe_ratio_threshold, + skipped_aggregation_rows, + )) + } } else { None }; @@ -609,7 +665,7 @@ impl GroupedHashAggregateStream { let reduction_factor = if agg.mode == AggregateMode::Partial { Some( MetricBuilder::new(&agg.metrics) - .with_type(metrics::MetricType::SUMMARY) + .with_type(metrics::MetricType::Summary) .ratio_metrics("reduction_factor", partition), ) } else { @@ -618,6 +674,7 @@ impl GroupedHashAggregateStream { Ok(GroupedHashAggregateStream { schema: agg_schema, + input_schema: agg.input().schema(), input, mode: agg.mode, accumulators, @@ -625,6 +682,7 @@ impl GroupedHashAggregateStream { filter_expressions, group_by: agg_group_by, reservation, + oom_mode, group_values, current_group_indices: Default::default(), exec_state, @@ -634,7 +692,7 @@ impl GroupedHashAggregateStream { group_ordering, input_done: false, spill_state, - group_values_soft_limit: agg.limit, + group_values_soft_limit: agg.limit_options().map(|config| config.limit()), skip_aggregation_probe, reduction_factor, }) @@ -674,21 +732,24 @@ impl Stream for GroupedHashAggregateStream { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { - // New batch to aggregate in partial aggregation operator - Some(Ok(batch)) if self.mode == AggregateMode::Partial => { + // New batch to aggregate + Some(Ok(batch)) => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - if let Some(reduction_factor) = self.reduction_factor.as_ref() + if self.mode == AggregateMode::Partial + && let Some(reduction_factor) = + self.reduction_factor.as_ref() { reduction_factor.add_total(input_rows); } - // Do the grouping + // Do the grouping. + // `group_aggregate_batch` will _not_ have updated the memory reservation yet. + // The rest of the code will first try to reduce memory usage by + // already emitting results. self.group_aggregate_batch(&batch)?; - // If we can begin emitting rows, do so, - // otherwise keep consuming input assert!(!self.input_done); // If the number of group values equals or exceeds the soft limit, @@ -700,7 +761,13 @@ impl Stream for GroupedHashAggregateStream { break 'reading_input; } - if let Some(to_emit) = self.group_ordering.emit_to() { + // Try to emit completed groups if possible. + // If we already started spilling, we can no longer emit since + // this might lead to incorrect output ordering + if (self.spill_state.spills.is_empty() + || self.spill_state.is_stream_merging) + && let Some(to_emit) = self.group_ordering.emit_to() + { timer.done(); if let Some(batch) = self.emit(to_emit, false)? { self.exec_state = @@ -710,18 +777,28 @@ impl Stream for GroupedHashAggregateStream { break 'reading_input; } - // Check if we should switch to skip aggregation mode - // It's important that we do this before we early emit since we've - // already updated the probe. - self.update_skip_aggregation_probe(input_rows); - if let Some(new_state) = self.switch_to_skip_aggregation()? { - timer.done(); - self.exec_state = new_state; - break 'reading_input; + if self.mode == AggregateMode::Partial { + // Spilling should never be activated in partial aggregation mode. + assert!(!self.spill_state.is_stream_merging); + + // Check if we should switch to skip aggregation mode + // It's important that we do this before we early emit since we've + // already updated the probe. + self.update_skip_aggregation_probe(input_rows); + if let Some(new_state) = + self.switch_to_skip_aggregation()? + { + timer.done(); + self.exec_state = new_state; + break 'reading_input; + } } - // Check if we need to emit early due to memory pressure - if let Some(new_state) = self.emit_early_if_necessary()? { + // If we reach this point, try to update the memory reservation + // handling out-of-memory conditions as determined by the OOM mode. + if let Some(new_state) = + self.try_update_memory_reservation()? + { timer.done(); self.exec_state = new_state; break 'reading_input; @@ -730,43 +807,6 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } - // New batch to aggregate in terminal aggregation operator - // (Final/FinalPartitioned/Single/SinglePartitioned) - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - - // Make sure we have enough capacity for `batch`, otherwise spill - self.spill_previous_if_necessary(&batch)?; - - // Do the grouping - self.group_aggregate_batch(&batch)?; - - // If we can begin emitting rows, do so, - // otherwise keep consuming input - assert!(!self.input_done); - - // If the number of group values equals or exceeds the soft limit, - // emit all groups and switch to producing output - if self.hit_soft_group_limit() { - timer.done(); - self.set_input_done_and_produce_output()?; - // make sure the exec_state just set is not overwritten below - break 'reading_input; - } - - if let Some(to_emit) = self.group_ordering.emit_to() { - timer.done(); - if let Some(batch) = self.emit(to_emit, false)? { - self.exec_state = - ExecutionState::ProducingOutput(batch); - }; - // make sure the exec_state just set is not overwritten below - break 'reading_input; - } - - timer.done(); - } - // Found error from input stream Some(Err(e)) => { // inner had error, return to caller @@ -808,6 +848,10 @@ impl Stream for GroupedHashAggregateStream { self.group_values.len() ))); } + // Release the input pipeline's resources. + let input_schema = self.input.schema(); + self.input = + Box::pin(EmptyRecordBatchStream::new(input_schema)); self.exec_state = ExecutionState::Done; } } @@ -860,7 +904,8 @@ impl Stream for GroupedHashAggregateStream { return Poll::Ready(Some(internal_err!( "AggregateStream was in Done state with {} groups left in hash table. \ This is a bug - all groups should have been emitted before entering Done state.", - self.group_values.len()))); + self.group_values.len() + ))); } // release the memory reservation since sending back output batch itself needs // some memory reservation, so make some room for it. @@ -954,29 +999,24 @@ impl GroupedHashAggregateStream { // Call the appropriate method on each aggregator with // the entire input row and the relevant group indexes - match self.mode { - AggregateMode::Partial - | AggregateMode::Single - | AggregateMode::SinglePartitioned - if !self.spill_state.is_stream_merging => - { - acc.update_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; - } - _ => { - assert_or_internal_err!( - opt_filter.is_none(), - "aggregate filter should be applied in partial stage, there should be no filter in final stage" - ); - - // if aggregation is over intermediate states, - // use merge - acc.merge_batch(values, group_indices, None, total_num_groups)?; - } + if self.mode.input_mode() == AggregateInputMode::Raw + && !self.spill_state.is_stream_merging + { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + assert_or_internal_err!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + // if aggregation is over intermediate states, + // use merge + acc.merge_batch(values, group_indices, None, total_num_groups)?; } self.group_by_metrics .aggregation_time @@ -984,25 +1024,76 @@ impl GroupedHashAggregateStream { } } - match self.update_memory_reservation() { - // Here we can ignore `insufficient_capacity_err` because we will spill later, - // but at least one batch should fit in the memory - Err(DataFusionError::ResourcesExhausted(_)) - if self.group_values.len() >= self.batch_size => - { - Ok(()) + Ok(()) + } + + /// Attempts to update the memory reservation. If that fails due to a + /// [DataFusionError::ResourcesExhausted] error, an attempt will be made to resolve + /// the out-of-memory condition based on the [out-of-memory handling mode](OutOfMemoryMode). + /// + /// If the out-of-memory condition can not be resolved, an `Err` value will be returned + /// + /// Returns `Ok(Some(ExecutionState))` if the state should be changed, `Ok(None)` otherwise. + fn try_update_memory_reservation(&mut self) -> Result> { + let oom = match self.update_memory_reservation() { + Err(e @ DataFusionError::ResourcesExhausted(_)) => e, + Err(e) => return Err(e), + Ok(_) => return Ok(None), + }; + + match self.oom_mode { + OutOfMemoryMode::Spill if !self.group_values.is_empty() => { + self.spill()?; + self.clear_shrink(self.batch_size); + self.update_memory_reservation()?; + Ok(None) } - other => other, + OutOfMemoryMode::EmitEarly if self.group_values.len() > 1 => { + let n = if self.group_values.len() >= self.batch_size { + // Try to emit an integer multiple of batch size if possible + self.group_values.len() / self.batch_size * self.batch_size + } else { + // Otherwise emit whatever we can + self.group_values.len() + }; + + if let Some(emit_to) = self.group_ordering.oom_emit_to(n) + && let Some(batch) = self.emit(emit_to, false)? + { + return Ok(Some(ExecutionState::ProducingOutput(batch))); + } + Err(oom) + } + OutOfMemoryMode::EmitEarly + | OutOfMemoryMode::Spill + | OutOfMemoryMode::ReportError => Err(oom), } } fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - let reservation_result = self.reservation.try_resize( - acc + self.group_values.size() - + self.group_ordering.size() - + self.current_group_indices.allocated_size(), - ); + let groups_and_acc_size = acc + + self.group_values.size() + + self.group_ordering.size() + + self.current_group_indices.allocated_size(); + + // Reserve extra headroom for sorting during potential spill. + // When OOM triggers, group_aggregate_batch has already processed the + // latest input batch, so the internal state may have grown well beyond + // the last successful reservation. The emit batch reflects this larger + // actual state, and the sort needs memory proportional to it. + // By reserving headroom equal to the data size, we trigger OOM earlier + // (before too much data accumulates), ensuring the freed reservation + // after clear_shrink is sufficient to cover the sort memory. + let sort_headroom = + if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() { + acc + self.group_values.size() + } else { + 0 + }; + + let new_size = groups_and_acc_size + sort_headroom; + let reservation_result = self.reservation.try_resize(new_size); if reservation_result.is_ok() { self.spill_state @@ -1033,17 +1124,12 @@ impl GroupedHashAggregateStream { // Next output each aggregate value for acc in self.accumulators.iter_mut() { - match self.mode { - AggregateMode::Partial => output.extend(acc.state(emit_to)?), - _ if spilling => { - // If spilling, output partial state because the spilled data will be - // merged and re-evaluated later. - output.extend(acc.state(emit_to)?) - } - AggregateMode::Final - | AggregateMode::FinalPartitioned - | AggregateMode::Single - | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?), + if self.mode.output_mode() == AggregateOutputMode::Final && !spilling { + output.push(acc.evaluate(emit_to)?) + } else { + // Output partial state: either because we're in a non-final mode, + // or because we're spilling and will merge/re-evaluate later. + output.extend(acc.state(emit_to)?) } } drop(timer); @@ -1057,21 +1143,101 @@ impl GroupedHashAggregateStream { Ok(Some(batch)) } - /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly - /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the - /// memory. Currently only [`GroupOrdering::None`] is supported for spilling. - fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> { - // TODO: support group_ordering for spilling - if !self.group_values.is_empty() - && batch.num_rows() > 0 - && matches!(self.group_ordering, GroupOrdering::None) - && !self.spill_state.is_stream_merging - && self.update_memory_reservation().is_err() - { - assert_ne!(self.mode, AggregateMode::Partial); - self.spill()?; - self.clear_shrink(batch); + /// Registers groups for empty grouping sets when no input rows were seen. + /// + /// `GROUP BY GROUPING SETS (())` must always produce one row even when there + /// are no input rows (standard SQL semantics for a "grand total" group). + /// Mixed grouping sets like `GROUPING SETS (a, ())` also produce one row for + /// the empty set `()` on empty input. + /// + /// This method interns the group keys and primes the accumulators so they + /// produce their zero-row aggregate values (e.g. `NULL` for `SUM`, + /// `0` for `COUNT`). + fn init_empty_grouping_sets(&mut self) -> Result<()> { + if !self.group_by.has_grouping_set() || !self.group_values.is_empty() { + return Ok(()); } + + let max_ordinal = max_duplicate_ordinal(self.group_by.groups()); + let mut ordinals: std::collections::HashMap<&[bool], usize> = + std::collections::HashMap::new(); + let group_schema = self.group_by.group_schema(&self.input_schema)?; + let n_expr = self.group_by.expr().len(); + let mut any_interned = false; + + for group in self.group_by.groups() { + let ordinal = { + let entry = ordinals.entry(group.as_slice()).or_insert(0); + let o = *entry; + *entry += 1; + o + }; + + if !group.iter().all(|&is_null| is_null) { + continue; + } + + // Build the group key: one NULL per group-by expression, then the grouping_id. + let mut cols: Vec = group_schema + .fields() + .iter() + .take(n_expr) + .map(|f| new_null_array(f.data_type(), 1)) + .collect(); + cols.push(group_id_array(group, ordinal, max_ordinal, 1)?); + + let starting_groups = self.group_values.len(); + self.group_values + .intern(&cols, &mut self.current_group_indices)?; + let total_groups = self.group_values.len(); + if total_groups > starting_groups { + self.group_ordering.new_groups( + &cols, + &self.current_group_indices, + total_groups, + )?; + } + any_interned = true; + } + + if any_interned { + // Prime each accumulator for the registered group count with no data. + // + // We build 1-row null arrays for each aggregate argument and pass them + // with an all-false filter. The filter ensures no row is accumulated + // into any group, which keeps every group in its "zero" initial state + // (NULL for SUM/AVG/MIN/MAX, 0 for COUNT). + // + // Using a 1-row batch rather than 0 rows is required to avoid a fast + // path in `NullState::accumulate` that treats "0 nulls in a 0-row + // array" as "all groups have been seen", which would cause SUM to + // return 0 instead of NULL. + // + // Argument types are inferred directly from the expression metadata so + // we never need to construct a full `RecordBatch`. + let total_groups = self.group_values.len(); + let null_args: Vec> = self + .aggregate_arguments + .iter() + .map(|args| { + args.iter() + .map(|expr| { + let dt = expr.data_type(&self.input_schema)?; + Ok(new_null_array(&dt, 1)) + }) + .collect::>>() + }) + .collect::>>()?; + let false_filter = BooleanArray::from(vec![false]); + for (acc, args) in self.accumulators.iter_mut().zip(null_args.iter()) { + if self.mode.input_mode() == AggregateInputMode::Raw { + acc.update_batch(args, &[0], Some(&false_filter), total_groups)?; + } else { + acc.merge_batch(args, &[0], Some(&false_filter), total_groups)?; + } + } + } + Ok(()) } @@ -1083,17 +1249,47 @@ impl GroupedHashAggregateStream { let Some(emit) = self.emit(EmitTo::All, true)? else { return Ok(()); }; - let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; - // Spill sorted state to disk + // Free accumulated state now that data has been emitted into `emit`. + // This must happen before reserving sort memory so the pool has room. + // Use 0 to minimize allocated capacity and maximize memory available for sorting. + self.clear_shrink(0); + self.update_memory_reservation()?; + + let batch_size_ratio = self.batch_size as f32 / emit.num_rows() as f32; + let batch_memory = get_record_batch_memory_size(&emit); + // The maximum worst case for a sort is 2X the original underlying buffers(regardless of slicing) + // First we get the underlying buffers' size, then we get the sliced("actual") size of the batch, + // and multiply it by the ratio of batch_size to actual size to get the estimated memory needed for sorting the batch. + // If something goes wrong in get_sliced_size()(double counting or something), + // we fall back to the worst case. + let sort_memory = (batch_memory + + (emit.get_sliced_size()? as f32 * batch_size_ratio) as usize) + .min(batch_memory * 2); + + // If we can't grow even that, we have no choice but to return an error since we can't spill to disk without sorting the data first. + self.reservation.try_grow(sort_memory).map_err(|err| { + resources_datafusion_err!( + "Failed to reserve memory for sort during spill: {err}" + ) + })?; + + let sorted_iter = IncrementalSortIterator::new( + emit, + self.spill_state.spill_expr.clone(), + self.batch_size, + ); let spillfile = self .spill_state .spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &sorted, + .spill_record_batch_iter_and_return_max_batch_memory( + sorted_iter, "HashAggSpill", - self.batch_size, )?; + + // Shrink the memory we allocated for sorting as the sorting is fully done at this point. + self.reservation.shrink(sort_memory); + match spillfile { Some((spillfile, max_record_batch_memory)) => { self.spill_state.spills.push(SortedSpillFile { @@ -1111,73 +1307,16 @@ impl GroupedHashAggregateStream { Ok(()) } - /// Clear memory and shirk capacities to the size of the batch. - fn clear_shrink(&mut self, batch: &RecordBatch) { - self.group_values.clear_shrink(batch); + /// Clear memory and shrink capacities to the given number of rows. + fn clear_shrink(&mut self, num_rows: usize) { + self.group_values.clear_shrink(num_rows); self.current_group_indices.clear(); - self.current_group_indices.shrink_to(batch.num_rows()); + self.current_group_indices.shrink_to(num_rows); } - /// Clear memory and shirk capacities to zero. + /// Clear memory and shrink capacities to zero. fn clear_all(&mut self) { - let s = self.schema(); - self.clear_shrink(&RecordBatch::new_empty(s)); - } - - /// Emit if the used memory exceeds the target for partial aggregation. - /// Currently only [`GroupOrdering::None`] is supported for early emitting. - /// TODO: support group_ordering for early emitting - /// - /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise. - fn emit_early_if_necessary(&mut self) -> Result> { - if self.group_values.len() >= self.batch_size - && matches!(self.group_ordering, GroupOrdering::None) - && self.update_memory_reservation().is_err() - { - assert_eq!(self.mode, AggregateMode::Partial); - let n = self.group_values.len() / self.batch_size * self.batch_size; - if let Some(batch) = self.emit(EmitTo::First(n), false)? { - return Ok(Some(ExecutionState::ProducingOutput(batch))); - }; - } - Ok(None) - } - - /// At this point, all the inputs are read and there are some spills. - /// Emit the remaining rows and create a batch. - /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully - /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`]. - fn update_merged_stream(&mut self) -> Result<()> { - let Some(batch) = self.emit(EmitTo::All, true)? else { - return Ok(()); - }; - // clear up memory for streaming_merge - self.clear_all(); - self.update_memory_reservation()?; - let mut streams: Vec = vec![]; - let expr = self.spill_state.spill_expr.clone(); - let schema = batch.schema(); - streams.push(Box::pin(RecordBatchStreamAdapter::new( - Arc::clone(&schema), - futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, &expr, None) - })), - ))); - - self.spill_state.is_stream_merging = true; - self.input = StreamingMergeBuilder::new() - .with_streams(streams) - .with_schema(schema) - .with_spill_manager(self.spill_state.spill_manager.clone()) - .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills)) - .with_expressions(&self.spill_state.spill_expr) - .with_metrics(self.baseline_metrics.clone()) - .with_batch_size(self.batch_size) - .with_reservation(self.reservation.new_empty()) - .build()?; - self.input_done = false; - self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); - Ok(()) + self.clear_shrink(0); } /// returns true if there is a soft groups limit and the number of distinct @@ -1189,18 +1328,78 @@ impl GroupedHashAggregateStream { group_values_soft_limit <= self.group_values.len() } - /// common function for signalling end of processing of the input stream + /// Finalizes reading of the input stream and prepares for producing output values. + /// + /// This method is called both when the original input stream and, + /// in case of disk spilling, the SPM stream have been drained. fn set_input_done_and_produce_output(&mut self) -> Result<()> { self.input_done = true; self.group_ordering.input_done(); + // Release the original input pipeline's resources now that we're done + // reading from it. In the spill branch below, `self.input` is replaced + // again with a stream that merges spill files. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let timer = elapsed_compute.timer(); self.exec_state = if self.spill_state.spills.is_empty() { + // Input has been entirely processed without spilling to disk. + self.init_empty_grouping_sets()?; + + // Flush any remaining group values. let batch = self.emit(EmitTo::All, false)?; + + // If there are none, we're done; otherwise switch to emitting them batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput) } else { - // If spill files exist, stream-merge them. - self.update_merged_stream()?; + // Spill any remaining data to disk. There is some performance overhead in + // writing out this last chunk of data and reading it back. The benefit of + // doing this is that memory usage for this stream is reduced, and the more + // sophisticated memory handling in `MultiLevelMergeBuilder` can take over + // instead. + // Spilling to disk and reading back also ensures batch size is consistent + // rather than potentially having one significantly larger last batch. + self.spill()?; + + // Mark that we're switching to stream merging mode. + self.spill_state.is_stream_merging = true; + + self.input = StreamingMergeBuilder::new() + .with_schema(Arc::clone(&self.spill_state.spill_schema)) + .with_spill_manager(self.spill_state.spill_manager.clone()) + .with_sorted_spill_files(std::mem::take(&mut self.spill_state.spills)) + .with_expressions(&self.spill_state.spill_expr) + .with_metrics(self.baseline_metrics.clone()) + .with_batch_size(self.batch_size) + .with_reservation(self.reservation.new_empty()) + .build()?; + self.input_done = false; + + // Reset the group values collectors. + self.clear_all(); + + // We can now use `GroupOrdering::Full` since the spill files are sorted + // on the grouping columns. + self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); + + // Recreate `group_values` for streaming merge so group ids are assigned + // in first-seen order, as required by `GroupOrderingFull`. + // The pre-spill multi-column collector may use `vectorized_intern`, which + // can assign new group ids out of input order under hash collisions. + let group_schema = self + .spill_state + .merging_group_by + .group_schema(&self.spill_state.spill_schema)?; + if group_schema.fields().len() > 1 { + self.group_values = new_group_values(group_schema, &self.group_ordering)?; + } + + // Use `OutOfMemoryMode::ReportError` from this point on + // to ensure we don't spill the spilled data to disk again. + self.oom_mode = OutOfMemoryMode::ReportError; + + self.update_memory_reservation()?; + ExecutionState::ReadingInput }; timer.done(); @@ -1226,13 +1425,12 @@ impl GroupedHashAggregateStream { /// /// Returns `Some(ExecutionState)` if the state should be changed, None otherwise. fn switch_to_skip_aggregation(&mut self) -> Result> { - if let Some(probe) = self.skip_aggregation_probe.as_mut() { - if probe.should_skip() { - if let Some(batch) = self.emit(EmitTo::All, false)? { - return Ok(Some(ExecutionState::ProducingOutput(batch))); - }; - } - } + if let Some(probe) = self.skip_aggregation_probe.as_mut() + && probe.should_skip() + && let Some(batch) = self.emit(EmitTo::All, false)? + { + return Ok(Some(ExecutionState::ProducingOutput(batch))); + }; Ok(None) } @@ -1280,15 +1478,15 @@ impl GroupedHashAggregateStream { #[cfg(test)] mod tests { use super::*; + use crate::InputOrderMode; + use crate::execution_plan::ExecutionPlan; use crate::test::TestMemoryExec; use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; - use datafusion_execution::TaskContext; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::col; - use std::sync::Arc; #[tokio::test] async fn test_double_emission_race_condition_bug() -> Result<()> { @@ -1395,4 +1593,255 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_skip_aggregation_probe_not_locked_until_skip() -> Result<()> { + // Test that the probe is not locked until we actually decide to skip. + // This allows us to continue evaluating the skip condition across multiple batches. + // + // Scenario: + // - Batch 1: Hits rows threshold but NOT ratio threshold (low cardinality) -> don't skip + // - Batch 2: Now hits ratio threshold (high cardinality) -> skip + // + // Without the fix, the probe would be locked after batch 1, preventing the skip + // decision from being made on batch 2. + + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int32, false), + ])); + + // Configure thresholds: + // - probe_rows_threshold: 100 rows + // - probe_ratio_threshold: 0.8 (80%) + let probe_rows_threshold = 100; + let probe_ratio_threshold = 0.8; + + // Batch 1: 100 rows with only 10 unique groups + // Ratio: 10/100 = 0.1 (10%) < 0.8 -> should NOT skip + // This will hit the rows threshold but not the ratio threshold + let batch1_rows = 100; + let batch1_groups = 10; + let mut group_ids_batch1 = Vec::new(); + for i in 0..batch1_rows { + group_ids_batch1.push((i % batch1_groups) as i32); + } + let values_batch1: Vec = vec![1; batch1_rows]; + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids_batch1)), + Arc::new(Int32Array::from(values_batch1)), + ], + )?; + + // Batch 2: 360 rows with 360 unique NEW groups (starting from group 10) + // After batch 2, total: 460 rows, 370 groups + // Ratio: 370/460 ≈ 0.804 (80.4%) > 0.8 -> SHOULD decide to skip + let batch2_rows = 360; + let batch2_groups = 360; + let group_ids_batch2: Vec = (batch1_groups..(batch1_groups + batch2_groups)) + .map(|x| x as i32) + .collect(); + let values_batch2: Vec = vec![1; batch2_rows]; + + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids_batch2)), + Arc::new(Int32Array::from(values_batch2)), + ], + )?; + + // Batch 3: This batch should be skipped since we decided to skip after batch 2 + // 100 rows with 100 unique groups (continuing from where batch 2 left off) + let batch3_rows = 100; + let batch3_groups = 100; + let batch3_start_group = batch1_groups + batch2_groups; + let group_ids_batch3: Vec = (batch3_start_group + ..(batch3_start_group + batch3_groups)) + .map(|x| x as i32) + .collect(); + let values_batch3: Vec = vec![1; batch3_rows]; + + let batch3 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids_batch3)), + Arc::new(Int32Array::from(values_batch3)), + ], + )?; + + let input_partitions = vec![vec![batch1, batch2, batch3]]; + + let runtime = RuntimeEnvBuilder::default().build_arc()?; + let mut task_ctx = TaskContext::default().with_runtime(runtime); + + // Configure skip aggregation settings + let mut session_config = task_ctx.session_config().clone(); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &datafusion_common::ScalarValue::UInt64(Some(probe_rows_threshold)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &datafusion_common::ScalarValue::Float64(Some(probe_ratio_threshold)), + ); + task_ctx = task_ctx.with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + + // Create aggregate: COUNT(*) GROUP BY group_col + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + // Use Partial mode + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + // Execute and collect results + let mut stream = + GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?; + let mut results = Vec::new(); + + while let Some(result) = stream.next().await { + let batch = result?; + results.push(batch); + } + + // Check that skip aggregation actually happened + // The key metric is skipped_aggregation_rows + let metrics = aggregate_exec.metrics().unwrap(); + let skipped_rows = metrics + .sum_by_name("skipped_aggregation_rows") + .map(|m| m.as_usize()) + .unwrap_or(0); + + // We expect batch 3's rows to be skipped (100 rows) + assert_eq!( + skipped_rows, batch3_rows, + "Expected batch 3's rows ({batch3_rows}) to be skipped", + ); + + Ok(()) + } + + #[tokio::test] + async fn test_emit_early_with_partially_sorted() -> Result<()> { + // Reproducer for #20445: EmitEarly with PartiallySorted panics in + // remove_groups because it emits more groups than the sort boundary. + let schema = Arc::new(Schema::new(vec![ + Field::new("sort_col", DataType::Int32, false), + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + // All rows share sort_col=1 (no sort boundary), with unique group_col + // values to create many groups and trigger memory pressure. + let n = 256; + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1; n])), + Arc::new(Int32Array::from((0..n as i32).collect::>())), + Arc::new(Int64Array::from(vec![1; n])), + ], + )?; + + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(4096, 1.0) + .build_arc()?; + let mut task_ctx = TaskContext::default().with_runtime(runtime); + let mut cfg = task_ctx.session_config().clone(); + cfg = cfg.set( + "datafusion.execution.batch_size", + &datafusion_common::ScalarValue::UInt64(Some(128)), + ); + cfg = cfg.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &datafusion_common::ScalarValue::UInt64(Some(u64::MAX)), + ); + task_ctx = task_ctx.with_session_config(cfg); + let task_ctx = Arc::new(task_ctx); + + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new( + Column::new("sort_col", 0), + ) + as _)]) + .unwrap(); + let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)? + .try_with_sort_information(vec![ordering])?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + // GROUP BY sort_col, group_col with input sorted on sort_col + // gives PartiallySorted([0]) + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![ + (col("sort_col", &schema)?, "sort_col".to_string()), + (col("group_col", &schema)?, "group_col".to_string()), + ]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )], + vec![None], + exec, + Arc::clone(&schema), + )?; + assert!(matches!( + aggregate_exec.input_order_mode(), + InputOrderMode::PartiallySorted(_) + )); + + // Must not panic with "assertion failed: *current_sort >= n" + let mut stream = GroupedHashAggregateStream::new(&aggregate_exec, &task_ctx, 0)?; + while let Some(result) = stream.next().await { + if let Err(e) = result { + if e.to_string().contains("Resources exhausted") { + break; + } + return Err(e); + } + } + + Ok(()) + } + + #[test] + fn test_skip_aggregation_probe_equality_does_not_skip() { + // When num_groups / input_rows == probe_ratio_threshold, the `>` boundary + // means we must NOT skip — equality is not sufficient to trigger skip. + let threshold_ratio = 0.5_f64; + let threshold_rows = 10_usize; + let mut probe = SkipAggregationProbe::new( + threshold_rows, + threshold_ratio, + metrics::Count::new(), + ); + + // 10 rows, 5 groups → ratio = 5/10 = 0.5 exactly equals threshold + probe.update_state(10, 5); + + assert!( + !probe.should_skip(), + "ratio == threshold should not trigger skip (boundary is exclusive)" + ); + } } diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 974aea3b6292c..694780f08547f 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -15,22 +15,23 @@ // specific language governing permissions and limitations // under the License. -//! A wrapper around `hashbrown::RawTable` that allows entries to be tracked by index +//! A wrapper around `hashbrown::HashTable` that allows entries to be tracked by index use crate::aggregates::group_values::HashValue; use crate::aggregates::topk::heap::Comparable; -use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::{ - builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef, - ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray, + Array, ArrayRef, ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, + StringViewArray, builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, }; -use arrow::datatypes::{i256, DataType}; -use datafusion_common::exec_datafusion_err; +use arrow::datatypes::{DataType, i256}; use datafusion_common::Result; +use datafusion_common::exec_datafusion_err; +use datafusion_common::hash_utils::RandomState; use half::f16; -use hashbrown::raw::RawTable; +use hashbrown::hash_table::HashTable; use std::fmt::Debug; +use std::hash::BuildHasher; use std::sync::Arc; /// A "type alias" for Keys which are stored in our map @@ -48,13 +49,17 @@ pub struct HashTableItem { pub heap_idx: usize, } -/// A custom wrapper around `hashbrown::RawTable` that: +/// A custom wrapper around `hashbrown::HashTable` that: /// 1. limits the number of entries to the top K /// 2. Allocates a capacity greater than top K to maintain a low-fill factor and prevent resizing /// 3. Tracks indexes to allow corresponding heap to refer to entries by index vs hash -/// 4. Catches resize events to allow the corresponding heap to update it's indexes struct TopKHashTable { - map: RawTable>, + map: HashTable, + // Store the actual items separately to allow for index-based access + store: Vec>>, + // Free index in the store for reuse + free_index: Option, + // The maximum number of entries allowed limit: usize, } @@ -62,25 +67,23 @@ struct TopKHashTable { pub trait ArrowHashTable { fn set_batch(&mut self, ids: ArrayRef); fn len(&self) -> usize; - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the caller must provide valid indexes - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the caller must provide a valid index - unsafe fn heap_idx_at(&self, map_idx: usize) -> usize; - unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef; - - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the caller must provide valid indexes - unsafe fn find_or_insert( - &mut self, - row_idx: usize, - replace_idx: usize, - map: &mut Vec<(usize, usize)>, - ) -> (usize, bool); + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]); + fn heap_idx_at(&self, map_idx: usize) -> usize; + fn take_all(&mut self, indexes: Vec) -> ArrayRef; + fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool); +} + +/// Returns true if the given data type can be used as a top-K aggregation hash key. +/// +/// Supported types include Arrow primitives (integers, floats, decimals, intervals) +/// and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`). This is used internally by +/// `PriorityMap::supports()` to validate grouping key type compatibility. +pub fn is_supported_hash_key_type(kt: &DataType) -> bool { + kt.is_primitive() + || matches!( + kt, + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) } // An implementation of ArrowHashTable for String keys @@ -119,6 +122,34 @@ impl StringHashTable { data_type, } } + + /// Extracts the string value at the given row index, handling nulls and different string types. + /// + /// Returns `None` if the value is null, otherwise `Some(value.to_string())`. + fn extract_string_value(&self, row_idx: usize) -> Option { + let is_null_and_value = match self.data_type { + DataType::Utf8 => { + let arr = self.owned.as_string::(); + (arr.is_null(row_idx), arr.value(row_idx)) + } + DataType::LargeUtf8 => { + let arr = self.owned.as_string::(); + (arr.is_null(row_idx), arr.value(row_idx)) + } + DataType::Utf8View => { + let arr = self.owned.as_string_view(); + (arr.is_null(row_idx), arr.value(row_idx)) + } + _ => panic!("Unsupported data type"), + }; + + let (is_null, value) = is_null_and_value; + if is_null { + None + } else { + Some(value.to_string()) + } + } } impl ArrowHashTable for StringHashTable { @@ -130,15 +161,15 @@ impl ArrowHashTable for StringHashTable { self.map.len() } - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { self.map.update_heap_idx(mapper); } - unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + fn heap_idx_at(&self, map_idx: usize) -> usize { self.map.heap_idx_at(map_idx) } - unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { + fn take_all(&mut self, indexes: Vec) -> ArrayRef { let ids = self.map.take_all(indexes); match self.data_type { DataType::Utf8 => Arc::new(StringArray::from(ids)), @@ -148,67 +179,16 @@ impl ArrowHashTable for StringHashTable { } } - unsafe fn find_or_insert( - &mut self, - row_idx: usize, - replace_idx: usize, - mapper: &mut Vec<(usize, usize)>, - ) -> (usize, bool) { - let id = match self.data_type { - DataType::Utf8 => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected StringArray for DataType::Utf8"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } - } - DataType::LargeUtf8 => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected LargeStringArray for DataType::LargeUtf8"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } - } - DataType::Utf8View => { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("Expected StringViewArray for DataType::Utf8View"); - if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) - } - } - _ => panic!("Unsupported data type"), - }; + fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool) { + let id = self.extract_string_value(row_idx); - let hash = self.rnd.hash_one(id); - if let Some(map_idx) = self - .map - .find(hash, |mi| id == mi.as_ref().map(|id| id.as_str())) - { - return (map_idx, false); - } + // Compute hash and create equality closure for hash table lookup. + let hash = self.rnd.hash_one(id.as_deref()); + let id_for_eq = id.clone(); + let eq = move |mi: &Option| id_for_eq.as_deref() == mi.as_deref(); - // we're full and this is a better value, so remove the worst - let heap_idx = self.map.remove_if_full(replace_idx); - - // add the new group - let id = id.map(|id| id.to_string()); - let map_idx = self.map.insert(hash, id, heap_idx, mapper); - (map_idx, true) + // Use entry API to avoid double lookup + self.map.find_or_insert(hash, id, replace_idx, eq) } } @@ -245,15 +225,15 @@ where self.map.len() } - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { self.map.update_heap_idx(mapper); } - unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { + fn heap_idx_at(&self, map_idx: usize) -> usize { self.map.heap_idx_at(map_idx) } - unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { + fn take_all(&mut self, indexes: Vec) -> ArrayRef { let ids = self.map.take_all(indexes); let mut builder: PrimitiveBuilder = PrimitiveArray::builder(ids.len()).with_data_type(self.kt.clone()); @@ -267,112 +247,117 @@ where Arc::new(ids) } - unsafe fn find_or_insert( - &mut self, - row_idx: usize, - replace_idx: usize, - mapper: &mut Vec<(usize, usize)>, - ) -> (usize, bool) { + fn find_or_insert(&mut self, row_idx: usize, replace_idx: usize) -> (usize, bool) { let ids = self.owned.as_primitive::(); let id: Option = if ids.is_null(row_idx) { None } else { Some(ids.value(row_idx)) }; - + // Compute hash and create equality closure for hash table lookup. let hash: u64 = id.hash(&self.rnd); - if let Some(map_idx) = self.map.find(hash, |mi| id == *mi) { - return (map_idx, false); - } + let eq = |mi: &Option| id == *mi; - // we're full and this is a better value, so remove the worst - let heap_idx = self.map.remove_if_full(replace_idx); - - // add the new group - let map_idx = self.map.insert(hash, id, heap_idx, mapper); - (map_idx, true) + // Use entry API to avoid double lookup + self.map.find_or_insert(hash, id, replace_idx, eq) } } -impl TopKHashTable { +use hashbrown::hash_table::Entry; +impl TopKHashTable { pub fn new(limit: usize, capacity: usize) -> Self { Self { - map: RawTable::with_capacity(capacity), + map: HashTable::with_capacity(capacity), + store: Vec::with_capacity(capacity), + free_index: None, limit, } } - pub fn find(&self, hash: u64, mut eq: impl FnMut(&ID) -> bool) -> Option { - let bucket = self.map.find(hash, |mi| eq(&mi.id))?; - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: getting the index of a bucket we just found - let idx = unsafe { self.map.bucket_index(&bucket) }; - Some(idx) - } - - pub unsafe fn heap_idx_at(&self, map_idx: usize) -> usize { - let bucket = unsafe { self.map.bucket(map_idx) }; - bucket.as_ref().heap_idx + pub fn heap_idx_at(&self, map_idx: usize) -> usize { + self.store[map_idx].as_ref().unwrap().heap_idx } - pub unsafe fn remove_if_full(&mut self, replace_idx: usize) -> usize { + pub fn remove_if_full(&mut self, replace_idx: usize) -> usize { if self.map.len() >= self.limit { - self.map.erase(self.map.bucket(replace_idx)); + let item_to_remove = self.store[replace_idx].as_ref().unwrap(); + let hash = item_to_remove.hash; + let id_to_remove = &item_to_remove.id; + + let eq = |&idx: &usize| self.store[idx].as_ref().unwrap().id == *id_to_remove; + let hasher = |idx: &usize| self.store[*idx].as_ref().unwrap().hash; + match self.map.entry(hash, eq, hasher) { + Entry::Occupied(entry) => { + let (removed_idx, _) = entry.remove(); + self.store[removed_idx] = None; + self.free_index = Some(removed_idx); + } + Entry::Vacant(_) => unreachable!(), + } 0 // if full, always replace top node } else { self.map.len() // if we're not full, always append to end } } - unsafe fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { + fn update_heap_idx(&mut self, mapper: &[(usize, usize)]) { for (m, h) in mapper { - self.map.bucket(*m).as_mut().heap_idx = *h + self.store[*m].as_mut().unwrap().heap_idx = *h; } } - pub fn insert( + /// Find an existing entry or insert a new one, avoiding double hash table lookup. + /// Returns (map_idx, is_new) where is_new indicates if this was a new insertion. + /// If inserting a new entry and the table is full, replaces the entry at replace_idx. + pub fn find_or_insert( &mut self, hash: u64, id: ID, - heap_idx: usize, - mapper: &mut Vec<(usize, usize)>, - ) -> usize { - let mi = HashTableItem::new(hash, id, heap_idx); - let bucket = self.map.try_insert_no_grow(hash, mi); - let bucket = match bucket { - Ok(bucket) => bucket, - Err(new_item) => { - let bucket = self.map.insert(hash, new_item, |mi| mi.hash); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: we're getting indexes of buckets, not dereferencing them - unsafe { - for bucket in self.map.iter() { - let heap_idx = bucket.as_ref().heap_idx; - let map_idx = self.map.bucket_index(&bucket); - mapper.push((heap_idx, map_idx)); - } - } - bucket + replace_idx: usize, + mut eq: impl FnMut(&ID) -> bool, + ) -> (usize, bool) { + // Check if entry exists - this is the only hash table lookup + { + let eq_fn = |idx: &usize| eq(&self.store[*idx].as_ref().unwrap().id); + if let Some(&map_idx) = self.map.find(hash, eq_fn) { + return (map_idx, false); } + } + + // Entry doesn't exist - compute heap_idx and prepare item + let heap_idx = self.remove_if_full(replace_idx); + let mi = HashTableItem::new(hash, id, heap_idx); + let store_idx = if let Some(idx) = self.free_index.take() { + self.store[idx] = Some(mi); + idx + } else { + self.store.push(Some(mi)); + self.store.len() - 1 }; - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: we're getting indexes of buckets, not dereferencing them - unsafe { self.map.bucket_index(&bucket) } + + // Reserve space if needed + let hasher = |idx: &usize| self.store[*idx].as_ref().unwrap().hash; + if self.map.len() == self.map.capacity() { + self.map.reserve(self.limit, hasher); + } + + // Insert without checking again since we already confirmed it doesn't exist + self.map.insert_unique(hash, store_idx, hasher); + (store_idx, true) } pub fn len(&self) -> usize { self.map.len() } - pub unsafe fn take_all(&mut self, idxs: Vec) -> Vec { + pub fn take_all(&mut self, idxs: Vec) -> Vec { let ids = idxs .into_iter() - .map(|idx| self.map.bucket(idx).as_ref().id.clone()) + .map(|idx| self.store[idx].take().unwrap().id) .collect(); self.map.clear(); + self.store.clear(); + self.free_index = None; ids } } @@ -451,11 +436,8 @@ mod tests { let dt = DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())); let mut ht = new_hash_table(1, dt.clone())?; ht.set_batch(Arc::new(ids)); - let mut mapper = vec![]; - let ids = unsafe { - ht.find_or_insert(0, 0, &mut mapper); - ht.take_all(vec![0]) - }; + ht.find_or_insert(0, 0); + let ids = ht.take_all(vec![0]); assert_eq!(ids.data_type(), &dt); Ok(()) @@ -464,28 +446,29 @@ mod tests { #[test] fn should_resize_properly() -> Result<()> { let mut heap_to_map = BTreeMap::::new(); + // Create TopKHashTable with limit=5 and capacity=3 to force resizing let mut map = TopKHashTable::>::new(5, 3); - for (heap_idx, id) in vec!["1", "2", "3", "4", "5"].into_iter().enumerate() { - let mut mapper = vec![]; + + // Insert 5 entries, tracking the heap-to-map index mapping + for (heap_idx, id) in ["1", "2", "3", "4", "5"].iter().enumerate() { + let value = Some(id.to_string()); let hash = heap_idx as u64; - let map_idx = map.insert(hash, Some(id.to_string()), heap_idx, &mut mapper); - let _ = heap_to_map.insert(heap_idx, map_idx); - if heap_idx == 3 { - assert_eq!( - mapper, - vec![(0, 0), (1, 1), (2, 2), (3, 3)], - "Pass {heap_idx} resized incorrectly!" - ); - for (heap_idx, map_idx) in mapper { - let _ = heap_to_map.insert(heap_idx, map_idx); - } - } else { - assert_eq!(mapper, vec![], "Pass {heap_idx} should not have resized!"); - } + let (map_idx, is_new) = + map.find_or_insert(hash, value.clone(), heap_idx, |v| *v == value); + assert!(is_new, "Entry should be new"); + heap_to_map.insert(heap_idx, map_idx); } + // Verify all 5 entries are present + assert_eq!(map.len(), 5); + + // Verify that the hash table resized properly (capacity should have grown beyond 3) + // This is implicit - if it didn't resize, insertions would have failed or been slow + + // Drain all values in heap order let (_heap_idxs, map_idxs): (Vec<_>, Vec<_>) = heap_to_map.into_iter().unzip(); - let ids = unsafe { map.take_all(map_idxs) }; + let ids = map.take_all(map_idxs); + assert_eq!( format!("{ids:?}"), r#"[Some("1"), Some("2"), Some("3"), Some("4"), Some("5")]"# diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 83d76a919e4fa..889fe04bf830a 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -15,17 +15,25 @@ // specific language governing permissions and limitations // under the License. -//! A custom binary heap implementation for performant top K aggregation - +//! A custom binary heap implementation for performant top K aggregation. +//! +//! the `new_heap` //! factory function selects an appropriate heap implementation +//! based on the Arrow data type. +//! +//! Supported value types include Arrow primitives (integers, floats, decimals, intervals) +//! and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`) using lexicographic ordering. + +use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, downcast_primitive}; +use arrow::array::{LargeStringBuilder, StringBuilder, StringViewBuilder}; use arrow::array::{ + StringArray, cast::AsArray, types::{IntervalDayTime, IntervalMonthDayNano}, }; -use arrow::array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::{i256, DataType}; -use datafusion_common::exec_datafusion_err; +use arrow::datatypes::{DataType, i256}; use datafusion_common::Result; +use datafusion_common::exec_datafusion_err; use half::f16; use std::cmp::Ordering; @@ -72,7 +80,6 @@ pub trait ArrowHeap { fn set_batch(&mut self, vals: ArrayRef); fn is_worse(&self, idx: usize) -> bool; fn worst_map_idx(&self) -> usize; - fn renumber(&mut self, heap_to_map: &[(usize, usize)]); fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>); fn replace_if_better( &mut self, @@ -131,10 +138,6 @@ where self.heap.worst_map_idx() } - fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { - self.heap.renumber(heap_to_map); - } - fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { let vals = self.batch.as_primitive::(); let new_val = vals.value(row_idx); @@ -161,6 +164,164 @@ where } } +/// An implementation of `ArrowHeap` that deals with string values. +/// +/// Supports all three UTF-8 string types: `Utf8`, `LargeUtf8`, and `Utf8View`. +/// String values are compared lexicographically using the compare-first pattern: +/// borrowed strings are compared before allocation, and only allocated when the +/// heap confirms they improve the top-K set. +/// +pub struct StringHeap { + batch: ArrayRef, + heap: TopKHeap>, + desc: bool, + data_type: DataType, +} + +impl StringHeap { + pub fn new(limit: usize, desc: bool, data_type: DataType) -> Self { + let batch: ArrayRef = Arc::new(StringArray::from(Vec::<&str>::new())); + Self { + batch, + heap: TopKHeap::new(limit, desc), + desc, + data_type, + } + } + + /// Extracts a string value from the current batch at the given row index. + /// + /// Panics if the row index is out of bounds or if the data type is not one of + /// the supported UTF-8 string types. + /// + /// Note: Null values should not appear in the input; the aggregation layer + /// ensures nulls are filtered before reaching this code. + fn value(&self, row_idx: usize) -> &str { + extract_string_value(&self.batch, &self.data_type, row_idx) + } +} + +/// Helper to extract a string value from an ArrayRef at a given index. +/// +/// Supports `Utf8`, `LargeUtf8`, and `Utf8View` data types. +/// +/// # Panics +/// Panics if the index is out of bounds or if the data type is unsupported. +fn extract_string_value<'a>( + batch: &'a ArrayRef, + data_type: &DataType, + idx: usize, +) -> &'a str { + match data_type { + DataType::Utf8 => batch.as_string::().value(idx), + DataType::LargeUtf8 => batch.as_string::().value(idx), + DataType::Utf8View => batch.as_string_view().value(idx), + _ => unreachable!("Unsupported string type: {data_type}"), + } +} + +impl ArrowHeap for StringHeap { + fn set_batch(&mut self, vals: ArrayRef) { + self.batch = vals; + } + + fn is_worse(&self, row_idx: usize) -> bool { + if !self.heap.is_full() { + return false; + } + // Compare borrowed `&str` against the worst heap value first to avoid + // allocating a `String` unless this row would actually replace an + // existing heap entry. + let new_val = self.value(row_idx); + let worst_val = self.heap.worst_val().expect("Missing root"); + match worst_val { + None => false, + Some(worst_str) => { + (!self.desc && new_val > worst_str.as_str()) + || (self.desc && new_val < worst_str.as_str()) + } + } + } + + fn worst_map_idx(&self) -> usize { + self.heap.worst_map_idx() + } + + fn insert(&mut self, row_idx: usize, map_idx: usize, map: &mut Vec<(usize, usize)>) { + // When appending (heap not full) we must allocate to own the string + // because it will be stored in the heap. For replacements we avoid + // allocation until `replace_if_better` confirms a replacement is + // necessary. + let new_str = self.value(row_idx).to_string(); + let new_val = Some(new_str); + self.heap.append_or_replace(new_val, map_idx, map); + } + + fn replace_if_better( + &mut self, + heap_idx: usize, + row_idx: usize, + map: &mut Vec<(usize, usize)>, + ) { + let new_str = self.value(row_idx); + let existing = self.heap.heap[heap_idx] + .as_ref() + .expect("Missing heap item"); + + // Compare borrowed reference first—no allocation yet. + // We compare the borrowed `&str` with the stored `Option` and + // only allocate (`to_string()`) when a replacement is required. + match &existing.val { + None => { + // Existing is null; new value always wins + let new_val = Some(new_str.to_string()); + self.heap.replace_if_better(heap_idx, new_val, map); + } + Some(existing_str) => { + // Compare borrowed strings first + if (!self.desc && new_str < existing_str.as_str()) + || (self.desc && new_str > existing_str.as_str()) + { + let new_val = Some(new_str.to_string()); + self.heap.replace_if_better(heap_idx, new_val, map); + } + // Else: no improvement, no allocation + } + } + } + + fn drain(&mut self) -> (ArrayRef, Vec) { + let (vals, map_idxs) = self.heap.drain(); + // Use Arrow builders to safely construct arrays from the owned + // `Option` values. Builders avoid needing to maintain + // references to temporary storage. + + // Macro to eliminate duplication across string builder types. + // All three builders share the same interface for append_value, + // append_null, and finish, differing only in their concrete types. + macro_rules! build_string_array { + ($builder_type:ty) => {{ + let mut builder = <$builder_type>::new(); + for val in vals { + match val { + Some(s) => builder.append_value(&s), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) + }}; + } + + let arr: ArrayRef = match self.data_type { + DataType::Utf8 => build_string_array!(StringBuilder), + DataType::LargeUtf8 => build_string_array!(LargeStringBuilder), + DataType::Utf8View => build_string_array!(StringViewBuilder), + _ => unreachable!("Unsupported string type: {}", self.data_type), + }; + (arr, map_idxs) + } +} + impl TopKHeap { pub fn new(limit: usize, desc: bool) -> Self { Self { @@ -268,14 +429,6 @@ impl TopKHeap { self.heapify_down(heap_idx, mapper); } - pub fn renumber(&mut self, heap_to_map: &[(usize, usize)]) { - for (heap_idx, map_idx) in heap_to_map.iter() { - if let Some(Some(hi)) = self.heap.get_mut(*heap_idx) { - hi.map_idx = *map_idx; - } - } - } - fn heapify_up(&mut self, mut idx: usize, mapper: &mut Vec<(usize, usize)>) { let desc = self.desc; while idx != 0 { @@ -311,13 +464,12 @@ impl TopKHeap { let mut best_idx = node_idx; let mut best_val = &entry.val; for child_idx in left_child..=left_child + 1 { - if let Some(Some(child)) = self.heap.get(child_idx) { - if (!desc && child.val.comp(best_val) == Ordering::Greater) - || (desc && child.val.comp(best_val) == Ordering::Less) - { - best_val = &child.val; - best_idx = child_idx; - } + if let Some(Some(child)) = self.heap.get(child_idx) + && ((!desc && child.val.comp(best_val) == Ordering::Greater) + || (desc && child.val.comp(best_val) == Ordering::Less)) + { + best_val = &child.val; + best_idx = child_idx; } } if best_val.comp(&entry.val) != Ordering::Equal { @@ -329,11 +481,7 @@ impl TopKHeap { fn _tree_print(&self, idx: usize, prefix: &str, is_tail: bool, output: &mut String) { if let Some(Some(hi)) = self.heap.get(idx) { let connector = if idx != 0 { - if is_tail { - "└── " - } else { - "├── " - } + if is_tail { "└── " } else { "├── " } } else { "" }; @@ -456,11 +604,31 @@ compare_integer!(u8, u16, u32, u64); compare_integer!(IntervalDayTime, IntervalMonthDayNano); compare_float!(f16, f32, f64); +/// Returns true if the given data type can be stored in a top-K aggregation heap. +/// +/// Supported types include Arrow primitives (integers, floats, decimals, intervals) +/// and UTF-8 strings (`Utf8`, `LargeUtf8`, `Utf8View`). This is used internally by +/// `PriorityMap::supports()` to validate aggregate value type compatibility. +pub fn is_supported_heap_type(vt: &DataType) -> bool { + vt.is_primitive() + || matches!( + vt, + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) +} + pub fn new_heap( limit: usize, desc: bool, vt: DataType, ) -> Result> { + if matches!( + vt, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) { + return Ok(Box::new(StringHeap::new(limit, desc, vt))); + } + macro_rules! downcast_helper { ($vt:ty, $d:ident) => { return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) @@ -472,7 +640,9 @@ pub fn new_heap( _ => {} } - Err(exec_datafusion_err!("Can't group type: {vt:?}")) + Err(exec_datafusion_err!( + "Unsupported TopK aggregate value type: {vt:?}" + )) } #[cfg(test)] @@ -488,9 +658,7 @@ mod tests { heap.append_or_replace(1, 1, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=1 idx=0, bucket=1 - "#); + assert_snapshot!(actual, @"val=1 idx=0, bucket=1"); Ok(()) } @@ -507,10 +675,10 @@ val=1 idx=0, bucket=1 assert_eq!(map, vec![(2, 0), (1, 1)]); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=2 idx=0, bucket=2 -└── val=1 idx=1, bucket=1 - "#); + assert_snapshot!(actual, @r" + val=2 idx=0, bucket=2 + └── val=1 idx=1, bucket=1 + "); Ok(()) } @@ -524,20 +692,20 @@ val=2 idx=0, bucket=2 heap.append_or_replace(2, 2, &mut map); heap.append_or_replace(3, 3, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=3 idx=0, bucket=3 -├── val=1 idx=1, bucket=1 -└── val=2 idx=2, bucket=2 - "#); + assert_snapshot!(actual, @r" + val=3 idx=0, bucket=3 + ├── val=1 idx=1, bucket=1 + └── val=2 idx=2, bucket=2 + "); let mut map = vec![]; heap.append_or_replace(0, 0, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=2 idx=0, bucket=2 -├── val=1 idx=1, bucket=1 -└── val=0 idx=2, bucket=0 - "#); + assert_snapshot!(actual, @r" + val=2 idx=0, bucket=2 + ├── val=1 idx=1, bucket=1 + └── val=0 idx=2, bucket=0 + "); assert_eq!(map, vec![(2, 0), (0, 2)]); Ok(()) @@ -553,22 +721,22 @@ val=2 idx=0, bucket=2 heap.append_or_replace(3, 3, &mut map); heap.append_or_replace(4, 4, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=4 idx=0, bucket=4 -├── val=3 idx=1, bucket=3 -│ └── val=1 idx=3, bucket=1 -└── val=2 idx=2, bucket=2 - "#); + assert_snapshot!(actual, @r" + val=4 idx=0, bucket=4 + ├── val=3 idx=1, bucket=3 + │ └── val=1 idx=3, bucket=1 + └── val=2 idx=2, bucket=2 + "); let mut map = vec![]; heap.replace_if_better(1, 0, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=4 idx=0, bucket=4 -├── val=1 idx=1, bucket=1 -│ └── val=0 idx=3, bucket=3 -└── val=2 idx=2, bucket=2 - "#); + assert_snapshot!(actual, @r" + val=4 idx=0, bucket=4 + ├── val=1 idx=1, bucket=1 + │ └── val=0 idx=3, bucket=3 + └── val=2 idx=2, bucket=2 + "); assert_eq!(map, vec![(1, 1), (3, 3)]); Ok(()) @@ -583,10 +751,10 @@ val=4 idx=0, bucket=4 heap.append_or_replace(2, 2, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=2 idx=0, bucket=2 -└── val=1 idx=1, bucket=1 - "#); + assert_snapshot!(actual, @r" + val=2 idx=0, bucket=2 + └── val=1 idx=1, bucket=1 + "); assert_eq!(heap.worst_val(), Some(&2)); assert_eq!(heap.worst_map_idx(), 2); @@ -603,10 +771,10 @@ val=2 idx=0, bucket=2 heap.append_or_replace(2, 2, &mut map); let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=2 idx=0, bucket=2 -└── val=1 idx=1, bucket=1 - "#); + assert_snapshot!(actual, @r" + val=2 idx=0, bucket=2 + └── val=1 idx=1, bucket=1 + "); let (vals, map_idxs) = heap.drain(); assert_eq!(vals, vec![1, 2]); @@ -615,29 +783,4 @@ val=2 idx=0, bucket=2 Ok(()) } - - #[test] - fn should_renumber() -> Result<()> { - let mut map = vec![]; - let mut heap = TopKHeap::new(10, false); - - heap.append_or_replace(1, 1, &mut map); - heap.append_or_replace(2, 2, &mut map); - - let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=2 idx=0, bucket=2 -└── val=1 idx=1, bucket=1 - "#); - - let numbers = vec![(0, 1), (1, 2)]; - heap.renumber(numbers.as_slice()); - let actual = heap.to_string(); - assert_snapshot!(actual, @r#" -val=2 idx=0, bucket=1 -└── val=1 idx=1, bucket=2 - "#); - - Ok(()) - } } diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index a09d70f7471f3..c74b648d373ce 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -17,8 +17,8 @@ //! A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` -use crate::aggregates::topk::hash_table::{new_hash_table, ArrowHashTable}; -use crate::aggregates::topk::heap::{new_heap, ArrowHeap}; +use crate::aggregates::topk::hash_table::{ArrowHashTable, new_hash_table}; +use crate::aggregates::topk::heap::{ArrowHeap, new_heap}; use arrow::array::ArrayRef; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -63,40 +63,26 @@ impl PriorityMap { // handle new groups we haven't seen yet map.clear(); let replace_idx = self.heap.worst_map_idx(); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: replace_idx kept valid during resizes - let (map_idx, did_insert) = - unsafe { self.map.find_or_insert(row_idx, replace_idx, map) }; + + let (map_idx, did_insert) = self.map.find_or_insert(row_idx, replace_idx); if did_insert { - self.heap.renumber(map); - map.clear(); self.heap.insert(row_idx, map_idx, map); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the map was created on the line above, so all the indexes should be valid - unsafe { self.map.update_heap_idx(map) }; + self.map.update_heap_idx(map); return Ok(()); }; // this is a value for an existing group map.clear(); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: map_idx was just found, so it is valid - let heap_idx = unsafe { self.map.heap_idx_at(map_idx) }; + let heap_idx = self.map.heap_idx_at(map_idx); self.heap.replace_if_better(heap_idx, row_idx, map); - // JUSTIFICATION - // Benefit: ~15% speedup + required to index into RawTable from binary heap - // Soundness: the index map was just built, so it will be valid - unsafe { self.map.update_heap_idx(map) }; + self.map.update_heap_idx(map); Ok(()) } pub fn emit(&mut self) -> Result> { let (vals, map_idxs) = self.heap.drain(); - let ids = unsafe { self.map.take_all(map_idxs) }; + let ids = self.map.take_all(map_idxs); Ok(vec![ids, vals]) } @@ -182,13 +168,13 @@ mod tests { let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 1 | 1 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 1 | + +----------+--------------+ + " ); Ok(()) @@ -207,13 +193,13 @@ mod tests { let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 1 | 1 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 1 | + +----------+--------------+ + " ); Ok(()) @@ -231,13 +217,13 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 2 | 2 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 2 | 2 | + +----------+--------------+ + " ); Ok(()) @@ -255,13 +241,13 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 1 | 1 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 1 | + +----------+--------------+ + " ); Ok(()) @@ -279,13 +265,13 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 1 | 2 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 2 | + +----------+--------------+ + " ); Ok(()) @@ -303,13 +289,13 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 1 | 1 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 1 | + +----------+--------------+ + " ); Ok(()) @@ -327,13 +313,13 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 2 | 2 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 2 | 2 | + +----------+--------------+ + " ); Ok(()) @@ -351,13 +337,13 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| 1 | 1 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 1 | + +----------+--------------+ + " ); Ok(()) @@ -375,14 +361,110 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | 1 | 2 | + +----------+--------------+ + " + ); + + Ok(()) + } + + #[test] + fn should_track_lexicographic_min_utf8_value() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(StringArray::from(vec!["zulu", "alpha"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::Utf8, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::Utf8), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + assert_snapshot!(actual, @r#" +----------+--------------+ | trace_id | timestamp_ms | +----------+--------------+ -| 1 | 2 | +| 1 | alpha | +----------+--------------+ - "# - ); + "#); + + Ok(()) + } + + #[test] + fn should_track_lexicographic_max_utf8_value_desc() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(StringArray::from(vec!["alpha", "zulu"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::Utf8, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::Utf8), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | zulu | ++----------+--------------+ + "#); + + Ok(()) + } + + #[test] + fn should_track_large_utf8_values() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(LargeStringArray::from(vec!["zulu", "alpha"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::LargeUtf8, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::LargeUtf8), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | alpha | ++----------+--------------+ + "#); + + Ok(()) + } + + #[test] + fn should_track_utf8_view_values() -> Result<()> { + let ids: ArrayRef = Arc::new(Int64Array::from(vec![1, 1])); + let vals: ArrayRef = Arc::new(StringViewArray::from(vec!["alpha", "zulu"])); + let mut agg = PriorityMap::new(DataType::Int64, DataType::Utf8View, 1, true)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + agg.insert(1)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_value(DataType::Utf8View), cols)?; + let actual = format!("{}", pretty_format_batches(&[batch])?); + + assert_snapshot!(actual, @r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | zulu | ++----------+--------------+ + "#); Ok(()) } @@ -400,14 +482,14 @@ mod tests { let cols = agg.emit()?; let batch = RecordBatch::try_new(test_schema(), cols)?; let actual = format!("{}", pretty_format_batches(&[batch])?); - assert_snapshot!(actual, @r#" -+----------+--------------+ -| trace_id | timestamp_ms | -+----------+--------------+ -| | 3 | -| 1 | 1 | -+----------+--------------+ - "# + assert_snapshot!(actual, @r" + +----------+--------------+ + | trace_id | timestamp_ms | + +----------+--------------+ + | | 3 | + | 1 | 1 | + +----------+--------------+ + " ); Ok(()) @@ -433,4 +515,11 @@ mod tests { Field::new("timestamp_ms", DataType::Int64, true), ])) } + + fn test_schema_value(value_type: DataType) -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Int64, true), + Field::new("timestamp_ms", value_type, true), + ])) + } } diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index c706b48e348eb..97f4662c11342 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -19,21 +19,26 @@ use crate::aggregates::group_values::GroupByMetrics; use crate::aggregates::topk::priority_map::PriorityMap; +#[cfg(debug_assertions)] +use crate::aggregates::topk_types_supported; use crate::aggregates::{ - aggregate_expressions, evaluate_group_by, evaluate_many, AggregateExec, - PhysicalGroupBy, + AggregateExec, PhysicalGroupBy, aggregate_expressions, evaluate_group_by, + evaluate_many, }; use crate::metrics::BaselineMetrics; +use crate::stream::EmptyRecordBatchStream; use crate::{RecordBatchStream, SendableRecordBatchStream}; -use arrow::array::{Array, ArrayRef, RecordBatch}; +use arrow::array::{Array, ArrayRef, RecordBatch, new_null_array}; +use arrow::compute::concat; use arrow::datatypes::SchemaRef; use arrow::util::pretty::print_batches; -use datafusion_common::internal_datafusion_err; use datafusion_common::Result; +use datafusion_common::internal_datafusion_err; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::metrics::RecordOutput; use futures::stream::{Stream, StreamExt}; -use log::{trace, Level}; +use log::{Level, trace}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -42,13 +47,16 @@ pub struct GroupedTopKAggregateStream { partition: usize, row_count: usize, started: bool, + done: bool, schema: SchemaRef, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, group_by_metrics: GroupByMetrics, aggregate_arguments: Vec>>, - group_by: PhysicalGroupBy, + group_by: Arc, priority_map: PriorityMap, + /// Whether a NULL group key has been seen for a group-by-only aggregation. + null_group_seen: bool, } impl GroupedTopKAggregateStream { @@ -59,25 +67,53 @@ impl GroupedTopKAggregateStream { limit: usize, ) -> Result { let agg_schema = Arc::clone(&aggr.schema); - let group_by = aggr.group_by.clone(); + let group_by = Arc::clone(&aggr.group_by); let input = aggr.input.execute(partition, Arc::clone(context))?; let baseline_metrics = BaselineMetrics::new(&aggr.metrics, partition); let group_by_metrics = GroupByMetrics::new(&aggr.metrics, partition); let aggregate_arguments = aggregate_expressions(&aggr.aggr_expr, &aggr.mode, group_by.expr.len())?; - let (val_field, desc) = aggr - .get_minmax_desc() - .ok_or_else(|| internal_datafusion_err!("Min/max required"))?; let (expr, _) = &aggr.group_expr().expr()[0]; let kt = expr.data_type(&aggr.input().schema())?; - let vt = val_field.data_type().clone(); + // Check if this is a MIN/MAX aggregate or a DISTINCT-like operation + let (vt, desc) = if let Some((val_field, desc)) = aggr.get_minmax_desc() { + // MIN/MAX case: use the aggregate output type + (val_field.data_type().clone(), desc) + } else { + // DISTINCT case: use the group key type and get ordering from limit_order_descending + // The ordering direction is set by the optimizer when it pushes down the limit + let desc = aggr + .limit_options() + .and_then(|config| config.descending) + .ok_or_else(|| { + internal_datafusion_err!( + "Ordering direction required for DISTINCT with limit" + ) + })?; + (kt.clone(), desc) + }; + + // Type validation is performed by the optimizer and can_use_topk() check. + // This debug assertion documents the contract without runtime overhead in release builds. + #[cfg(debug_assertions)] + { + debug_assert!( + topk_types_supported(&kt, &vt), + "TopK type validation should have been performed by optimizer and can_use_topk(). \ + Found unsupported types: key={kt:?}, value={vt:?}" + ); + } + + // Note: Null values in aggregate columns are filtered by the aggregation layer + // before reaching the heap, so the heap implementations don't need explicit null handling. let priority_map = PriorityMap::new(kt, vt, limit, desc)?; Ok(GroupedTopKAggregateStream { partition, started: false, + done: false, row_count: 0, schema: agg_schema, input, @@ -86,6 +122,7 @@ impl GroupedTopKAggregateStream { aggregate_arguments, group_by, priority_map, + null_group_seen: false, }) } } @@ -97,6 +134,10 @@ impl RecordBatchStream for GroupedTopKAggregateStream { } impl GroupedTopKAggregateStream { + fn is_group_by_only(&self) -> bool { + self.aggregate_arguments.is_empty() + } + fn intern(&mut self, ids: &ArrayRef, vals: &ArrayRef) -> Result<()> { let _timer = self.group_by_metrics.time_calculating_group_ids.timer(); @@ -105,6 +146,9 @@ impl GroupedTopKAggregateStream { .set_batch(Arc::clone(ids), Arc::clone(vals)); let has_nulls = vals.null_count() > 0; + if has_nulls && self.is_group_by_only() { + self.null_group_seen = true; + } for row_idx in 0..len { if has_nulls && vals.is_null(row_idx) { continue; @@ -113,6 +157,39 @@ impl GroupedTopKAggregateStream { } Ok(()) } + + fn emit_columns(&mut self) -> Result> { + let mut cols = if self.priority_map.is_empty() { + vec![] + } else { + self.priority_map.emit()? + }; + + // GROUP BY-only aggregation covers DISTINCT-like queries. The group + // key and heap value are the same column, but the output schema has + // only the group key. + if self.is_group_by_only() { + cols.truncate(1); + if self.null_group_seen { + self.append_null_group(&mut cols)?; + } + } + + Ok(cols) + } + + fn append_null_group(&self, cols: &mut Vec) -> Result<()> { + let dt = self.schema.field(0).data_type(); + let null_arr = new_null_array(dt, 1); + if cols.is_empty() { + cols.push(null_arr); + } else { + // NULL group keys are tracked outside the heap, so append a + // one-row NULL array to the emitted non-NULL group key column. + cols[0] = concat(&[cols[0].as_ref(), null_arr.as_ref()])?; + } + Ok(()) + } } impl Stream for GroupedTopKAggregateStream { @@ -122,6 +199,9 @@ impl Stream for GroupedTopKAggregateStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + if self.done { + return Poll::Ready(None); + } let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); let emitting_time = self.group_by_metrics.emitting_time.clone(); while let Poll::Ready(res) = self.input.poll_next_unpin(cx) { @@ -154,33 +234,41 @@ impl Stream for GroupedTopKAggregateStream { "Exactly 1 group value required" ); let group_by_values = Arc::clone(&group_by_values[0][0]); - let input_values = { - let _timer = (!self.aggregate_arguments.is_empty()).then(|| { - self.group_by_metrics.aggregate_arguments_time.timer() - }); - evaluate_many( + let input_values = if self.is_group_by_only() { + // GROUP BY-only case: use group key as both key and value + Arc::clone(&group_by_values) + } else { + // MIN/MAX case: evaluate aggregate expressions + let _timer = + self.group_by_metrics.aggregate_arguments_time.timer(); + let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), - )? + )?; + assert_eq!(input_values.len(), 1, "Exactly 1 input required"); + assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); + Arc::clone(&input_values[0][0]) }; - assert_eq!(input_values.len(), 1, "Exactly 1 input required"); - assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); - let input_values = Arc::clone(&input_values[0][0]); // iterate over each column of group_by values (*self).intern(&group_by_values, &input_values)?; } // inner is done, emit all rows and switch to producing output None => { - if self.priority_map.is_empty() { + // Release the input pipeline's resources before emitting. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + if self.priority_map.is_empty() && !self.null_group_seen { trace!("partition {} emit None", self.partition); + self.done = true; return Poll::Ready(None); } let batch = { let _timer = emitting_time.timer(); - let cols = self.priority_map.emit()?; + let cols = self.emit_columns()?; RecordBatch::try_new(Arc::clone(&self.schema), cols)? }; + let batch = batch.record_output(&self.baseline_metrics); trace!( "partition {} emit batch with {} rows", self.partition, @@ -189,6 +277,7 @@ impl Stream for GroupedTopKAggregateStream { if log::log_enabled!(Level::Trace) { print_batches(std::slice::from_ref(&batch))?; } + self.done = true; return Poll::Ready(Some(Ok(batch))); } // inner had error, return to caller diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 01f997f23d6a9..27e0f5e923d85 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -17,7 +17,6 @@ //! Defines the ANALYZE operator -use std::any::Any; use std::sync::Arc; use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; @@ -26,12 +25,16 @@ use super::{ SendableRecordBatchStream, }; use crate::display::DisplayableExecutionPlan; -use crate::metrics::MetricType; +use crate::execution_plan::EvaluationType; +use crate::metrics::{MetricCategory, MetricType}; use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; +use datafusion_common::format::ExplainFormat; use datafusion_common::instant::Instant; -use datafusion_common::{assert_eq_or_internal_err, DataFusionError, Result}; +use datafusion_common::{ + DataFusionError, Result, assert_eq_or_internal_err, internal_err, +}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -47,32 +50,92 @@ pub struct AnalyzeExec { show_statistics: bool, /// Which metric categories should be displayed metric_types: Vec, + /// Optional filter by semantic category (rows / bytes / timing). + metric_categories: Option>, + /// Output format for the rendered plan + metrics. + format: ExplainFormat, /// The input plan (the plan being analyzed) pub(crate) input: Arc, /// The output schema for RecordBatches of this exec node schema: SchemaRef, - cache: PlanProperties, + cache: Arc, } -impl AnalyzeExec { - /// Create a new AnalyzeExec +/// Builder for [`AnalyzeExec`]. +/// +/// Builder for [AnalyzeExec]. +pub struct AnalyzeExecBuilder { + verbose: bool, + show_statistics: bool, + input: Arc, + schema: SchemaRef, + metric_types: Vec, + metric_categories: Option>, + format: ExplainFormat, +} + +impl AnalyzeExecBuilder { pub fn new( verbose: bool, show_statistics: bool, - metric_types: Vec, input: Arc, schema: SchemaRef, ) -> Self { - let cache = Self::compute_properties(&input, Arc::clone(&schema)); - AnalyzeExec { + Self { verbose, show_statistics, - metric_types, input, schema, - cache, + metric_types: vec![MetricType::Summary, MetricType::Dev], + metric_categories: None, + format: ExplainFormat::Indent, + } + } + + pub fn with_metric_types(mut self, metric_types: Vec) -> Self { + self.metric_types = metric_types; + self + } + + pub fn with_metric_categories( + mut self, + metric_categories: Option>, + ) -> Self { + self.metric_categories = metric_categories; + self + } + + pub fn with_format(mut self, format: ExplainFormat) -> Self { + self.format = format; + self + } + + pub fn build(self) -> AnalyzeExec { + let cache = + AnalyzeExec::compute_properties(&self.input, Arc::clone(&self.schema)); + AnalyzeExec { + verbose: self.verbose, + show_statistics: self.show_statistics, + metric_types: self.metric_types, + metric_categories: self.metric_categories, + format: self.format, + input: self.input, + schema: self.schema, + cache: Arc::new(cache), } } +} + +impl AnalyzeExec { + /// Returns a builder for constructing an [`AnalyzeExec`]. + pub fn builder( + verbose: bool, + show_statistics: bool, + input: Arc, + schema: SchemaRef, + ) -> AnalyzeExecBuilder { + AnalyzeExecBuilder::new(verbose, show_statistics, input, schema) + } /// Access to verbose pub fn verbose(&self) -> bool { @@ -84,6 +147,16 @@ impl AnalyzeExec { self.show_statistics } + /// Access to metric_categories + pub fn metric_categories(&self) -> Option<&[MetricCategory]> { + self.metric_categories.as_deref() + } + + /// Access to format + pub fn format(&self) -> &ExplainFormat { + &self.format + } + /// The input plan pub fn input(&self) -> &Arc { &self.input @@ -100,6 +173,7 @@ impl AnalyzeExec { input.pipeline_behavior(), input.boundedness(), ) + .with_evaluation_type(EvaluationType::Eager) } } @@ -127,11 +201,7 @@ impl ExecutionPlan for AnalyzeExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -147,13 +217,18 @@ impl ExecutionPlan for AnalyzeExec { self: Arc, mut children: Vec>, ) -> Result> { - Ok(Arc::new(Self::new( - self.verbose, - self.show_statistics, - self.metric_types.clone(), - children.pop().unwrap(), - Arc::clone(&self.schema), - ))) + Ok(Arc::new( + AnalyzeExec::builder( + self.verbose, + self.show_statistics, + children.pop().unwrap(), + Arc::clone(&self.schema), + ) + .with_metric_types(self.metric_types.clone()) + .with_metric_categories(self.metric_categories.clone()) + .with_format(self.format.clone()) + .build(), + )) } fn execute( @@ -189,6 +264,8 @@ impl ExecutionPlan for AnalyzeExec { let verbose = self.verbose; let show_statistics = self.show_statistics; let metric_types = self.metric_types.clone(); + let metric_categories = self.metric_categories.clone(); + let format = self.format.clone(); // future that gathers the results from all the tasks in the // JoinSet that computes the overall row count and final @@ -199,6 +276,7 @@ impl ExecutionPlan for AnalyzeExec { while let Some(batch) = input_stream.next().await.transpose()? { total_rows += batch.num_rows(); } + drop(input_stream); let duration = Instant::now() - start; create_output_batch( @@ -209,6 +287,8 @@ impl ExecutionPlan for AnalyzeExec { &captured_input, &captured_schema, &metric_types, + metric_categories.as_deref(), + &format, ) }; @@ -220,6 +300,7 @@ impl ExecutionPlan for AnalyzeExec { } /// Creates the output of AnalyzeExec as a RecordBatch +#[expect(clippy::too_many_arguments)] fn create_output_batch( verbose: bool, show_statistics: bool, @@ -228,37 +309,62 @@ fn create_output_batch( input: &Arc, schema: &SchemaRef, metric_types: &[MetricType], + metric_categories: Option<&[MetricCategory]>, + format: &ExplainFormat, ) -> Result { let mut type_builder = StringBuilder::with_capacity(1, 1024); let mut plan_builder = StringBuilder::with_capacity(1, 1024); - // TODO use some sort of enum rather than strings? - type_builder.append_value("Plan with Metrics"); - - let annotated_plan = DisplayableExecutionPlan::with_metrics(input.as_ref()) - .set_metric_types(metric_types.to_vec()) - .set_show_statistics(show_statistics) - .indent(verbose) - .to_string(); - plan_builder.append_value(annotated_plan); - - // Verbose output - // TODO make this more sophisticated - if verbose { - type_builder.append_value("Plan with Full Metrics"); - - let annotated_plan = DisplayableExecutionPlan::with_full_metrics(input.as_ref()) - .set_metric_types(metric_types.to_vec()) - .set_show_statistics(show_statistics) - .indent(verbose) - .to_string(); - plan_builder.append_value(annotated_plan); - - type_builder.append_value("Output Rows"); - plan_builder.append_value(total_rows.to_string()); - - type_builder.append_value("Duration"); - plan_builder.append_value(format!("{duration:?}")); + match format { + ExplainFormat::Indent => { + // TODO use some sort of enum rather than strings? + type_builder.append_value("Plan with Metrics"); + let annotated_plan = DisplayableExecutionPlan::with_metrics(input.as_ref()) + .set_metric_types(metric_types.to_vec()) + .set_metric_categories(metric_categories.map(|c| c.to_vec())) + .set_show_statistics(show_statistics) + .indent(verbose) + .to_string(); + plan_builder.append_value(annotated_plan); + // Verbose output + // TODO make this more sophisticated + if verbose { + type_builder.append_value("Plan with Full Metrics"); + let annotated_plan = + DisplayableExecutionPlan::with_full_metrics(input.as_ref()) + .set_metric_types(metric_types.to_vec()) + .set_metric_categories(metric_categories.map(|c| c.to_vec())) + .set_show_statistics(show_statistics) + .indent(verbose) + .to_string(); + plan_builder.append_value(annotated_plan); + type_builder.append_value("Output Rows"); + plan_builder.append_value(total_rows.to_string()); + type_builder.append_value("Duration"); + plan_builder.append_value(format!("{duration:?}")); + } + } + ExplainFormat::PostgresJSON => { + // `show_statistics` is intentionally not forwarded here: the pgjson + // renderer does not emit statistics, and the planner rejects the + // `show_statistics` + pgjson combination up front. + type_builder.append_value("Plan with Metrics"); + let mut displayable = if verbose { + DisplayableExecutionPlan::with_full_metrics(input.as_ref()) + } else { + DisplayableExecutionPlan::with_metrics(input.as_ref()) + }; + displayable = displayable + .set_metric_types(metric_types.to_vec()) + .set_metric_categories(metric_categories.map(|c| c.to_vec())); + if verbose { + displayable = displayable.set_summary(Some(total_rows), Some(duration)); + } + plan_builder.append_value(displayable.pgjson(verbose).to_string()); + } + ExplainFormat::Tree | ExplainFormat::Graphviz => { + return internal_err!("AnalyzeExec does not support {format} output format"); + } } RecordBatch::try_new( @@ -278,7 +384,7 @@ mod tests { collect, test::{ assert_is_pending, - exec::{assert_strong_count_converges_to_zero, BlockingExec}, + exec::{BlockingExec, assert_strong_count_converges_to_zero}, }, }; @@ -293,13 +399,8 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); - let analyze_exec = Arc::new(AnalyzeExec::new( - true, - false, - vec![MetricType::SUMMARY, MetricType::DEV], - blocking_exec, - schema, - )); + let analyze_exec = + Arc::new(AnalyzeExec::builder(true, false, blocking_exec, schema).build()); let fut = collect(analyze_exec, task_ctx); let mut fut = fut.boxed(); diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index d442307e9488e..1b15bf27e78cc 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -15,37 +15,42 @@ // specific language governing permissions and limitations // under the License. +use crate::coalesce::LimitedBatchCoalescer; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::stream::RecordBatchStreamAdapter; +use crate::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + check_if_same_properties, }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; -use datafusion_common::{assert_eq_or_internal_err, Result}; -use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_common::{Result, assert_eq_or_internal_err}; +use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use futures::Stream; use futures::stream::StreamExt; use log::trace; -use std::any::Any; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll, ready}; /// This structure evaluates a set of async expressions on a record /// batch producing a new record batch /// /// The schema of the output of the AsyncFuncExec is: /// Input columns followed by one column for each async expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AsyncFuncExec { /// The async expressions to evaluate async_exprs: Vec>, input: Arc, - cache: PlanProperties, + cache: Arc, metrics: ExecutionPlanMetricsSet, } @@ -79,7 +84,7 @@ impl AsyncFuncExec { Ok(Self { input, async_exprs, - cache, + cache: Arc::new(cache), metrics: ExecutionPlanMetricsSet::new(), }) } @@ -100,6 +105,25 @@ impl AsyncFuncExec { input.boundedness(), )) } + + pub fn async_exprs(&self) -> &[Arc] { + &self.async_exprs + } + + pub fn input(&self) -> &Arc { + &self.input + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for AsyncFuncExec { @@ -132,11 +156,7 @@ impl ExecutionPlan for AsyncFuncExec { "async_func" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -146,16 +166,17 @@ impl ExecutionPlan for AsyncFuncExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { assert_eq_or_internal_err!( children.len(), 1, "AsyncFuncExec wrong number of children" ); + check_if_same_properties!(self, children); Ok(Arc::new(AsyncFuncExec::try_new( self.async_exprs.clone(), - Arc::clone(&children[0]), + children.swap_remove(0), )?)) } @@ -170,22 +191,35 @@ impl ExecutionPlan for AsyncFuncExec { context.session_id(), context.task_id() ); - // TODO figure out how to record metrics // first execute the input stream let input_stream = self.input.execute(partition, Arc::clone(&context))?; + // TODO: Track `elapsed_compute` in `BaselineMetrics` + // Issue: + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + // now, for each record batch, evaluate the async expressions and add the columns to the result let async_exprs_captured = Arc::new(self.async_exprs.clone()); let schema_captured = self.schema(); let config_options_ref = Arc::clone(context.session_config().options()); - let stream_with_async_functions = input_stream.then(move |batch| { + let coalesced_input_stream = CoalesceInputStream { + input_stream, + batch_coalescer: LimitedBatchCoalescer::new( + Arc::clone(&self.input.schema()), + config_options_ref.execution.batch_size, + None, + ), + }; + + let stream_with_async_functions = coalesced_input_stream.then(move |batch| { // need to clone *again* to capture the async_exprs and schema in the // stream and satisfy lifetime requirements. let async_exprs_captured = Arc::clone(&async_exprs_captured); let schema_captured = Arc::clone(&schema_captured); let config_options = Arc::clone(&config_options_ref); + let baseline_metrics_captured = baseline_metrics.clone(); async move { let batch = batch?; @@ -198,7 +232,8 @@ impl ExecutionPlan for AsyncFuncExec { output_arrays.push(output.to_array(batch.num_rows())?); } let batch = RecordBatch::try_new(schema_captured, output_arrays)?; - Ok(batch) + + Ok(batch.record_output(&baseline_metrics_captured)) } }); @@ -213,6 +248,53 @@ impl ExecutionPlan for AsyncFuncExec { } } +struct CoalesceInputStream { + input_stream: Pin>, + batch_coalescer: LimitedBatchCoalescer, +} + +impl Stream for CoalesceInputStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut completed = false; + + loop { + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + return Poll::Ready(Some(Ok(batch))); + } + + if completed { + return Poll::Ready(None); + } + + match ready!(self.input_stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if let Err(err) = self.batch_coalescer.push_batch(batch) { + return Poll::Ready(Some(Err(err))); + } + } + Some(err) => { + return Poll::Ready(Some(err)); + } + None => { + completed = true; + // Release the input pipeline's resources. + let input_schema = self.input_stream.schema(); + self.input_stream = + Box::pin(EmptyRecordBatchStream::new(input_schema)); + if let Err(err) = self.batch_coalescer.finish() { + return Poll::Ready(Some(Err(err))); + } + } + } + } + } +} + const ASYNC_FN_PREFIX: &str = "__async_fn_"; /// Maps async_expressions to new columns @@ -252,17 +334,15 @@ impl AsyncMapper { ) -> Result<()> { // recursively look for references to async functions physical_expr.apply(|expr| { - if let Some(scalar_func_expr) = - expr.as_any().downcast_ref::() + if let Some(scalar_func_expr) = expr.downcast_ref::() + && scalar_func_expr.fun().as_async().is_some() { - if scalar_func_expr.fun().as_async().is_some() { - let next_name = self.next_column_name(); - self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new( - next_name, - Arc::clone(expr), - schema, - )?)); - } + let next_name = self.next_column_name(); + self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new( + next_name, + Arc::clone(expr), + schema, + )?)); } Ok(TreeNodeRecursion::Continue) })?; @@ -300,3 +380,51 @@ impl AsyncMapper { Arc::new(Column::new(async_expr.name(), output_idx)) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{RecordBatch, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_execution::{TaskContext, config::SessionConfig}; + use futures::StreamExt; + + use crate::{ExecutionPlan, async_func::AsyncFuncExec, test::TestMemoryExec}; + + #[tokio::test] + async fn test_async_fn_with_coalescing() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6]))], + )?; + + let batches: Vec = std::iter::repeat_n(batch, 50).collect(); + + let session_config = SessionConfig::new().with_batch_size(200); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + + let test_exec = + TestMemoryExec::try_new_exec(&[batches], Arc::clone(&schema), None)?; + let exec = AsyncFuncExec::try_new(vec![], test_exec)?; + + let mut stream = exec.execute(0, Arc::clone(&task_ctx))?; + let batch = stream + .next() + .await + .expect("expected to get a record batch")?; + assert_eq!(200, batch.num_rows()); + let batch = stream + .next() + .await + .expect("expected to get a record batch")?; + assert_eq!(100, batch.num_rows()); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/buffer.rs b/datafusion/physical-plan/src/buffer.rs new file mode 100644 index 0000000000000..2985dc57661b0 --- /dev/null +++ b/datafusion/physical-plan/src/buffer.rs @@ -0,0 +1,639 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`BufferExec`] decouples production and consumption on messages by buffering the input in the +//! background up to a certain capacity. + +use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, +}; +use crate::projection::ProjectionExec; +use crate::stream::RecordBatchStreamAdapter; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SortOrderPushdownResult, + check_if_same_properties, +}; +use arrow::array::RecordBatch; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{Result, Statistics, internal_err, plan_err}; +use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, MetricsSet, +}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use futures::{Stream, StreamExt, TryStreamExt}; +use pin_project_lite::pin_project; +use std::fmt; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::sync::mpsc::UnboundedReceiver; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// WARNING: EXPERIMENTAL +/// +/// Decouples production and consumption of record batches with an internal queue per partition, +/// eagerly filling up the capacity of the queues even before any message is requested. +/// +/// ```text +/// ┌───────────────────────────┐ +/// │ BufferExec │ +/// │ │ +/// │┌────── Partition 0 ──────┐│ +/// ││ ┌────┐ ┌────┐││ ┌────┐ +/// ──background poll────────▶│ │ │ ├┼┼───────▶ │ +/// ││ └────┘ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// │┌────── Partition 1 ──────┐│ +/// ││ ┌────┐ ┌────┐ ┌────┐││ ┌────┐ +/// ──background poll─▶│ │ │ │ │ ├┼┼───────▶ │ +/// ││ └────┘ └────┘ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// │ │ +/// │ ... │ +/// │ │ +/// │┌────── Partition N ──────┐│ +/// ││ ┌────┐││ ┌────┐ +/// ──background poll───────────────▶│ ├┼┼───────▶ │ +/// ││ └────┘││ └────┘ +/// │└─────────────────────────┘│ +/// └───────────────────────────┘ +/// ``` +/// +/// The capacity is provided in bytes, and for each buffered record batch it will take into account +/// the size reported by [RecordBatch::get_array_memory_size]. +/// +/// If a single record batch exceeds the maximum capacity set in the `capacity` argument, it's still +/// allowed to pass in order to not deadlock the buffer. +/// +/// This is useful for operators that conditionally start polling one of their children only after +/// other child has finished, allowing to perform some early work and accumulating batches in +/// memory so that they can be served immediately when requested. +#[derive(Debug, Clone)] +pub struct BufferExec { + input: Arc, + properties: Arc, + capacity: usize, + metrics: ExecutionPlanMetricsSet, +} + +impl BufferExec { + /// Builds a new [BufferExec] with the provided capacity in bytes. + pub fn new(input: Arc, capacity: usize) -> Self { + let properties = PlanProperties::clone(input.properties()) + .with_scheduling_type(SchedulingType::Cooperative) + .with_evaluation_type(EvaluationType::Eager); + + Self { + input, + properties: Arc::new(properties), + capacity, + metrics: ExecutionPlanMetricsSet::new(), + } + } + + /// Returns the input [ExecutionPlan] of this [BufferExec]. + pub fn input(&self) -> &Arc { + &self.input + } + + /// Returns the per-partition capacity in bytes for this [BufferExec]. + pub fn capacity(&self) -> usize { + self.capacity + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } +} + +impl DisplayAs for BufferExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "BufferExec: capacity={}", self.capacity) + } + DisplayFormatType::TreeRender => { + writeln!(f, "target_batch_size={}", self.capacity) + } + } + } +} + +impl ExecutionPlan for BufferExec { + fn name(&self) -> &str { + "BufferExec" + } + + fn properties(&self) -> &Arc { + &self.properties + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + check_if_same_properties!(self, children); + if children.len() != 1 { + return plan_err!("BufferExec can only have one child"); + } + Ok(Arc::new(Self::new(children.swap_remove(0), self.capacity))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let mem_reservation = MemoryConsumer::new(format!("BufferExec[{partition}]")) + .register(context.memory_pool()); + let in_stream = self.input.execute(partition, context)?; + + // Set up the metrics for the stream. + let curr_mem_in = Arc::new(AtomicUsize::new(0)); + let curr_mem_out = Arc::clone(&curr_mem_in); + let mut max_mem_in = 0; + let max_mem = MetricBuilder::new(&self.metrics) + .with_category(MetricCategory::Bytes) + .gauge("max_mem_used", partition); + + let curr_queued_in = Arc::new(AtomicUsize::new(0)); + let curr_queued_out = Arc::clone(&curr_queued_in); + let mut max_queued_in = 0; + let max_queued = MetricBuilder::new(&self.metrics) + .with_category(MetricCategory::Rows) + .gauge("max_queued", partition); + + // Capture metrics when an element is queued on the stream. + let in_stream = in_stream.inspect_ok(move |v| { + let size = v.get_array_memory_size(); + let curr_size = curr_mem_in.fetch_add(size, Ordering::Relaxed) + size; + if curr_size > max_mem_in { + max_mem_in = curr_size; + max_mem.set(max_mem_in); + } + + let curr_queued = curr_queued_in.fetch_add(1, Ordering::Relaxed) + 1; + if curr_queued > max_queued_in { + max_queued_in = curr_queued; + max_queued.set(max_queued_in); + } + }); + // Buffer the input. + let out_stream = + MemoryBufferedStream::new(in_stream, self.capacity, mem_reservation); + // Update in the metrics that when an element gets out, some memory gets freed. + let out_stream = out_stream.inspect_ok(move |v| { + curr_mem_out.fetch_sub(v.get_array_memory_size(), Ordering::Relaxed); + curr_queued_out.fetch_sub(1, Ordering::Relaxed); + }); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + out_stream, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + self.input.supports_limit_pushdown() + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } + + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + match self.input.try_swapping_with_projection(projection)? { + Some(new_input) => Ok(Some( + Arc::new(self.clone()).with_new_children(vec![new_input])?, + )), + None => Ok(None), + } + } + + fn gather_filters_for_pushdown( + &self, + _phase: FilterPushdownPhase, + parent_filters: Vec>, + _config: &ConfigOptions, + ) -> Result { + FilterDescription::from_children(parent_filters, &self.children()) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) + } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // CoalesceBatchesExec is transparent for sort ordering - it preserves order + // Delegate to the child and wrap with a new CoalesceBatchesExec + self.input.try_pushdown_sort(order)?.try_map(|new_input| { + Ok(Arc::new(Self::new(new_input, self.capacity)) as Arc) + }) + } +} + +/// Represents anything that occupies a capacity in a [MemoryBufferedStream]. +pub trait SizedMessage { + fn size(&self) -> usize; +} + +impl SizedMessage for RecordBatch { + fn size(&self) -> usize { + self.get_array_memory_size() + } +} + +pin_project! { +/// Decouples production and consumption of messages in a stream with an internal queue, eagerly +/// filling it up to the specified maximum capacity even before any message is requested. +/// +/// Allows each message to have a different size, which is taken into account for determining if +/// the queue is full or not. +pub struct MemoryBufferedStream { + task: SpawnedTask<()>, + batch_rx: UnboundedReceiver>, + memory_reservation: Arc, +}} + +impl MemoryBufferedStream { + /// Builds a new [MemoryBufferedStream] with the provided capacity and event handler. + /// + /// This immediately spawns a Tokio task that will start consumption of the input stream. + pub fn new( + mut input: impl Stream> + Unpin + Send + 'static, + capacity: usize, + memory_reservation: MemoryReservation, + ) -> Self { + let semaphore = Arc::new(Semaphore::new(capacity)); + let (batch_tx, batch_rx) = tokio::sync::mpsc::unbounded_channel(); + + let memory_reservation = Arc::new(memory_reservation); + let memory_reservation_clone = Arc::clone(&memory_reservation); + let task = SpawnedTask::spawn(async move { + loop { + // Select on both the input stream and the channel being closed. + // By down this, we abort polling the input as soon as the consumer channel is + // closed. Otherwise, we would need to wait for a full new message to be available + // in order to consider aborting the stream + let item_or_err = tokio::select! { + biased; + _ = batch_tx.closed() => break, + item_or_err = input.next() => { + let Some(item_or_err) = item_or_err else { + break; // stream finished + }; + item_or_err + } + }; + + let item = match item_or_err { + Ok(batch) => batch, + Err(err) => { + let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine. + break; + } + }; + + let size = item.size(); + if let Err(err) = memory_reservation.try_grow(size) { + let _ = batch_tx.send(Err(err)); // If there's an error it means the channel was closed, which is fine. + break; + } + + // We need to cap the minimum between amount of permits and the actual size of the + // message. If at any point we try to acquire more permits than the capacity of the + // semaphore, the stream will deadlock. + let capped_size = size.min(capacity) as u32; + + let semaphore = Arc::clone(&semaphore); + let Ok(permit) = semaphore.acquire_many_owned(capped_size).await else { + let _ = batch_tx.send(internal_err!("Closed semaphore in MemoryBufferedStream. This is a bug in DataFusion, please report it!")); + break; + }; + + if batch_tx.send(Ok((item, permit))).is_err() { + break; // stream was closed + }; + } + }); + + Self { + task, + batch_rx, + memory_reservation: memory_reservation_clone, + } + } + + /// Returns the number of queued messages. + pub fn messages_queued(&self) -> usize { + self.batch_rx.len() + } +} + +impl Stream for MemoryBufferedStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let self_project = self.project(); + match self_project.batch_rx.poll_recv(cx) { + Poll::Ready(Some(Ok((item, _semaphore_permit)))) => { + self_project.memory_reservation.shrink(item.size()); + Poll::Ready(Some(Ok(item))) + } + Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + if self.batch_rx.is_closed() { + let len = self.batch_rx.len(); + (len, Some(len)) + } else { + (self.batch_rx.len(), None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::{DataFusionError, assert_contains}; + use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryPool, UnboundedMemoryPool, + }; + use std::error::Error; + use std::fmt::Debug; + use std::time::Duration; + use tokio::time::timeout; + + #[tokio::test] + async fn buffers_only_some_messages() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let buffered = MemoryBufferedStream::new(input, 4, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 2); + Ok(()) + } + + #[tokio::test] + async fn yields_all_messages() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 4); + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn yields_first_msg_even_if_big() -> Result<(), Box> { + let input = futures::stream::iter([25, 1, 2, 3]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn memory_pool_kills_stream() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = bounded_memory_pool_and_reservation(7); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + let msg = pull_err_msg(&mut buffered).await?; + + assert_contains!(msg.to_string(), "Failed to allocate additional 4.0 B"); + Ok(()) + } + + #[tokio::test] + async fn memory_pool_does_not_kill_stream() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (_, res) = bounded_memory_pool_and_reservation(7); + + let mut buffered = MemoryBufferedStream::new(input, 3, res); + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn messages_pass_even_if_all_exceed_limit() -> Result<(), Box> { + let input = futures::stream::iter([3, 3, 3, 3]).map(Ok); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 2, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 1); + pull_ok_msg(&mut buffered).await?; + + wait_for_buffering().await; + finished(&mut buffered).await?; + Ok(()) + } + + #[tokio::test] + async fn errors_get_propagated() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(|v| { + if v == 3 { + return internal_err!("Error on 3"); + } + Ok(v) + }); + let (_, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + + pull_ok_msg(&mut buffered).await?; + pull_ok_msg(&mut buffered).await?; + pull_err_msg(&mut buffered).await?; + + Ok(()) + } + + #[tokio::test] + async fn memory_gets_released_if_stream_drops() -> Result<(), Box> { + let input = futures::stream::iter([1, 2, 3, 4]).map(Ok); + let (pool, res) = memory_pool_and_reservation(); + + let mut buffered = MemoryBufferedStream::new(input, 10, res); + wait_for_buffering().await; + assert_eq!(buffered.messages_queued(), 4); + assert_eq!(pool.reserved(), 10); + + pull_ok_msg(&mut buffered).await?; + assert_eq!(buffered.messages_queued(), 3); + assert_eq!(pool.reserved(), 9); + + pull_ok_msg(&mut buffered).await?; + assert_eq!(buffered.messages_queued(), 2); + assert_eq!(pool.reserved(), 7); + + drop(buffered); + assert_eq!(pool.reserved(), 0); + Ok(()) + } + + fn memory_pool_and_reservation() -> (Arc, MemoryReservation) { + let pool = Arc::new(UnboundedMemoryPool::default()) as _; + let reservation = MemoryConsumer::new("test").register(&pool); + (pool, reservation) + } + + fn bounded_memory_pool_and_reservation( + size: usize, + ) -> (Arc, MemoryReservation) { + let pool = Arc::new(GreedyMemoryPool::new(size)) as _; + let reservation = MemoryConsumer::new("test").register(&pool); + (pool, reservation) + } + + async fn wait_for_buffering() { + // We do not have control over the spawned task, so the best we can do is to yield some + // cycles to the tokio runtime and let the task make progress on its own. + tokio::time::sleep(Duration::from_millis(1)).await; + } + + async fn pull_ok_msg( + buffered: &mut MemoryBufferedStream, + ) -> Result> { + Ok(timeout(Duration::from_millis(1), buffered.next()) + .await? + .unwrap_or_else(|| internal_err!("Stream should not have finished"))?) + } + + async fn pull_err_msg( + buffered: &mut MemoryBufferedStream, + ) -> Result> { + Ok(timeout(Duration::from_millis(1), buffered.next()) + .await? + .map(|v| match v { + Ok(v) => internal_err!( + "Stream should not have failed, but succeeded with {v:?}" + ), + Err(err) => Ok(err), + }) + .unwrap_or_else(|| internal_err!("Stream should not have finished"))?) + } + + async fn finished( + buffered: &mut MemoryBufferedStream, + ) -> Result<(), Box> { + match timeout(Duration::from_millis(1), buffered.next()) + .await? + .is_none() + { + true => Ok(()), + false => internal_err!("Stream should have finished")?, + } + } + + impl SizedMessage for usize { + fn size(&self) -> usize { + *self + } + } +} diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index d0930b2c0e58a..ea1a87d091481 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -18,7 +18,7 @@ use arrow::array::RecordBatch; use arrow::compute::BatchCoalescer; use arrow::datatypes::SchemaRef; -use datafusion_common::{assert_or_internal_err, Result}; +use datafusion_common::{Result, assert_or_internal_err}; /// Concatenate multiple [`RecordBatch`]es and apply a limit /// @@ -134,6 +134,10 @@ impl LimitedBatchCoalescer { Ok(()) } + pub(crate) fn is_finished(&self) -> bool { + self.finished + } + /// Return the next completed batch, if any pub fn next_completed_batch(&mut self) -> Option { self.inner.next_completed_batch() diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index eb3c3b5befbdd..76b2f63798f88 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -17,15 +17,17 @@ //! [`CoalesceBatchesExec`] combines small batches into larger batches. -use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics}; +use crate::projection::ProjectionExec; +use crate::stream::EmptyRecordBatchStream; use crate::{ DisplayFormatType, ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, + check_if_same_properties, }; use arrow::datatypes::SchemaRef; @@ -40,7 +42,9 @@ use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; +use crate::sort_pushdown::SortOrderPushdownResult; use datafusion_common::config::ConfigOptions; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use futures::ready; use futures::stream::{Stream, StreamExt}; @@ -54,6 +58,10 @@ use futures::stream::{Stream, StreamExt}; /// reaches the `fetch` value. /// /// See [`LimitedBatchCoalescer`] for more information +#[deprecated( + since = "52.0.0", + note = "We now use BatchCoalescer from arrow-rs instead of a dedicated operator" +)] #[derive(Debug, Clone)] pub struct CoalesceBatchesExec { /// The input plan @@ -64,9 +72,10 @@ pub struct CoalesceBatchesExec { fetch: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + cache: Arc, } +#[expect(deprecated)] impl CoalesceBatchesExec { /// Create a new CoalesceBatchesExec pub fn new(input: Arc, target_batch_size: usize) -> Self { @@ -76,7 +85,7 @@ impl CoalesceBatchesExec { target_batch_size, fetch: None, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), } } @@ -107,8 +116,20 @@ impl CoalesceBatchesExec { input.boundedness(), ) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } +#[expect(deprecated)] impl DisplayAs for CoalesceBatchesExec { fn fmt_as( &self, @@ -139,17 +160,14 @@ impl DisplayAs for CoalesceBatchesExec { } } +#[expect(deprecated)] impl ExecutionPlan for CoalesceBatchesExec { fn name(&self) -> &'static str { "CoalesceBatchesExec" } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -167,10 +185,11 @@ impl ExecutionPlan for CoalesceBatchesExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new( - CoalesceBatchesExec::new(Arc::clone(&children[0]), self.target_batch_size) + CoalesceBatchesExec::new(children.swap_remove(0), self.target_batch_size) .with_fetch(self.fetch), )) } @@ -196,14 +215,9 @@ impl ExecutionPlan for CoalesceBatchesExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - self.input - .partition_statistics(partition)? - .with_fetch(self.fetch, 0, 1) + fn partition_statistics(&self, partition: Option) -> Result> { + let stats = Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + Ok(Arc::new(stats.with_fetch(self.fetch, 0, 1)?)) } fn with_fetch(&self, limit: Option) -> Option> { @@ -212,7 +226,7 @@ impl ExecutionPlan for CoalesceBatchesExec { target_batch_size: self.target_batch_size, fetch: limit, metrics: self.metrics.clone(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), })) } @@ -224,6 +238,18 @@ impl ExecutionPlan for CoalesceBatchesExec { CardinalityEffect::Equal } + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + match self.input.try_swapping_with_projection(projection)? { + Some(new_input) => Ok(Some( + Arc::new(self.clone()).with_new_children(vec![new_input])?, + )), + None => Ok(None), + } + } + fn gather_filters_for_pushdown( &self, _phase: FilterPushdownPhase, @@ -241,6 +267,20 @@ impl ExecutionPlan for CoalesceBatchesExec { ) -> Result>> { Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // CoalesceBatchesExec is transparent for sort ordering - it preserves order + // Delegate to the child and wrap with a new CoalesceBatchesExec + self.input.try_pushdown_sort(order)?.try_map(|new_input| { + Ok(Arc::new( + CoalesceBatchesExec::new(new_input, self.target_batch_size) + .with_fetch(self.fetch), + ) as Arc) + }) + } } /// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. @@ -296,6 +336,8 @@ impl CoalesceBatchesStream { None => { // Input stream is exhausted, finalize any remaining batches self.completed = true; + self.input = + Box::pin(EmptyRecordBatchStream::new(self.coalescer.schema())); self.coalescer.finish()?; } Some(Ok(batch)) => { @@ -306,6 +348,9 @@ impl CoalesceBatchesStream { PushBatchStatus::LimitReached => { // limit was reached, so stop early self.completed = true; + self.input = Box::pin(EmptyRecordBatchStream::new( + self.coalescer.schema(), + )); self.coalescer.finish()?; } } diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 64e0315a523d1..fa200ef845f3a 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -18,7 +18,6 @@ //! Defines the merge plan for executing partitions in parallel and then merging the results //! into a single partition -use std::any::Any; use std::sync::Arc; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; @@ -29,11 +28,13 @@ use super::{ }; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; -use crate::projection::{make_with_child, ProjectionExec}; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::projection::{ProjectionExec, make_with_child}; +use crate::sort_pushdown::SortOrderPushdownResult; +use crate::{DisplayFormatType, ExecutionPlan, Partitioning, check_if_same_properties}; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_common::config::ConfigOptions; -use datafusion_common::{assert_eq_or_internal_err, internal_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; @@ -45,7 +46,7 @@ pub struct CoalescePartitionsExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + cache: Arc, /// Optional number of rows to fetch. Stops producing rows after this fetch pub(crate) fetch: Option, } @@ -57,7 +58,7 @@ impl CoalescePartitionsExec { CoalescePartitionsExec { input, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), fetch: None, } } @@ -98,6 +99,17 @@ impl CoalescePartitionsExec { .with_evaluation_type(drive) .with_scheduling_type(scheduling) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for CoalescePartitionsExec { @@ -129,11 +141,7 @@ impl ExecutionPlan for CoalescePartitionsExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -147,9 +155,10 @@ impl ExecutionPlan for CoalescePartitionsExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { - let mut plan = CoalescePartitionsExec::new(Arc::clone(&children[0])); + check_if_same_properties!(self, children); + let mut plan = CoalescePartitionsExec::new(children.swap_remove(0)); plan.fetch = self.fetch; Ok(Arc::new(plan)) } @@ -222,14 +231,9 @@ impl ExecutionPlan for CoalescePartitionsExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, _partition: Option) -> Result { - self.input - .partition_statistics(None)? - .with_fetch(self.fetch, 0, 1) + fn partition_statistics(&self, _partition: Option) -> Result> { + let stats = Arc::unwrap_or_clone(self.input.partition_statistics(None)?); + Ok(Arc::new(stats.with_fetch(self.fetch, 0, 1)?)) } fn supports_limit_pushdown(&self) -> bool { @@ -272,10 +276,23 @@ impl ExecutionPlan for CoalescePartitionsExec { input: Arc::clone(&self.input), fetch: limit, metrics: self.metrics.clone(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), })) } + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } + fn gather_filters_for_pushdown( &self, _phase: FilterPushdownPhase, @@ -284,17 +301,56 @@ impl ExecutionPlan for CoalescePartitionsExec { ) -> Result { FilterDescription::from_children(parent_filters, &self.children()) } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // CoalescePartitionsExec merges multiple partitions into one, which loses + // global ordering. However, we can still push the sort requirement down + // to optimize individual partitions - the Sort operator above will handle + // the global ordering. + // + // Note: The result will always be at most Inexact (never Exact) when there + // are multiple partitions, because merging destroys global ordering. + let result = self.input.try_pushdown_sort(order)?; + + // If we have multiple partitions, we can't return Exact even if the + // underlying source claims Exact - merging destroys global ordering + let has_multiple_partitions = + self.input.output_partitioning().partition_count() > 1; + + result + .try_map(|new_input| { + Ok( + Arc::new( + CoalescePartitionsExec::new(new_input).with_fetch(self.fetch), + ) as Arc, + ) + }) + .map(|r| { + if has_multiple_partitions { + // Downgrade Exact to Inexact when merging multiple partitions + r.into_inexact() + } else { + r + } + }) + } } #[cfg(test)] mod tests { use super::*; use crate::test::exec::{ - assert_strong_count_converges_to_zero, BlockingExec, PanicExec, + BarrierExec, BlockingExec, PanicExec, assert_strong_count_converges_to_zero, }; use crate::test::{self, assert_is_pending}; use crate::{collect, common}; + use std::time::Duration; + + use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; @@ -329,6 +385,45 @@ mod tests { Ok(()) } + #[tokio::test] + async fn drops_input_plan_after_input_streams_start() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + let input_partitions = 2; + let batch = RecordBatch::new_empty(Arc::clone(&schema)); + let input = Arc::new( + BarrierExec::new(vec![vec![batch]; input_partitions], schema) + .without_start_barrier() + .with_finish_barrier() + .with_log(false), + ); + let refs = Arc::downgrade(&input); + + let input_plan: Arc = Arc::clone(&input); + let coalesce = CoalescePartitionsExec::new(input_plan); + let stream = coalesce.execute(0, task_ctx)?; + drop(coalesce); + + tokio::time::timeout(Duration::from_secs(5), async { + // Why not `wait_finish` here: that releases the barrier which lets the input tasks + // finish, which drops the input Arcs and hides the bug. + while !input.is_finish_barrier_reached() { + tokio::task::yield_now().await; + } + }) + .await + .expect("input streams should reach pending"); + + drop(input); + + assert_strong_count_converges_to_zero(refs).await; + + drop(stream); + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/column_rewriter.rs b/datafusion/physical-plan/src/column_rewriter.rs new file mode 100644 index 0000000000000..2df95cd61474e --- /dev/null +++ b/datafusion/physical-plan/src/column_rewriter.rs @@ -0,0 +1,382 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::{ + DataFusionError, HashMap, + tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter}, +}; +use datafusion_physical_expr::{PhysicalExpr, expressions::Column}; + +/// Rewrite column references in a physical expr according to a mapping. +/// +/// This rewriter traverses the expression tree and replaces [`Column`] nodes +/// with the corresponding expression found in the `column_map`. +/// +/// If a column is found in the map, it is replaced by the mapped expression. +/// If a column is NOT found in the map, a `DataFusionError::Internal` is +/// returned. +pub struct PhysicalColumnRewriter<'a> { + /// Mapping from original column to new column. + pub column_map: &'a HashMap>, +} + +impl<'a> PhysicalColumnRewriter<'a> { + /// Create a new PhysicalColumnRewriter with the given column mapping. + pub fn new(column_map: &'a HashMap>) -> Self { + Self { column_map } + } +} + +impl<'a> TreeNodeRewriter for PhysicalColumnRewriter<'a> { + type Node = Arc; + + fn f_down( + &mut self, + node: Self::Node, + ) -> datafusion_common::Result> { + if let Some(column) = node.downcast_ref::() { + if let Some(new_column) = self.column_map.get(column) { + // jump to prevent rewriting the new sub-expression again + return Ok(Transformed::new( + Arc::clone(new_column), + true, + TreeNodeRecursion::Jump, + )); + } else { + // Column not found in mapping + return Err(DataFusionError::Internal(format!( + "Column {column:?} not found in column mapping {:?}", + self.column_map + ))); + } + } + Ok(Transformed::no(node)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Result, tree_node::TreeNode}; + use datafusion_physical_expr::{ + PhysicalExpr, + expressions::{Column, binary, col, lit}, + }; + + /// Helper function to create a test schema + fn create_test_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("c", DataType::Int32, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Int32, true), + Field::new("new_col", DataType::Int32, true), + Field::new("inner_col", DataType::Int32, true), + Field::new("another_col", DataType::Int32, true), + ])) + } + + /// Helper function to create a complex nested expression with multiple columns + /// Create: (col_a + col_b) * (col_c - col_d) + col_e + fn create_complex_expression(schema: &Schema) -> Arc { + let col_a = col("a", schema).unwrap(); + let col_b = col("b", schema).unwrap(); + let col_c = col("c", schema).unwrap(); + let col_d = col("d", schema).unwrap(); + let col_e = col("e", schema).unwrap(); + + let add_expr = + binary(col_a, datafusion_expr::Operator::Plus, col_b, schema).unwrap(); + let sub_expr = + binary(col_c, datafusion_expr::Operator::Minus, col_d, schema).unwrap(); + let mul_expr = binary( + add_expr, + datafusion_expr::Operator::Multiply, + sub_expr, + schema, + ) + .unwrap(); + binary(mul_expr, datafusion_expr::Operator::Plus, col_e, schema).unwrap() + } + + /// Helper function to create a deeply nested expression + /// Create: col_a + (col_b + (col_c + (col_d + col_e))) + fn create_deeply_nested_expression(schema: &Schema) -> Arc { + let col_a = col("a", schema).unwrap(); + let col_b = col("b", schema).unwrap(); + let col_c = col("c", schema).unwrap(); + let col_d = col("d", schema).unwrap(); + let col_e = col("e", schema).unwrap(); + + let inner1 = + binary(col_d, datafusion_expr::Operator::Plus, col_e, schema).unwrap(); + let inner2 = + binary(col_c, datafusion_expr::Operator::Plus, inner1, schema).unwrap(); + let inner3 = + binary(col_b, datafusion_expr::Operator::Plus, inner2, schema).unwrap(); + binary(col_a, datafusion_expr::Operator::Plus, inner3, schema).unwrap() + } + + #[test] + fn test_simple_column_replacement_with_jump() -> Result<()> { + let schema = create_test_schema(); + + // Test that Jump prevents re-processing of replaced columns + let mut column_map = HashMap::new(); + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32)); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + lit("replaced_b"), + ); + column_map.insert( + Column::new_with_schema("c", &schema).unwrap(), + col("c", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("e", &schema).unwrap(), + col("e", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_complex_expression(&schema); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify the transformation occurred + assert!(result.transformed); + + assert_eq!( + format!("{}", result.data), + "(42 + replaced_b) * (c@2 - d@3) + e@4" + ); + + Ok(()) + } + + #[test] + fn test_nested_column_replacement_with_jump() -> Result<()> { + let schema = create_test_schema(); + // Test Jump behavior with deeply nested expressions + let mut column_map = HashMap::new(); + // Replace col_c with a complex expression containing new columns + let replacement_expr = binary( + lit(100i32), + datafusion_expr::Operator::Plus, + col("new_col", &schema).unwrap(), + &schema, + ) + .unwrap(); + column_map.insert( + Column::new_with_schema("c", &schema).unwrap(), + replacement_expr, + ); + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + col("a", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("e", &schema).unwrap(), + col("e", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_deeply_nested_expression(&schema); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + + assert_eq!( + format!("{}", result.data), + "a@0 + b@1 + 100 + new_col@5 + d@3 + e@4" + ); + + Ok(()) + } + + #[test] + fn test_circular_reference_prevention() -> Result<()> { + let schema = create_test_schema(); + // Test that Jump prevents infinite recursion with circular references + let mut column_map = HashMap::new(); + + // Create a circular reference: col_a -> col_b -> col_a (but Jump should prevent the second visit) + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("a", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Start with an expression containing col_a + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + + assert_eq!(format!("{}", result.data), "b@1 + a@0"); + + Ok(()) + } + + #[test] + fn test_multiple_replacements_in_same_expression() -> Result<()> { + let schema = create_test_schema(); + // Test multiple column replacements in the same complex expression + let mut column_map = HashMap::new(); + + // Replace multiple columns with literals + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(10i32)); + column_map.insert(Column::new_with_schema("c", &schema).unwrap(), lit(20i32)); + column_map.insert(Column::new_with_schema("e", &schema).unwrap(), lit(30i32)); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + column_map.insert( + Column::new_with_schema("d", &schema).unwrap(), + col("d", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + let expr = create_complex_expression(&schema); // (col_a + col_b) * (col_c - col_d) + col_e + + let result = expr.rewrite(&mut rewriter)?; + + // Verify transformation occurred + assert!(result.transformed); + assert_eq!(format!("{}", result.data), "(10 + b@1) * (20 - d@3) + 30"); + + Ok(()) + } + + #[test] + fn test_jump_with_complex_replacement_expression() -> Result<()> { + let schema = create_test_schema(); + // Test Jump behavior when replacing with very complex expressions + let mut column_map = HashMap::new(); + + // Replace col_a with a complex nested expression + let inner_expr = binary( + lit(5i32), + datafusion_expr::Operator::Multiply, + col("a", &schema).unwrap(), + &schema, + ) + .unwrap(); + let middle_expr = binary( + inner_expr, + datafusion_expr::Operator::Plus, + lit(3i32), + &schema, + ) + .unwrap(); + let complex_replacement = binary( + middle_expr, + datafusion_expr::Operator::Minus, + col("another_col", &schema).unwrap(), + &schema, + ) + .unwrap(); + + column_map.insert( + Column::new_with_schema("a", &schema).unwrap(), + complex_replacement, + ); + column_map.insert( + Column::new_with_schema("b", &schema).unwrap(), + col("b", &schema).unwrap(), + ); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Create expression: col_a + col_b + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let result = expr.rewrite(&mut rewriter)?; + + assert_eq!( + format!("{}", result.data), + "5 * a@0 + 3 - another_col@7 + b@1" + ); + + // Verify transformation occurred + assert!(result.transformed); + + Ok(()) + } + + #[test] + fn test_unmapped_columns_detection() -> Result<()> { + let schema = create_test_schema(); + let mut column_map = HashMap::new(); + + // Only map col_a, leave col_b unmapped + column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32)); + + let mut rewriter = PhysicalColumnRewriter::new(&column_map); + + // Create expression: col_a + col_b + let expr = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Plus, + col("b", &schema).unwrap(), + &schema, + ) + .unwrap(); + + let err = expr.rewrite(&mut rewriter).unwrap_err(); + assert!(matches!(err, DataFusionError::Internal(_))); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index e9a8499a7c9ac..0dafcf6bd3390 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -22,21 +22,22 @@ use std::fs::metadata; use std::sync::Arc; use super::SendableRecordBatchStream; +use crate::expressions::{CastExpr, Column}; +use crate::projection::{ProjectionExec, ProjectionExpr}; use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, Statistics}; +use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::array::Array; -use arrow::datatypes::Schema; +use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{Result, plan_err}; use datafusion_execution::memory_pool::MemoryReservation; use futures::{StreamExt, TryStreamExt}; -use parking_lot::Mutex; /// [`MemoryReservation`] used across query execution streams -pub(crate) type SharedMemoryReservation = Arc>; +pub(crate) type SharedMemoryReservation = Arc; /// Create a vector of record batches from a stream pub async fn collect(stream: SendableRecordBatchStream) -> Result> { @@ -89,6 +90,96 @@ fn build_file_list_recurse( Ok(()) } +/// Align `input`'s physical plan schema with `expected_schema`. +/// +/// This helper is intended for operators that combine independently planned children but +/// expose a single declared output schema. It returns `input` unchanged when schemas already +/// match exactly. Otherwise, it validates that projection can safely produce the expected +/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing +/// positional order and aliases them to `expected_schema`'s field names. +/// +/// [`ProjectionExec`] can rename fields. When the expected field is nullable and the input +/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects +/// differences that projection cannot safely normalize exactly, such as data type, metadata, +/// schema metadata, and nullability narrowing. +pub fn project_plan_to_schema( + input: Arc, + expected_schema: &SchemaRef, +) -> Result> { + let input_schema = input.schema(); + if input_schema.as_ref() == expected_schema.as_ref() { + return Ok(input); + } + + if input_schema.fields().len() != expected_schema.fields().len() { + return plan_err!( + "Cannot project plan to expected schema: expected {} column(s), got {}", + expected_schema.fields().len(), + input_schema.fields().len() + ); + } + + if input_schema.metadata() != expected_schema.metadata() { + return plan_err!( + "Cannot project plan to expected schema: schema metadata differ" + ); + } + + if let Some((i, input_field, expected_field, mismatch)) = input_schema + .fields() + .iter() + .zip(expected_schema.fields().iter()) + .enumerate() + .find_map(|(i, (input_field, expected_field))| { + if input_field.data_type() != expected_field.data_type() { + Some((i, input_field, expected_field, "data type")) + } else if input_field.is_nullable() && !expected_field.is_nullable() { + Some((i, input_field, expected_field, "nullability")) + } else if input_field.metadata() != expected_field.metadata() { + Some((i, input_field, expected_field, "metadata")) + } else { + None + } + }) + { + return plan_err!( + "Cannot project plan column {i} ('{}') to expected output field '{}': \ + field {mismatch} differs (input field: {:?}, expected field: {:?})", + input_field.name(), + expected_field.name(), + input_field, + expected_field + ); + } + + let projection_exprs = expected_schema + .fields() + .iter() + .enumerate() + .map(|(i, expected_field)| { + let input_field = input_schema.field(i); + let column = Arc::new(Column::new(input_field.name(), i)); + let expr = if !input_field.is_nullable() && expected_field.is_nullable() { + Arc::new(CastExpr::new_with_target_field( + column, + Arc::clone(expected_field), + None, + )) as _ + } else { + column as _ + }; + ProjectionExpr { + expr, + alias: expected_field.name().clone(), + } + }) + .collect::>(); + + let projection = ProjectionExec::try_new(projection_exprs, input)?; + debug_assert_eq!(projection.schema().as_ref(), expected_schema.as_ref()); + Ok(Arc::new(projection)) +} + /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub fn spawn_buffered( @@ -179,10 +270,7 @@ pub fn compute_record_batch_statistics( } /// Checks if the given projection is valid for the given schema. -pub fn can_project( - schema: &arrow::datatypes::SchemaRef, - projection: Option<&Vec>, -) -> Result<()> { +pub fn can_project(schema: &SchemaRef, projection: Option<&[usize]>) -> Result<()> { match projection { Some(columns) => { if columns @@ -207,12 +295,20 @@ pub fn can_project( #[cfg(test)] mod tests { use super::*; + use crate::empty::EmptyExec; + use crate::projection::ProjectionExec; + + use std::collections::HashMap; use arrow::{ array::{Float32Array, Float64Array, UInt64Array}, - datatypes::{DataType, Field}, + datatypes::{DataType, Field, Schema}, }; + fn empty_exec(fields: Vec) -> Arc { + Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) + } + #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -262,6 +358,7 @@ mod tests { min_value: Precision::Absent, sum_value: Precision::Absent, null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Absent, @@ -269,6 +366,7 @@ mod tests { min_value: Precision::Absent, sum_value: Precision::Absent, null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ], }; @@ -302,10 +400,158 @@ mod tests { min_value: Precision::Absent, sum_value: Precision::Absent, null_count: Precision::Exact(3), + byte_size: Precision::Absent, }], }; assert_eq!(actual, expected); Ok(()) } + + #[test] + fn project_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let result = project_plan_to_schema(Arc::clone(&input), &schema)?; + + assert!(Arc::ptr_eq(&input, &result)); + Ok(()) + } + + #[test] + fn project_plan_to_schema_aliases_field_names_with_projection_exec() -> Result<()> { + let input = empty_exec(vec![ + Field::new("recursive_a", DataType::Int32, false), + Field::new("recursive_b", DataType::Utf8, true), + ]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ])); + + let result = project_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let projection = result + .downcast_ref::() + .expect("schema rename should use ProjectionExec"); + assert!(Arc::ptr_eq(projection.input(), &input)); + assert_eq!(projection.schema(), expected_schema); + assert_eq!(projection.expr()[0].alias, "a"); + assert_eq!(projection.expr()[1].alias, "b"); + Ok(()) + } + + #[test] + fn project_plan_to_schema_preserves_matching_metadata_while_renaming() -> Result<()> { + let field_metadata = HashMap::from([("key".to_string(), "value".to_string())]); + let schema_metadata = + HashMap::from([("schema-key".to_string(), "schema-value".to_string())]); + let input_schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("input", DataType::Int32, false) + .with_metadata(field_metadata.clone()), + ], + schema_metadata.clone(), + )); + let input: Arc = Arc::new(EmptyExec::new(input_schema)); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("expected", DataType::Int32, false) + .with_metadata(field_metadata), + ], + schema_metadata, + )); + + let result = project_plan_to_schema(input, &expected_schema)?; + + assert_eq!(result.schema(), expected_schema); + Ok(()) + } + + #[test] + fn project_plan_to_schema_errors_on_column_count_mismatch() { + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("expected 2 column")); + } + + #[test] + fn project_plan_to_schema_errors_on_type_mismatch() { + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field data type differs")); + } + + #[test] + fn project_plan_to_schema_widens_nullability() -> Result<()> { + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = Arc::new(Schema::new(vec![Field::new( + "renamed", + DataType::Int32, + true, + )])); + + let result = project_plan_to_schema(input, &expected_schema)?; + + assert_eq!(result.schema(), expected_schema); + Ok(()) + } + + #[test] + fn project_plan_to_schema_errors_on_nullability_narrowing() { + let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); + let expected_schema = Arc::new(Schema::new(vec![Field::new( + "renamed", + DataType::Int32, + false, + )])); + + let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field nullability differs")); + } + + #[test] + fn project_plan_to_schema_errors_on_field_metadata_mismatch() { + let input = + empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( + HashMap::from([("source".to_string(), "input".to_string())]), + )]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ + ("source".to_string(), "expected".to_string()), + ])), + ])); + + let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field metadata differs")); + } + + #[test] + fn project_plan_to_schema_errors_on_schema_metadata_mismatch() { + let input_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("a", DataType::Int32, false)], + HashMap::from([("source".to_string(), "input".to_string())]), + )); + let input: Arc = Arc::new(EmptyExec::new(input_schema)); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("renamed", DataType::Int32, false)], + HashMap::from([("source".to_string(), "expected".to_string())]), + )); + + let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("schema metadata differ")); + } } diff --git a/datafusion/physical-plan/src/coop.rs b/datafusion/physical-plan/src/coop.rs index aa5e7b4a8cec1..111999b71c91d 100644 --- a/datafusion/physical-plan/src/coop.rs +++ b/datafusion/physical-plan/src/coop.rs @@ -22,10 +22,15 @@ //! A single call to `poll_next` on a top-level [`Stream`] may potentially perform a lot of work //! before it returns a `Poll::Pending`. Think for instance of calculating an aggregation over a //! large dataset. +//! //! If a `Stream` runs for a long period of time without yielding back to the Tokio executor, //! it can starve other tasks waiting on that executor to execute them. //! Additionally, this prevents the query execution from being cancelled. //! +//! For more background, please also see the [Using Rust async for Query Execution and Cancelling Long-Running Queries blog] +//! +//! [Using Rust async for Query Execution and Cancelling Long-Running Queries blog]: https://datafusion.apache.org/blog/2025/06/30/cancellation +//! //! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield //! points using the utilities in this module. For most operators this is **not** necessary. The //! `Stream`s of the built-in DataFusion operators that generate (rather than manipulate) @@ -69,7 +74,6 @@ use datafusion_common::config::ConfigOptions; use datafusion_physical_expr::PhysicalExpr; #[cfg(datafusion_coop = "tokio_fallback")] use futures::Future; -use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -79,17 +83,19 @@ use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; +use crate::projection::ProjectionExec; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, + SendableRecordBatchStream, SortOrderPushdownResult, check_if_same_properties, }; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use datafusion_common::{assert_eq_or_internal_err, Result, Statistics}; +use datafusion_common::{Result, Statistics, assert_eq_or_internal_err}; use datafusion_execution::TaskContext; use crate::execution_plan::SchedulingType; use crate::stream::RecordBatchStreamAdapter; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use futures::{Stream, StreamExt}; /// A stream that passes record batches through unchanged while cooperating with the Tokio runtime. @@ -207,19 +213,18 @@ where /// An execution plan decorator that enables cooperative multitasking. /// It wraps the streams produced by its input execution plan using the [`make_cooperative`] function, /// which makes the stream participate in Tokio cooperative scheduling. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CooperativeExec { input: Arc, - properties: PlanProperties, + properties: Arc, } impl CooperativeExec { /// Creates a new `CooperativeExec` operator that wraps the given input execution plan. pub fn new(input: Arc) -> Self { - let properties = input - .properties() - .clone() - .with_scheduling_type(SchedulingType::Cooperative); + let properties = PlanProperties::clone(input.properties()) + .with_scheduling_type(SchedulingType::Cooperative) + .into(); Self { input, properties } } @@ -228,6 +233,16 @@ impl CooperativeExec { pub fn input(&self) -> &Arc { &self.input } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + ..Self::clone(self) + } + } } impl DisplayAs for CooperativeExec { @@ -245,15 +260,11 @@ impl ExecutionPlan for CooperativeExec { "CooperativeExec" } - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> Arc { self.input.schema() } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.properties } @@ -274,6 +285,7 @@ impl ExecutionPlan for CooperativeExec { 1, "CooperativeExec requires exactly one child" ); + check_if_same_properties!(self, children); Ok(Arc::new(CooperativeExec::new(children.swap_remove(0)))) } @@ -286,7 +298,7 @@ impl ExecutionPlan for CooperativeExec { Ok(make_cooperative(child_stream)) } - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { self.input.partition_statistics(partition) } @@ -298,6 +310,18 @@ impl ExecutionPlan for CooperativeExec { Equal } + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + match self.input.try_swapping_with_projection(projection)? { + Some(new_input) => Ok(Some( + Arc::new(self.clone()).with_new_children(vec![new_input])?, + )), + None => Ok(None), + } + } + fn gather_filters_for_pushdown( &self, _phase: FilterPushdownPhase, @@ -315,6 +339,27 @@ impl ExecutionPlan for CooperativeExec { ) -> Result>> { Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + let child = self.input(); + + match child.try_pushdown_sort(order)? { + SortOrderPushdownResult::Exact { inner } => { + let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?; + Ok(SortOrderPushdownResult::Exact { inner: new_exec }) + } + SortOrderPushdownResult::Inexact { inner } => { + let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?; + Ok(SortOrderPushdownResult::Inexact { inner: new_exec }) + } + SortOrderPushdownResult::Unsupported => { + Ok(SortOrderPushdownResult::Unsupported) + } + } + } } /// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`]. @@ -343,11 +388,10 @@ pub fn make_cooperative(stream: SendableRecordBatchStream) -> SendableRecordBatc #[cfg(test)] mod tests { use super::*; - use crate::stream::RecordBatchStreamAdapter; use arrow_schema::SchemaRef; - use futures::{stream, StreamExt}; + use futures::stream; // This is the hardcoded value Tokio uses const TASK_BUDGET: usize = 128; diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 35ca0b65ae294..4642a9a4b1222 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -21,6 +21,7 @@ use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::fmt::Formatter; +use std::time::Duration; use arrow::datatypes::SchemaRef; @@ -28,10 +29,10 @@ use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; use datafusion_expr::display_schema; use datafusion_physical_expr::LexOrdering; -use crate::metrics::MetricType; +use crate::metrics::{MetricCategory, MetricType, MetricValue}; use crate::render_tree::RenderTree; -use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; +use super::{ExecutionPlan, ExecutionPlanVisitor, accept}; /// Options for controlling how each [`ExecutionPlan`] should format itself #[derive(Debug, Clone, Copy, PartialEq)] @@ -75,7 +76,7 @@ pub enum DisplayFormatType { /// │ partition_sizes: [1] │ /// │ Parquet │ /// └───────────────────────────┘ - /// ``` + /// ``` TreeRender, } @@ -123,13 +124,27 @@ pub struct DisplayableExecutionPlan<'a> { show_schema: bool, /// Which metric categories should be included when rendering metric_types: Vec, + /// Optional filter by semantic category (rows / bytes / timing). + /// `None` means show all categories; `Some(vec![])` means plan-only. + metric_categories: Option>, // (TreeRender) Maximum total width of the rendered tree tree_maximum_render_width: usize, + /// Optional summary totals (currently only used by `pgjson`) — the total + /// row count and wall-clock duration of the `AnalyzeExec` execution. + summary: Option, +} + +/// Summary information attached to the root of an `EXPLAIN ANALYZE` +/// pgjson render. +#[derive(Debug, Clone, Copy)] +struct AnalyzeSummary { + total_rows: Option, + duration: Option, } impl<'a> DisplayableExecutionPlan<'a> { fn default_metric_types() -> Vec { - vec![MetricType::SUMMARY, MetricType::DEV] + vec![MetricType::Summary, MetricType::Dev] } /// Create a wrapper around an [`ExecutionPlan`] which can be @@ -141,7 +156,9 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: false, show_schema: false, metric_types: Self::default_metric_types(), + metric_categories: None, tree_maximum_render_width: 240, + summary: None, } } @@ -155,7 +172,9 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: false, show_schema: false, metric_types: Self::default_metric_types(), + metric_categories: None, tree_maximum_render_width: 240, + summary: None, } } @@ -169,7 +188,9 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: false, show_schema: false, metric_types: Self::default_metric_types(), + metric_categories: None, tree_maximum_render_width: 240, + summary: None, } } @@ -194,12 +215,44 @@ impl<'a> DisplayableExecutionPlan<'a> { self } + /// Specify which metric categories to include. + /// + /// - `None` means show all categories (default). + /// - `Some(vec![])` means plan-only — suppress all metrics. + /// - `Some(vec![Rows])` means show only row-count metrics (plus + /// uncategorized metrics). + /// + /// See [`MetricCategory`] for the determinism properties of each + /// category. + pub fn set_metric_categories( + mut self, + metric_categories: Option>, + ) -> Self { + self.metric_categories = metric_categories; + self + } + /// Set the maximum render width for the tree format pub fn set_tree_maximum_render_width(mut self, width: usize) -> Self { self.tree_maximum_render_width = width; self } + /// Attach an `EXPLAIN ANALYZE` summary (total output rows and duration) + /// to the rendered output. Currently only used by [`Self::pgjson`], which + /// serializes the summary alongside the root plan object. + pub fn set_summary( + mut self, + total_rows: Option, + duration: Option, + ) -> Self { + self.summary = Some(AnalyzeSummary { + total_rows, + duration, + }); + self + } + /// Return a `format`able structure that produces a single line /// per node. /// @@ -223,6 +276,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: bool, show_schema: bool, metric_types: Vec, + metric_categories: Option>, } impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { @@ -234,6 +288,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: self.show_statistics, show_schema: self.show_schema, metric_types: &self.metric_types, + metric_categories: self.metric_categories.as_deref(), }; accept(self.plan, &mut visitor) } @@ -245,6 +300,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: self.show_statistics, show_schema: self.show_schema, metric_types: self.metric_types.clone(), + metric_categories: self.metric_categories.clone(), } } @@ -265,6 +321,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: ShowMetrics, show_statistics: bool, metric_types: Vec, + metric_categories: Option>, } impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { @@ -276,6 +333,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: self.show_metrics, show_statistics: self.show_statistics, metric_types: &self.metric_types, + metric_categories: self.metric_categories.as_deref(), graphviz_builder: GraphvizBuilder::default(), parents: Vec::new(), }; @@ -294,6 +352,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_metrics: self.show_metrics, show_statistics: self.show_statistics, metric_types: self.metric_types.clone(), + metric_categories: self.metric_categories.clone(), } } @@ -320,6 +379,75 @@ impl<'a> DisplayableExecutionPlan<'a> { } } + /// Returns a `format`able structure that produces PostgreSQL-style JSON + /// output, mirroring the logical-plan pgjson format. + /// + /// Each node is rendered as a JSON object with: + /// - `"Node Type"` — `ExecutionPlan::name()` + /// - `"Details"` — the one-line `DisplayAs::Default` rendering + /// - `"Output"` — schema column names (when `set_show_schema(true)`) + /// - `"Actual Rows"` / `"Actual Total Time"` — PG-canonical metric keys + /// populated from `output_rows` / `elapsed_compute` when available + /// - `"Extras"` — remaining metrics keyed by DataFusion metric name + /// - `"Plans"` — array of child nodes + /// + /// When a summary has been set via [`Self::set_summary`], `"Total Rows"` + /// and `"Duration"` fields are attached at the root. + pub fn pgjson(&self, verbose: bool) -> impl fmt::Display + 'a { + struct Wrapper<'a> { + plan: &'a dyn ExecutionPlan, + verbose: bool, + show_metrics: ShowMetrics, + show_schema: bool, + metric_types: Vec, + metric_categories: Option>, + summary: Option, + } + impl fmt::Display for Wrapper<'_> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let mut visitor = PgJsonExecutionPlanVisitor { + verbose: self.verbose, + show_metrics: self.show_metrics, + show_schema: self.show_schema, + metric_types: &self.metric_types, + metric_categories: self.metric_categories.as_deref(), + objects: HashMap::new(), + parent_ids: Vec::new(), + next_id: 0, + root: None, + }; + accept(self.plan, &mut visitor).map_err(|_| fmt::Error)?; + let root = visitor.root.ok_or(fmt::Error)?; + let mut root_entry = serde_json::json!({ "Plan": root }); + if let Some(summary) = self.summary { + if let Some(total_rows) = summary.total_rows { + root_entry["Total Rows"] = serde_json::Value::from(total_rows); + } + if let Some(duration) = summary.duration { + root_entry["Duration"] = + serde_json::Value::from(format!("{duration:?}")); + } + } + let doc = serde_json::Value::Array(vec![root_entry]); + write!( + f, + "{}", + serde_json::to_string_pretty(&doc).map_err(|_| fmt::Error)? + ) + } + } + + Wrapper { + plan: self.inner, + verbose, + show_metrics: self.show_metrics, + show_schema: self.show_schema, + metric_types: self.metric_types.clone(), + metric_categories: self.metric_categories.clone(), + summary: self.summary, + } + } + /// Return a single-line summary of the root of the plan /// Example: `ProjectionExec: expr=[a@0 as a]`. pub fn one_line(&self) -> impl fmt::Display + 'a { @@ -329,6 +457,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: bool, show_schema: bool, metric_types: Vec, + metric_categories: Option>, } impl fmt::Display for Wrapper<'_> { @@ -341,6 +470,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: self.show_statistics, show_schema: self.show_schema, metric_types: &self.metric_types, + metric_categories: self.metric_categories.as_deref(), }; visitor.pre_visit(self.plan)?; Ok(()) @@ -353,6 +483,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: self.show_statistics, show_schema: self.show_schema, metric_types: self.metric_types.clone(), + metric_categories: self.metric_categories.clone(), } } @@ -409,6 +540,8 @@ struct IndentVisitor<'a, 'b> { show_schema: bool, /// Which metric types should be rendered metric_types: &'a [MetricType], + /// Optional filter by semantic category (rows / bytes / timing). + metric_categories: Option<&'a [MetricCategory]>, } impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { @@ -420,12 +553,14 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { ShowMetrics::None => {} ShowMetrics::Aggregated => { if let Some(metrics) = plan.metrics() { - let metrics = metrics + let mut metrics = metrics .filter_by_metric_types(self.metric_types) .aggregate_by_name() .sorted_for_display() .timestamps_removed(); - + if let Some(cats) = self.metric_categories { + metrics = metrics.filter_by_categories(cats); + } write!(self.f, ", metrics=[{metrics}]")?; } else { write!(self.f, ", metrics=[]")?; @@ -433,7 +568,10 @@ impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { } ShowMetrics::Full => { if let Some(metrics) = plan.metrics() { - let metrics = metrics.filter_by_metric_types(self.metric_types); + let mut metrics = metrics.filter_by_metric_types(self.metric_types); + if let Some(cats) = self.metric_categories { + metrics = metrics.filter_by_categories(cats); + } write!(self.f, ", metrics=[{metrics}]")?; } else { write!(self.f, ", metrics=[]")?; @@ -472,6 +610,8 @@ struct GraphvizVisitor<'a, 'b> { show_statistics: bool, /// Which metric types should be rendered metric_types: &'a [MetricType], + /// Optional filter by semantic category + metric_categories: Option<&'a [MetricCategory]>, graphviz_builder: GraphvizBuilder, /// Used to record parent node ids when visiting a plan. @@ -508,12 +648,14 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { ShowMetrics::None => "".to_string(), ShowMetrics::Aggregated => { if let Some(metrics) = plan.metrics() { - let metrics = metrics + let mut metrics = metrics .filter_by_metric_types(self.metric_types) .aggregate_by_name() .sorted_for_display() .timestamps_removed(); - + if let Some(cats) = self.metric_categories { + metrics = metrics.filter_by_categories(cats); + } format!("metrics=[{metrics}]") } else { "metrics=[]".to_string() @@ -521,7 +663,10 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { } ShowMetrics::Full => { if let Some(metrics) = plan.metrics() { - let metrics = metrics.filter_by_metric_types(self.metric_types); + let mut metrics = metrics.filter_by_metric_types(self.metric_types); + if let Some(cats) = self.metric_categories { + metrics = metrics.filter_by_categories(cats); + } format!("metrics=[{metrics}]") } else { "metrics=[]".to_string() @@ -565,6 +710,182 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { } } +/// Formats physical plans into PostgreSQL-style JSON output with live +/// per-operator metrics. +/// +/// This visitor mirrors the logical-plan `PgJsonVisitor` in +/// `datafusion-expr`: during `pre_visit` it assembles a JSON object for the +/// current node; during `post_visit` it attaches that object into its +/// parent's `"Plans"` array (or stores it as the root). +struct PgJsonExecutionPlanVisitor<'a> { + verbose: bool, + show_metrics: ShowMetrics, + show_schema: bool, + metric_types: &'a [MetricType], + metric_categories: Option<&'a [MetricCategory]>, + objects: HashMap, + parent_ids: Vec, + next_id: u32, + root: Option, +} + +impl PgJsonExecutionPlanVisitor<'_> { + /// Produce the one-line `DisplayAs::Default` rendering of a node. + fn one_line_details(plan: &dyn ExecutionPlan) -> String { + struct One<'b>(&'b dyn ExecutionPlan); + impl fmt::Display for One<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.0.fmt_as(DisplayFormatType::Default, f) + } + } + // Some operators include internal newlines; collapse them so the + // rendered JSON value stays on a single line. + format!("{}", One(plan)) + .replace('\n', " ") + .trim() + .to_string() + } + + /// Render the given `MetricValue` into the most natural `serde_json::Value` + /// we can produce: a number for simple counts/gauges/times, a float-ms for + /// `ElapsedCompute`, and a string fallback for anything else. + fn metric_value_to_json(value: &MetricValue) -> serde_json::Value { + match value { + MetricValue::OutputRows(c) => serde_json::Value::from(c.value()), + MetricValue::SpillCount(c) + | MetricValue::OutputBatches(c) + | MetricValue::SpilledRows(c) => serde_json::Value::from(c.value()), + MetricValue::SpilledBytes(c) | MetricValue::OutputBytes(c) => { + serde_json::Value::from(c.value()) + } + MetricValue::CurrentMemoryUsage(g) => serde_json::Value::from(g.value()), + MetricValue::ElapsedCompute(t) => { + // Emit as float milliseconds to align with PG's + // `"Actual Total Time"` convention. DataFusion tracks compute + // time (summed across partitions), not wall time — visualizers + // should be read with that caveat in mind. + let ms = (t.value() as f64) / 1_000_000.0; + serde_json::Value::from(ms) + } + MetricValue::Count { count, .. } => serde_json::Value::from(count.value()), + MetricValue::Gauge { gauge, .. } => serde_json::Value::from(gauge.value()), + MetricValue::Time { time, .. } => { + let ms = (time.value() as f64) / 1_000_000.0; + serde_json::Value::from(ms) + } + // Timestamps, PruningMetrics, Ratio, Custom: fall back to Display. + other => serde_json::Value::String(format!("{other}")), + } + } + + /// Populate `"Actual Rows"`, `"Actual Total Time"`, and `"Extras"` for + /// the given node from its aggregated `MetricsSet`, honoring the same + /// filtering pipeline used by `IndentVisitor`. + fn attach_metrics(&self, plan: &dyn ExecutionPlan, object: &mut serde_json::Value) { + if matches!(self.show_metrics, ShowMetrics::None) { + return; + } + let Some(metrics) = plan.metrics() else { + return; + }; + + let metrics = match self.show_metrics { + ShowMetrics::None => return, + ShowMetrics::Aggregated => metrics + .filter_by_metric_types(self.metric_types) + .aggregate_by_name() + .sorted_for_display() + .timestamps_removed(), + ShowMetrics::Full => metrics.filter_by_metric_types(self.metric_types), + }; + let metrics = if let Some(cats) = self.metric_categories { + metrics.filter_by_categories(cats) + } else { + metrics + }; + + // Build the Extras bucket, while extracting PG-canonical keys to the + // top level. + let mut extras = serde_json::Map::new(); + for metric in metrics.iter() { + let value = metric.value(); + match value { + MetricValue::OutputRows(c) => { + object["Actual Rows"] = serde_json::Value::from(c.value()); + } + MetricValue::ElapsedCompute(t) => { + let ms = (t.value() as f64) / 1_000_000.0; + object["Actual Total Time"] = serde_json::Value::from(ms); + } + _ => { + extras.insert( + value.name().to_string(), + Self::metric_value_to_json(value), + ); + } + } + } + if !extras.is_empty() { + object["Extras"] = serde_json::Value::Object(extras); + } + } +} + +impl ExecutionPlanVisitor for PgJsonExecutionPlanVisitor<'_> { + type Error = fmt::Error; + + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + let id = self.next_id; + self.next_id += 1; + + // Build fields in reading order: Node Type, Details, (schema), + // (metrics), Plans last — so the JSON output reads top-down like a + // PostgreSQL plan. + let mut object = serde_json::json!({ + "Node Type": plan.name(), + "Details": Self::one_line_details(plan), + }); + + if self.show_schema || self.verbose { + // Always include output columns when a caller asked for schema; + // also include them in verbose mode so the pgjson output mirrors + // the extra context shown by indent's verbose flag. + let columns: Vec = plan + .schema() + .fields() + .iter() + .map(|f| serde_json::Value::String(f.name().to_string())) + .collect(); + object["Output"] = serde_json::Value::Array(columns); + } + + self.attach_metrics(plan, &mut object); + + object["Plans"] = serde_json::Value::Array(vec![]); + + self.objects.insert(id, object); + self.parent_ids.push(id); + Ok(true) + } + + fn post_visit(&mut self, _plan: &dyn ExecutionPlan) -> Result { + let id = self.parent_ids.pop().ok_or(fmt::Error)?; + let current = self.objects.remove(&id).ok_or(fmt::Error)?; + + if let Some(parent_id) = self.parent_ids.last() { + let parent = self.objects.get_mut(parent_id).ok_or(fmt::Error)?; + let plans = parent + .get_mut("Plans") + .and_then(|p| p.as_array_mut()) + .ok_or(fmt::Error)?; + plans.push(current); + } else { + self.root = Some(current); + } + Ok(true) + } +} + /// This module implements a tree-like art renderer for execution plans, /// based on DuckDB's implementation: /// @@ -734,13 +1055,14 @@ impl TreeRenderVisitor<'_, '_> { if let Some(node) = root.get_node(x, y) { write!(self.f, "{}", Self::VERTICAL)?; - // Rigure out what to render. - let mut render_text = String::new(); - if render_y == 0 { - render_text = node.name.clone(); + // Figure out what to render. + let mut render_text = if render_y == 0 { + node.name.clone() } else if render_y <= extra_info[x].len() { - render_text = extra_info[x][render_y - 1].clone(); - } + extra_info[x][render_y - 1].clone() + } else { + String::new() + }; render_text = Self::adjust_text_for_rendering( &render_text, @@ -1120,7 +1442,7 @@ mod tests { use std::fmt::Write; use std::sync::Arc; - use datafusion_common::{internal_datafusion_err, Result, Statistics}; + use datafusion_common::{Result, Statistics, internal_datafusion_err}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use crate::{DisplayAs, ExecutionPlan, PlanProperties}; @@ -1149,11 +1471,7 @@ mod tests { "TestStatsExecPlan" } - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } @@ -1176,18 +1494,17 @@ mod tests { todo!() } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics( + &self, + partition: Option, + ) -> Result> { if partition.is_some() { - return Ok(Statistics::new_unknown(self.schema().as_ref())); + return Ok(Arc::new(Statistics::new_unknown(self.schema().as_ref()))); } match self { Self::Panic => panic!("expected panic"), Self::Error => Err(internal_datafusion_err!("expected error")), - Self::Ok => Ok(Statistics::new_unknown(self.schema().as_ref())), + Self::Ok => Ok(Arc::new(Statistics::new_unknown(self.schema().as_ref()))), } } } @@ -1233,4 +1550,192 @@ mod tests { fn test_display_when_stats_ok_with_show_stats() { test_stats_display(TestStatsExecPlan::Ok, false); } + + mod pgjson { + use std::sync::Arc; + use std::time::Duration; + + use arrow::datatypes::{DataType, Field, Schema}; + use insta::assert_snapshot; + + use super::super::DisplayableExecutionPlan; + use crate::empty::EmptyExec; + use crate::filter::FilterExec; + use crate::projection::ProjectionExec; + use datafusion_physical_expr::expressions::{binary, col, lit}; + use datafusion_physical_expr::{Partitioning, PhysicalExpr}; + + fn sample_plan() -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + let empty = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let predicate = binary( + col("a", &schema).unwrap(), + datafusion_expr::Operator::Gt, + lit(5i32), + &schema, + ) + .unwrap(); + let filter = Arc::new(FilterExec::try_new(predicate, empty).unwrap()); + let proj_expr: Vec<(Arc, String)> = + vec![(col("a", &schema).unwrap(), "a".to_string())]; + let _ = Partitioning::UnknownPartitioning(1); + Arc::new(ProjectionExec::try_new(proj_expr, filter).unwrap()) + } + + #[test] + fn pgjson_renders_plan_without_metrics() { + let plan = sample_plan(); + let out = DisplayableExecutionPlan::new(plan.as_ref()) + .pgjson(false) + .to_string(); + let value: serde_json::Value = serde_json::from_str(&out).unwrap(); + // Root is an array with one {"Plan": ...} entry. + let root = value + .as_array() + .expect("root array") + .first() + .expect("root entry") + .get("Plan") + .expect("plan object"); + assert_eq!(root["Node Type"].as_str(), Some("ProjectionExec")); + assert!(root.get("Actual Rows").is_none()); + assert!(root.get("Extras").is_none()); + let plans = root["Plans"].as_array().expect("Plans array"); + assert_eq!(plans.len(), 1); + assert_eq!(plans[0]["Node Type"].as_str(), Some("FilterExec")); + } + + #[test] + fn pgjson_emits_pg_canonical_metric_keys() { + use crate::metrics::{Count, Metric, MetricValue, MetricsSet, Time}; + use crate::{DisplayFormatType, ExecutionPlan, PlanProperties}; + use datafusion_common::Result; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + + // Wrap `sample_plan()` with an adapter node that exposes a + // hand-crafted `MetricsSet` so we can assert the PG key mapping + // without running anything. + #[derive(Debug)] + struct WithMetrics { + inner: Arc, + metrics: MetricsSet, + } + impl crate::DisplayAs for WithMetrics { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "WithMetrics") + } + } + impl ExecutionPlan for WithMetrics { + fn name(&self) -> &'static str { + "WithMetrics" + } + fn properties(&self) -> &Arc { + self.inner.properties() + } + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + fn execute( + &self, + _: usize, + _: Arc, + ) -> Result { + unimplemented!() + } + fn metrics(&self) -> Option { + Some(self.metrics.clone()) + } + } + + let mut metrics = MetricsSet::new(); + let rows = Count::new(); + rows.add(42); + metrics.push(Arc::new(Metric::new(MetricValue::OutputRows(rows), None))); + let elapsed = Time::new(); + elapsed.add_duration(Duration::from_millis(5)); + metrics.push(Arc::new(Metric::new( + MetricValue::ElapsedCompute(elapsed), + None, + ))); + let batches = Count::new(); + batches.add(7); + metrics.push(Arc::new(Metric::new( + MetricValue::OutputBatches(batches), + None, + ))); + + let plan: Arc = Arc::new(WithMetrics { + inner: sample_plan(), + metrics, + }); + + let out = DisplayableExecutionPlan::with_metrics(plan.as_ref()) + .pgjson(false) + .to_string(); + let value: serde_json::Value = serde_json::from_str(&out).unwrap(); + let root = value[0].get("Plan").expect("plan"); + assert_eq!(root["Actual Rows"].as_u64(), Some(42)); + assert_eq!(root["Actual Total Time"].as_f64(), Some(5.0)); + assert_eq!(root["Extras"]["output_batches"].as_u64(), Some(7)); + } + + #[test] + fn pgjson_includes_summary_when_set() { + let plan = sample_plan(); + let out = DisplayableExecutionPlan::with_metrics(plan.as_ref()) + .set_summary(Some(42), Some(Duration::from_millis(7))) + .pgjson(false) + .to_string(); + let value: serde_json::Value = serde_json::from_str(&out).unwrap(); + let entry = &value.as_array().unwrap()[0]; + assert_eq!(entry["Total Rows"].as_u64(), Some(42)); + assert!(entry["Duration"].is_string()); + } + + #[test] + fn pgjson_snapshot_of_sample_plan() { + let plan = sample_plan(); + let out = DisplayableExecutionPlan::new(plan.as_ref()) + .pgjson(false) + .to_string(); + // This snapshot assumes `serde_json` is built with the + // `preserve_order` feature (enabled via this crate's dev-deps). + assert_snapshot!(out, @r#" + [ + { + "Plan": { + "Node Type": "ProjectionExec", + "Details": "ProjectionExec: expr=[a@0 as a]", + "Plans": [ + { + "Node Type": "FilterExec", + "Details": "FilterExec: a@0 > 5", + "Plans": [ + { + "Node Type": "EmptyExec", + "Details": "EmptyExec", + "Plans": [] + } + ] + } + ] + } + } + ] + "#); + } + } } diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index e072b55ecff44..2e7f982a51a31 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -17,19 +17,19 @@ //! EmptyRelation with produce_one_row=false execution plan -use std::any::Any; use std::sync::Arc; use crate::memory::MemoryStream; -use crate::{common, DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; +use crate::{DisplayAs, PlanProperties, SendableRecordBatchStream, Statistics}; use crate::{ - execution_plan::{Boundedness, EmissionType}, DisplayFormatType, ExecutionPlan, Partitioning, + execution_plan::{Boundedness, EmissionType}, }; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{assert_or_internal_err, Result}; +use datafusion_common::stats::Precision; +use datafusion_common::{ColumnStatistics, Result, ScalarValue, assert_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -43,7 +43,7 @@ pub struct EmptyExec { schema: SchemaRef, /// Number of partitions partitions: usize, - cache: PlanProperties, + cache: Arc, } impl EmptyExec { @@ -53,7 +53,7 @@ impl EmptyExec { EmptyExec { schema, partitions: 1, - cache, + cache: Arc::new(cache), } } @@ -62,7 +62,7 @@ impl EmptyExec { self.partitions = partitions; // Changing partitions may invalidate output partitioning, so update it: let output_partitioning = Self::output_partitioning_helper(self.partitions); - self.cache = self.cache.with_partitioning(output_partitioning); + Arc::make_mut(&mut self.cache).partitioning = output_partitioning; self } @@ -110,11 +110,7 @@ impl ExecutionPlan for EmptyExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -134,7 +130,12 @@ impl ExecutionPlan for EmptyExec { partition: usize, context: Arc, ) -> Result { - trace!("Start EmptyExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start EmptyExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); assert_or_internal_err!( partition < self.partitions, @@ -150,11 +151,7 @@ impl ExecutionPlan for EmptyExec { )?)) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if let Some(partition) = partition { assert_or_internal_err!( partition < self.partitions, @@ -164,20 +161,31 @@ impl ExecutionPlan for EmptyExec { ); } - let batch = self - .data() - .expect("Create empty RecordBatch should not fail"); - Ok(common::compute_record_batch_statistics( - &[batch], - &self.schema, - None, - )) + // Build explicit stats: exact zero rows and bytes, with explicit known column stats + let mut stats = Statistics::default() + .with_num_rows(Precision::Exact(0)) + .with_total_byte_size(Precision::Exact(0)); + + // Add explicit column stats for each field in schema + for _ in self.schema.fields() { + stats = stats.add_column_statistics(ColumnStatistics { + null_count: Precision::Exact(0), + distinct_count: Precision::Exact(0), + min_value: Precision::::Absent, + max_value: Precision::::Absent, + sum_value: Precision::::Absent, + byte_size: Precision::Exact(0), + }); + } + + Ok(Arc::new(stats)) } } #[cfg(test)] mod tests { use super::*; + use crate::common; use crate::test; use crate::with_new_children_if_necessary; diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 553e3e26cec04..8577e86f00514 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -22,39 +22,46 @@ use crate::filter_pushdown::{ }; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; +use crate::sort_pushdown::SortOrderPushdownResult; pub use crate::stream::EmptyRecordBatchStream; +use arrow_schema::Schema; pub use datafusion_common::hash_utils; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; pub use datafusion_common::utils::project_schema; -pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +pub use datafusion_common::{ColumnStatistics, Statistics, internal_err}; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; pub use datafusion_expr::{Accumulator, ColumnarValue}; pub use datafusion_physical_expr::window::WindowExpr; pub use datafusion_physical_expr::{ - expressions, Distribution, Partitioning, PhysicalExpr, + Distribution, Partitioning, PhysicalExpr, expressions, }; use std::any::Any; use std::fmt::Debug; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::display::DisplayableExecutionPlan; use crate::metrics::MetricsSet; use crate::projection::ProjectionExec; +use crate::repartition::RepartitionExec; +use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::stream::RecordBatchStreamAdapter; use arrow::array::{Array, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - assert_eq_or_internal_err, assert_or_internal_err, exec_err, Constraints, - DataFusionError, Result, + Constraints, DataFusionError, Result, assert_eq_or_internal_err, + assert_or_internal_err, exec_err, }; use datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, OrderingRequirements, PhysicalSortExpr, +}; use futures::stream::{StreamExt, TryStreamExt}; @@ -85,8 +92,8 @@ use futures::stream::{StreamExt, TryStreamExt}; /// `ExecutionPlan` with memory tracking and spilling support. /// /// [`datafusion-examples`]: https://github.com/apache/datafusion/tree/main/datafusion-examples -/// [`memory_pool_execution_plan.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/memory_pool_execution_plan.rs -pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { +/// [`memory_pool_execution_plan.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/execution_monitoring/memory_pool_execution_plan.rs +pub trait ExecutionPlan: Any + Debug + DisplayAs + Send + Sync { /// Short name for the ExecutionPlan, such as 'DataSourceExec'. /// /// Implementation note: this method can just proxy to @@ -110,9 +117,29 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { } } - /// Returns the execution plan as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; + /// Returns the plan that provides this plan's public + /// [`ExecutionPlan`] downcast identity. + /// + /// This hook is for wrapper nodes that delegate their public downcast + /// identity to another plan while adding cross-cutting behavior such as + /// instrumentation. The default implementation returns `None`, meaning this + /// plan's concrete type is used for type introspection. + /// + /// Most `ExecutionPlan` implementations should use the default `None`; + /// override this only for wrapper plans that intentionally delegate their + /// public downcast identity to another plan. + /// + /// The `is` and `downcast_ref` helpers follow the returned delegate instead + /// of checking the current concrete type, making intermediate delegating + /// wrappers invisible to normal downcast-based inspection. + /// + /// Implementations that opt in should return the delegate plan, not `self`. + /// + /// This is independent from [`Self::children`] and should not be used for + /// plan traversal or optimizer rewrites. + fn downcast_delegate(&self) -> Option<&dyn ExecutionPlan> { + None + } /// Get the schema for this execution plan fn schema(&self) -> SchemaRef { @@ -124,7 +151,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// /// This information is available via methods on [`ExecutionPlanProperties`] /// trait, which is implemented for all `ExecutionPlan`s. - fn properties(&self) -> &PlanProperties; + fn properties(&self) -> &Arc; /// Returns an error if this individual node does not conform to its invariants. /// These invariants are typically only checked in debug mode. @@ -468,22 +495,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { None } - /// Returns statistics for this `ExecutionPlan` node. If statistics are not - /// available, should return [`Statistics::new_unknown`] (the default), not - /// an error. - /// - /// For TableScan executors, which supports filter pushdown, special attention - /// needs to be paid to whether the stats returned by this method are exact or not - #[deprecated(since = "48.0.0", note = "Use `partition_statistics` method instead")] - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - /// Returns statistics for a specific partition of this `ExecutionPlan` node. /// If statistics are not available, should return [`Statistics::new_unknown`] /// (the default), not an error. /// If `partition` is `None`, it returns statistics for the entire plan. - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if let Some(idx) = partition { // Validate partition index let partition_count = self.properties().partitioning.partition_count(); @@ -494,7 +510,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { partition_count ); } - Ok(Statistics::new_unknown(&self.schema())) + Ok(Arc::new(Statistics::new_unknown(&self.schema()))) } /// Returns `true` if a limit can be safely pushed down through this @@ -509,6 +525,10 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// Returns a fetching variant of this `ExecutionPlan` node, if it supports /// fetch limits. Returns `None` otherwise. + /// + /// See physical optimizer rule [`limit_pushdown`] for details. + /// + /// [`limit_pushdown`]: https://docs.rs/datafusion/latest/datafusion/physical_optimizer/limit_pushdown/index.html fn with_fetch(&self, _limit: Option) -> Option> { None } @@ -573,6 +593,7 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { } /// Handle the result of a child pushdown. + /// /// This method is called as we recurse back up the plan tree after pushing /// filters down to child nodes via [`ExecutionPlan::gather_filters_for_pushdown`]. /// It allows the current node to process the results of filter pushdown from @@ -682,6 +703,74 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { ) -> Option> { None } + + /// Try to push down sort ordering requirements to this node. + /// + /// This method is called during sort pushdown optimization to determine if this + /// node can optimize for a requested sort ordering. Implementations should: + /// + /// - Return [`SortOrderPushdownResult::Exact`] if the node can guarantee the exact + /// ordering (allowing the Sort operator to be removed) + /// - Return [`SortOrderPushdownResult::Inexact`] if the node can optimize for the + /// ordering but cannot guarantee perfect sorting (Sort operator is kept) + /// - Return [`SortOrderPushdownResult::Unsupported`] if the node cannot optimize + /// for the ordering + /// + /// For transparent nodes (that preserve ordering), implement this to delegate to + /// children and wrap the result with a new instance of this node. + /// + /// Default implementation returns `Unsupported`. + fn try_pushdown_sort( + &self, + _order: &[PhysicalSortExpr], + ) -> Result>> { + Ok(SortOrderPushdownResult::Unsupported) + } + + /// Returns a variant of this `ExecutionPlan` that is aware of order-sensitivity. + /// + /// This is used to signal to data sources that the output ordering must be + /// preserved, even if it might be more efficient to ignore it (e.g. by + /// skipping some row groups in Parquet). + /// + fn with_preserve_order( + &self, + _preserve_order: bool, + ) -> Option> { + None + } +} + +impl dyn ExecutionPlan { + /// Returns `true` if the plan is of type `T`. + /// + /// If this plan provides a [`ExecutionPlan::downcast_delegate`], delegates + /// to it. + /// + /// Prefer this over `downcast_ref::().is_some()`. Works correctly when + /// called on `Arc` via auto-deref. + pub fn is(&self) -> bool { + match self.downcast_delegate() { + Some(delegate) => delegate.is::(), + None => (self as &dyn Any).is::(), + } + } + + /// Attempts to downcast this plan to a concrete type `T`, returning `None` + /// if the plan is not of that type. + /// + /// If this plan provides a [`ExecutionPlan::downcast_delegate`], delegates + /// to it. + /// + /// Works correctly when called on `Arc` via auto-deref, + /// unlike `(&arc as &dyn Any).downcast_ref::()` which would attempt to + /// downcast the `Arc` itself. + pub fn downcast_ref(&self) -> Option<&T> { + match self.downcast_delegate() { + Some(delegate) => delegate.downcast_ref::(), + None => (self as &dyn Any).downcast_ref(), + } + } } /// [`ExecutionPlan`] Invariant Level @@ -875,25 +964,36 @@ pub enum SchedulingType { Cooperative, } -/// Represents how an operator's `Stream` implementation generates `RecordBatch`es. +/// Represents how an operator's stream drives [`RecordBatch`] production +/// relative to downstream demand. /// -/// Most operators in DataFusion generate `RecordBatch`es when asked to do so by a call to -/// `Stream::poll_next`. This is known as demand-driven or lazy evaluation. -/// -/// Some operators like `Repartition` need to drive `RecordBatch` generation themselves though. This -/// is known as data-driven or eager evaluation. +/// This is execution-topology metadata for optimizers. It distinguishes streams +/// whose batch production is driven directly by downstream calls to +/// `Stream::poll_next` from streams that may also drive input or output +/// production independently, such as by spawning tasks or buffering batches +/// ahead of demand. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EvaluationType { - /// The stream generated by [`execute`](ExecutionPlan::execute) only generates `RecordBatch` - /// instances when it is demanded by invoking `Stream::poll_next`. - /// Filter, projection, and join are examples of such lazy operators. + /// The stream generated by [`execute`](ExecutionPlan::execute) is + /// demand-driven: it produces [`RecordBatch`]es in response to downstream + /// calls to `Stream::poll_next`. + /// + /// Filter, projection, and join operators are examples of lazy operators. /// /// Lazy operators are also known as demand-driven operators. Lazy, - /// The stream generated by [`execute`](ExecutionPlan::execute) eagerly generates `RecordBatch` - /// in one or more spawned Tokio tasks. Eager evaluation is only started the first time - /// `Stream::poll_next` is called. - /// Examples of eager operators are repartition, coalesce partitions, and sort preserving merge. + /// The stream generated by [`execute`](ExecutionPlan::execute) may drive + /// input or output [`RecordBatch`] production ahead of, or independently + /// from, downstream calls to `Stream::poll_next`. + /// + /// Eager operators commonly poll input streams from spawned Tokio tasks, + /// buffer batches ahead of demand, or otherwise create an independent + /// child-polling pipeline. Eager work may start when `execute` creates the + /// stream or when the returned stream is first polled; that timing is an + /// implementation detail. + /// + /// Repartition, coalesce partitions, sort-preserving merge, buffer, and + /// analyze operators are examples of eager operators. /// /// Eager operators are also known as a data-driven operators. Eager, @@ -921,7 +1021,7 @@ pub(crate) fn boundedness_from_children<'a>( } => { return Boundedness::Unbounded { requires_infinite_memory: true, - } + }; } Boundedness::Unbounded { requires_infinite_memory: false, @@ -1020,12 +1120,17 @@ impl PlanProperties { self } - /// Overwrite equivalence properties with its new value. - pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { + /// Set equivalence properties having mut reference. + pub fn set_eq_properties(&mut self, eq_properties: EquivalenceProperties) { // Changing equivalence properties also changes output ordering, so // make sure to overwrite it: self.output_ordering = eq_properties.output_ordering(); self.eq_properties = eq_properties; + } + + /// Overwrite equivalence properties with its new value. + pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { + self.set_eq_properties(eq_properties); self } @@ -1057,9 +1162,14 @@ impl PlanProperties { self } + /// Set constraints having mut reference. + pub fn set_constraints(&mut self, constraints: Constraints) { + self.eq_properties.set_constraints(constraints); + } + /// Overwrite constraints with its new value. pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.eq_properties = self.eq_properties.with_constraints(constraints); + self.set_constraints(constraints); self } @@ -1112,15 +1222,31 @@ pub fn check_default_invariants( Ok(()) } -/// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful -/// especially for the distributed engine to judge whether need to deal with shuffling. -/// Currently, there are 3 kinds of execution plan which needs data exchange -/// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s -/// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee -/// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee +/// Indicate whether a data exchange is needed for the input of `plan`. +/// +/// This identifies physical operators that redistribute child partitions or +/// gather multiple child partitions into one output partition: +/// +/// 1. RepartitionExec for non-round-robin repartitioning +/// 2. CoalescePartitionsExec for collapsing multiple partitions into one without ordering guarantee +/// 3. SortPreservingMergeExec for collapsing multiple sorted partitions into one with ordering guarantee #[expect(clippy::needless_pass_by_value)] pub fn need_data_exchange(plan: Arc) -> bool { - plan.properties().evaluation_type == EvaluationType::Eager + if let Some(repartition) = plan.downcast_ref::() { + !matches!(repartition.partitioning(), Partitioning::RoundRobinBatch(_)) + } else if let Some(coalesce) = plan.downcast_ref::() { + coalesce.input().output_partitioning().partition_count() > 1 + } else if let Some(sort_preserving_merge) = + plan.downcast_ref::() + { + sort_preserving_merge + .input() + .output_partitioning() + .partition_count() + > 1 + } else { + false + } } /// Returns a copy of this plan if we change any child according to the pointer comparison. @@ -1358,6 +1484,68 @@ pub fn check_not_null_constraints( Ok(batch) } +/// Make plan ready to be re-executed returning its clone with state reset for all nodes. +/// +/// Some plans will change their internal states after execution, making them unable to be executed again. +/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. +/// +/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. +/// However, if the data of the left table is derived from the work table, it will become outdated +/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. +/// +/// # Limitations +/// +/// While this function enables plan reuse, it does not allow the same plan to be executed if it (OR): +/// +/// * uses dynamic filters, +/// * represents a recursive query. +/// +pub fn reset_plan_states(plan: Arc) -> Result> { + plan.transform_up(|plan| { + let new_plan = Arc::clone(&plan).reset_state()?; + Ok(Transformed::yes(new_plan)) + }) + .data() +} + +/// Check if the `plan` children has the same properties as passed `children`. +/// In this case plan can avoid self properties re-computation when its children +/// replace is requested. +/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. +pub fn has_same_children_properties( + plan: &impl ExecutionPlan, + children: &[Arc], +) -> Result { + let old_children = plan.children(); + assert_eq_or_internal_err!( + children.len(), + old_children.len(), + "Wrong number of children" + ); + for (lhs, rhs) in old_children.iter().zip(children.iter()) { + if !Arc::ptr_eq(lhs.properties(), rhs.properties()) { + return Ok(false); + } + } + Ok(true) +} + +/// Helper macro to avoid properties re-computation if passed children properties +/// the same as plan already has. Could be used to implement fast-path for method +/// [`ExecutionPlan::with_new_children`]. +#[macro_export] +macro_rules! check_if_same_properties { + ($plan: expr, $children: expr) => { + if $crate::execution_plan::has_same_children_properties( + $plan.as_ref(), + &$children, + )? { + let plan = $plan.with_new_children_and_same_properties($children); + return Ok(::std::sync::Arc::new(plan)); + } + }; +} + /// Utility function yielding a string representation of the given [`ExecutionPlan`]. pub fn get_plan_string(plan: &Arc) -> Vec { let formatted = displayable(plan.as_ref()).indent(true).to_string(); @@ -1379,18 +1567,30 @@ pub enum CardinalityEffect { GreaterEqual, } +/// Can be used in contexts where properties have not yet been initialized properly. +pub(crate) fn stub_properties() -> Arc { + static STUB_PROPERTIES: LazyLock> = LazyLock::new(|| { + Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::new(Schema::empty())), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + )) + }); + + Arc::clone(&STUB_PROPERTIES) +} + #[cfg(test)] mod tests { - use std::any::Any; - use std::sync::Arc; use super::*; + use crate::buffer::BufferExec; + use crate::test::exec::MockExec; use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; use arrow::array::{DictionaryArray, Int32Array, NullArray, RunArray}; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::{Result, Statistics}; - use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use arrow::datatypes::{DataType, Field, Schema}; #[derive(Debug)] pub struct EmptyExec; @@ -1416,11 +1616,7 @@ mod tests { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } @@ -1443,11 +1639,10 @@ mod tests { unimplemented!() } - fn statistics(&self) -> Result { - unimplemented!() - } - - fn partition_statistics(&self, _partition: Option) -> Result { + fn partition_statistics( + &self, + _partition: Option, + ) -> Result> { unimplemented!() } } @@ -1483,11 +1678,7 @@ mod tests { "MyRenamedEmptyExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { unimplemented!() } @@ -1510,15 +1701,66 @@ mod tests { unimplemented!() } - fn statistics(&self) -> Result { + fn partition_statistics( + &self, + _partition: Option, + ) -> Result> { unimplemented!() } + } + + #[derive(Debug)] + struct DowncastDelegatingExec(Arc); - fn partition_statistics(&self, _partition: Option) -> Result { + impl DisplayAs for DowncastDelegatingExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { unimplemented!() } } + impl ExecutionPlan for DowncastDelegatingExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn properties(&self) -> &Arc { + unimplemented!() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn downcast_delegate(&self) -> Option<&dyn ExecutionPlan> { + Some(self.0.as_ref()) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn partition_statistics( + &self, + _partition: Option, + ) -> Result> { + unimplemented!() + } + } #[test] fn test_execution_plan_name() { let schema1 = Arc::new(Schema::empty()); @@ -1531,14 +1773,41 @@ mod tests { assert_eq!(RenamedEmptyExec::static_name(), "MyRenamedEmptyExec"); } + #[test] + fn test_execution_plan_downcast_delegates_to_downcast_delegate() { + let schema = Arc::new(Schema::empty()); + let inner: Arc = Arc::new(EmptyExec::new(schema)); + let wrapped: Arc = Arc::new(DowncastDelegatingExec(inner)); + let nested: Arc = + Arc::new(DowncastDelegatingExec(Arc::clone(&wrapped))); + + for plan in [wrapped.as_ref(), nested.as_ref()] { + assert!(!plan.is::()); + assert!(plan.downcast_ref::().is_none()); + assert!(plan.is::()); + assert!(plan.downcast_ref::().is_some()); + assert!(!plan.is::()); + assert!(plan.downcast_ref::().is_none()); + } + } + /// A compilation test to ensure that the `ExecutionPlan::name()` method can /// be called from a trait object. /// Related ticket: https://github.com/apache/datafusion/pull/11047 - #[allow(dead_code)] + #[expect(unused)] fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { let _ = plan.name(); } + #[test] + fn buffer_exec_does_not_need_data_exchange() { + let schema = Arc::new(Schema::empty()); + let input: Arc = Arc::new(MockExec::new(vec![], schema)); + let buffer: Arc = Arc::new(BufferExec::new(input, 1024)); + + assert!(!need_data_exchange(buffer)); + } + #[test] fn test_check_not_null_constraints_accept_non_null() -> Result<()> { check_not_null_constraints( diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 4b8491cf14dd8..98eac3d28b5df 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -17,7 +17,6 @@ //! Defines the EXPLAIN operator -use std::any::Any; use std::sync::Arc; use super::{DisplayAs, PlanProperties, SendableRecordBatchStream}; @@ -27,7 +26,7 @@ use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use datafusion_common::display::StringifiedPlan; -use datafusion_common::{assert_eq_or_internal_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -44,7 +43,7 @@ pub struct ExplainExec { stringified_plans: Vec, /// control which plans to print verbose: bool, - cache: PlanProperties, + cache: Arc, } impl ExplainExec { @@ -59,7 +58,7 @@ impl ExplainExec { schema, stringified_plans, verbose, - cache, + cache: Arc::new(cache), } } @@ -108,11 +107,7 @@ impl ExecutionPlan for ExplainExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -133,7 +128,12 @@ impl ExecutionPlan for ExplainExec { partition: usize, context: Arc, ) -> Result { - trace!("Start ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); assert_eq_or_internal_err!( partition, 0, @@ -174,7 +174,11 @@ impl ExecutionPlan for ExplainExec { )?; trace!( - "Before returning RecordBatchStream in ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + "Before returning RecordBatchStream in ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); Ok(Box::pin(RecordBatchStreamAdapter::new( Arc::clone(&self.schema), diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 58185c8cdf5b2..11d36192f3aae 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -15,33 +15,37 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; use std::pin::Pin; use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll, ready}; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use itertools::Itertools; use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::coalesce::LimitedBatchCoalescer; -use crate::coalesce::PushBatchStatus::LimitReached; +use crate::check_if_same_properties; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; use crate::common::can_project; use crate::execution_plan::CardinalityEffect; use crate::filter_pushdown::{ ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, - FilterPushdownPropagation, PushedDown, PushedDownPredicate, + FilterPushdownPropagation, PushedDown, }; +use crate::limit::LocalLimitExec; use crate::metrics::{MetricBuilder, MetricType}; use crate::projection::{ - make_with_child, try_embed_projection, update_expr, EmbeddedProjection, - ProjectionExec, ProjectionExpr, + EmbeddedProjection, ProjectionExec, ProjectionExpr, make_with_child, + try_embed_projection, update_expr, }; +use crate::stream::EmptyRecordBatchStream; use crate::{ - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RatioMetrics}, DisplayFormatType, ExecutionPlan, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RatioMetrics}, }; use arrow::compute::filter_record_batch; @@ -51,17 +55,19 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_err, plan_err, project_schema, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, internal_err, plan_err, project_schema, }; use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::{lit, BinaryExpr, Column}; +use datafusion_physical_expr::expressions::{ + BinaryExpr, Column, IsNotNullExpr, Literal, lit, +}; use datafusion_physical_expr::intervals::utils::check_support; -use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; use datafusion_physical_expr::{ - analyze, conjunction, split_conjunction, AcrossPartitions, AnalysisContext, - ConstExpr, ExprBoundaries, PhysicalExpr, + AcrossPartitions, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, analyze, + conjunction, split_conjunction, }; use datafusion_physical_expr_common::physical_expr::fmt_sql; @@ -84,48 +90,168 @@ pub struct FilterExec { /// Selectivity for statistics. 0 = no rows, 100 = all rows default_selectivity: u8, /// Properties equivalence properties, partitioning, etc. - cache: PlanProperties, + cache: Arc, /// The projection indices of the columns in the output schema of join - projection: Option>, + projection: Option, /// Target batch size for output batches batch_size: usize, /// Number of rows to fetch fetch: Option, } +/// Builder for [`FilterExec`] to set optional parameters +pub struct FilterExecBuilder { + predicate: Arc, + input: Arc, + projection: Option, + default_selectivity: u8, + batch_size: usize, + fetch: Option, +} + +impl FilterExecBuilder { + /// Create a new builder with required parameters (predicate and input) + pub fn new(predicate: Arc, input: Arc) -> Self { + Self { + predicate, + input, + projection: None, + default_selectivity: FILTER_EXEC_DEFAULT_SELECTIVITY, + batch_size: FILTER_EXEC_DEFAULT_BATCH_SIZE, + fetch: None, + } + } + + /// Set the input execution plan + pub fn with_input(mut self, input: Arc) -> Self { + self.input = input; + self + } + + /// Set the predicate expression + pub fn with_predicate(mut self, predicate: Arc) -> Self { + self.predicate = predicate; + self + } + + /// Set the projection, composing with any existing projection. + /// + /// If a projection is already set, the new projection indices are mapped + /// through the existing projection. For example, if the current projection + /// is `[0, 2, 3]` and `apply_projection(Some(vec![0, 2]))` is called, the + /// resulting projection will be `[0, 3]` (indices 0 and 2 of `[0, 2, 3]`). + /// + /// If no projection is currently set, the new projection is used directly. + /// If `None` is passed, the projection is cleared. + pub fn apply_projection(self, projection: Option>) -> Result { + let projection = projection.map(Into::into); + self.apply_projection_by_ref(projection.as_ref()) + } + + /// The same as [`Self::apply_projection`] but takes projection shared reference. + pub fn apply_projection_by_ref( + mut self, + projection: Option<&ProjectionRef>, + ) -> Result { + // Check if the projection is valid against current output schema + can_project(&self.input.schema(), projection.map(AsRef::as_ref))?; + self.projection = combine_projections(projection, self.projection.as_ref())?; + Ok(self) + } + + /// Set the default selectivity + pub fn with_default_selectivity(mut self, default_selectivity: u8) -> Self { + self.default_selectivity = default_selectivity; + self + } + + /// Set the batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Set the fetch limit + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Build the FilterExec, computing properties once with all configured parameters + pub fn build(self) -> Result { + // Validate predicate type + match self.predicate.data_type(self.input.schema().as_ref())? { + DataType::Boolean => {} + other => { + return plan_err!( + "Filter predicate must return BOOLEAN values, got {other:?}" + ); + } + } + + // Validate selectivity + if self.default_selectivity > 100 { + return plan_err!( + "Default filter selectivity value needs to be less than or equal to 100" + ); + } + + // Validate projection if provided + can_project(&self.input.schema(), self.projection.as_deref())?; + + // Compute properties once with all parameters + let cache = FilterExec::compute_properties( + &self.input, + &self.predicate, + self.default_selectivity, + self.projection.as_deref(), + )?; + + Ok(FilterExec { + predicate: self.predicate, + input: self.input, + metrics: ExecutionPlanMetricsSet::new(), + default_selectivity: self.default_selectivity, + cache: Arc::new(cache), + projection: self.projection, + batch_size: self.batch_size, + fetch: self.fetch, + }) + } +} + +impl From<&FilterExec> for FilterExecBuilder { + fn from(exec: &FilterExec) -> Self { + Self { + predicate: Arc::clone(&exec.predicate), + input: Arc::clone(&exec.input), + projection: exec.projection.clone(), + default_selectivity: exec.default_selectivity, + batch_size: exec.batch_size, + fetch: exec.fetch, + // We could cache / copy over PlanProperties + // here but that would require invalidating them in FilterExecBuilder::apply_projection, etc. + // and currently every call to this method ends up invalidating them anyway. + // If useful this can be added in the future as a non-breaking change. + } + } +} + impl FilterExec { - /// Create a FilterExec on an input - #[expect(clippy::needless_pass_by_value)] + /// Create a FilterExec on an input using the builder pattern pub fn try_new( predicate: Arc, input: Arc, ) -> Result { - match predicate.data_type(input.schema().as_ref())? { - DataType::Boolean => { - let default_selectivity = FILTER_EXEC_DEFAULT_SELECTIVITY; - let cache = Self::compute_properties( - &input, - &predicate, - default_selectivity, - None, - )?; - Ok(Self { - predicate, - input: Arc::clone(&input), - metrics: ExecutionPlanMetricsSet::new(), - default_selectivity, - cache, - projection: None, - batch_size: FILTER_EXEC_DEFAULT_BATCH_SIZE, - fetch: None, - }) - } - other => { - plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") - } - } + FilterExecBuilder::new(predicate, input).build() } + /// Get a batch size + pub fn batch_size(&self) -> usize { + self.batch_size + } + + /// Set the default selectivity pub fn with_default_selectivity( mut self, default_selectivity: u8, @@ -140,43 +266,26 @@ impl FilterExec { } /// Return new instance of [FilterExec] with the given projection. + /// + /// # Deprecated + /// Use [`FilterExecBuilder::apply_projection`] instead + #[deprecated( + since = "52.0.0", + note = "Use FilterExecBuilder::apply_projection instead" + )] pub fn with_projection(&self, projection: Option>) -> Result { - // Check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - - let cache = Self::compute_properties( - &self.input, - &self.predicate, - self.default_selectivity, - projection.as_ref(), - )?; - Ok(Self { - predicate: Arc::clone(&self.predicate), - input: Arc::clone(&self.input), - metrics: self.metrics.clone(), - default_selectivity: self.default_selectivity, - cache, - projection, - batch_size: self.batch_size, - fetch: self.fetch, - }) + let builder = FilterExecBuilder::from(self); + builder.apply_projection(projection)?.build() } + /// Set the batch size pub fn with_batch_size(&self, batch_size: usize) -> Result { Ok(Self { predicate: Arc::clone(&self.predicate), input: Arc::clone(&self.input), metrics: self.metrics.clone(), default_selectivity: self.default_selectivity, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), projection: self.projection.clone(), batch_size, fetch: self.fetch, @@ -199,43 +308,100 @@ impl FilterExec { } /// Projection - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() + pub fn projection(&self) -> &Option { + &self.projection } - /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. - fn statistics_helper( + /// Calculates `Statistics` for `FilterExec` by applying the filter's + /// selectivity (default, or estimated from interval analysis) to the input + /// statistics. + /// + /// The estimated output row count is used to keep the per-column statistics + /// consistent with it: + /// - null and distinct counts are capped at the estimated row count; + /// - byte sizes (per column and total) are scaled by the selectivity; + /// - a column constrained to a single value (`col = literal`, or an + /// interval that collapses to one point) gets a distinct count of 1; + /// - a column in a null-rejecting conjunct gets a null count of 0. + /// + /// When interval analysis applies, min/max are also tightened to the + /// surviving value range. + /// + /// A contradictory predicate (e.g. `a = 1 AND a = 2`) yields zero rows and + /// empty-column statistics. + pub(crate) fn statistics_helper( schema: &SchemaRef, input_stats: Statistics, predicate: &Arc, default_selectivity: u8, ) -> Result { - if !check_support(predicate, schema) { - let selectivity = default_selectivity as f64 / 100.0; - let mut stats = input_stats.to_inexact(); - stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); - stats.total_byte_size = stats - .total_byte_size - .with_estimated_selectivity(selectivity); - return Ok(stats); - } - - let num_rows = input_stats.num_rows; - let total_byte_size = input_stats.total_byte_size; - let input_analysis_ctx = - AnalysisContext::try_from_statistics(schema, &input_stats.column_statistics)?; + let (eq_columns, is_infeasible) = collect_equality_columns(predicate); + + let input_num_rows = input_stats.num_rows; + let input_total_byte_size = input_stats.total_byte_size; + + let (selectivity, num_rows, column_statistics) = if is_infeasible { + // Contradictory predicate: no rows survive. Row-bounded counts are + // zero; value statistics are undefined on an empty column. + let mut cs = input_stats.to_inexact().column_statistics; + for col_stat in &mut cs { + col_stat.distinct_count = Precision::Exact(0); + col_stat.null_count = Precision::Exact(0); + col_stat.min_value = Precision::Absent; + col_stat.max_value = Precision::Absent; + col_stat.sum_value = Precision::Absent; + col_stat.byte_size = Precision::Exact(0); + } + (0.0, Precision::Exact(0), cs) + } else { + let null_rejecting_columns = collect_null_rejecting_columns(predicate); - let analysis_ctx = analyze(predicate, input_analysis_ctx, schema)?; + if check_support(predicate, schema) { + let input_analysis_ctx = AnalysisContext::try_from_statistics( + schema, + &input_stats.column_statistics, + )?; + let analysis_ctx = analyze(predicate, input_analysis_ctx, schema)?; + let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); + let filtered_num_rows = + input_num_rows.with_estimated_selectivity(selectivity); + let cs = collect_new_statistics( + schema, + &input_stats.column_statistics, + analysis_ctx.boundaries, + selectivity, + &null_rejecting_columns, + filtered_num_rows, + ); + (selectivity, filtered_num_rows, cs) + } else { + // Without interval boundaries, use the default selectivity and + // apply the row-count constraints that still follow from the + // filter predicate. + let selectivity = default_selectivity as f64 / 100.0; + let filtered_num_rows = + input_num_rows.with_estimated_selectivity(selectivity); + let mut cs = input_stats.to_inexact().column_statistics; + for (idx, col_stat) in cs.iter_mut().enumerate() { + col_stat.byte_size = scale_byte_size(col_stat.byte_size, selectivity); + col_stat.null_count = if null_rejecting_columns.contains(&idx) { + Precision::Exact(0) + } else { + cap_at_rows(col_stat.null_count, filtered_num_rows) + }; + col_stat.distinct_count = if eq_columns.contains(&idx) { + distinct_count_for_singleton_domain(filtered_num_rows) + } else { + cap_at_rows(col_stat.distinct_count, filtered_num_rows) + }; + } + (selectivity, filtered_num_rows, cs) + } + }; - // Estimate (inexact) selectivity of predicate - let selectivity = analysis_ctx.selectivity.unwrap_or(1.0); - let num_rows = num_rows.with_estimated_selectivity(selectivity); - let total_byte_size = total_byte_size.with_estimated_selectivity(selectivity); + let total_byte_size = + input_total_byte_size.with_estimated_selectivity(selectivity); - let column_statistics = collect_new_statistics( - &input_stats.column_statistics, - analysis_ctx.boundaries, - ); Ok(Statistics { num_rows, total_byte_size, @@ -243,49 +409,19 @@ impl FilterExec { }) } - fn extend_constants( - input: &Arc, - predicate: &Arc, - ) -> Vec { - let mut res_constants = Vec::new(); - let input_eqs = input.equivalence_properties(); - - let conjunctions = split_conjunction(predicate); - for conjunction in conjunctions { - if let Some(binary) = conjunction.as_any().downcast_ref::() { - if binary.op() == &Operator::Eq { - // Filter evaluates to single value for all partitions - if input_eqs.is_expr_constant(binary.left()).is_some() { - let across = input_eqs - .is_expr_constant(binary.right()) - .unwrap_or_default(); - res_constants - .push(ConstExpr::new(Arc::clone(binary.right()), across)); - } else if input_eqs.is_expr_constant(binary.right()).is_some() { - let across = input_eqs - .is_expr_constant(binary.left()) - .unwrap_or_default(); - res_constants - .push(ConstExpr::new(Arc::clone(binary.left()), across)); - } - } - } - } - res_constants - } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( input: &Arc, predicate: &Arc, default_selectivity: u8, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: let schema = input.schema(); let stats = Self::statistics_helper( &schema, - input.partition_statistics(None)?, + Arc::unwrap_or_clone(input.partition_statistics(None)?), predicate, default_selectivity, )?; @@ -310,14 +446,17 @@ impl FilterExec { eq_properties.add_constants(constants)?; // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) - eq_properties.add_constants(Self::extend_constants(input, predicate))?; + eq_properties.add_constants(ConstExpr::collect_predicate_constants( + input.equivalence_properties(), + predicate, + ))?; let mut output_partitioning = input.output_partitioning().clone(); // If contains projection, update the PlanProperties. if let Some(projection) = projection { let schema = eq_properties.schema(); let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -330,6 +469,17 @@ impl FilterExec { input.boundedness(), )) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for FilterExec { @@ -368,6 +518,9 @@ impl DisplayAs for FilterExec { ) } DisplayFormatType::TreeRender => { + if let Some(fetch) = self.fetch { + writeln!(f, "fetch={fetch}")?; + } write!(f, "predicate={}", fmt_sql(self.predicate.as_ref())) } } @@ -380,11 +533,7 @@ impl ExecutionPlan for FilterExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -401,13 +550,12 @@ impl ExecutionPlan for FilterExec { self: Arc, mut children: Vec>, ) -> Result> { - FilterExec::try_new(Arc::clone(&self.predicate), children.swap_remove(0)) - .and_then(|e| { - let selectivity = e.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .and_then(|e| e.with_projection(self.projection().cloned())) - .map(|e| e.with_fetch(self.fetch).unwrap()) + check_if_same_properties!(self, children); + let new_input = children.swap_remove(0); + FilterExecBuilder::from(&*self) + .with_input(new_input) + .build() + .map(|e| Arc::new(e) as _) } fn execute( @@ -415,7 +563,12 @@ impl ExecutionPlan for FilterExec { partition: usize, context: Arc, ) -> Result { - trace!("Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); let metrics = FilterExecMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { schema: self.schema(), @@ -437,20 +590,16 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - let input_stats = self.input.partition_statistics(partition)?; - let schema = self.schema(); + fn partition_statistics(&self, partition: Option) -> Result> { + let input_stats = + Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); let stats = Self::statistics_helper( - &schema, + &self.input.schema(), input_stats, self.predicate(), self.default_selectivity, )?; - Ok(stats.project(self.projection.as_ref())) + Ok(Arc::new(stats.project(self.projection.as_ref()))) } fn cardinality_effect(&self) -> CardinalityEffect { @@ -469,15 +618,15 @@ impl ExecutionPlan for FilterExec { if let Some(new_predicate) = update_expr(self.predicate(), projection.expr(), false)? { - return FilterExec::try_new( - new_predicate, - make_with_child(projection, self.input())?, - ) - .and_then(|e| { - let selectivity = self.default_selectivity(); - e.with_default_selectivity(selectivity) - }) - .map(|e| Some(Arc::new(e) as _)); + return FilterExecBuilder::from(self) + .with_input(make_with_child(projection, self.input())?) + .with_predicate(new_predicate) + // The original FilterExec projection referenced columns from its old + // input. After the swap the new input is the ProjectionExec which + // already handles column selection, so clear the projection here. + .apply_projection(None)? + .build() + .map(|e| Some(Arc::new(e) as _)); } } try_embed_projection(projection, self) @@ -489,16 +638,10 @@ impl ExecutionPlan for FilterExec { parent_filters: Vec>, _config: &ConfigOptions, ) -> Result { - if !matches!(phase, FilterPushdownPhase::Pre) { - // For non-pre phase, filters pass through unchanged - let filter_supports = parent_filters - .into_iter() - .map(PushedDownPredicate::supported) - .collect(); - return Ok(FilterDescription::new().with_child(ChildFilterDescription { - parent_filters: filter_supports, - self_filters: vec![], - })); + if phase != FilterPushdownPhase::Pre { + let child = + ChildFilterDescription::from_child(&parent_filters, self.input())?; + return Ok(FilterDescription::new().with_child(child)); } let child = ChildFilterDescription::from_child(&parent_filters, self.input())? @@ -518,14 +661,30 @@ impl ExecutionPlan for FilterExec { child_pushdown_result: ChildPushdownResult, _config: &ConfigOptions, ) -> Result>> { - if !matches!(phase, FilterPushdownPhase::Pre) { + if phase != FilterPushdownPhase::Pre { return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); } // We absorb any parent filters that were not handled by our children - let unsupported_parent_filters = - child_pushdown_result.parent_filters.iter().filter_map(|f| { - matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) - }); + let mut unsupported_parent_filters: Vec> = + child_pushdown_result + .parent_filters + .iter() + .filter_map(|f| { + matches!(f.all(), PushedDown::No).then_some(Arc::clone(&f.filter)) + }) + .collect(); + + // If this FilterExec has a projection, the unsupported parent filters + // are in the output schema (after projection) coordinates. We need to + // remap them to the input schema coordinates before combining with self filters. + if self.projection.is_some() { + let input_schema = self.input().schema(); + unsupported_parent_filters = unsupported_parent_filters + .into_iter() + .map(|expr| reassign_expr_columns(expr, &input_schema)) + .collect::>>()?; + } + let unsupported_self_filters = child_pushdown_result .self_filters .first() @@ -546,8 +705,23 @@ impl ExecutionPlan for FilterExec { let filter_input = Arc::clone(self.input()); let new_predicate = conjunction(unhandled_filters); let updated_node = if new_predicate.eq(&lit(true)) { - // FilterExec is no longer needed, but we may need to leave a projection in place - match self.projection() { + // FilterExec is no longer needed, but we may need to leave a projection in place. + // If this FilterExec had a fetch limit, propagate it to the child. + // When the child also has a fetch, use the minimum of both to preserve + // the tighter constraint. + let filter_input = if let Some(outer_fetch) = self.fetch { + let effective_fetch = match filter_input.fetch() { + Some(inner_fetch) => outer_fetch.min(inner_fetch), + None => outer_fetch, + }; + match filter_input.with_fetch(Some(effective_fetch)) { + Some(node) => node, + None => Arc::new(LocalLimitExec::new(filter_input, effective_fetch)), + } + } else { + filter_input + }; + match self.projection().as_ref() { Some(projection_indices) => { let filter_child_schema = filter_input.schema(); let proj_exprs = projection_indices @@ -573,19 +747,19 @@ impl ExecutionPlan for FilterExec { // The new predicate is the same as our current predicate None } else { - // Create a new FilterExec with the new predicate + // Create a new FilterExec with the new predicate, preserving the projection let new = FilterExec { predicate: Arc::clone(&new_predicate), input: Arc::clone(&filter_input), metrics: self.metrics.clone(), default_selectivity: self.default_selectivity, - cache: Self::compute_properties( + cache: Arc::new(Self::compute_properties( &filter_input, &new_predicate, self.default_selectivity, - self.projection.as_ref(), - )?, - projection: None, + self.projection.as_deref(), + )?), + projection: self.projection.clone(), batch_size: self.batch_size, fetch: self.fetch, }; @@ -598,33 +772,213 @@ impl ExecutionPlan for FilterExec { }) } + fn fetch(&self) -> Option { + self.fetch + } + fn with_fetch(&self, fetch: Option) -> Option> { Some(Arc::new(Self { predicate: Arc::clone(&self.predicate), input: Arc::clone(&self.input), metrics: self.metrics.clone(), default_selectivity: self.default_selectivity, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), projection: self.projection.clone(), batch_size: self.batch_size, fetch, })) } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } } impl EmbeddedProjection for FilterExec { fn with_projection(&self, projection: Option>) -> Result { - self.with_projection(projection) + FilterExecBuilder::from(self) + .apply_projection(projection)? + .build() + } +} + +/// Collects column equality information from `col = literal` predicates in a +/// conjunction. +/// +/// Returns `(eq_columns, is_infeasible)`: +/// - `eq_columns`: set of column indices constrained to a single literal value. +/// - `is_infeasible`: `true` when the same column is equated to two different +/// non-null literals (e.g. `name = 'alice' AND name = 'bob'`), which is +/// always unsatisfiable. +/// +/// Only AND conjunctions are traversed; OR is intentionally skipped +/// since `a = 1 OR a = 2` does not pin NDV to 1. +fn collect_equality_columns(predicate: &Arc) -> (HashSet, bool) { + let mut eq_values: HashMap = HashMap::new(); + let mut infeasible = false; + + for expr in split_conjunction(predicate) { + let Some(binary) = expr.downcast_ref::() else { + continue; + }; + if *binary.op() != Operator::Eq { + continue; + } + let left = binary.left(); + let right = binary.right(); + let pair = if let Some(col) = left.downcast_ref::() + && let Some(lit) = right.downcast_ref::() + && !lit.value().is_null() + { + Some((col.index(), lit.value().clone())) + } else if let Some(col) = right.downcast_ref::() + && let Some(lit) = left.downcast_ref::() + && !lit.value().is_null() + { + Some((col.index(), lit.value().clone())) + } else { + None + }; + + if let Some((idx, value)) = pair { + match eq_values.entry(idx) { + Entry::Occupied(prev) => { + if *prev.get() != value { + infeasible = true; + break; + } + } + Entry::Vacant(slot) => { + slot.insert(value); + } + } + } + } + + (eq_values.into_keys().collect(), infeasible) +} + +/// Collects columns that cannot be NULL in any surviving row. +/// +/// A filter keeps only rows where the predicate is TRUE, so a column is +/// null-rejecting if some top-level AND conjunct evaluates to NULL or FALSE +/// whenever that column is NULL. Two such conjuncts are recognized: +/// +/// - a binary operator that returns NULL on NULL input, applied directly to the +/// column (e.g. `a = 10`, `a < b`); +/// - an `IS NOT NULL` check on the column (e.g. `a IS NOT NULL`). +/// +/// This analysis is conservative; for example, OR clauses are not considered +/// null-rejecting, and neither are indirect operands like `a + 1 < 10`. +fn collect_null_rejecting_columns(predicate: &Arc) -> HashSet { + let mut columns = HashSet::new(); + + for expr in split_conjunction(predicate) { + // `col IS NOT NULL` keeps only rows where `col` is non-null. + if let Some(is_not_null) = expr.downcast_ref::() { + if let Some(col) = is_not_null.arg().downcast_ref::() { + columns.insert(col.index()); + } + continue; + } + + // A binary operator that returns NULL on NULL input rejects rows where + // a direct column operand is NULL. + if let Some(binary) = expr.downcast_ref::() { + if !binary.op().returns_null_on_null() { + continue; + } + if let Some(col) = binary.left().downcast_ref::() { + columns.insert(col.index()); + } + if let Some(col) = binary.right().downcast_ref::() { + columns.insert(col.index()); + } + } + } + + columns +} + +/// Converts an interval bound to a [`Precision`] value. NULL bounds (which +/// represent "unbounded" in the interval type) map to [`Precision::Absent`]. +fn interval_bound_to_precision( + bound: ScalarValue, + is_exact: bool, +) -> Precision { + if bound.is_null() { + Precision::Absent + } else if is_exact { + Precision::Exact(bound) + } else { + Precision::Inexact(bound) + } +} + +/// Scales a column's `byte_size` by the estimated filter `selectivity`. An +/// exact zero is preserved: an empty column stays exactly empty after +/// filtering. +fn scale_byte_size(byte_size: Precision, selectivity: f64) -> Precision { + match byte_size { + Precision::Exact(0) => Precision::Exact(0), + byte_size => byte_size.with_estimated_selectivity(selectivity), + } +} + +/// Caps a row-bounded column statistic (a null count or distinct count) at the +/// filtered row estimate, since a column cannot have more nulls or distinct +/// values than it has rows. Known counts are demoted to inexact because the +/// filtered row count is itself an estimate. +fn cap_at_rows( + value: Precision, + filtered_num_rows: Precision, +) -> Precision { + match filtered_num_rows { + Precision::Absent => value.to_inexact(), + rows => value.to_inexact().min(&rows), + } +} + +/// Returns the NDV for a column constrained to one non-null value (e.g. +/// `column = literal` or a singleton interval), derived from the filtered row +/// estimate: zero rows means zero distinct values, a known positive row count +/// means exactly one, and an unknown row count means an inexact one (the column +/// could still be empty). +/// +/// The caller is responsible for proving the singleton domain. +fn distinct_count_for_singleton_domain( + filtered_num_rows: Precision, +) -> Precision { + match filtered_num_rows { + Precision::Exact(0) | Precision::Inexact(0) => filtered_num_rows, + // The row count is unknown, so the column could still be empty (zero + // distinct values); report an inexact one rather than overstating it. + Precision::Absent => Precision::Inexact(1), + _ => Precision::Exact(1), } } -/// This function ensures that all bounds in the `ExprBoundaries` vector are -/// converted to closed bounds. If a lower/upper bound is initially open, it -/// is adjusted by using the next/previous value for its data type to convert -/// it into a closed bound. +/// Builds output column statistics from interval-analysis boundaries. +/// +/// The interval bounds become min/max values, singleton intervals become +/// singleton NDV, and row-bounded counts are kept consistent with the filtered +/// row estimate. fn collect_new_statistics( + schema: &SchemaRef, input_column_stats: &[ColumnStatistics], analysis_boundaries: Vec, + selectivity: f64, + null_rejecting_columns: &HashSet, + filtered_num_rows: Precision, ) -> Vec { analysis_boundaries .into_iter() @@ -639,27 +993,49 @@ fn collect_new_statistics( }, )| { let Some(interval) = interval else { - // If the interval is `None`, we can say that there are no rows: + // If the interval is `None`, we can say that there are no rows. + // Use a typed null to preserve the column's data type, so that + // downstream interval analysis can still intersect intervals + // of the same type. + let typed_null = ScalarValue::try_from(schema.field(idx).data_type()) + .unwrap_or(ScalarValue::Null); return ColumnStatistics { null_count: Precision::Exact(0), - max_value: Precision::Exact(ScalarValue::Null), - min_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + max_value: Precision::Exact(typed_null.clone()), + min_value: Precision::Exact(typed_null.clone()), + sum_value: Precision::Exact(typed_null), distinct_count: Precision::Exact(0), + byte_size: Precision::Exact(0), }; }; let (lower, upper) = interval.into_bounds(); - let (min_value, max_value) = if lower.eq(&upper) { - (Precision::Exact(lower), Precision::Exact(upper)) + let is_single_value = + !lower.is_null() && !upper.is_null() && lower == upper; + let min_value = interval_bound_to_precision(lower, is_single_value); + let max_value = interval_bound_to_precision(upper, is_single_value); + + // Distinct and null counts cannot exceed the number of rows + // that survive the filter. Singleton intervals and + // null-rejecting predicates provide tighter bounds. + let capped_distinct_count = if is_single_value { + distinct_count_for_singleton_domain(filtered_num_rows) + } else { + cap_at_rows(distinct_count, filtered_num_rows) + }; + let capped_null_count = if null_rejecting_columns.contains(&idx) { + Precision::Exact(0) } else { - (Precision::Inexact(lower), Precision::Inexact(upper)) + cap_at_rows(input_column_stats[idx].null_count, filtered_num_rows) }; + let byte_size = + scale_byte_size(input_column_stats[idx].byte_size, selectivity); ColumnStatistics { - null_count: input_column_stats[idx].null_count.to_inexact(), + null_count: capped_null_count, max_value, min_value, sum_value: Precision::Absent, - distinct_count: distinct_count.to_inexact(), + distinct_count: capped_distinct_count, + byte_size, } }, ) @@ -678,17 +1054,19 @@ struct FilterExecStream { /// Runtime metrics recording metrics: FilterExecMetrics, /// The projection indices of the columns in the input schema - projection: Option>, + projection: Option, /// Batch coalescer to combine small batches batch_coalescer: LimitedBatchCoalescer, } /// The metrics for `FilterExec` struct FilterExecMetrics { - // Common metrics for most operators + /// Common metrics for most operators baseline_metrics: BaselineMetrics, - // Selectivity of the filter, calculated as output_rows / input_rows + /// Selectivity of the filter, calculated as output_rows / input_rows selectivity: RatioMetrics, + // Remember to update `docs/source/user-guide/metrics.md` when adding new metrics, + // or modifying metrics comments } impl FilterExecMetrics { @@ -696,29 +1074,12 @@ impl FilterExecMetrics { Self { baseline_metrics: BaselineMetrics::new(metrics, partition), selectivity: MetricBuilder::new(metrics) - .with_type(MetricType::SUMMARY) + .with_type(MetricType::Summary) .ratio_metrics("selectivity", partition), } } } -impl FilterExecStream { - fn flush_remaining_batches( - &mut self, - ) -> Poll>> { - // Flush any remaining buffered batch - match self.batch_coalescer.finish() { - Ok(()) => { - Poll::Ready(self.batch_coalescer.next_completed_batch().map(|batch| { - self.metrics.selectivity.add_part(batch.num_rows()); - Ok(batch) - })) - } - Err(e) => Poll::Ready(Some(Err(e))), - } - } -} - pub fn batch_filter( batch: &RecordBatch, predicate: &Arc, @@ -758,18 +1119,37 @@ impl Stream for FilterExecStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll; let elapsed_compute = self.metrics.baseline_metrics.elapsed_compute().clone(); loop { + // If there is a completed batch ready, return it + if let Some(batch) = self.batch_coalescer.next_completed_batch() { + self.metrics.selectivity.add_part(batch.num_rows()); + let poll = Poll::Ready(Some(Ok(batch))); + return self.metrics.baseline_metrics.record_poll(poll); + } + + if self.batch_coalescer.is_finished() { + // If input is done and no batches are ready, return None to signal end of stream. + return Poll::Ready(None); + } + + // Attempt to pull the next batch from the input stream. match ready!(self.input.poll_next_unpin(cx)) { + None => { + self.batch_coalescer.finish()?; + // Release the input pipeline's resources. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + // continue draining the coalescer + } Some(Ok(batch)) => { let timer = elapsed_compute.timer(); let status = self.predicate.as_ref() .evaluate(&batch) .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { - Ok(match self.projection { - Some(ref projection) => { + Ok(match self.projection.as_ref() { + Some(projection) => { let projected_batch = batch.project(projection)?; (array, projected_batch) }, @@ -793,37 +1173,26 @@ impl Stream for FilterExecStream { })?; timer.done(); - if let LimitReached = status { - poll = self.flush_remaining_batches(); - break; - } - - if let Some(batch) = self.batch_coalescer.next_completed_batch() { - self.metrics.selectivity.add_part(batch.num_rows()); - poll = Poll::Ready(Some(Ok(batch))); - break; - } - continue; - } - None => { - // Flush any remaining buffered batch - match self.batch_coalescer.finish() { - Ok(()) => { - poll = self.flush_remaining_batches(); + match status { + PushBatchStatus::Continue => { + // Keep pushing more batches } - Err(e) => { - poll = Poll::Ready(Some(Err(e))); + PushBatchStatus::LimitReached => { + // limit was reached, so stop early + self.batch_coalescer.finish()?; + // Release the input pipeline's resources. + let input_schema = self.input.schema(); + self.input = + Box::pin(EmptyRecordBatchStream::new(input_schema)); + // continue draining the coalescer } } - break; - } - value => { - poll = Poll::Ready(value); - break; } + + // Error case + other => return Poll::Ready(other), } } - self.metrics.baseline_metrics.record_poll(poll) } fn size_hint(&self) -> (usize, Option) { @@ -856,7 +1225,20 @@ fn collect_columns_from_predicate_inner( let predicates = split_conjunction(predicate); predicates.into_iter().for_each(|p| { - if let Some(binary) = p.as_any().downcast_ref::() { + if let Some(binary) = p.downcast_ref::() { + // Only extract pairs where at least one side is a Column reference. + // Pairs like `complex_expr = literal` should not create equivalence + // classes — the literal could appear in many unrelated expressions + // (e.g. sort keys), and normalize_expr's deep traversal would + // replace those occurrences with the complex expression, corrupting + // sort orderings. Constant propagation for such pairs is handled + // separately by `extend_constants`. + let has_direct_column_operand = + binary.left().downcast_ref::().is_some() + || binary.right().downcast_ref::().is_some(); + if !has_direct_column_operand { + return; + } match binary.op() { Operator::Eq => { eq_predicate_columns.push((binary.left(), binary.right())) @@ -887,7 +1269,6 @@ mod tests { use crate::test; use crate::test::exec::StatisticsExec; use arrow::datatypes::{Field, Schema, UnionFields, UnionMode}; - use datafusion_common::ScalarValue; #[tokio::test] async fn collect_columns_predicates() -> Result<()> { @@ -970,6 +1351,8 @@ mod tests { assert_eq!( statistics.column_statistics, vec![ColumnStatistics { + // `a <= 25` rejects nulls, so the column has no surviving nulls. + null_count: Precision::Exact(0), min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), ..Default::default() @@ -1016,6 +1399,8 @@ mod tests { assert_eq!( statistics.column_statistics, vec![ColumnStatistics { + // `a <= 25 AND a >= 10` rejects nulls in `a`. + null_count: Precision::Exact(0), min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), ..Default::default() @@ -1083,11 +1468,16 @@ mod tests { statistics.column_statistics, vec![ ColumnStatistics { + // `a <= 25 AND a >= 10` rejects nulls in `a`. + null_count: Precision::Exact(0), min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), max_value: Precision::Inexact(ScalarValue::Int32(Some(25))), ..Default::default() }, ColumnStatistics { + // `b > 45` in the upstream filter zeroes b's nulls; the outer + // filter then caps the (already zero) count, demoting to inexact. + null_count: Precision::Inexact(0), min_value: Precision::Inexact(ScalarValue::Int32(Some(46))), max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), ..Default::default() @@ -1214,7 +1604,7 @@ mod tests { ]; let _ = exp_col_stats .into_iter() - .zip(statistics.column_statistics) + .zip(statistics.column_statistics.clone()) .map(|(expected, actual)| { if let Some(val) = actual.min_value.get_value() { if val.data_type().is_floating() { @@ -1284,8 +1674,13 @@ mod tests { Arc::new(Column::new("b", 1)), )), )); - // Since filter predicate passes all entries, statistics after filter shouldn't change. - let expected = input.partition_statistics(None)?.column_statistics; + // The filter predicate passes all (non-null) entries, so min/max/NDV + // are unchanged. `a < 200` and `1 <= b` are null-rejecting, though, so + // both columns lose any nulls regardless of selectivity. + let mut expected = input.partition_statistics(None)?.column_statistics.clone(); + for col in &mut expected { + col.null_count = Precision::Exact(0); + } let filter: Arc = Arc::new(FilterExec::try_new(predicate, input)?); let statistics = filter.partition_statistics(None)?; @@ -1349,18 +1744,20 @@ mod tests { statistics.column_statistics, vec![ ColumnStatistics { - min_value: Precision::Exact(ScalarValue::Null), - max_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Int32(None)), + max_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), null_count: Precision::Exact(0), + byte_size: Precision::Exact(0), }, ColumnStatistics { - min_value: Precision::Exact(ScalarValue::Null), - max_value: Precision::Exact(ScalarValue::Null), - sum_value: Precision::Exact(ScalarValue::Null), + min_value: Precision::Exact(ScalarValue::Int32(None)), + max_value: Precision::Exact(ScalarValue::Int32(None)), + sum_value: Precision::Exact(ScalarValue::Int32(None)), distinct_count: Precision::Exact(0), null_count: Precision::Exact(0), + byte_size: Precision::Exact(0), }, ] ); @@ -1368,6 +1765,70 @@ mod tests { Ok(()) } + /// Regression test: stacking two FilterExecs where the inner filter + /// proves zero selectivity should not panic with a type mismatch + /// during interval intersection. + /// + /// Previously, when a filter proved no rows could match, the column + /// statistics used untyped `ScalarValue::Null` (data type `Null`). + /// If an outer FilterExec then tried to analyze its own predicate + /// against those statistics, `Interval::intersect` would fail with: + /// "Only intervals with the same data type are intersectable, lhs:Null, rhs:Int32" + #[tokio::test] + async fn test_nested_filter_with_zero_selectivity_inner() -> Result<()> { + // Inner table: a: [1, 100], b: [1, 3] + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(3))), + ..Default::default() + }, + ], + }, + schema, + )); + + // Inner filter: a > 200 (impossible given a max=100 → zero selectivity) + let inner_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(200)))), + )); + let inner_filter: Arc = + Arc::new(FilterExec::try_new(inner_predicate, input)?); + + // Outer filter: a = 50 + // Before the fix, this would panic because the inner filter's + // zero-selectivity statistics produced Null-typed intervals for + // column `a`, which couldn't intersect with the Int32 literal. + let outer_predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + let outer_filter: Arc = + Arc::new(FilterExec::try_new(outer_predicate, inner_filter)?); + + // Should succeed without error + let statistics = outer_filter.partition_statistics(None)?; + assert_eq!(statistics.num_rows, Precision::Inexact(0)); + + Ok(()) + } + #[tokio::test] async fn test_filter_statistics_more_inputs() -> Result<()> { let schema = Schema::new(vec![ @@ -1409,10 +1870,14 @@ mod tests { statistics.column_statistics, vec![ ColumnStatistics { + // `a < 50` rejects nulls in `a`. + null_count: Precision::Exact(0), min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), max_value: Precision::Inexact(ScalarValue::Int32(Some(49))), ..Default::default() }, + // `b` is not referenced by the predicate, so its stats are + // unchanged (null count stays unknown). ColumnStatistics { min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), @@ -1457,15 +1922,18 @@ mod tests { num_rows: Precision::Absent, total_byte_size: Precision::Absent, column_statistics: vec![ColumnStatistics { - null_count: Precision::Absent, + // `a <= 10` rejects nulls, so `a` has no surviving nulls even + // though the input statistics are entirely unknown. + null_count: Precision::Exact(0), min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), max_value: Precision::Inexact(ScalarValue::Int32(Some(10))), sum_value: Precision::Absent, distinct_count: Precision::Absent, + byte_size: Precision::Absent, }], }; - assert_eq!(filter_statistics, expected_filter_statistics); + assert_eq!(*filter_statistics, expected_filter_statistics); Ok(()) } @@ -1545,13 +2013,14 @@ mod tests { #[test] fn test_equivalence_properties_union_type() -> Result<()> { let union_type = DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![0, 1], vec![ Field::new("f1", DataType::Int32, true), Field::new("f2", DataType::Utf8, true), ], - ), + ) + .unwrap(), UnionMode::Sparse, ); @@ -1574,4 +2043,1545 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_builder_with_projection() -> Result<()> { + // Create a schema with multiple columns + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a filter predicate: a > 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // Create filter with projection [0, 2] (columns a and c) using builder + let projection = Some(vec![0, 2]); + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(projection.clone()) + .unwrap() + .build()?; + + // Verify projection is set correctly + assert_eq!(filter.projection(), &Some([0, 2].into())); + + // Verify schema contains only projected columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "a"); + assert_eq!(output_schema.field(1).name(), "c"); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_without_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // Create filter without projection using builder + let filter = FilterExecBuilder::new(predicate, input).build()?; + + // Verify no projection is set + assert!(filter.projection().is_none()); + + // Verify schema contains all columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_invalid_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )); + + // Try to create filter with invalid projection (index out of bounds) using builder + let result = + FilterExecBuilder::new(predicate, input).apply_projection(Some(vec![0, 5])); // 5 is out of bounds + + // Should return an error + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_vs_with_projection() -> Result<()> { + // This test verifies that the builder with projection produces the same result + // as try_new().with_projection(), but more efficiently (one compute_properties call) + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + ]); + + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(4000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ColumnStatistics { + ..Default::default() + }, + ], + }, + schema, + )); + let input: Arc = input; + + let predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + + let projection = Some(vec![0, 2]); + + // Method 1: Builder with projection (one call to compute_properties) + let filter1 = FilterExecBuilder::new(Arc::clone(&predicate), Arc::clone(&input)) + .apply_projection(projection.clone()) + .unwrap() + .build()?; + + // Method 2: Also using builder for comparison (deprecated try_new().with_projection() removed) + let filter2 = FilterExecBuilder::new(predicate, input) + .apply_projection(projection) + .unwrap() + .build()?; + + // Both methods should produce equivalent results + assert_eq!(filter1.schema(), filter2.schema()); + assert_eq!(filter1.projection(), filter2.projection()); + + // Verify statistics are the same + let stats1 = filter1.partition_statistics(None)?; + let stats2 = filter2.partition_statistics(None)?; + assert_eq!(stats1.num_rows, stats2.num_rows); + assert_eq!(stats1.total_byte_size, stats2.total_byte_size); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_statistics_with_projection() -> Result<()> { + // Test that statistics are correctly computed when using builder with projection + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(12000), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(5))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), + ..Default::default() + }, + ], + }, + schema, + )); + + // Filter: a < 50, Project: [0, 2] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(50)))), + )); + + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 2])) + .unwrap() + .build()?; + + let statistics = filter.partition_statistics(None)?; + + // Verify statistics reflect both filtering and projection + assert!(matches!(statistics.num_rows, Precision::Inexact(_))); + + // Schema should only have 2 columns after projection + assert_eq!(filter.schema().fields().len(), 2); + + Ok(()) + } + + #[test] + fn test_builder_predicate_validation() -> Result<()> { + // Test that builder validates predicate type correctly + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a predicate that doesn't return boolean (returns Int32) + let invalid_predicate = Arc::new(Column::new("a", 0)); + + // Should fail because predicate doesn't return boolean + let result = FilterExecBuilder::new(invalid_predicate, input) + .apply_projection(Some(vec![0])) + .unwrap() + .build(); + + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_projection_composition() -> Result<()> { + // Test that calling apply_projection multiple times composes projections + // If initial projection is [0, 2, 3] and we call apply_projection([0, 2]), + // the result should be [0, 3] (indices 0 and 2 of [0, 2, 3]) + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + Field::new("d", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // Create a filter predicate: a > 10 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // First projection: [0, 2, 3] -> select columns a, c, d + // Second projection: [0, 2] -> select indices 0 and 2 of [0, 2, 3] -> [0, 3] + // Final result: columns a and d + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 2, 3]))? + .apply_projection(Some(vec![0, 2]))? + .build()?; + + // Verify composed projection is [0, 3] + assert_eq!(filter.projection(), &Some([0, 3].into())); + + // Verify schema contains only columns a and d + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + assert_eq!(output_schema.field(0).name(), "a"); + assert_eq!(output_schema.field(1).name(), "d"); + + Ok(()) + } + + #[tokio::test] + async fn test_builder_projection_composition_none_clears() -> Result<()> { + // Test that passing None clears the projection + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )); + + // Set a projection then clear it with None + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0]))? + .apply_projection(None)? + .build()?; + + // Projection should be cleared + assert_eq!(filter.projection(), &None); + + // Schema should have all columns + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 2); + + Ok(()) + } + + #[test] + fn test_filter_with_projection_remaps_post_phase_parent_filters() -> Result<()> { + // Test that FilterExec with a projection must remap parent dynamic + // filter column indices from its output schema to the input schema + // before passing them to the child. + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Float64, false), + ])); + let input = Arc::new(EmptyExec::new(Arc::clone(&input_schema))); + + // FilterExec: a > 0, projection=[c@2] + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )); + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![2]))? + .build()?; + + // Output schema should be [c:Float64] + let output_schema = filter.schema(); + assert_eq!(output_schema.fields().len(), 1); + assert_eq!(output_schema.field(0).name(), "c"); + + // Simulate a parent dynamic filter referencing output column c@0 + let parent_filter: Arc = Arc::new(Column::new("c", 0)); + + let config = ConfigOptions::new(); + let desc = filter.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![parent_filter], + &config, + )?; + + // The filter pushed to the child must reference c@2 (input schema), + // not c@0 (output schema). + let parent_filters = desc.parent_filters(); + assert_eq!(parent_filters.len(), 1); // one child + assert_eq!(parent_filters[0].len(), 1); // one filter + let remapped = &parent_filters[0][0].predicate; + let display = format!("{remapped}"); + assert_eq!( + display, "c@2", + "Post-phase parent filter column index must be remapped \ + from output schema (c@0) to input schema (c@2)" + ); + + Ok(()) + } + + /// Regression test for https://github.com/apache/datafusion/issues/20194 + /// + /// `collect_columns_from_predicate_inner` should only extract equality + /// pairs where at least one side is a Column. Pairs like + /// `complex_expr = literal` must not create equivalence classes because + /// `normalize_expr`'s deep traversal would replace the literal inside + /// unrelated expressions (e.g. sort keys) with the complex expression. + #[test] + fn test_collect_columns_skips_non_column_pairs() -> Result<()> { + let schema = test::aggr_test_schema(); + + // Simulate: nvl(c2, 0) = 0 → (c2 IS DISTINCT FROM 0) = 0 + // Neither side is a Column, so this should NOT be extracted. + let complex_expr: Arc = binary( + col("c2", &schema)?, + Operator::IsDistinctFrom, + lit(0u32), + &schema, + )?; + let predicate: Arc = + binary(complex_expr, Operator::Eq, lit(0u32), &schema)?; + + let (equal_pairs, _) = collect_columns_from_predicate_inner(&predicate); + assert_eq!( + 0, + equal_pairs.len(), + "Should not extract equality pairs where neither side is a Column" + ); + + // But col = literal should still be extracted + let predicate: Arc = + binary(col("c2", &schema)?, Operator::Eq, lit(0u32), &schema)?; + let (equal_pairs, _) = collect_columns_from_predicate_inner(&predicate); + assert_eq!( + 1, + equal_pairs.len(), + "Should extract equality pairs where one side is a Column" + ); + + Ok(()) + } + + /// Columns with Absent min/max statistics should remain Absent after + /// FilterExec. + #[tokio::test] + async fn test_filter_statistics_absent_columns_stay_absent() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Absent, + column_statistics: vec![ + ColumnStatistics::default(), + ColumnStatistics::default(), + ], + }, + schema.clone(), + )); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.partition_statistics(None)?; + let col_b_stats = &statistics.column_statistics[1]; + assert_eq!(col_b_stats.min_value, Precision::Absent); + assert_eq!(col_b_stats.max_value, Precision::Absent); + + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_ndv() -> Result<()> { + #[expect(clippy::type_complexity)] + let cases: Vec<( + &str, + Vec, + Vec, + Arc, + Vec>, + )> = vec![ + ( + "utf8 equality", + vec![Field::new("name", DataType::Utf8, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some("hello".to_string())))), + )), + vec![Precision::Exact(1)], + ), + ( + "utf8view equality", + vec![Field::new("name", DataType::Utf8View, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8View(Some( + "hello".to_string(), + )))), + )), + vec![Precision::Exact(1)], + ), + ( + "largeutf8 equality", + vec![Field::new("name", DataType::LargeUtf8, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::LargeUtf8(Some( + "hello".to_string(), + )))), + )), + vec![Precision::Exact(1)], + ), + ( + "utf8 reversed (literal = column)", + vec![Field::new("name", DataType::Utf8, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Utf8(Some("hello".to_string())))), + Operator::Eq, + Arc::new(Column::new("name", 0)), + )), + vec![Precision::Exact(1)], + ), + ( + "OR is not collapsed to NDV=1, but NDV is capped at filtered rows", + vec![Field::new("name", DataType::Utf8, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some("b".to_string())))), + )), + )), + // Input NDV is 50, but the 20% default selectivity on 100 rows + // estimates 20 output rows, so NDV is capped at 20. + vec![Precision::Inexact(20)], + ), + ( + "AND with mixed types (Utf8 + Int32)", + vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + ], + vec![ + ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }, + ColumnStatistics { + distinct_count: Precision::Inexact(80), + ..Default::default() + }, + ], + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "hello".to_string(), + )))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("age", 1)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + )), + vec![Precision::Exact(1), Precision::Exact(1)], + ), + ( + "numeric equality with min/max bounds (interval analysis path)", + vec![Field::new("a", DataType::Int32, false)], + vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + distinct_count: Precision::Inexact(80), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + vec![Precision::Exact(1)], + ), + ( + "timestamp equality", + vec![Field::new( + "ts", + DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), + false, + )], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(500), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Column::new("ts", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::TimestampNanosecond( + Some(1_609_459_200_000_000_000), + None, + ))), + )), + vec![Precision::Exact(1)], + ), + ( + "contradictory numeric equality (infeasible)", + vec![Field::new("a", DataType::Int32, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(99)))), + )), + )), + vec![Precision::Exact(0)], + ), + ( + "utf8 equality with absent input NDV", + vec![Field::new("name", DataType::Utf8, false)], + vec![ColumnStatistics { + distinct_count: Precision::Absent, + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some("hello".to_string())))), + )), + vec![Precision::Exact(1)], + ), + ( + "contradictory utf8 equality (infeasible)", + vec![Field::new("name", DataType::Utf8, false)], + vec![ColumnStatistics { + distinct_count: Precision::Inexact(100), + ..Default::default() + }], + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "alice".to_string(), + )))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "bob".to_string(), + )))), + )), + )), + vec![Precision::Exact(0)], + ), + ( + "redundant same-value equality combined with another column", + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ], + vec![ + ColumnStatistics { + distinct_count: Precision::Inexact(80), + ..Default::default() + }, + ColumnStatistics { + distinct_count: Precision::Inexact(40), + ..Default::default() + }, + ], + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))), + )), + )), + vec![Precision::Exact(1), Precision::Exact(1)], + ), + ]; + + for (desc, fields, col_stats, predicate, expected_ndvs) in cases { + let schema = Schema::new(fields); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1000), + column_statistics: col_stats, + }, + schema.clone(), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + + for (i, expected) in expected_ndvs.iter().enumerate() { + assert_eq!( + statistics.column_statistics[i].distinct_count, *expected, + "case '{desc}': column {i} NDV mismatch" + ); + } + } + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_empty_input_equality_ndv_zero() -> Result<()> { + let cases: Vec<(&str, Schema, Statistics, Arc)> = vec![ + ( + "fallback string equality", + Schema::new(vec![Field::new("name", DataType::Utf8, true)]), + Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(0), + null_count: Precision::Exact(0), + byte_size: Precision::Exact(0), + ..Default::default() + }], + }, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("name", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some("x".to_string())))), + )), + ), + ( + "interval numeric equality", + Schema::new(vec![Field::new("a", DataType::Int32, true)]), + Statistics { + num_rows: Precision::Exact(0), + total_byte_size: Precision::Exact(0), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(10))), + distinct_count: Precision::Exact(0), + null_count: Precision::Exact(0), + byte_size: Precision::Exact(0), + ..Default::default() + }], + }, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )), + ), + ]; + + for (desc, schema, input_stats, predicate) in cases { + let input = Arc::new(StatisticsExec::new(input_stats, schema)); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + + assert_eq!( + statistics.num_rows, + Precision::Inexact(0), + "case '{desc}': row count mismatch" + ); + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Inexact(0), + "case '{desc}': NDV should be capped at zero rows" + ); + } + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_and_equality_ndv() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1200), + column_statistics: vec![ + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + null_count: Precision::Inexact(80), + distinct_count: Precision::Inexact(80), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(50))), + distinct_count: Precision::Inexact(40), + ..Default::default() + }, + ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + null_count: Precision::Inexact(90), + distinct_count: Precision::Inexact(150), + ..Default::default() + }, + ], + }, + schema.clone(), + )); + + // a = 42 AND b > 10 AND c = 7 + let predicate = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(7)))), + )), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + // Equality predicates collapse NDV and reject nulls for their columns. + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Exact(0) + ); + // b > 10 narrows to [11, 50] but doesn't collapse to a single value. + // The combined selectivity of a=42 (1/80) and c=7 (1/150) on 100 rows + // computes num_rows = 1, so NDV is capped at the row count: min(40, 1) = 1. + assert_eq!( + statistics.column_statistics[1].distinct_count, + Precision::Inexact(1) + ); + assert_eq!( + statistics.column_statistics[2].distinct_count, + Precision::Exact(1) + ); + assert_eq!( + statistics.column_statistics[2].null_count, + Precision::Exact(0) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_absent_bounds_ndv() -> Result<()> { + // a: ndv=80, no min/max + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(400), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Inexact(80), + ..Default::default() + }], + }, + schema.clone(), + )); + + // Even without input bounds, interval analysis can derive singleton + // bounds from the equality itself. + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_int8_ndv() -> Result<()> { + // a: min=-100, max=100, ndv=50 + let schema = Schema::new(vec![Field::new("a", DataType::Int8, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int8(Some(-100))), + max_value: Precision::Inexact(ScalarValue::Int8(Some(100))), + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + }, + schema.clone(), + )); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int8(Some(42)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_int64_ndv() -> Result<()> { + // a: min=0, max=1_000_000, ndv=100_000 + let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100_000), + total_byte_size: Precision::Inexact(800_000), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int64(Some(0))), + max_value: Precision::Inexact(ScalarValue::Int64(Some(1_000_000))), + distinct_count: Precision::Inexact(100_000), + ..Default::default() + }], + }, + schema.clone(), + )); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int64(Some(42)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_float32_ndv() -> Result<()> { + // a: min=0.0, max=100.0, ndv=50 + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(400), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Float32(Some(0.0))), + max_value: Precision::Inexact(ScalarValue::Float32(Some(100.0))), + distinct_count: Precision::Inexact(50), + ..Default::default() + }], + }, + schema.clone(), + )); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Float32(Some(42.5)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_reversed_ndv() -> Result<()> { + // a: min=1, max=100, ndv=80 + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(400), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + distinct_count: Precision::Inexact(80), + ..Default::default() + }], + }, + schema.clone(), + )); + + // 42 = a (literal on the left) + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + Operator::Eq, + Arc::new(Column::new("a", 0)), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_equality_timestamp_ndv() -> Result<()> { + // ts: min=1_000_000_000, max=2_000_000_000, ndv=500 + let schema = Schema::new(vec![Field::new( + "ts", + DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), + false, + )]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(1000), + total_byte_size: Precision::Inexact(8000), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::TimestampNanosecond( + Some(1_000_000_000), + None, + )), + max_value: Precision::Inexact(ScalarValue::TimestampNanosecond( + Some(2_000_000_000), + None, + )), + distinct_count: Precision::Inexact(500), + ..Default::default() + }], + }, + schema.clone(), + )); + + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("ts", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::TimestampNanosecond( + Some(1_500_000_000), + None, + ))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + let statistics = filter.partition_statistics(None)?; + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Exact(1) + ); + Ok(()) + } + + #[test] + fn test_collect_equality_columns() { + use std::collections::HashSet; + // (description, predicate, expected_column_indices, expected_infeasible) + #[expect(clippy::type_complexity)] + let cases: Vec<(&str, Arc, Vec, bool)> = vec![ + ( + "simple col = literal", + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + vec![0], + false, + ), + ( + "reversed literal = col", + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + Operator::Eq, + Arc::new(Column::new("a", 0)), + )), + vec![0], + false, + ), + ( + "AND with two equalities", + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "hello".to_string(), + )))), + )), + )), + vec![0, 1], + false, + ), + ( + "OR produces empty set", + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + Operator::Or, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(99)))), + )), + )), + vec![], + false, + ), + ( + "greater-than produces empty set", + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + vec![], + false, + ), + ( + "col = col produces empty set", + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Column::new("b", 1)), + )), + vec![], + false, + ), + ( + "nested AND with three equalities", + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))), + )), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c", 2)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + )), + )), + vec![0, 1, 2], + false, + ), + ( + "AND with mixed equality and non-equality", + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )), + )), + vec![0], + false, + ), + ( + "col = NULL is excluded", + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(None))), + )), + vec![], + false, + ), + ( + "NULL = col is excluded", + Arc::new(BinaryExpr::new( + Arc::new(Literal::new(ScalarValue::Utf8(None))), + Operator::Eq, + Arc::new(Column::new("a", 0)), + )), + vec![], + false, + ), + ( + "contradictory: same col, different literals", + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "alice".to_string(), + )))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Utf8(Some( + "bob".to_string(), + )))), + )), + )), + vec![0], + true, + ), + ( + "same col, same literal is not contradictory", + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + Operator::And, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Eq, + Arc::new(Literal::new(ScalarValue::Int32(Some(42)))), + )), + )), + vec![0], + false, + ), + ]; + + for (desc, expr, expected_cols, expected_infeasible) in cases { + let (result, infeasible) = collect_equality_columns(&expr); + let expected: HashSet = expected_cols.into_iter().collect(); + if expected_infeasible { + // When infeasible, the scan is short-circuited, so we only + // assert the infeasibility flag — the partial column set + // contents are an implementation detail. + assert!(infeasible, "case '{desc}': expected infeasible"); + } else { + assert_eq!(result, expected, "case '{desc}': columns mismatch"); + assert!(!infeasible, "case '{desc}': expected feasible"); + } + } + } + + /// Regression test: ProjectionExec on top of a FilterExec that already has + /// an explicit projection must not panic when `try_swapping_with_projection` + /// attempts to swap the two nodes. + /// + /// Before the fix, `FilterExecBuilder::from(self)` copied the old projection + /// (e.g. `[0, 1, 2]`) from the FilterExec. After `.with_input` replaced the + /// input with the narrower ProjectionExec (2 columns), `.build()` tried to + /// validate the stale `[0, 1, 2]` projection against the 2-column schema and + /// panicked with "project index 2 out of bounds, max field 2". + #[test] + fn test_filter_with_projection_swap_does_not_panic() -> Result<()> { + use crate::projection::ProjectionExpr; + use datafusion_physical_expr::expressions::col; + + // Schema: [ts: Int64, tokens: Int64, svc: Utf8] + let schema = Arc::new(Schema::new(vec![ + Field::new("ts", DataType::Int64, false), + Field::new("tokens", DataType::Int64, false), + Field::new("svc", DataType::Utf8, false), + ])); + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + // FilterExec: ts > 0, projection=[ts@0, tokens@1, svc@2] (all 3 cols) + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("ts", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + let filter = Arc::new( + FilterExecBuilder::new(predicate, input) + .apply_projection(Some(vec![0, 1, 2]))? + .build()?, + ); + + // ProjectionExec: narrows to [ts, tokens] (drops svc) + let proj_exprs = vec![ + ProjectionExpr { + expr: col("ts", &filter.schema())?, + alias: "ts".to_string(), + }, + ProjectionExpr { + expr: col("tokens", &filter.schema())?, + alias: "tokens".to_string(), + }, + ]; + let projection = Arc::new(ProjectionExec::try_new( + proj_exprs, + Arc::clone(&filter) as _, + )?); + + // This must not panic + let result = filter.try_swapping_with_projection(&projection)?; + assert!(result.is_some(), "swap should succeed"); + + let new_plan = result.unwrap(); + // Output schema must still be [ts, tokens] + let out_schema = new_plan.schema(); + assert_eq!(out_schema.fields().len(), 2); + assert_eq!(out_schema.field(0).name(), "ts"); + assert_eq!(out_schema.field(1).name(), "tokens"); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_ndv_capped_at_row_count() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1000), + column_statistics: vec![ColumnStatistics { + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + null_count: Precision::Inexact(80), + distinct_count: Precision::Inexact(80), + byte_size: Precision::Exact(1000), + ..Default::default() + }], + }, + schema.clone(), + )); + + // a <= 10 => ~10 rows out of 100 + let predicate: Arc = + binary(col("a", &schema)?, Operator::LtEq, lit(10i32), &schema)?; + + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.partition_statistics(None)?; + assert_eq!(statistics.num_rows, Precision::Inexact(10)); + let ndv = &statistics.column_statistics[0].distinct_count; + assert!( + ndv.get_value().copied() <= Some(10), + "Expected NDV <= 10 (filtered row count), got {ndv:?}" + ); + // `a <= 10` rejects nulls, so the 80 input nulls drop to exactly zero. + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Exact(0) + ); + // byte_size follows the same 10% selectivity estimate. + assert_eq!( + statistics.column_statistics[0].byte_size, + Precision::Inexact(100) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_default_selectivity_column_stats() -> Result<()> { + let schema = Schema::new(vec![Field::new("name", DataType::Utf8, true)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(80), + distinct_count: Precision::Inexact(60), + byte_size: Precision::Exact(1000), + ..Default::default() + }], + }, + schema.clone(), + )); + + // Utf8 interval analysis is unsupported, so this exercises the default + // selectivity path. The predicate rejects nulls but does not constrain + // the column to one value. + let predicate: Arc = + binary(col("name", &schema)?, Operator::Gt, lit("m"), &schema)?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.partition_statistics(None)?; + assert_eq!(statistics.num_rows, Precision::Inexact(20)); + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Exact(0) + ); + assert_eq!( + statistics.column_statistics[0].byte_size, + Precision::Inexact(200) + ); + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Inexact(20) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_or_does_not_reject_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("name", DataType::Utf8, true)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(80), + distinct_count: Precision::Inexact(60), + byte_size: Precision::Exact(1000), + ..Default::default() + }], + }, + schema.clone(), + )); + + let predicate: Arc = binary( + binary(col("name", &schema)?, Operator::Gt, lit("m"), &schema)?, + Operator::Or, + is_null(col("name", &schema)?)?, + &schema, + )?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.partition_statistics(None)?; + assert_eq!(statistics.num_rows, Precision::Inexact(20)); + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Inexact(20) + ); + assert_eq!( + statistics.column_statistics[0].byte_size, + Precision::Inexact(200) + ); + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Inexact(20) + ); + Ok(()) + } + + #[tokio::test] + async fn test_filter_statistics_is_not_null_rejects_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("name", DataType::Utf8, true)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1000), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(80), + distinct_count: Precision::Inexact(60), + byte_size: Precision::Exact(1000), + ..Default::default() + }], + }, + schema.clone(), + )); + + // `name IS NOT NULL` keeps only non-null rows, so the surviving null + // count is exactly zero. Utf8 interval analysis is unsupported, so this + // also exercises the default-selectivity path. + let predicate: Arc = is_not_null(col("name", &schema)?)?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, input)?); + + let statistics = filter.partition_statistics(None)?; + assert_eq!(statistics.num_rows, Precision::Inexact(20)); + assert_eq!( + statistics.column_statistics[0].null_count, + Precision::Exact(0) + ); + assert_eq!( + statistics.column_statistics[0].byte_size, + Precision::Inexact(200) + ); + assert_eq!( + statistics.column_statistics[0].distinct_count, + Precision::Inexact(20) + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 1274e954eaeb3..810f9ffcbcdb1 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -37,10 +37,13 @@ use std::collections::HashSet; use std::sync::Arc; -use datafusion_common::Result; -use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns}; +use arrow_schema::SchemaRef; +use datafusion_common::{ + Result, + tree_node::{Transformed, TreeNode}, +}; +use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use itertools::Itertools; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FilterPushdownPhase { @@ -217,13 +220,13 @@ pub struct ChildPushdownResult { /// Returned from [`ExecutionPlan::handle_child_pushdown_result`] to communicate /// to the optimizer: /// -/// 1. What to do with any parent filters that were could not be pushed down into the children. +/// 1. What to do with any parent filters that could not be pushed down into the children. /// 2. If the node needs to be replaced in the execution plan with a new node or not. /// /// [`ExecutionPlan::handle_child_pushdown_result`]: crate::ExecutionPlan::handle_child_pushdown_result #[derive(Debug, Clone)] pub struct FilterPushdownPropagation { - /// What filters were pushed into the parent node. + /// Which parent filters were pushed down into this node's children. pub filters: Vec, /// The updated node, if it was updated during pushdown pub updated_node: Option, @@ -306,6 +309,83 @@ pub struct ChildFilterDescription { pub(crate) self_filters: Vec>, } +/// Validates and remaps filter column references to a target schema in one step. +/// +/// When pushing filters from a parent to a child node, we need to: +/// 1. Verify that all columns referenced by the filter exist in the target +/// 2. Remap column indices to match the target schema +/// +/// `allowed_indices` controls which column indices (in the parent schema) are +/// considered valid. For single-input nodes this defaults to +/// `0..child_schema.len()` (all columns are reachable). For join nodes it is +/// restricted to the subset of output columns that map to the target child, +/// which is critical when different sides have same-named columns. +pub(crate) struct FilterRemapper { + /// The target schema to remap column indices into. + child_schema: SchemaRef, + /// Only columns at these indices (in the *parent* schema) are considered + /// valid. For non-join nodes this defaults to `0..child_schema.len()`. + allowed_indices: HashSet, +} + +impl FilterRemapper { + /// Create a remapper that accepts any column whose index falls within + /// `0..child_schema.len()` and whose name exists in the target schema. + pub(crate) fn new(child_schema: SchemaRef) -> Self { + let allowed_indices = (0..child_schema.fields().len()).collect(); + Self { + child_schema, + allowed_indices, + } + } + + /// Create a remapper that only accepts columns at the given indices. + /// This is used by join nodes to restrict pushdown to one side of the + /// join when both sides have same-named columns. + fn with_allowed_indices( + child_schema: SchemaRef, + allowed_indices: HashSet, + ) -> Self { + Self { + child_schema, + allowed_indices, + } + } + + /// Try to remap a filter's column references to the target schema. + /// + /// Validates and remaps in a single tree traversal: for each column, + /// checks that its index is in the allowed set and that + /// its name exists in the target schema, then remaps the index. + /// Returns `Some(remapped)` if all columns are valid, or `None` if any + /// column fails validation. + pub(crate) fn try_remap( + &self, + filter: &Arc, + ) -> Result>> { + let mut all_valid = true; + let transformed = Arc::clone(filter).transform_down(|expr| { + if let Some(col) = expr.downcast_ref::() { + if self.allowed_indices.contains(&col.index()) + && let Ok(new_index) = self.child_schema.index_of(col.name()) + { + Ok(Transformed::yes(Arc::new(Column::new( + col.name(), + new_index, + )))) + } else { + all_valid = false; + Ok(Transformed::complete(expr)) + } + } else { + Ok(Transformed::no(expr)) + } + })?; + + Ok(all_valid.then_some(transformed.data)) + } +} + impl ChildFilterDescription { /// Build a child filter description by analyzing which parent filters can be pushed to a specific child. /// @@ -318,36 +398,41 @@ impl ChildFilterDescription { parent_filters: &[Arc], child: &Arc, ) -> Result { - let child_schema = child.schema(); + let remapper = FilterRemapper::new(child.schema()); + Self::remap_filters(parent_filters, &remapper) + } - // Get column names from child schema for quick lookup - let child_column_names: HashSet<&str> = child_schema - .fields() - .iter() - .map(|f| f.name().as_str()) - .collect(); + /// Like [`Self::from_child`], but restricts which parent-level columns are + /// considered reachable through this child. + /// + /// `allowed_indices` is the set of column indices (in the *parent* + /// schema) that map to this child's side of a join. A filter is only + /// eligible for pushdown when **every** column index it references + /// appears in `allowed_indices`. + /// + /// This prevents incorrect pushdown when different join sides have + /// columns with the same name: matching on index ensures a filter + /// referencing the right side's `k@2` is not pushed to the left side + /// which also has a column named `k` but at a different index. + pub fn from_child_with_allowed_indices( + parent_filters: &[Arc], + allowed_indices: HashSet, + child: &Arc, + ) -> Result { + let remapper = + FilterRemapper::with_allowed_indices(child.schema(), allowed_indices); + Self::remap_filters(parent_filters, &remapper) + } - // Analyze each parent filter + fn remap_filters( + parent_filters: &[Arc], + remapper: &FilterRemapper, + ) -> Result { let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); - for filter in parent_filters { - // Check which columns the filter references - let referenced_columns = collect_columns(filter); - - // Check if all referenced columns exist in the child schema - let all_columns_exist = referenced_columns - .iter() - .all(|col| child_column_names.contains(col.name())); - - if all_columns_exist { - // All columns exist in child - we can push down - // Need to reassign column indices to match child schema - let reassigned_filter = - reassign_expr_columns(Arc::clone(filter), &child_schema)?; - child_parent_filters - .push(PushedDownPredicate::supported(reassigned_filter)); + if let Some(remapped) = remapper.try_remap(filter)? { + child_parent_filters.push(PushedDownPredicate::supported(remapped)); } else { - // Some columns don't exist in child - cannot push down child_parent_filters .push(PushedDownPredicate::unsupported(Arc::clone(filter))); } @@ -359,6 +444,17 @@ impl ChildFilterDescription { }) } + /// Mark all parent filters as unsupported for this child. + pub fn all_unsupported(parent_filters: &[Arc]) -> Self { + Self { + parent_filters: parent_filters + .iter() + .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) + .collect(), + self_filters: vec![], + } + } + /// Add a self filter (from the current node) to be pushed down to this child. pub fn with_self_filter(mut self, filter: Arc) -> Self { self.self_filters.push(filter); @@ -434,15 +530,9 @@ impl FilterDescription { children: &[&Arc], ) -> Self { let mut desc = Self::new(); - let child_filters = parent_filters - .iter() - .map(|f| PushedDownPredicate::unsupported(Arc::clone(f))) - .collect_vec(); for _ in 0..children.len() { - desc = desc.with_child(ChildFilterDescription { - parent_filters: child_filters.clone(), - self_filters: vec![], - }); + desc = + desc.with_child(ChildFilterDescription::all_unsupported(parent_filters)); } desc } diff --git a/datafusion/physical-plan/src/joins/array_map.rs b/datafusion/physical-plan/src/joins/array_map.rs new file mode 100644 index 0000000000000..ad40d6776df4f --- /dev/null +++ b/datafusion/physical-plan/src/joins/array_map.rs @@ -0,0 +1,547 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::DataType; +use num_traits::AsPrimitive; +use std::mem::size_of; + +use crate::joins::MapOffset; +use crate::joins::chain::traverse_chain; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::buffer::BooleanBuffer; +use arrow::datatypes::ArrowNumericType; +use datafusion_common::{Result, ScalarValue, internal_err}; + +/// A macro to downcast only supported integer types (up to 64-bit) and invoke a generic function. +/// +/// Usage: `downcast_supported_integer!(data_type => (Method, arg1, arg2, ...))` +/// +/// The `Method` must be an associated method of [`ArrayMap`] that is generic over +/// `` and allow `T::Native: AsPrimitive`. +macro_rules! downcast_supported_integer { + ($DATA_TYPE:expr => ($METHOD:ident $(, $ARGS:expr)*)) => { + match $DATA_TYPE { + arrow::datatypes::DataType::Int8 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::Int16 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::Int32 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::Int64 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt8 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt16 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt32 => ArrayMap::$METHOD::($($ARGS),*), + arrow::datatypes::DataType::UInt64 => ArrayMap::$METHOD::($($ARGS),*), + _ => { + return internal_err!( + "Unsupported type for ArrayMap: {:?}", + $DATA_TYPE + ); + } + } + }; +} + +/// A dense map for single-column integer join keys within a limited range. +/// +/// Maps join keys to build-side indices using direct array indexing: +/// `data[val - min_val_in_build_side] -> val_idx_in_build_side + 1`. +/// +/// NULL values are ignored on both the build side and the probe side. +/// +/// # Handling Negative Numbers with `wrapping_sub` +/// +/// This implementation supports signed integer ranges (e.g., `[-5, 5]`) efficiently by +/// treating them as `u64` (Two's Complement) and relying on the bitwise properties of +/// wrapping arithmetic (`wrapping_sub`). +/// +/// In Two's Complement representation, `a_signed - b_signed` produces the same bit pattern +/// as `a_unsigned.wrapping_sub(b_unsigned)` (modulo 2^N). This allows us to perform +/// range calculations and zero-based index mapping uniformly for both signed and unsigned +/// types without branching. +/// +/// ## Examples +/// +/// Consider an `Int64` range `[-5, 5]`. +/// * `min_val (-5)` casts to `u64`: `...11111011` (`u64::MAX - 4`) +/// * `max_val (5)` casts to `u64`: `...00000101` (`5`) +/// +/// **1. Range Calculation** +/// +/// ```text +/// In modular arithmetic, this is equivalent to: +/// (5 - (2^64 - 5)) mod 2^64 +/// = (5 - 2^64 + 5) mod 2^64 +/// = (10 - 2^64) mod 2^64 +/// = 10 +/// +/// ``` +/// The resulting `range` (10) correctly represents the size of the interval `[-5, 5]`. +/// +/// **2. Index Lookup (in `get_matched_indices`)** +/// +/// For a probe value of `0` (which is stored as `0u64`): +/// ```text +/// In modular arithmetic, this is equivalent to: +/// (0 - (2^64 - 5)) mod 2^64 +/// = (-2^64 + 5) mod 2^64 +/// = 5 +/// ``` +/// This correctly maps `-5` to index `0`, `0` to index `5`, etc. +#[derive(Debug)] +pub struct ArrayMap { + // data[probSideVal-offset] -> valIdxInBuildSide + 1; 0 for absent + data: Vec, + // min val in buildSide + offset: u64, + // next[buildSideIdx] -> next matching valIdxInBuildSide + 1; 0 for end of chain. + // If next is empty, it means there are no duplicate keys (no conflicts). + // It uses the same chain-based conflict resolution as [`JoinHashMapType`]. + next: Vec, + num_of_distinct_key: usize, +} + +impl ArrayMap { + pub fn is_supported_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) + } + + pub(crate) fn key_to_u64(v: &ScalarValue) -> Option { + match v { + ScalarValue::Int8(Some(v)) => Some(*v as u64), + ScalarValue::Int16(Some(v)) => Some(*v as u64), + ScalarValue::Int32(Some(v)) => Some(*v as u64), + ScalarValue::Int64(Some(v)) => Some(*v as u64), + ScalarValue::UInt8(Some(v)) => Some(*v as u64), + ScalarValue::UInt16(Some(v)) => Some(*v as u64), + ScalarValue::UInt32(Some(v)) => Some(*v as u64), + ScalarValue::UInt64(Some(v)) => Some(*v), + _ => None, + } + } + + /// Estimates the maximum memory usage for an `ArrayMap` with the given parameters. + /// + pub fn estimate_memory_size(min_val: u64, max_val: u64, num_rows: usize) -> usize { + let range = Self::calculate_range(min_val, max_val); + if range >= usize::MAX as u64 { + return usize::MAX; + } + let size = (range + 1) as usize; + size.saturating_mul(size_of::()) + .saturating_add(num_rows.saturating_mul(size_of::())) + } + + pub fn calculate_range(min_val: u64, max_val: u64) -> u64 { + max_val.wrapping_sub(min_val) + } + + /// Creates a new [`ArrayMap`] from the given array of join keys. + /// + /// Note: This function processes only the non-null values in the input `array`, + /// ignoring any rows where the key is `NULL`. + /// + pub(crate) fn try_new(array: &ArrayRef, min_val: u64, max_val: u64) -> Result { + let range = max_val.wrapping_sub(min_val); + if range >= usize::MAX as u64 { + return internal_err!("ArrayMap key range is too large to be allocated."); + } + let size = (range + 1) as usize; + + let mut data: Vec = vec![0; size]; + let mut next: Vec = vec![]; + let mut num_of_distinct_key = 0; + + downcast_supported_integer!( + array.data_type() => ( + fill_data, + array, + min_val, + &mut data, + &mut next, + &mut num_of_distinct_key + ) + )?; + + Ok(Self { + data, + offset: min_val, + next, + num_of_distinct_key, + }) + } + + fn fill_data( + array: &ArrayRef, + offset_val: u64, + data: &mut [u32], + next: &mut Vec, + num_of_distinct_key: &mut usize, + ) -> Result<()> + where + T::Native: AsPrimitive, + { + let arr = array.as_primitive::(); + // Iterate in reverse to maintain FIFO order when there are duplicate keys. + for (i, val) in arr.iter().enumerate().rev() { + if let Some(val) = val { + let key: u64 = val.as_(); + let idx = key.wrapping_sub(offset_val) as usize; + if idx >= data.len() { + return internal_err!("failed build Array idx >= data.len()"); + } + + if data[idx] != 0 { + if next.is_empty() { + *next = vec![0; array.len()] + } + next[i] = data[idx] + } else { + *num_of_distinct_key += 1; + } + data[idx] = (i) as u32 + 1; + } + } + Ok(()) + } + + pub fn num_of_distinct_key(&self) -> usize { + self.num_of_distinct_key + } + + /// Returns the memory usage of this [`ArrayMap`] in bytes. + pub fn size(&self) -> usize { + self.data.capacity() * size_of::() + self.next.capacity() * size_of::() + } + + pub fn get_matched_indices_with_limit_offset( + &self, + prob_side_keys: &[ArrayRef], + limit: usize, + current_offset: MapOffset, + probe_indices: &mut Vec, + build_indices: &mut Vec, + ) -> Result> { + if prob_side_keys.len() != 1 { + return internal_err!( + "ArrayMap expects 1 join key, but got {}", + prob_side_keys.len() + ); + } + let array = &prob_side_keys[0]; + + downcast_supported_integer!( + array.data_type() => ( + lookup_and_get_indices, + self, + array, + limit, + current_offset, + probe_indices, + build_indices + ) + ) + } + + fn lookup_and_get_indices( + &self, + array: &ArrayRef, + limit: usize, + current_offset: MapOffset, + probe_indices: &mut Vec, + build_indices: &mut Vec, + ) -> Result> + where + T::Native: Copy + AsPrimitive, + { + probe_indices.clear(); + build_indices.clear(); + + let arr = array.as_primitive::(); + + let have_null = arr.null_count() > 0; + + if self.next.is_empty() { + for prob_idx in current_offset.0..arr.len() { + if build_indices.len() == limit { + return Ok(Some((prob_idx, None))); + } + + // short circuit + if have_null && arr.is_null(prob_idx) { + continue; + } + // SAFETY: prob_idx is guaranteed to be within bounds by the loop range. + let prob_val: u64 = unsafe { arr.value_unchecked(prob_idx) }.as_(); + let idx_in_build_side = prob_val.wrapping_sub(self.offset) as usize; + + if idx_in_build_side >= self.data.len() + || self.data[idx_in_build_side] == 0 + { + continue; + } + build_indices.push((self.data[idx_in_build_side] - 1) as u64); + probe_indices.push(prob_idx as u32); + } + Ok(None) + } else { + let mut remaining_output = limit; + let to_skip = match current_offset { + // None `initial_next_idx` indicates that `initial_idx` processing hasn't been started + (idx, None) => idx, + // Zero `initial_next_idx` indicates that `initial_idx` has been processed during + // previous iteration, and it should be skipped + (idx, Some(0)) => idx + 1, + // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, + // to start with the next index + (idx, Some(next_idx)) => { + let is_last = idx == arr.len() - 1; + if let Some(next_offset) = traverse_chain( + &self.next, + idx, + next_idx as u32, + &mut remaining_output, + probe_indices, + build_indices, + is_last, + ) { + return Ok(Some(next_offset)); + } + idx + 1 + } + }; + + for prob_side_idx in to_skip..arr.len() { + if remaining_output == 0 { + return Ok(Some((prob_side_idx, None))); + } + + if arr.is_null(prob_side_idx) { + continue; + } + + let is_last = prob_side_idx == arr.len() - 1; + + // SAFETY: prob_idx is guaranteed to be within bounds by the loop range. + let prob_val: u64 = unsafe { arr.value_unchecked(prob_side_idx) }.as_(); + let idx_in_build_side = prob_val.wrapping_sub(self.offset) as usize; + if idx_in_build_side >= self.data.len() + || self.data[idx_in_build_side] == 0 + { + continue; + } + + let build_idx = self.data[idx_in_build_side]; + + if let Some(offset) = traverse_chain( + &self.next, + prob_side_idx, + build_idx, + &mut remaining_output, + probe_indices, + build_indices, + is_last, + ) { + return Ok(Some(offset)); + } + } + Ok(None) + } + } + + pub fn contain_keys(&self, probe_side_keys: &[ArrayRef]) -> Result { + if probe_side_keys.len() != 1 { + return internal_err!( + "ArrayMap join expects 1 join key, but got {}", + probe_side_keys.len() + ); + } + let array = &probe_side_keys[0]; + + downcast_supported_integer!( + array.data_type() => ( + contain_hashes_helper, + self, + array + ) + ) + } + + fn contain_hashes_helper( + &self, + array: &ArrayRef, + ) -> Result + where + T::Native: AsPrimitive, + { + let arr = array.as_primitive::(); + let buffer = BooleanBuffer::collect_bool(arr.len(), |i| { + if arr.is_null(i) { + return false; + } + // SAFETY: i is within bounds [0, arr.len()) + let key: u64 = unsafe { arr.value_unchecked(i) }.as_(); + let idx = key.wrapping_sub(self.offset) as usize; + idx < self.data.len() && self.data[idx] != 0 + }); + Ok(BooleanArray::new(buffer, None)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int32Array; + use arrow::array::Int64Array; + use std::sync::Arc; + + #[test] + fn test_array_map_limit_offset_duplicate_elements() -> Result<()> { + let build: ArrayRef = Arc::new(Int32Array::from(vec![1, 1, 2])); + let map = ArrayMap::try_new(&build, 1, 2)?; + let probe = [Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef]; + + let mut prob_idx = Vec::new(); + let mut build_idx = Vec::new(); + let mut next = Some((0, None)); + let mut results = vec![]; + + while let Some(o) = next { + next = map.get_matched_indices_with_limit_offset( + &probe, + 1, + o, + &mut prob_idx, + &mut build_idx, + )?; + results.push((prob_idx.clone(), build_idx.clone(), next)); + } + + let expected = vec![ + (vec![0], vec![0], Some((0, Some(2)))), + (vec![0], vec![1], Some((0, Some(0)))), + (vec![1], vec![2], None), + ]; + assert_eq!(results, expected); + Ok(()) + } + + #[test] + fn test_array_map_with_limit_and_misses() -> Result<()> { + let build: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let map = ArrayMap::try_new(&build, 1, 2)?; + let probe = [Arc::new(Int32Array::from(vec![10, 1, 2])) as ArrayRef]; + + let (mut p_idx, mut b_idx) = (vec![], vec![]); + // Skip 10, find 1, next is 2 + let next = map.get_matched_indices_with_limit_offset( + &probe, + 1, + (0, None), + &mut p_idx, + &mut b_idx, + )?; + assert_eq!(p_idx, vec![1]); + assert_eq!(b_idx, vec![0]); + assert_eq!(next, Some((2, None))); + + // Find 2, end + let next = map.get_matched_indices_with_limit_offset( + &probe, + 1, + next.unwrap(), + &mut p_idx, + &mut b_idx, + )?; + assert_eq!(p_idx, vec![2]); + assert_eq!(b_idx, vec![1]); + assert!(next.is_none()); + Ok(()) + } + + #[test] + fn test_array_map_with_build_duplicates_and_misses() -> Result<()> { + let build_array: ArrayRef = Arc::new(Int32Array::from(vec![1, 1])); + let array_map = ArrayMap::try_new(&build_array, 1, 1)?; + // prob: 10(m), 1(h1, h2), 20(m), 1(h1, h2) + let probe_array: ArrayRef = Arc::new(Int32Array::from(vec![10, 1, 20, 1])); + let prob_side_keys = [probe_array]; + + let mut prob_indices = Vec::new(); + let mut build_indices = Vec::new(); + + // batch_size=3, should get 2 matches from first '1' and 1 match from second '1' + let result_offset = array_map.get_matched_indices_with_limit_offset( + &prob_side_keys, + 3, + (0, None), + &mut prob_indices, + &mut build_indices, + )?; + + assert_eq!(prob_indices, vec![1, 1, 3]); + assert_eq!(build_indices, vec![0, 1, 0]); + assert_eq!(result_offset, Some((3, Some(2)))); + Ok(()) + } + + #[test] + fn test_array_map_i64_with_negative_and_positive_numbers() -> Result<()> { + // Build array with a mix of negative and positive i64 values, no duplicates + let build_array: ArrayRef = Arc::new(Int64Array::from(vec![-5, 0, 5, -2, 3, 10])); + let min_val = -5_i128; + let max_val = 10_i128; + + let array_map = ArrayMap::try_new(&build_array, min_val as u64, max_val as u64)?; + + // Probe array + let probe_array: ArrayRef = Arc::new(Int64Array::from(vec![0, -5, 10, -1])); + let prob_side_keys = [Arc::clone(&probe_array)]; + + let mut prob_indices = Vec::new(); + let mut build_indices = Vec::new(); + + // Call once to get all matches + let result_offset = array_map.get_matched_indices_with_limit_offset( + &prob_side_keys, + 10, // A batch size larger than number of probes + (0, None), + &mut prob_indices, + &mut build_indices, + )?; + + // Expected matches, in probe-side order: + // Probe 0 (value 0) -> Build 1 (value 0) + // Probe 1 (value -5) -> Build 0 (value -5) + // Probe 2 (value 10) -> Build 5 (value 10) + let expected_prob_indices = vec![0, 1, 2]; + let expected_build_indices = vec![1, 0, 5]; + + assert_eq!(prob_indices, expected_prob_indices); + assert_eq!(build_indices, expected_build_indices); + assert!(result_offset.is_none()); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/chain.rs b/datafusion/physical-plan/src/joins/chain.rs new file mode 100644 index 0000000000000..846b7505d6478 --- /dev/null +++ b/datafusion/physical-plan/src/joins/chain.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; +use std::ops::Sub; + +use arrow::datatypes::ArrowNativeType; + +use crate::joins::MapOffset; + +/// Traverses the chain of matching indices, collecting results up to the remaining limit. +/// Returns `Some(offset)` if the limit was reached and there are more results to process, +/// or `None` if the chain was fully traversed. +#[inline(always)] +pub(crate) fn traverse_chain( + next_chain: &[T], + prob_idx: usize, + start_chain_idx: T, + remaining: &mut usize, + input_indices: &mut Vec, + match_indices: &mut Vec, + is_last_input: bool, +) -> Option +where + T: Copy + TryFrom + PartialOrd + Into + Sub, + >::Error: Debug, + T: ArrowNativeType, +{ + let zero = T::usize_as(0); + let one = T::usize_as(1); + let mut match_row_idx = start_chain_idx - one; + + loop { + match_indices.push(match_row_idx.into()); + input_indices.push(prob_idx as u32); + *remaining -= 1; + + let next = next_chain[match_row_idx.into() as usize]; + + if *remaining == 0 { + // Limit reached - return offset for next call + return if is_last_input && next == zero { + // Finished processing the last input row + None + } else { + Some((prob_idx, Some(next.into()))) + }; + } + if next == zero { + // End of chain + return None; + } + match_row_idx = next - one; + } +} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 0488cd35a8e36..45b34692abed4 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -18,23 +18,24 @@ //! Defines the cross join plan for loading the left side of the cross join //! and producing batches in parallel for the right partitions -use std::{any::Any, sync::Arc, task::Poll}; +use std::{sync::Arc, task::Poll}; use super::utils::{ - adjust_right_output_partitioning, reorder_output_after_swap, BatchSplitter, - BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, - StatefulStreamResult, + BatchSplitter, BatchTransformer, BuildProbeJoinMetrics, NoopBatchTransformer, + OnceAsync, OnceFut, StatefulStreamResult, adjust_right_output_partitioning, + reorder_output_after_swap, }; -use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ - join_allows_pushdown, join_table_borders, new_join_children, - physical_to_column_exprs, ProjectionExec, + ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children, + physical_to_column_exprs, }; +use crate::stream::EmptyRecordBatchStream; use crate::{ - handle_state, ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, + ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, + ExecutionPlanProperties, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, Statistics, check_if_same_properties, handle_state, }; use arrow::array::{RecordBatch, RecordBatchOptions}; @@ -42,14 +43,14 @@ use arrow::compute::concat_batches; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::stats::Precision; use datafusion_common::{ - assert_eq_or_internal_err, internal_err, JoinType, Result, ScalarValue, + JoinType, Result, ScalarValue, assert_eq_or_internal_err, internal_err, }; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_physical_expr::equivalence::join_equivalence_properties; use async_trait::async_trait; -use futures::{ready, Stream, StreamExt, TryStreamExt}; +use futures::{Stream, StreamExt, TryStreamExt, ready}; /// Data of the left side that is buffered into memory #[derive(Debug)] @@ -61,7 +62,7 @@ struct JoinLeftData { _reservation: MemoryReservation, } -#[allow(rustdoc::private_intra_doc_links)] +#[expect(rustdoc::private_intra_doc_links)] /// Cross Join Execution Plan /// /// This operator is used when there are no predicates between two tables and @@ -94,7 +95,7 @@ pub struct CrossJoinExec { /// Execution plan metrics metrics: ExecutionPlanMetricsSet, /// Properties such as schema, equivalence properties, ordering, partitioning, etc. - cache: PlanProperties, + cache: Arc, } impl CrossJoinExec { @@ -125,7 +126,7 @@ impl CrossJoinExec { schema, left_fut: Default::default(), metrics: ExecutionPlanMetricsSet::default(), - cache, + cache: Arc::new(cache), } } @@ -192,6 +193,23 @@ impl CrossJoinExec { &self.right.schema(), ) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + left_fut: Default::default(), + cache: Arc::clone(&self.cache), + schema: Arc::clone(&self.schema), + } + } } /// Asynchronously collect the result of the left child @@ -206,7 +224,7 @@ async fn load_left_input( let (batches, _metrics, reservation) = stream .try_fold( (Vec::new(), metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; @@ -252,11 +270,7 @@ impl ExecutionPlan for CrossJoinExec { "CrossJoinExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -272,6 +286,7 @@ impl ExecutionPlan for CrossJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(CrossJoinExec::new( Arc::clone(&children[0]), Arc::clone(&children[1]), @@ -285,7 +300,7 @@ impl ExecutionPlan for CrossJoinExec { schema: Arc::clone(&self.schema), left_fut: Default::default(), // reset the build side! metrics: ExecutionPlanMetricsSet::default(), - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), }; Ok(Arc::new(new_exec)) } @@ -356,16 +371,13 @@ impl ExecutionPlan for CrossJoinExec { } } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { // Get the all partitions statistics of the left - let left_stats = self.left.partition_statistics(None)?; - let right_stats = self.right.partition_statistics(partition)?; + let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(None)?); + let right_stats = + Arc::unwrap_or_clone(self.right.partition_statistics(partition)?); - Ok(stats_cartesian_product(left_stats, right_stats)) + Ok(Arc::new(stats_cartesian_product(left_stats, right_stats))) } /// Tries to swap the projection with its input [`CrossJoinExec`]. If it can be done, @@ -418,13 +430,14 @@ fn stats_cartesian_product( let left_row_count = left_stats.num_rows; let right_row_count = right_stats.num_rows; - // calculate global stats + // Calculate global stats let num_rows = left_row_count.multiply(&right_row_count); - // the result size is two times a*b because you have the columns of both left and right - let total_byte_size = left_stats - .total_byte_size - .multiply(&right_stats.total_byte_size) - .multiply(&Precision::Exact(2)); + + // Each output row includes every left and right column, so the left side is + // repeated once per right row and the right side once per left row. + let left_byte_size = left_stats.total_byte_size.multiply(&right_row_count); + let right_byte_size = right_stats.total_byte_size.multiply(&left_row_count); + let total_byte_size = left_byte_size.add(&right_byte_size); let left_col_stats = left_stats.column_statistics; let right_col_stats = right_stats.column_statistics; @@ -433,31 +446,34 @@ fn stats_cartesian_product( // Min, max and distinct_count on the other hand are invariants. let cross_join_stats = left_col_stats .into_iter() - .map(|s| ColumnStatistics { - null_count: s.null_count.multiply(&right_row_count), - distinct_count: s.distinct_count, - min_value: s.min_value, - max_value: s.max_value, - sum_value: s - .sum_value - .get_value() - // Cast the row count into the same type as any existing sum value - .and_then(|v| { - Precision::::from(right_row_count) - .cast_to(&v.data_type()) - .ok() - }) - .map(|row_count| s.sum_value.multiply(&row_count)) - .unwrap_or(Precision::Absent), + .map(|s| { + let widened_sum = s.sum_value.cast_to_sum_type(); + ColumnStatistics { + null_count: s.null_count.multiply(&right_row_count), + distinct_count: s.distinct_count, + min_value: s.min_value, + max_value: s.max_value, + sum_value: widened_sum + .get_value() + // Cast the row count into the same type as any existing sum value + .and_then(|v| { + Precision::::from(right_row_count) + .cast_to(&v.data_type()) + .ok() + }) + .map(|row_count| widened_sum.multiply(&row_count)) + .unwrap_or(Precision::Absent), + byte_size: Precision::Absent, + } }) .chain(right_col_stats.into_iter().map(|s| { + let widened_sum = s.sum_value.cast_to_sum_type(); ColumnStatistics { null_count: s.null_count.multiply(&left_row_count), distinct_count: s.distinct_count, min_value: s.min_value, max_value: s.max_value, - sum_value: s - .sum_value + sum_value: widened_sum .get_value() // Cast the row count into the same type as any existing sum value .and_then(|v| { @@ -465,8 +481,9 @@ fn stats_cartesian_product( .cast_to(&v.data_type()) .ok() }) - .map(|row_count| s.sum_value.multiply(&row_count)) + .map(|row_count| widened_sum.multiply(&row_count)) .unwrap_or(Precision::Absent), + byte_size: Precision::Absent, } })) .collect(); @@ -478,7 +495,7 @@ fn stats_cartesian_product( } } -/// A stream that issues [RecordBatch]es as they arrive from the right of the join. +/// A stream that issues [RecordBatch]es as they arrive from the right of the join. struct CrossJoinStream { /// Input schema schema: Arc, @@ -620,7 +637,12 @@ impl CrossJoinStream { let right_data = match ready!(self.right.poll_next_unpin(cx)) { Some(Ok(right_data)) => right_data, Some(Err(e)) => return Poll::Ready(Err(e)), - None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))), + None => { + // Release the right (probe) input pipeline's resources. + let right_schema = self.right.schema(); + self.right = Box::pin(EmptyRecordBatchStream::new(right_schema)); + return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); + } }; self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(right_data.num_rows()); @@ -704,6 +726,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(1), @@ -711,6 +734,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::from("a")), sum_value: Precision::Absent, null_count: Precision::Exact(3), + byte_size: Precision::Absent, }, ], }; @@ -724,6 +748,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(0))), sum_value: Precision::Exact(ScalarValue::Int64(Some(20))), null_count: Precision::Exact(2), + byte_size: Precision::Absent, }], }; @@ -731,7 +756,9 @@ mod tests { let expected = Statistics { num_rows: Precision::Exact(left_row_count * right_row_count), - total_byte_size: Precision::Exact(2 * left_bytes * right_bytes), + total_byte_size: Precision::Exact( + left_bytes * right_row_count + right_bytes * left_row_count, + ), column_statistics: vec![ ColumnStatistics { distinct_count: Precision::Exact(5), @@ -741,6 +768,7 @@ mod tests { 42 * right_row_count as i64, ))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(1), @@ -748,6 +776,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::from("a")), sum_value: Precision::Absent, null_count: Precision::Exact(3 * right_row_count), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(3), @@ -757,6 +786,7 @@ mod tests { 20 * left_row_count as i64, ))), null_count: Precision::Exact(2 * left_row_count), + byte_size: Precision::Absent, }, ], }; @@ -778,6 +808,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), null_count: Precision::Exact(0), + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(1), @@ -785,6 +816,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::from("a")), sum_value: Precision::Absent, null_count: Precision::Exact(3), + byte_size: Precision::Absent, }, ], }; @@ -798,6 +830,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(0))), sum_value: Precision::Exact(ScalarValue::Int64(Some(20))), null_count: Precision::Exact(2), + byte_size: Precision::Absent, }], }; @@ -813,6 +846,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), sum_value: Precision::Absent, // we don't know the row count on the right null_count: Precision::Absent, // we don't know the row count on the right + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(1), @@ -820,6 +854,7 @@ mod tests { min_value: Precision::Exact(ScalarValue::from("a")), sum_value: Precision::Absent, null_count: Precision::Absent, // we don't know the row count on the right + byte_size: Precision::Absent, }, ColumnStatistics { distinct_count: Precision::Exact(3), @@ -829,6 +864,7 @@ mod tests { 20 * left_row_count as i64, ))), null_count: Precision::Exact(2 * left_row_count), + byte_size: Precision::Absent, }, ], }; @@ -836,6 +872,49 @@ mod tests { assert_eq!(result, expected); } + #[tokio::test] + async fn test_stats_cartesian_product_unsigned_sum_widens_to_u64() { + let left_row_count = 2; + let right_row_count = 3; + + let left = Statistics { + num_rows: Precision::Exact(left_row_count), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::UInt32(Some(10))), + min_value: Precision::Exact(ScalarValue::UInt32(Some(1))), + sum_value: Precision::Exact(ScalarValue::UInt32(Some(7))), + null_count: Precision::Exact(0), + byte_size: Precision::Absent, + }], + }; + + let right = Statistics { + num_rows: Precision::Exact(right_row_count), + total_byte_size: Precision::Exact(10), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::UInt32(Some(12))), + min_value: Precision::Exact(ScalarValue::UInt32(Some(0))), + sum_value: Precision::Exact(ScalarValue::UInt32(Some(11))), + null_count: Precision::Exact(0), + byte_size: Precision::Absent, + }], + }; + + let result = stats_cartesian_product(left, right); + + assert_eq!( + result.column_statistics[0].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(21))) + ); + assert_eq!( + result.column_statistics[1].sum_value, + Precision::Exact(ScalarValue::UInt64(Some(22))) + ); + } + #[tokio::test] async fn test_join() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -855,18 +934,18 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 12 | 14 | - | 1 | 4 | 7 | 11 | 13 | 15 | - | 2 | 5 | 8 | 10 | 12 | 14 | - | 2 | 5 | 8 | 11 | 13 | 15 | - | 3 | 6 | 9 | 10 | 12 | 14 | - | 3 | 6 | 9 | 11 | 13 | 15 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 12 | 14 | + | 1 | 4 | 7 | 11 | 13 | 15 | + | 2 | 5 | 8 | 10 | 12 | 14 | + | 2 | 5 | 8 | 11 | 13 | 15 | + | 3 | 6 | 9 | 10 | 12 | 14 | + | 3 | 6 | 9 | 11 | 13 | 15 | + +----+----+----+----+----+----+ + "); assert_join_metrics!(metrics, 6); diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 97ee8ecbdba8a..3774a300209d0 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -15,46 +15,54 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashSet; use std::fmt; use std::mem::size_of; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; -use std::{any::Any, vec}; +use std::vec; -use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::ExecutionPlanProperties; +use crate::execution_plan::{ + EmissionType, boundedness_from_children, has_same_children_properties, + stub_properties, +}; use crate::filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPhase, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; +use crate::joins::Map; +use crate::joins::array_map::ArrayMap; +use crate::joins::hash_join::inlist_builder::build_struct_inlist_values; use crate::joins::hash_join::shared_bounds::{ - ColumnBounds, PartitionBounds, SharedBuildAccumulator, + ColumnBounds, PartitionBounds, PushdownStrategy, SharedBuildAccumulator, }; use crate::joins::hash_join::stream::{ BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState, }; use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64}; use crate::joins::utils::{ - asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection, - update_hash, OnceAsync, OnceFut, + OnceAsync, OnceFut, asymmetric_join_output_partitioning, reorder_output_after_swap, + swap_join_projection, update_hash, }; use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder}; +use crate::metrics::{Count, MetricBuilder, MetricCategory}; use crate::projection::{ - try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, - ProjectionExec, + EmbeddedProjection, JoinData, ProjectionExec, try_embed_projection, + try_pushdown_through_join, }; use crate::repartition::REPARTITION_RANDOM_STATE; use crate::spill::get_record_batch_memory_size; -use crate::ExecutionPlanProperties; use crate::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + PlanProperties, SendableRecordBatchStream, Statistics, common::can_project, joins::utils::{ + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, build_join_schema, check_join_is_valid, estimate_join_statistics, need_produce_result_in_final, symmetric_join_output_partitioning, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, }, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, SendableRecordBatchStream, Statistics, }; use arrow::array::{ArrayRef, BooleanBufferBuilder}; @@ -62,38 +70,126 @@ use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use arrow_schema::DataType; +use arrow_schema::{DataType, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - assert_or_internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, - Result, + JoinSide, JoinType, NullEquality, Result, assert_or_internal_err, internal_err, + plan_err, project_schema, }; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_expr::Accumulator; use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::equivalence::{ - join_equivalence_properties, ProjectionMapping, + ProjectionMapping, join_equivalence_properties, }; -use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::expressions::{Column, DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use ahash::RandomState; +use datafusion_common::hash_utils::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::TryStreamExt; use parking_lot::Mutex; +use super::partitioned_hash_eval::SeededRandomState; + /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. -const HASH_JOIN_SEED: RandomState = - RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); +pub(crate) const HASH_JOIN_SEED: SeededRandomState = + SeededRandomState::with_seed(12210250226015887276); + +const ARRAY_MAP_CREATED_COUNT_METRIC_NAME: &str = "array_map_created_count"; + +#[expect(clippy::too_many_arguments)] +fn try_create_array_map( + bounds: &Option, + schema: &SchemaRef, + batches: &[RecordBatch], + on_left: &[PhysicalExprRef], + reservation: &mut MemoryReservation, + perfect_hash_join_small_build_threshold: usize, + perfect_hash_join_min_key_density: f64, + null_equality: NullEquality, +) -> Result)>> { + if on_left.len() != 1 { + return Ok(None); + } + + if null_equality == NullEquality::NullEqualsNull { + for batch in batches.iter() { + let arrays = evaluate_expressions_to_arrays(on_left, batch)?; + if arrays[0].null_count() > 0 { + return Ok(None); + } + } + } + + let (min_val, max_val) = if let Some(bounds) = bounds { + let (min_val, max_val) = if let Some(cb) = bounds.get_column_bounds(0) { + (cb.min.clone(), cb.max.clone()) + } else { + return Ok(None); + }; + + if min_val.is_null() || max_val.is_null() { + return Ok(None); + } + + if min_val > max_val { + return internal_err!("min_val>max_val"); + } + + if let Some((mi, ma)) = + ArrayMap::key_to_u64(&min_val).zip(ArrayMap::key_to_u64(&max_val)) + { + (mi, ma) + } else { + return Ok(None); + } + } else { + return Ok(None); + }; + + let range = ArrayMap::calculate_range(min_val, max_val); + let num_row: usize = batches.iter().map(|x| x.num_rows()).sum(); + + // TODO: support create ArrayMap + if num_row >= u32::MAX as usize { + return Ok(None); + } + + // When the key range spans the full integer domain (e.g. i64::MIN to i64::MAX), + // range is u64::MAX and `range + 1` below would overflow. + if range == usize::MAX as u64 { + return Ok(None); + } + + let dense_ratio = (num_row as f64) / ((range + 1) as f64); + + if range >= perfect_hash_join_small_build_threshold as u64 + && dense_ratio <= perfect_hash_join_min_key_density + { + return Ok(None); + } + + let mem_size = ArrayMap::estimate_memory_size(min_val, max_val, num_row); + reservation.try_grow(mem_size)?; + + let batch = concat_batches(schema, batches)?; + let left_values = evaluate_expressions_to_arrays(on_left, &batch)?; + + let array_map = ArrayMap::try_new(&left_values[0], min_val, max_val)?; + + Ok(Some((array_map, batch, left_values))) +} /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` /// Arc is used to allow sharing with SharedBuildAccumulator for hash map pushdown - pub(super) hash_map: Arc, + pub(super) map: Arc, /// The input rows for the build side batch: RecordBatch, /// The build side on expressions values @@ -112,12 +208,20 @@ pub(super) struct JoinLeftData { /// If the partition is empty (no rows) this will be None. /// If the partition has some rows this will be Some with the bounds for each join key column. pub(super) bounds: Option, + /// Membership testing strategy for filter pushdown + /// Contains either InList values for small build sides or hash table reference for large build sides + pub(super) membership: PushdownStrategy, + /// Shared atomic flag indicating if any probe partition saw data (for null-aware anti joins) + /// This is shared across all probe partitions to provide global knowledge + pub(super) probe_side_non_empty: AtomicBool, + /// Shared atomic flag indicating if any probe partition saw NULL in join keys (for null-aware anti joins) + pub(super) probe_side_has_null: AtomicBool, } impl JoinLeftData { - /// return a reference to the hash map - pub(super) fn hash_map(&self) -> &dyn JoinHashMapType { - &*self.hash_map + /// return a reference to the map + pub(super) fn map(&self) -> &Map { + &self.map } /// returns a reference to the build side batch @@ -135,6 +239,11 @@ impl JoinLeftData { &self.visited_indices_bitmap } + /// returns a reference to the InList values for filter pushdown + pub(super) fn membership(&self) -> &PushdownStrategy { + &self.membership + } + /// Decrements the counter of running threads, and returns `true` /// if caller is the last running thread pub(super) fn report_probe_completed(&self) -> bool { @@ -142,7 +251,278 @@ impl JoinLeftData { } } -#[allow(rustdoc::private_intra_doc_links)] +/// Helps to build [`HashJoinExec`]. +/// +/// Builder can be created from an existing [`HashJoinExec`] using [`From::from`]. +/// In this case, all its fields are inherited. If a field that affects the node's +/// properties is modified, they will be automatically recomputed during the build. +/// +/// # Adding setters +/// +/// When adding a new setter, it is necessary to ensure that the `preserve_properties` +/// flag is set to false if modifying the field requires a recomputation of the plan's +/// properties. +/// +pub struct HashJoinExecBuilder { + exec: HashJoinExec, + preserve_properties: bool, +} + +impl HashJoinExecBuilder { + /// Make a new [`HashJoinExecBuilder`]. + pub fn new( + left: Arc, + right: Arc, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + join_type: JoinType, + ) -> Self { + Self { + exec: HashJoinExec { + left, + right, + on, + filter: None, + join_type, + left_fut: Default::default(), + random_state: HASH_JOIN_SEED, + mode: PartitionMode::Auto, + fetch: None, + metrics: ExecutionPlanMetricsSet::new(), + projection: None, + column_indices: vec![], + null_equality: NullEquality::NullEqualsNothing, + null_aware: false, + dynamic_filter: None, + // Will be computed at when plan will be built. + cache: stub_properties(), + join_schema: Arc::new(Schema::empty()), + }, + // As `exec` is initialized with stub properties, + // they will be properly computed when plan will be built. + preserve_properties: false, + } + } + + /// Set join type. + pub fn with_type(mut self, join_type: JoinType) -> Self { + self.exec.join_type = join_type; + self.preserve_properties = false; + self + } + + /// Set projection from the vector. + pub fn with_projection(self, projection: Option>) -> Self { + self.with_projection_ref(projection.map(Into::into)) + } + + /// Set projection from the shared reference. + pub fn with_projection_ref(mut self, projection: Option) -> Self { + self.exec.projection = projection; + self.preserve_properties = false; + self + } + + /// Set optional filter. + pub fn with_filter(mut self, filter: Option) -> Self { + self.exec.filter = filter; + self + } + + /// Set expressions to join on. + pub fn with_on(mut self, on: Vec<(PhysicalExprRef, PhysicalExprRef)>) -> Self { + self.exec.on = on; + self.preserve_properties = false; + self + } + + /// Set partition mode. + pub fn with_partition_mode(mut self, mode: PartitionMode) -> Self { + self.exec.mode = mode; + self.preserve_properties = false; + self + } + + /// Set null equality property. + pub fn with_null_equality(mut self, null_equality: NullEquality) -> Self { + self.exec.null_equality = null_equality; + self + } + + /// Set null aware property. + pub fn with_null_aware(mut self, null_aware: bool) -> Self { + self.exec.null_aware = null_aware; + self + } + + /// Set fetch property. + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.exec.fetch = fetch; + self + } + + /// Require to recompute plan properties. + pub fn recompute_properties(mut self) -> Self { + self.preserve_properties = false; + self + } + + /// Replace children. + pub fn with_new_children( + mut self, + mut children: Vec>, + ) -> Result { + assert_or_internal_err!( + children.len() == 2, + "wrong number of children passed into `HashJoinExecBuilder`" + ); + self.preserve_properties &= has_same_children_properties(&self.exec, &children)?; + self.exec.right = children.swap_remove(1); + self.exec.left = children.swap_remove(0); + Ok(self) + } + + /// Reset runtime state. + pub fn reset_state(mut self) -> Self { + self.exec.left_fut = Default::default(); + self.exec.dynamic_filter = None; + self.exec.metrics = ExecutionPlanMetricsSet::new(); + self + } + + /// Build result as a dyn execution plan. + pub fn build_exec(self) -> Result> { + self.build().map(|p| Arc::new(p) as _) + } + + /// Build resulting execution plan. + pub fn build(self) -> Result { + let Self { + exec, + preserve_properties, + } = self; + + // Validate null_aware flag + if exec.null_aware { + let join_type = exec.join_type(); + if !matches!(join_type, JoinType::LeftAnti) { + return plan_err!( + "null_aware can only be true for LeftAnti joins, got {join_type}" + ); + } + let on = exec.on(); + if on.len() != 1 { + return plan_err!( + "null_aware anti join only supports single column join key, got {} columns", + on.len() + ); + } + } + + if preserve_properties { + return Ok(exec); + } + + let HashJoinExec { + left, + right, + on, + filter, + join_type, + left_fut, + random_state, + mode, + metrics, + projection, + null_equality, + null_aware, + dynamic_filter, + fetch, + // Recomputed. + join_schema: _, + column_indices: _, + cache: _, + } = exec; + + let left_schema = left.schema(); + let right_schema = right.schema(); + if on.is_empty() { + return plan_err!("On constraints in HashJoinExec should be non-empty"); + } + + check_join_is_valid(&left_schema, &right_schema, &on)?; + let (join_schema, column_indices) = + build_join_schema(&left_schema, &right_schema, &join_type); + + let join_schema = Arc::new(join_schema); + + // Check if the projection is valid. + can_project(&join_schema, projection.as_deref())?; + + let cache = HashJoinExec::compute_properties( + &left, + &right, + &join_schema, + join_type, + &on, + mode, + projection.as_deref(), + )?; + + Ok(HashJoinExec { + left, + right, + on, + filter, + join_type, + join_schema, + left_fut, + random_state, + mode, + metrics, + projection, + column_indices, + null_equality, + null_aware, + cache: Arc::new(cache), + dynamic_filter, + fetch, + }) + } + + fn with_dynamic_filter(mut self, filter: Option) -> Self { + self.exec.dynamic_filter = filter; + self + } +} + +impl From<&HashJoinExec> for HashJoinExecBuilder { + fn from(exec: &HashJoinExec) -> Self { + Self { + exec: HashJoinExec { + left: Arc::clone(exec.left()), + right: Arc::clone(exec.right()), + on: exec.on.clone(), + filter: exec.filter.clone(), + join_type: exec.join_type, + join_schema: Arc::clone(&exec.join_schema), + left_fut: Arc::clone(&exec.left_fut), + random_state: exec.random_state.clone(), + mode: exec.mode, + metrics: exec.metrics.clone(), + projection: exec.projection.clone(), + column_indices: exec.column_indices.clone(), + null_equality: exec.null_equality, + null_aware: exec.null_aware, + cache: Arc::clone(&exec.cache), + dynamic_filter: exec.dynamic_filter.clone(), + fetch: exec.fetch, + }, + preserve_properties: true, + } + } +} + +#[expect(rustdoc::private_intra_doc_links)] /// Join execution plan: Evaluates equijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post /// join. @@ -157,6 +537,36 @@ impl JoinLeftData { /// ` != `) are known as "filter expressions" and are evaluated /// after the equijoin predicates. /// +/// # ArrayMap Optimization +/// +/// For joins with a single integer-based join key, `HashJoinExec` may use an [`ArrayMap`] +/// (also known as a "perfect hash join") instead of a general-purpose hash map. +/// This optimization is used when: +/// 1. There is exactly one join key. +/// 2. The join key is an integer type up to 64 bits wide that can be losslessly converted +/// to `u64` (128-bit integer types such as `i128` and `u128` are not supported). +/// 3. The range of keys is small enough (controlled by `perfect_hash_join_small_build_threshold`) +/// OR the keys are sufficiently dense (controlled by `perfect_hash_join_min_key_density`). +/// 4. build_side.num_rows() < u32::MAX +/// 5. NullEqualsNothing || (NullEqualsNull && build side doesn't contain null) +/// +/// See [`try_create_array_map`] for more details. +/// +/// Note that when using [`PartitionMode::Partitioned`], the build side is split into multiple +/// partitions. This can cause a dense build side to become sparse within each partition, +/// potentially disabling this optimization. +/// +/// For example, consider: +/// ```sql +/// SELECT t1.value, t2.value +/// FROM range(10000) AS t1 +/// JOIN range(10000) AS t2 +/// ON t1.value = t2.value; +/// ``` +/// With 24 partitions, each partition will only receive a subset of the 10,000 rows. +/// The first partition might contain values like `3, 10, 18, 39, 43`, which are sparse +/// relative to the original range, even though the overall data set is dense. +/// /// # "Build Side" vs "Probe Side" /// /// HashJoin takes two inputs, which are referred to as the "build" and the @@ -190,9 +600,9 @@ impl JoinLeftData { /// Resulting hash table stores hashed join-key fields for each row as a key, and /// indices of corresponding rows in concatenated batch. /// -/// Hash join uses LIFO data structure as a hash table, and in order to retain -/// original build-side input order while obtaining data during probe phase, hash -/// table is updated by iterating batch sequence in reverse order -- it allows to +/// When using the standard `JoinHashMap`, hash join uses LIFO data structure as a hash table, +/// and in order to retain original build-side input order while obtaining data during probe phase, +/// hash table is updated by iterating batch sequence in reverse order -- it allows to /// keep rows with smaller indices "on the top" of hash table, and still maintain /// correct indexing for concatenated build-side data batch. /// @@ -325,24 +735,28 @@ pub struct HashJoinExec { /// Each output stream waits on the `OnceAsync` to signal the completion of /// the hash table creation. left_fut: Arc>, - /// Shared the `RandomState` for the hashing algorithm - random_state: RandomState, + /// Shared the `SeededRandomState` for the hashing algorithm (seeds preserved for serialization) + random_state: SeededRandomState, /// Partitioning mode to use pub mode: PartitionMode, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// The projection indices of the columns in the output schema of join - pub projection: Option>, + pub projection: Option, /// Information of index and left / right placement of columns column_indices: Vec, /// The equality null-handling behavior of the join algorithm. pub null_equality: NullEquality, + /// Flag to indicate if this is a null-aware anti join + pub null_aware: bool, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Dynamic filter for pushing down to the probe side /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result. /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates. dynamic_filter: Option, + /// Maximum number of rows to return + fetch: Option, } #[derive(Clone)] @@ -383,11 +797,11 @@ impl EmbeddedProjection for HashJoinExec { } impl HashJoinExec { - /// Tries to create a new [HashJoinExec]. + /// Tries to create a new [`HashJoinExec`]. /// /// # Error /// This function errors when it is not possible to join the left and right sides on keys `on`. - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub fn try_new( left: Arc, right: Arc, @@ -397,55 +811,24 @@ impl HashJoinExec { projection: Option>, partition_mode: PartitionMode, null_equality: NullEquality, + null_aware: bool, ) -> Result { - let left_schema = left.schema(); - let right_schema = right.schema(); - if on.is_empty() { - return plan_err!("On constraints in HashJoinExec should be non-empty"); - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - - let (join_schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); - - let random_state = HASH_JOIN_SEED; - - let join_schema = Arc::new(join_schema); - - // check if the projection is valid - can_project(&join_schema, projection.as_ref())?; - - let cache = Self::compute_properties( - &left, - &right, - &join_schema, - *join_type, - &on, - partition_mode, - projection.as_ref(), - )?; - - // Initialize both dynamic filter and bounds accumulator to None - // They will be set later if dynamic filtering is enabled + HashJoinExecBuilder::new(left, right, on, *join_type) + .with_filter(filter) + .with_projection(projection) + .with_partition_mode(partition_mode) + .with_null_equality(null_equality) + .with_null_aware(null_aware) + .build() + } - Ok(HashJoinExec { - left, - right, - on, - filter, - join_type: *join_type, - join_schema, - left_fut: Default::default(), - random_state, - mode: partition_mode, - metrics: ExecutionPlanMetricsSet::new(), - projection, - column_indices, - null_equality, - cache, - dynamic_filter: None, - }) + /// Create a builder based on the existing [`HashJoinExec`]. + /// + /// Returned builder preserves all existing fields. If a field requiring properties + /// recomputation is modified, this will be done automatically during the node build. + /// + pub fn builder(&self) -> HashJoinExecBuilder { + self.into() } fn create_dynamic_filter(on: &JoinOn) -> Arc { @@ -456,6 +839,27 @@ impl HashJoinExec { Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true))) } + fn allow_join_dynamic_filter_pushdown(&self, config: &ConfigOptions) -> bool { + let (_, probe_preserved) = self.join_type.on_lr_is_preserved(); + if !probe_preserved || !config.optimizer.enable_join_dynamic_filter_pushdown { + return false; + } + + // `preserve_file_partitions` can report Hash partitioning for Hive-style + // file groups, but those partitions are not actually hash-distributed. + // Partitioned dynamic filters rely on hash routing, so disable them in + // this mode to avoid incorrect results. Follow-up work: enable dynamic + // filtering for preserve_file_partitioned scans (issue #20195). + // https://github.com/apache/datafusion/issues/20195 + if config.optimizer.preserve_file_partitions > 0 + && self.mode == PartitionMode::Partitioned + { + return false; + } + + true + } + /// left (build) side which gets hashed pub fn left(&self) -> &Arc { &self.left @@ -497,6 +901,35 @@ impl HashJoinExec { self.null_equality } + /// Get the dynamic filter expression for testing purposes. + /// Returns the dynamic filter expression for this hash join, if set. + pub fn dynamic_filter_expr(&self) -> Option<&Arc> { + self.dynamic_filter.as_ref().map(|df| &df.filter) + } + + /// Set the dynamic filter on this hash join. + /// + /// Resets any internal state that depends on any existing dynamic filter. + /// + /// Validates that the filter's children reference valid columns in + /// the probe (right) side's schema. + pub fn with_dynamic_filter_expr( + mut self, + filter: Arc, + ) -> Result { + let probe_schema = self.right.schema(); + for child in filter.children() { + child.data_type(&probe_schema)?; + } + self.dynamic_filter = Some(HashJoinExecDynamicFilter { + filter, + // Initialize with an empty accumulator which will be lazily populated + // during execution. + build_accumulator: OnceLock::new(), + }); + Ok(self) + } + /// Calculate order preservation flags for this hash join. fn maintains_input_order(join_type: JoinType) -> Vec { vec![ @@ -525,25 +958,12 @@ impl HashJoinExec { /// Return new instance of [HashJoinExec] with the given projection. pub fn with_projection(&self, projection: Option>) -> Result { + let projection = projection.map(Into::into); // check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - Self::try_new( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.on.clone(), - self.filter.clone(), - &self.join_type, - projection, - self.mode, - self.null_equality, - ) + can_project(&self.schema(), projection.as_deref())?; + let projection = + combine_projections(projection.as_ref(), self.projection.as_ref())?; + self.builder().with_projection_ref(projection).build() } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -554,7 +974,7 @@ impl HashJoinExec { join_type: JoinType, on: JoinOnRef, mode: PartitionMode, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Calculate equivalence properties: let mut eq_properties = join_equivalence_properties( @@ -606,7 +1026,7 @@ impl HashJoinExec { if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -649,24 +1069,25 @@ impl HashJoinExec { ) -> Result> { let left = self.left(); let right = self.right(); - let new_join = HashJoinExec::try_new( - Arc::clone(right), - Arc::clone(left), - self.on() - .iter() - .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) - .collect(), - self.filter().map(JoinFilter::swap), - &self.join_type().swap(), - swap_join_projection( + let new_join = self + .builder() + .with_type(self.join_type.swap()) + .with_new_children(vec![Arc::clone(right), Arc::clone(left)])? + .with_on( + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect(), + ) + .with_filter(self.filter().map(JoinFilter::swap)) + .with_projection(swap_join_projection( left.schema().fields().len(), right.schema().fields().len(), - self.projection.as_ref(), + self.projection.as_deref(), self.join_type(), - ), - partition_mode, - self.null_equality(), - )?; + )) + .with_partition_mode(partition_mode) + .build()?; // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( self.join_type(), @@ -712,11 +1133,14 @@ impl DisplayAs for HashJoinExec { "".to_string() }; let display_null_equality = - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { ", NullsEqual: true" } else { "" }; + let display_fetch = self + .fetch + .map_or_else(String::new, |f| format!(", fetch={f}")); let on = self .on .iter() @@ -725,13 +1149,14 @@ impl DisplayAs for HashJoinExec { .join(", "); write!( f, - "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}", + "HashJoinExec: mode={:?}, join_type={:?}, on=[{}]{}{}{}{}", self.mode, self.join_type, on, display_filter, display_projections, display_null_equality, + display_fetch, ) } DisplayFormatType::TreeRender => { @@ -750,7 +1175,7 @@ impl DisplayAs for HashJoinExec { writeln!(f, "on={on}")?; - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { writeln!(f, "NullsEqual: true")?; } @@ -758,6 +1183,10 @@ impl DisplayAs for HashJoinExec { writeln!(f, "filter={filter}")?; } + if let Some(fetch) = self.fetch { + writeln!(f, "fetch={fetch}")?; + } + Ok(()) } } @@ -769,11 +1198,7 @@ impl ExecutionPlan for HashJoinExec { "HashJoinExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -834,54 +1259,11 @@ impl ExecutionPlan for HashJoinExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(HashJoinExec { - left: Arc::clone(&children[0]), - right: Arc::clone(&children[1]), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: Arc::clone(&self.join_schema), - left_fut: Arc::clone(&self.left_fut), - random_state: self.random_state.clone(), - mode: self.mode, - metrics: ExecutionPlanMetricsSet::new(), - projection: self.projection.clone(), - column_indices: self.column_indices.clone(), - null_equality: self.null_equality, - cache: Self::compute_properties( - &children[0], - &children[1], - &self.join_schema, - self.join_type, - &self.on, - self.mode, - self.projection.as_ref(), - )?, - // Keep the dynamic filter, bounds accumulator will be reset - dynamic_filter: self.dynamic_filter.clone(), - })) + self.builder().with_new_children(children)?.build_exec() } fn reset_state(self: Arc) -> Result> { - Ok(Arc::new(HashJoinExec { - left: Arc::clone(&self.left), - right: Arc::clone(&self.right), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: Arc::clone(&self.join_schema), - // Reset the left_fut to allow re-execution - left_fut: Arc::new(OnceAsync::default()), - random_state: self.random_state.clone(), - mode: self.mode, - metrics: ExecutionPlanMetricsSet::new(), - projection: self.projection.clone(), - column_indices: self.column_indices.clone(), - null_equality: self.null_equality, - cache: self.cache.clone(), - // Reset dynamic filter and bounds accumulator to initial state - dynamic_filter: None, - })) + self.builder().reset_state().build_exec() } fn execute( @@ -910,25 +1292,71 @@ impl ExecutionPlan for HashJoinExec { consider using CoalescePartitionsExec or the EnforceDistribution rule" ); - let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some(); + // Only enable dynamic filter pushdown if: + // - The session config enables dynamic filter pushdown + // - A dynamic filter exists + // - At least one consumer is holding a reference to it, this avoids expensive filter + // computation when disabled or when no consumer will use it. + let enable_dynamic_filter_pushdown = self + .allow_join_dynamic_filter_pushdown(context.session_config().options()) + && self + .dynamic_filter + .as_ref() + .map(|df| df.filter.is_used()) + .unwrap_or(false); let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); - let left_fut = match self.mode { - PartitionMode::CollectLeft => self.left_fut.try_once(|| { - let left_stream = self.left.execute(0, Arc::clone(&context))?; - let reservation = - MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); + let array_map_created_count = MetricBuilder::new(&self.metrics) + .with_category(MetricCategory::Rows) + .counter(ARRAY_MAP_CREATED_COUNT_METRIC_NAME, partition); - Ok(collect_left_input( - self.random_state.clone(), - left_stream, - on_left.clone(), - join_metrics.clone(), - reservation, + // Initialize build_accumulator lazily with runtime partition counts (only if enabled) + // Use RepartitionExec's random state (seeds: 0,0,0,0) for partition routing + let repartition_random_state = REPARTITION_RANDOM_STATE; + let build_accumulator = enable_dynamic_filter_pushdown + .then(|| { + self.dynamic_filter.as_ref().map(|df| { + let filter = Arc::clone(&df.filter); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + Some(Arc::clone(df.build_accumulator.get_or_init(|| { + Arc::new(SharedBuildAccumulator::new_from_partition_mode( + self.mode, + self.left.as_ref(), + self.right.as_ref(), + filter, + on_right, + repartition_random_state, + )) + }))) + }) + }) + .flatten() + .flatten(); + + let left_fut = match self.mode { + PartitionMode::CollectLeft => self.left_fut.try_once(|| { + let left_stream = self.left.execute(0, Arc::clone(&context))?; + + let reservation = + MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); + + Ok(collect_left_input( + self.random_state.random_state().clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), enable_dynamic_filter_pushdown, + Arc::clone(context.session_config().options()), + self.null_equality, + array_map_created_count, )) })?, PartitionMode::Partitioned => { @@ -937,9 +1365,8 @@ impl ExecutionPlan for HashJoinExec { let reservation = MemoryConsumer::new(format!("HashJoinInput[{partition}]")) .register(context.memory_pool()); - OnceFut::new(collect_left_input( - self.random_state.clone(), + self.random_state.random_state().clone(), left_stream, on_left.clone(), join_metrics.clone(), @@ -947,6 +1374,9 @@ impl ExecutionPlan for HashJoinExec { need_produce_result_in_final(self.join_type), 1, enable_dynamic_filter_pushdown, + Arc::clone(context.session_config().options()), + self.null_equality, + array_map_created_count, )) } PartitionMode::Auto => { @@ -959,39 +1389,12 @@ impl ExecutionPlan for HashJoinExec { let batch_size = context.session_config().batch_size(); - // Initialize build_accumulator lazily with runtime partition counts (only if enabled) - // Use RepartitionExec's random state (seeds: 0,0,0,0) for partition routing - let repartition_random_state = REPARTITION_RANDOM_STATE; - let build_accumulator = enable_dynamic_filter_pushdown - .then(|| { - self.dynamic_filter.as_ref().map(|df| { - let filter = Arc::clone(&df.filter); - let on_right = self - .on - .iter() - .map(|(_, right_expr)| Arc::clone(right_expr)) - .collect::>(); - Some(Arc::clone(df.build_accumulator.get_or_init(|| { - Arc::new(SharedBuildAccumulator::new_from_partition_mode( - self.mode, - self.left.as_ref(), - self.right.as_ref(), - filter, - on_right, - repartition_random_state, - )) - }))) - }) - }) - .flatten() - .flatten(); - // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. let right_stream = self.right.execute(partition, context)?; // update column indices to reflect the projection - let column_indices_after_projection = match &self.projection { + let column_indices_after_projection = match self.projection.as_ref() { Some(projection) => projection .iter() .map(|i| self.column_indices[*i].clone()) @@ -1012,7 +1415,7 @@ impl ExecutionPlan for HashJoinExec { self.filter.clone(), self.join_type, right_stream, - self.random_state.clone(), + self.random_state.random_state().clone(), join_metrics, column_indices_after_projection, self.null_equality, @@ -1023,6 +1426,8 @@ impl ExecutionPlan for HashJoinExec { self.right.output_ordering().is_some(), build_accumulator, self.mode, + self.null_aware, + self.fetch, ))) } @@ -1030,26 +1435,63 @@ impl ExecutionPlan for HashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } + fn partition_statistics(&self, partition: Option) -> Result> { + let stats = match (partition, self.mode) { + // For CollectLeft mode, the left side is collected into a single partition, + // so all left partitions are available to each output partition. + // For the right side, we need the specific partition statistics. + (Some(partition), PartitionMode::CollectLeft) => { + let left_stats = self.left.partition_statistics(None)?; + let right_stats = self.right.partition_statistics(Some(partition))?; + + estimate_join_statistics( + Arc::unwrap_or_clone(left_stats), + Arc::unwrap_or_clone(right_stats), + &self.on, + self.null_equality, + &self.join_type, + &self.join_schema, + )? + } - fn partition_statistics(&self, partition: Option) -> Result { - if partition.is_some() { - return Ok(Statistics::new_unknown(&self.schema())); - } - // TODO stats: it is not possible in general to know the output size of joins - // There are some special cases though, for example: - // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` - let stats = estimate_join_statistics( - self.left.partition_statistics(None)?, - self.right.partition_statistics(None)?, - &self.on, - &self.join_type, - &self.join_schema, - )?; + // For Partitioned mode, both sides are partitioned, so each output partition + // only has access to the corresponding partition from both sides. + (Some(partition), PartitionMode::Partitioned) => { + let left_stats = self.left.partition_statistics(Some(partition))?; + let right_stats = self.right.partition_statistics(Some(partition))?; + + estimate_join_statistics( + Arc::unwrap_or_clone(left_stats), + Arc::unwrap_or_clone(right_stats), + &self.on, + self.null_equality, + &self.join_type, + &self.join_schema, + )? + } + + // For Auto mode or when no specific partition is requested, fall back to + // the current behavior of getting all partition statistics. + (None, _) | (Some(_), PartitionMode::Auto) => { + // TODO stats: it is not possible in general to know the output size of joins + // There are some special cases though, for example: + // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` + let left_stats = self.left.partition_statistics(None)?; + let right_stats = self.right.partition_statistics(None)?; + estimate_join_statistics( + Arc::unwrap_or_clone(left_stats), + Arc::unwrap_or_clone(right_stats), + &self.on, + self.null_equality, + &self.join_type, + &self.join_schema, + )? + } + }; // Project statistics if there is a projection - Ok(stats.project(self.projection.as_ref())) + let stats = stats.project(self.projection.as_ref()); + // Apply fetch limit to statistics + Ok(Arc::new(stats.with_fetch(self.fetch, 0, 1)?)) } /// Tries to push `projection` down through `hash_join`. If possible, performs the @@ -1078,17 +1520,17 @@ impl ExecutionPlan for HashJoinExec { &schema, self.filter(), )? { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::new(projected_left_child), - Arc::new(projected_right_child), - join_on, - join_filter, - self.join_type(), + self.builder() + .with_new_children(vec![ + Arc::new(projected_left_child), + Arc::new(projected_right_child), + ])? + .with_on(join_on) + .with_filter(join_filter) // Returned early if projection is not None - None, - *self.partition_mode(), - self.null_equality, - )?))) + .with_projection(None) + .build_exec() + .map(Some) } else { try_embed_projection(projection, self) } @@ -1100,30 +1542,111 @@ impl ExecutionPlan for HashJoinExec { parent_filters: Vec>, config: &ConfigOptions, ) -> Result { - // Other types of joins can support *some* filters, but restrictions are complex and error prone. - // For now we don't support them. - // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs - // See https://github.com/apache/datafusion/issues/16973 for tracking. - if self.join_type != JoinType::Inner { - return Ok(FilterDescription::all_unsupported( - &parent_filters, - &self.children(), - )); + // This is the physical-plan equivalent of `push_down_all_join` in + // `datafusion/optimizer/src/push_down_filter.rs`. That function uses `lr_is_preserved` + // to decide which parent predicates can be pushed past a logical join to its children, + // then checks column references to route each predicate to the correct side. + // + // We apply the same two-level logic here: + // 1. `lr_is_preserved` gates whether a side is eligible at all. + // 2. For each filter, we check that all column references belong to the + // target child (using `column_indices` to map output column positions + // to join sides). This is critical for correctness: name-based matching + // alone (as done by `ChildFilterDescription::from_child`) can incorrectly + // push filters when different join sides have columns with the same name + // (e.g. nested mark joins both producing "mark" columns). + let (left_preserved, right_preserved) = lr_is_preserved(self.join_type); + + // Build the set of allowed column indices for each side + let column_indices: Vec = match self.projection.as_ref() { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + let (mut left_allowed, mut right_allowed) = (HashSet::new(), HashSet::new()); + column_indices + .iter() + .enumerate() + .for_each(|(output_idx, ci)| { + match ci.side { + JoinSide::Left => left_allowed.insert(output_idx), + JoinSide::Right => right_allowed.insert(output_idx), + // Mark columns - don't allow pushdown to either side + JoinSide::None => false, + }; + }); + + // For semi/anti joins, the non-preserved side's columns are not in the + // output, but filters on join key columns can still be pushed there. + // We find output columns that are join keys on the preserved side and + // add their output indices to the non-preserved side's allowed set. + // The name-based remap in FilterRemapper will then match them to the + // corresponding column in the non-preserved child's schema. + match self.join_type { + JoinType::LeftSemi | JoinType::LeftAnti => { + let left_key_indices: HashSet = self + .on + .iter() + .filter_map(|(left_key, _)| { + left_key.downcast_ref::().map(|c| c.index()) + }) + .collect(); + for (output_idx, ci) in column_indices.iter().enumerate() { + if ci.side == JoinSide::Left && left_key_indices.contains(&ci.index) { + right_allowed.insert(output_idx); + } + } + } + JoinType::RightSemi | JoinType::RightAnti => { + let right_key_indices: HashSet = self + .on + .iter() + .filter_map(|(_, right_key)| { + right_key.downcast_ref::().map(|c| c.index()) + }) + .collect(); + for (output_idx, ci) in column_indices.iter().enumerate() { + if ci.side == JoinSide::Right && right_key_indices.contains(&ci.index) + { + left_allowed.insert(output_idx); + } + } + } + _ => {} } - // Get basic filter descriptions for both children - let left_child = crate::filter_pushdown::ChildFilterDescription::from_child( - &parent_filters, - self.left(), - )?; - let mut right_child = crate::filter_pushdown::ChildFilterDescription::from_child( - &parent_filters, - self.right(), - )?; + let left_child = if left_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + left_allowed, + self.left(), + )? + } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; + + let mut right_child = if right_preserved { + ChildFilterDescription::from_child_with_allowed_indices( + &parent_filters, + right_allowed, + self.right(), + )? + } else { + ChildFilterDescription::all_unsupported(&parent_filters) + }; - // Add dynamic filters in Post phase if enabled - if matches!(phase, FilterPushdownPhase::Post) - && config.optimizer.enable_join_dynamic_filter_pushdown + // Add dynamic filters in Post phase if enabled. Skip when this join + // already carries a dynamic filter from a previous pass — the shared + // `Arc` is still wired into the probe-side + // scan's predicate, and re-creating it would AND a fresh duplicate + // onto every Post-phase invocation (apache/datafusion-ballista#1359 + // surfaces this in AQE replan loops). + if phase == FilterPushdownPhase::Post + && self.dynamic_filter.is_none() + && self.allow_join_dynamic_filter_pushdown(config) { // Add actual dynamic filter to right side (probe side) let dynamic_filter = Self::create_dynamic_filter(&self.on); @@ -1141,23 +1664,10 @@ impl ExecutionPlan for HashJoinExec { child_pushdown_result: ChildPushdownResult, _config: &ConfigOptions, ) -> Result>> { - // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for - // non-inner joins in `gather_filters_for_pushdown`. - // However it's a cheap check and serves to inform future devs touching this function that they need to be really - // careful pushing down filters through non-inner joins. - if self.join_type != JoinType::Inner { - // Other types of joins can support *some* filters, but restrictions are complex and error prone. - // For now we don't support them. - // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs - return Ok(FilterPushdownPropagation::all_unsupported( - child_pushdown_result, - )); - } - let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone()); assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child - // We expect 0 or 1 self filters + // We expect 0 or 1 self filters if let Some(filter) = right_child_self_filters.first() { // Note that we don't check PushdDownPredicate::discrimnant because even if nothing said // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating @@ -1166,31 +1676,58 @@ impl ExecutionPlan for HashJoinExec { Arc::downcast::(predicate) { // We successfully pushed down our self filter - we need to make a new node with the dynamic filter - let new_node = Arc::new(HashJoinExec { - left: Arc::clone(&self.left), - right: Arc::clone(&self.right), - on: self.on.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - join_schema: Arc::clone(&self.join_schema), - left_fut: Arc::clone(&self.left_fut), - random_state: self.random_state.clone(), - mode: self.mode, - metrics: ExecutionPlanMetricsSet::new(), - projection: self.projection.clone(), - column_indices: self.column_indices.clone(), - null_equality: self.null_equality, - cache: self.cache.clone(), - dynamic_filter: Some(HashJoinExecDynamicFilter { + let new_node = self + .builder() + .with_dynamic_filter(Some(HashJoinExecDynamicFilter { filter: dynamic_filter, build_accumulator: OnceLock::new(), - }), - }); - result = result.with_updated_node(new_node as Arc); + })) + .build_exec()?; + result = result.with_updated_node(new_node); } } Ok(result) } + + fn supports_limit_pushdown(&self) -> bool { + // Hash join execution plan does not support pushing limit down through to children + // because the children don't know about the join condition and can't + // determine how many rows to produce + false + } + + fn fetch(&self) -> Option { + self.fetch + } + + fn with_fetch(&self, limit: Option) -> Option> { + self.builder() + .with_fetch(limit) + .build() + .ok() + .map(|exec| Arc::new(exec) as _) + } +} + +/// Determines which sides of a join are "preserved" for filter pushdown. +/// +/// A preserved side means filters on that side's columns can be safely pushed +/// below the join. This mirrors the logic in the logical optimizer's +/// `lr_is_preserved` in `datafusion/optimizer/src/push_down_filter.rs`. +fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { + match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), + // Filters in semi/anti joins are either on the preserved side, or on join keys, + // as all output columns come from the preserved side. Join key filters can be + // safely pushed down into the other side. + JoinType::LeftSemi | JoinType::LeftAnti => (true, true), + JoinType::RightSemi | JoinType::RightAnti => (true, true), + JoinType::LeftMark => (true, false), + JoinType::RightMark => (false, true), + } } /// Accumulator for collecting min/max bounds from build-side data during hash join. @@ -1308,6 +1845,19 @@ impl BuildSideState { } } +fn should_collect_min_max_for_perfect_hash( + on_left: &[PhysicalExprRef], + schema: &SchemaRef, +) -> Result { + if on_left.len() != 1 { + return Ok(false); + } + + let expr = &on_left[0]; + let data_type = expr.data_type(schema)?; + Ok(ArrayMap::is_supported_type(&data_type)) +} + /// Collects all batches from the left (build) side stream and creates a hash map for joining. /// /// This function is responsible for: @@ -1336,7 +1886,7 @@ impl BuildSideState { /// # Returns /// `JoinLeftData` containing the hash map, consolidated batch, join key values, /// visited indices bitmap, and computed bounds (if requested). -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] async fn collect_left_input( random_state: RandomState, left_stream: SendableRecordBatchStream, @@ -1346,18 +1896,21 @@ async fn collect_left_input( with_visited_indices_bitmap: bool, probe_threads_count: usize, should_compute_dynamic_filters: bool, + config: Arc, + null_equality: NullEquality, + array_map_created_count: Count, ) -> Result { let schema = left_stream.schema(); - // This operation performs 2 steps at once: - // 1. creates a [JoinHashMap] of all batches from the stream - // 2. stores the batches in a vector. + let should_collect_min_max_for_phj = + should_collect_min_max_for_perfect_hash(&on_left, &schema)?; + let initial = BuildSideState::try_new( metrics, reservation, on_left.clone(), &schema, - should_compute_dynamic_filters, + should_compute_dynamic_filters || should_collect_min_max_for_phj, )?; let state = left_stream @@ -1394,50 +1947,85 @@ async fn collect_left_input( bounds_accumulators, } = state; - // Estimation of memory size, required for hashtable, prior to allocation. - // Final result can be verified using `RawTable.allocation_info()` - let fixed_size_u32 = size_of::(); - let fixed_size_u64 = size_of::(); - - // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the - // `u64` indice variant - // Arc is used instead of Box to allow sharing with SharedBuildAccumulator for hash map pushdown - let mut hashmap: Box = if num_rows > u32::MAX as usize { - let estimated_hashtable_size = - estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU64::with_capacity(num_rows)) - } else { - let estimated_hashtable_size = - estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; - reservation.try_grow(estimated_hashtable_size)?; - metrics.build_mem_used.add(estimated_hashtable_size); - Box::new(JoinHashMapU32::with_capacity(num_rows)) + // Compute bounds + let mut bounds = match bounds_accumulators { + Some(accumulators) if num_rows > 0 => { + let bounds = accumulators + .into_iter() + .map(CollectLeftAccumulator::evaluate) + .collect::>>()?; + Some(PartitionBounds::new(bounds)) + } + _ => None, }; - let mut hashes_buffer = Vec::new(); - let mut offset = 0; - - // Updating hashmap starting from the last batch - let batches_iter = batches.iter().rev(); - for batch in batches_iter.clone() { - hashes_buffer.clear(); - hashes_buffer.resize(batch.num_rows(), 0); - update_hash( + let (join_hash_map, batch, left_values) = + if let Some((array_map, batch, left_value)) = try_create_array_map( + &bounds, + &schema, + &batches, &on_left, - batch, - &mut *hashmap, - offset, - &random_state, - &mut hashes_buffer, - 0, - true, - )?; - offset += batch.num_rows(); - } - // Merge all batches into a single batch, so we can directly index into the arrays - let batch = concat_batches(&schema, batches_iter)?; + &mut reservation, + config.execution.perfect_hash_join_small_build_threshold, + config.execution.perfect_hash_join_min_key_density, + null_equality, + )? { + array_map_created_count.add(1); + metrics.build_mem_used.add(array_map.size()); + + (Map::ArrayMap(array_map), batch, left_value) + } else { + // Estimation of memory size, required for hashtable, prior to allocation. + // Final result can be verified using `RawTable.allocation_info()` + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + + // Use `u32` indices for the JoinHashMap when num_rows ≤ u32::MAX, otherwise use the + // `u64` indice variant + // Arc is used instead of Box to allow sharing with SharedBuildAccumulator for hash map pushdown + let mut hashmap: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + let mut hashes_buffer = Vec::new(); + let mut offset = 0; + + let batches_iter = batches.iter().rev(); + + // Updating hashmap starting from the last batch + for batch in batches_iter.clone() { + hashes_buffer.clear(); + hashes_buffer.resize(batch.num_rows(), 0); + update_hash( + &on_left, + batch, + &mut *hashmap, + offset, + &random_state, + &mut hashes_buffer, + 0, + true, + )?; + offset += batch.num_rows(); + } + + // Merge all batches into a single batch, so we can directly index into the arrays + let batch = concat_batches(&schema, batches_iter.clone())?; + + let left_values = evaluate_expressions_to_arrays(&on_left, &batch)?; + + (Map::HashMap(hashmap), batch, left_values) + }; // Reserve additional memory for visited indices bitmap and create shared builder let visited_indices_bitmap = if with_visited_indices_bitmap { @@ -1452,31 +2040,49 @@ async fn collect_left_input( BooleanBufferBuilder::new(0) }; - let left_values = evaluate_expressions_to_arrays(&on_left, &batch)?; + let map = Arc::new(join_hash_map); - // Compute bounds for dynamic filter if enabled - let bounds = match bounds_accumulators { - Some(accumulators) if num_rows > 0 => { - let bounds = accumulators - .into_iter() - .map(CollectLeftAccumulator::evaluate) - .collect::>>()?; - Some(PartitionBounds::new(bounds)) + let membership = if num_rows == 0 { + PushdownStrategy::Empty + } else { + // If the build side is small enough we can use IN list pushdown. + // If it's too big we fall back to pushing down a reference to the hash table. + // See `PushdownStrategy` for more details. + let estimated_size = left_values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::(); + if left_values.is_empty() + || left_values[0].is_empty() + || estimated_size > config.optimizer.hash_join_inlist_pushdown_max_size + || map.num_of_distinct_key() + > config + .optimizer + .hash_join_inlist_pushdown_max_distinct_values + { + PushdownStrategy::Map(Arc::clone(&map)) + } else if let Some(in_list_values) = build_struct_inlist_values(&left_values)? { + PushdownStrategy::InList(in_list_values) + } else { + PushdownStrategy::Map(Arc::clone(&map)) } - _ => None, }; - // Convert Box to Arc for sharing with SharedBuildAccumulator - let hash_map: Arc = hashmap.into(); + if should_collect_min_max_for_phj && !should_compute_dynamic_filters { + bounds = None; + } let data = JoinLeftData { - hash_map, + map, batch, values: left_values, visited_indices_bitmap: Mutex::new(visited_indices_bitmap), probe_threads_counter: AtomicUsize::new(probe_threads_count), _reservation: reservation, bounds, + membership, + probe_side_non_empty: AtomicBool::new(false), + probe_side_has_null: AtomicBool::new(false), }; Ok(data) @@ -1485,29 +2091,66 @@ async fn collect_left_input( #[cfg(test)] mod tests { use super::*; + + fn assert_phj_used(metrics: &MetricsSet, use_phj: bool) { + if use_phj { + assert!( + metrics + .sum_by_name(ARRAY_MAP_CREATED_COUNT_METRIC_NAME) + .expect("should have array_map_created_count metrics") + .as_usize() + >= 1 + ); + } else { + assert_eq!( + metrics + .sum_by_name(ARRAY_MAP_CREATED_COUNT_METRIC_NAME) + .map(|v| v.as_usize()) + .unwrap_or(0), + 0 + ) + } + } + + fn build_schema_and_on() -> Result<(SchemaRef, SchemaRef, JoinOn)> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int32, true), + Field::new("b1", DataType::Int32, true), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, true), + Field::new("b1", DataType::Int32, true), + ])); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left_schema)?) as _, + Arc::new(Column::new_with_schema("b1", &right_schema)?) as _, + )]; + Ok((left_schema, right_schema, on)) + } + use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::hash_join::stream::lookup_join_hashmap; - use crate::test::{assert_join_metrics, TestMemoryExec}; + use crate::test::{TestMemoryExec, assert_join_metrics}; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; - use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; + use arrow::array::{ + Date32Array, Int32Array, Int64Array, StructArray, UInt32Array, UInt64Array, + }; use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field}; - use arrow_schema::Schema; use datafusion_common::hash_utils::create_hashes; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, - internal_err, ScalarValue, + ScalarValue, assert_batches_eq, assert_batches_sorted_eq, assert_contains, + exec_err, internal_err, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - use datafusion_physical_expr::PhysicalExpr; use hashbrown::HashTable; use insta::{allow_duplicates, assert_snapshot}; use rstest::*; @@ -1519,10 +2162,37 @@ mod tests { #[template] #[rstest] - fn batch_sizes(#[values(8192, 10, 5, 2, 1)] batch_size: usize) {} + fn hash_join_exec_configs( + #[values(8192, 10, 5, 2, 1)] batch_size: usize, + #[values(true, false)] use_perfect_hash_join_as_possible: bool, + ) { + } - fn prepare_task_ctx(batch_size: usize) -> Arc { - let session_config = SessionConfig::default().with_batch_size(batch_size); + fn prepare_task_ctx( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Arc { + let mut session_config = SessionConfig::default().with_batch_size(batch_size); + + if use_perfect_hash_join_as_possible { + session_config + .options_mut() + .execution + .perfect_hash_join_small_build_threshold = 819200; + session_config + .options_mut() + .execution + .perfect_hash_join_min_key_density = 0.0; + } else { + session_config + .options_mut() + .execution + .perfect_hash_join_small_build_threshold = 0; + session_config + .options_mut() + .execution + .perfect_hash_join_min_key_density = f64::INFINITY; + } Arc::new(TaskContext::default().with_session_config(session_config)) } @@ -1536,6 +2206,26 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() } + /// Build a table with two columns supporting nullable values + fn build_table_two_cols( + a: (&str, &Vec>), + b: (&str, &Vec>), + ) -> Arc { + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + ], + ) + .unwrap(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + fn join( left: Arc, right: Arc, @@ -1552,6 +2242,7 @@ mod tests { None, PartitionMode::CollectLeft, null_equality, + false, ) } @@ -1572,59 +2263,180 @@ mod tests { None, PartitionMode::CollectLeft, null_equality, + false, ) } - async fn join_collect( - left: Arc, - right: Arc, - on: JoinOn, - join_type: &JoinType, - null_equality: NullEquality, - context: Arc, - ) -> Result<(Vec, Vec, MetricsSet)> { - let join = join(left, right, on, join_type, null_equality)?; - let columns_header = columns(&join.schema()); + fn empty_build_with_probe_error_inputs() + -> (Arc, Arc, JoinOn) { + let left_batch = + build_table_i32(("a1", &vec![]), ("b1", &vec![]), ("c1", &vec![])); + let left_schema = left_batch.schema(); + let left: Arc = TestMemoryExec::try_new_exec( + &[vec![left_batch]], + Arc::clone(&left_schema), + None, + ) + .unwrap(); - let stream = join.execute(0, context)?; - let batches = common::collect(stream).await?; - let metrics = join.metrics().unwrap(); + let err = exec_err!("bad data error"); + let right_batch = + build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![])); + let right_schema = right_batch.schema(); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left_schema).unwrap()) as _, + Arc::new(Column::new_with_schema("b1", &right_schema).unwrap()) as _, + )]; + let right: Arc = Arc::new( + MockExec::new(vec![Ok(right_batch), err], right_schema).with_use_task(false), + ); - Ok((columns_header, batches, metrics)) + (left, right, on) } - async fn partitioned_join_collect( + async fn assert_empty_build_probe_behavior( + join_types: &[JoinType], + expect_probe_error: bool, + with_filter: bool, + ) { + let (left, right, on) = empty_build_with_probe_error_inputs(); + let filter = prepare_join_filter(); + + for join_type in join_types { + let join = if with_filter { + join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + join_type, + NullEquality::NullEqualsNothing, + ) + .unwrap() + } else { + join( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + join_type, + NullEquality::NullEqualsNothing, + ) + .unwrap() + }; + + let result = common::collect( + join.execute(0, Arc::new(TaskContext::default())).unwrap(), + ) + .await; + + if expect_probe_error { + let result_string = result.unwrap_err().to_string(); + assert!( + result_string.contains("bad data error"), + "actual: {result_string}" + ); + } else { + let batches = result.unwrap(); + assert!( + batches.is_empty(), + "expected no output batches for {join_type}, got {batches:?}" + ); + } + } + } + + fn hash_join_with_dynamic_filter( left: Arc, right: Arc, on: JoinOn, - join_type: &JoinType, - null_equality: NullEquality, - context: Arc, - ) -> Result<(Vec, Vec, MetricsSet)> { - join_collect_with_partition_mode( + join_type: JoinType, + ) -> Result<(HashJoinExec, Arc)> { + hash_join_with_dynamic_filter_and_mode( left, right, on, join_type, - PartitionMode::Partitioned, - null_equality, - context, + PartitionMode::CollectLeft, ) - .await } - async fn join_collect_with_partition_mode( + fn hash_join_with_dynamic_filter_and_mode( left: Arc, right: Arc, on: JoinOn, - join_type: &JoinType, - partition_mode: PartitionMode, - null_equality: NullEquality, - context: Arc, - ) -> Result<(Vec, Vec, MetricsSet)> { - let partition_count = 4; - - let (left_expr, right_expr) = on + join_type: JoinType, + mode: PartitionMode, + ) -> Result<(HashJoinExec, Arc)> { + let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); + let mut join = HashJoinExec::try_new( + left, + right, + on, + None, + &join_type, + None, + mode, + NullEquality::NullEqualsNothing, + false, + )?; + join.dynamic_filter = Some(HashJoinExecDynamicFilter { + filter: Arc::clone(&dynamic_filter), + build_accumulator: OnceLock::new(), + }); + + Ok((join, dynamic_filter)) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + null_equality: NullEquality, + context: Arc, + ) -> Result<(Vec, Vec, MetricsSet)> { + let join = join(left, right, on, join_type, null_equality)?; + let columns_header = columns(&join.schema()); + + let stream = join.execute(0, context)?; + let batches = common::collect(stream).await?; + let metrics = join.metrics().unwrap(); + + Ok((columns_header, batches, metrics)) + } + + async fn partitioned_join_collect( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + null_equality: NullEquality, + context: Arc, + ) -> Result<(Vec, Vec, MetricsSet)> { + join_collect_with_partition_mode( + left, + right, + on, + join_type, + PartitionMode::Partitioned, + null_equality, + context, + ) + .await + } + + async fn join_collect_with_partition_mode( + left: Arc, + right: Arc, + on: JoinOn, + join_type: &JoinType, + partition_mode: PartitionMode, + null_equality: NullEquality, + context: Arc, + ) -> Result<(Vec, Vec, MetricsSet)> { + let partition_count = 4; + + let (left_expr, right_expr) = on .iter() .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) .unzip(); @@ -1636,7 +2448,7 @@ mod tests { Partitioning::Hash(left_expr, partition_count), )?), PartitionMode::Auto => { - return internal_err!("Unexpected PartitionMode::Auto in join tests") + return internal_err!("Unexpected PartitionMode::Auto in join tests"); } }; @@ -1657,7 +2469,7 @@ mod tests { Partitioning::Hash(right_expr, partition_count), )?), PartitionMode::Auto => { - return internal_err!("Unexpected PartitionMode::Auto in join tests") + return internal_err!("Unexpected PartitionMode::Auto in join tests"); } }; @@ -1670,6 +2482,7 @@ mod tests { None, partition_mode, null_equality, + false, )?; let columns = columns(&join.schema()); @@ -1690,10 +2503,13 @@ mod tests { Ok((columns, batches, metrics)) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1724,26 +2540,30 @@ mod tests { allow_duplicates! { // Inner join output is expected to preserve both inputs order - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 5 | 9 | 20 | 5 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 5 | 9 | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_inner_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_inner_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1772,18 +2592,19 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 5 | 9 | 20 | 5 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 5 | 9 | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } @@ -1820,7 +2641,7 @@ mod tests { // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ @@ -1828,7 +2649,7 @@ mod tests { | 2 | 5 | 8 | 20 | 5 | 80 | | 3 | 5 | 9 | 20 | 5 | 80 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); @@ -1868,7 +2689,7 @@ mod tests { // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ @@ -1877,7 +2698,7 @@ mod tests { | 0 | 4 | 6 | 10 | 4 | 70 | | 1 | 4 | 7 | 10 | 4 | 70 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 4); @@ -1885,10 +2706,13 @@ mod tests { Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_two(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_two( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1936,11 +2760,16 @@ mod tests { div_ceil(9, batch_size) }; - assert_eq!(batches.len(), expected_batch_count); + // With batch coalescing, we may have fewer batches than expected + assert!( + batches.len() <= expected_batch_count, + "expected at most {expected_batch_count} batches, got {}", + batches.len() + ); // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b2 | c1 | a1 | b2 | c2 | +----+----+----+----+----+----+ @@ -1948,7 +2777,7 @@ mod tests { | 2 | 2 | 8 | 2 | 2 | 80 | | 2 | 2 | 9 | 2 | 2 | 80 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); @@ -1957,10 +2786,13 @@ mod tests { } /// Test where the left has 2 parts, the right with 1 part => 1 part - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_one_two_parts_left(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_one_two_parts_left( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -2016,11 +2848,16 @@ mod tests { div_ceil(9, batch_size) }; - assert_eq!(batches.len(), expected_batch_count); + // With batch coalescing, we may have fewer batches than expected + assert!( + batches.len() <= expected_batch_count, + "expected at most {expected_batch_count} batches, got {}", + batches.len() + ); // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b2 | c1 | a1 | b2 | c2 | +----+----+----+----+----+----+ @@ -2028,7 +2865,7 @@ mod tests { | 2 | 2 | 8 | 2 | 2 | 80 | | 2 | 2 | 9 | 2 | 2 | 80 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); @@ -2079,7 +2916,7 @@ mod tests { // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ @@ -2088,7 +2925,7 @@ mod tests { | 0 | 4 | 6 | 10 | 4 | 70 | | 1 | 4 | 7 | 10 | 4 | 70 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 4); @@ -2097,10 +2934,13 @@ mod tests { } /// Test where the left has 1 part, the right has 2 parts => 2 parts - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_one_two_parts_right(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_one_two_parts_right( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -2152,17 +2992,22 @@ mod tests { // and filtered later. div_ceil(6, batch_size) }; - assert_eq!(batches.len(), expected_batch_count); + // With batch coalescing, we may have fewer batches than expected + assert!( + batches.len() <= expected_batch_count, + "expected at most {expected_batch_count} batches, got {}", + batches.len() + ); // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ | 1 | 4 | 7 | 10 | 4 | 70 | +----+----+----+----+----+----+ - "#); + "); } // second part @@ -2177,20 +3022,28 @@ mod tests { // and filtered later. div_ceil(3, batch_size) }; - assert_eq!(batches.len(), expected_batch_count); + // With batch coalescing, we may have fewer batches than expected + assert!( + batches.len() <= expected_batch_count, + "expected at most {expected_batch_count} batches, got {}", + batches.len() + ); // Inner join output is expected to preserve both inputs order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ | 2 | 5 | 8 | 30 | 5 | 90 | | 3 | 5 | 9 | 30 | 5 | 90 | +----+----+----+----+----+----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } @@ -2204,10 +3057,13 @@ mod tests { TestMemoryExec::try_new_exec(&[vec![batch.clone(), batch]], schema, None).unwrap() } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_multi_batch(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_multi_batch( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2224,9 +3080,9 @@ mod tests { )]; let join = join( - left, - right, - on, + Arc::clone(&left), + Arc::clone(&right), + on.clone(), &JoinType::Left, NullEquality::NullEqualsNothing, ) @@ -2235,11 +3091,18 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0, task_ctx).unwrap(); - let batches = common::collect(stream).await.unwrap(); + let (_, batches, metrics) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::Left, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -2249,14 +3112,20 @@ mod tests { | 2 | 5 | 8 | 20 | 5 | 80 | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + return Ok(()); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_multi_batch(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_multi_batch( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2287,9 +3156,10 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); + let metrics = join.metrics().unwrap(); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ @@ -2301,14 +3171,19 @@ mod tests { | 2 | 5 | 8 | 20 | 5 | 80 | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_empty_right(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_empty_right( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2335,9 +3210,10 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); + let metrics = join.metrics().unwrap(); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -2345,14 +3221,19 @@ mod tests { | 2 | 5 | 8 | | | | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_empty_right(batch_size: usize) { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_empty_right( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -2379,9 +3260,10 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); + let metrics = join.metrics().unwrap(); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ @@ -2389,14 +3271,19 @@ mod tests { | 2 | 5 | 8 | | | | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2425,7 +3312,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -2433,18 +3320,22 @@ mod tests { | 2 | 5 | 8 | 20 | 5 | 80 | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_left_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_left_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -2473,7 +3364,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -2481,10 +3372,11 @@ mod tests { | 2 | 5 | 8 | 20 | 5 | 80 | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } @@ -2509,10 +3401,13 @@ mod tests { ) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_semi(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_semi( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left semi join right_table on left_table.b1 = right_table.b2 @@ -2537,7 +3432,7 @@ mod tests { // ignore the order allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+-----+ | a1 | b1 | c1 | +----+----+-----+ @@ -2545,16 +3440,22 @@ mod tests { | 13 | 10 | 130 | | 9 | 8 | 90 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_semi_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_semi_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2610,6 +3511,9 @@ mod tests { "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 > 10 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), @@ -2638,22 +3542,28 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+-----+ | a1 | b1 | c1 | +----+----+-----+ | 13 | 10 | 130 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_semi(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_semi( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2679,7 +3589,7 @@ mod tests { // RightSemi join output is expected to preserve right input order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ @@ -2687,16 +3597,22 @@ mod tests { | 12 | 10 | 40 | | 10 | 10 | 100 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_semi_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_semi_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); @@ -2742,7 +3658,7 @@ mod tests { // RightSemi join output is expected to preserve right input order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ @@ -2750,9 +3666,12 @@ mod tests { | 12 | 10 | 40 | | 10 | 10 | 100 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), @@ -2779,23 +3698,29 @@ mod tests { // RightSemi join output is expected to preserve right input order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ | 12 | 10 | 40 | | 10 | 10 | 100 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_anti(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_anti( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 @@ -2819,7 +3744,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+ | a1 | b1 | c1 | +----+----+----+ @@ -2828,15 +3753,22 @@ mod tests { | 5 | 5 | 50 | | 7 | 7 | 70 | +----+----+----+ - "#); + "); } + + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_anti_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_anti_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 @@ -2879,7 +3811,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+-----+ | a1 | b1 | c1 | +----+----+-----+ @@ -2890,9 +3822,12 @@ mod tests { | 7 | 7 | 70 | | 9 | 8 | 90 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 13 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), @@ -2922,7 +3857,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+-----+ | a1 | b1 | c1 | +----+----+-----+ @@ -2933,16 +3868,22 @@ mod tests { | 7 | 7 | 70 | | 9 | 8 | 90 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_anti(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_anti( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); let on = vec![( @@ -2966,7 +3907,7 @@ mod tests { // RightAnti join output is expected to preserve right input order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ @@ -2974,15 +3915,22 @@ mod tests { | 2 | 2 | 80 | | 4 | 4 | 120 | +----+----+-----+ - "#); + "); } + + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_anti_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_anti_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_semi_anti_left_table(); let right = build_semi_anti_right_table(); // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 @@ -3027,7 +3975,7 @@ mod tests { // RightAnti join output is expected to preserve right input order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ @@ -3037,9 +3985,12 @@ mod tests { | 10 | 10 | 100 | | 4 | 4 | 120 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 let column_indices = vec![ColumnIndex { index: 1, @@ -3074,7 +4025,7 @@ mod tests { // RightAnti join output is expected to preserve right input order allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+ | a2 | b2 | c2 | +----+----+-----+ @@ -3083,16 +4034,22 @@ mod tests { | 2 | 2 | 80 | | 4 | 4 | 120 | +----+----+-----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -3121,7 +4078,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -3129,18 +4086,22 @@ mod tests { | 1 | 4 | 7 | 10 | 4 | 70 | | 2 | 5 | 8 | 20 | 5 | 80 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_right_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_right_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -3169,7 +4130,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -3177,18 +4138,22 @@ mod tests { | 1 | 4 | 7 | 10 | 4 | 70 | | 2 | 5 | 8 | 20 | 5 | 80 | +----+----+----+----+----+----+ - "#); + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_one(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_one( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3219,7 +4184,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b2 | c2 | +----+----+----+----+----+----+ @@ -3228,16 +4193,22 @@ mod tests { | 2 | 5 | 8 | 20 | 5 | 80 | | 3 | 7 | 9 | | | | +----+----+----+----+----+----+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3266,7 +4237,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+-------+ | a1 | b1 | c1 | mark | +----+----+----+-------+ @@ -3274,18 +4245,22 @@ mod tests { | 2 | 5 | 8 | true | | 3 | 7 | 9 | false | +----+----+----+-------+ - "#); + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_left_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3314,7 +4289,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +----+----+----+-------+ | a1 | b1 | c1 | mark | +----+----+----+-------+ @@ -3322,18 +4297,22 @@ mod tests { | 2 | 5 | 8 | true | | 3 | 7 | 9 | false | +----+----+----+-------+ - "#); + "); } assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3373,14 +4352,18 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); assert_join_metrics!(metrics, 3); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn partitioned_join_right_mark(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn partitioned_join_right_mark( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -3421,6 +4404,7 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); assert_join_metrics!(metrics, 4); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); Ok(()) } @@ -3434,7 +4418,7 @@ mod tests { ("y", &vec![200, 300]), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; left.num_rows()]; let hashes = create_hashes([&left.columns()[0]], &random_state, hashes_buff)?; @@ -3467,6 +4451,8 @@ mod tests { let mut hashes_buffer = vec![0; right.num_rows()]; create_hashes([&right_keys_values], &random_state, &mut hashes_buffer)?; + let mut probe_indices_buffer = Vec::new(); + let mut build_indices_buffer = Vec::new(); let (l, r, _) = lookup_join_hashmap( &join_hash_map, &[left_keys_values], @@ -3475,6 +4461,8 @@ mod tests { &hashes_buffer, 8192, (0, None), + &mut probe_indices_buffer, + &mut build_indices_buffer, )?; let left_ids: UInt64Array = vec![0, 1].into(); @@ -3497,7 +4485,7 @@ mod tests { ("y", &vec![200, 300]), ); - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let hashes_buff = &mut vec![0; left.num_rows()]; let hashes = create_hashes([&left.columns()[0]], &random_state, hashes_buff)?; @@ -3524,6 +4512,8 @@ mod tests { let mut hashes_buffer = vec![0; right.num_rows()]; create_hashes([&right_keys_values], &random_state, &mut hashes_buffer)?; + let mut probe_indices_buffer = Vec::new(); + let mut build_indices_buffer = Vec::new(); let (l, r, _) = lookup_join_hashmap( &join_hash_map, &[left_keys_values], @@ -3532,6 +4522,8 @@ mod tests { &hashes_buffer, 8192, (0, None), + &mut probe_indices_buffer, + &mut build_indices_buffer, )?; // We still expect to match rows 0 and 1 on both sides @@ -3578,14 +4570,14 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +---+---+---+----+---+----+ | a | b | c | a | b | c | +---+---+---+----+---+----+ | 1 | 4 | 7 | 10 | 1 | 70 | | 2 | 5 | 8 | 20 | 2 | 80 | +---+---+---+----+---+----+ - "#); + "); } Ok(()) @@ -3619,10 +4611,13 @@ mod tests { ) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_inner_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_inner_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3655,23 +4650,29 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +---+---+---+----+---+---+ | a | b | c | a | b | c | +---+---+---+----+---+---+ | 2 | 7 | 9 | 10 | 2 | 7 | | 2 | 7 | 9 | 20 | 2 | 5 | +---+---+---+----+---+---+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_left_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_left_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3704,7 +4705,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +---+---+---+----+---+---+ | a | b | c | a | b | c | +---+---+---+----+---+---+ @@ -3714,16 +4715,22 @@ mod tests { | 2 | 7 | 9 | 20 | 2 | 5 | | 2 | 8 | 1 | | | | +---+---+---+----+---+---+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_right_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_right_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3756,7 +4763,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +---+---+---+----+---+---+ | a | b | c | a | b | c | +---+---+---+----+---+---+ @@ -3765,16 +4772,22 @@ mod tests { | 2 | 7 | 9 | 10 | 2 | 7 | | 2 | 7 | 9 | 20 | 2 | 5 | +---+---+---+----+---+---+ - "#); + "); } + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + Ok(()) } - #[apply(batch_sizes)] + #[apply(hash_join_exec_configs)] #[tokio::test] - async fn join_full_with_filter(batch_size: usize) -> Result<()> { - let task_ctx = prepare_task_ctx(batch_size); + async fn join_full_with_filter( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); let left = build_table( ("a", &vec![0, 1, 2, 2]), ("b", &vec![4, 5, 7, 8]), @@ -3821,6 +4834,9 @@ mod tests { ]; assert_batches_sorted_eq!(expected, &batches); + let metrics = join.metrics().unwrap(); + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + // THIS MIGRATION HALTED DUE TO ISSUE #15312 //allow_duplicates! { // assert_snapshot!(batches_to_sort_string(&batches), @r#" @@ -4011,7 +5027,7 @@ mod tests { let batches = common::collect(stream).await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches), @r#" + assert_snapshot!(batches_to_sort_string(&batches), @r" +------------+---+------------+---+ | date | n | date | n | +------------+---+------------+---+ @@ -4019,7 +5035,7 @@ mod tests { | 2022-04-26 | 2 | 2022-04-26 | 5 | | 2022-04-27 | 3 | 2022-04-27 | 6 | +------------+---+------------+---+ - "#); + "); } Ok(()) @@ -4079,6 +5095,70 @@ mod tests { } } + #[tokio::test] + async fn join_does_not_consume_probe_when_empty_build_fixes_output() { + assert_empty_build_probe_behavior( + &[ + JoinType::Inner, + JoinType::Left, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::LeftMark, + JoinType::RightSemi, + ], + false, + false, + ) + .await; + } + + #[tokio::test] + async fn join_does_not_consume_probe_when_empty_build_fixes_output_with_filter() { + assert_empty_build_probe_behavior( + &[ + JoinType::Inner, + JoinType::Left, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::LeftMark, + JoinType::RightSemi, + ], + false, + true, + ) + .await; + } + + #[tokio::test] + async fn join_still_consumes_probe_when_empty_build_needs_probe_rows() { + assert_empty_build_probe_behavior( + &[ + JoinType::Right, + JoinType::Full, + JoinType::RightAnti, + JoinType::RightMark, + ], + true, + false, + ) + .await; + } + + #[tokio::test] + async fn join_still_consumes_probe_when_empty_build_needs_probe_rows_with_filter() { + assert_empty_build_probe_behavior( + &[ + JoinType::Right, + JoinType::Full, + JoinType::RightAnti, + JoinType::RightMark, + ], + true, + true, + ) + .await; + } + #[tokio::test] async fn join_split_batch() { let left = build_table( @@ -4170,7 +5250,7 @@ mod tests { // validation of partial join results output for different batch_size setting for join_type in join_types { for batch_size in (1..21).rev() { - let task_ctx = prepare_task_ctx(batch_size); + let task_ctx = prepare_task_ctx(batch_size, true); let join = join( Arc::clone(&left), @@ -4197,10 +5277,11 @@ mod tests { } _ => div_ceil(expected_resultset_records, batch_size) + 1, }; - assert_eq!( - batches.len(), - expected_batch_count, - "expected {expected_batch_count} output batches for {join_type} join with batch_size = {batch_size}" + // With batch coalescing, we may have fewer batches than expected + assert!( + batches.len() <= expected_batch_count, + "expected at most {expected_batch_count} output batches for {join_type} join with batch_size = {batch_size}, got {}", + batches.len() ); let expected = match join_type { @@ -4210,7 +5291,17 @@ mod tests { JoinType::LeftAnti => left_empty.to_vec(), _ => common_result.to_vec(), }; - assert_batches_eq!(expected, &batches); + // For anti joins with empty results, we may get zero batches + // (with coalescing) instead of one empty batch with schema + if batches.is_empty() { + // Verify this is an expected empty result case + assert!( + matches!(join_type, JoinType::RightAnti | JoinType::LeftAnti), + "Unexpected empty result for {join_type} join" + ); + } else { + assert_batches_eq!(expected, &batches); + } } } } @@ -4339,6 +5430,7 @@ mod tests { None, PartitionMode::Partitioned, NullEquality::NullEqualsNothing, + false, )?; let stream = join.execute(1, task_ctx)?; @@ -4348,7 +5440,6 @@ mod tests { assert_contains!( err.to_string(), "Resources exhausted: Additional allocation failed for HashJoinInput[1] with top memory consumers (across reservations) as:\n HashJoinInput[1]" - ); assert_contains!( @@ -4411,7 +5502,7 @@ mod tests { assert_eq!(columns, vec!["n1", "n2"]); allow_duplicates! { - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +--------+--------+ | n1 | n2 | +--------+--------+ @@ -4419,7 +5510,7 @@ mod tests { | {a: 1} | {a: 1} | | {a: 2} | {a: 2} | +--------+--------+ - "#); + "); } assert_join_metrics!(metrics, 3); @@ -4450,13 +5541,13 @@ mod tests { .await?; allow_duplicates! { - assert_snapshot!(batches_to_sort_string(&batches_null_eq), @r#" + assert_snapshot!(batches_to_sort_string(&batches_null_eq), @r" +----+----+ | n1 | n2 | +----+----+ | | | +----+----+ - "#); + "); } assert_join_metrics!(metrics, 1); @@ -4473,9 +5564,15 @@ mod tests { assert_join_metrics!(metrics, 0); - let expected_null_neq = - ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; - assert_batches_eq!(expected_null_neq, &batches_null_neq); + // With batch coalescing, empty results may not emit any batches + // Check that either we have no batches, or an empty batch with proper schema + if batches_null_neq.is_empty() { + // This is fine - no output rows + } else { + let expected_null_neq = + ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; + assert_batches_eq!(expected_null_neq, &batches_null_neq); + } Ok(()) } @@ -4505,25 +5602,8 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - // Create a dynamic filter manually - let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); - let dynamic_filter_clone = Arc::clone(&dynamic_filter); - - // Create HashJoinExec with the dynamic filter - let mut join = HashJoinExec::try_new( - left, - right, - on, - None, - &JoinType::Inner, - None, - PartitionMode::CollectLeft, - NullEquality::NullEqualsNothing, - )?; - join.dynamic_filter = Some(HashJoinExecDynamicFilter { - filter: dynamic_filter, - build_accumulator: OnceLock::new(), - }); + let (join, dynamic_filter) = + hash_join_with_dynamic_filter(left, right, on, JoinType::Inner)?; // Execute the join let stream = join.execute(0, task_ctx)?; @@ -4531,7 +5611,7 @@ mod tests { // After the join completes, the dynamic filter should be marked as complete // wait_complete() should return immediately - dynamic_filter_clone.wait_complete().await; + dynamic_filter.wait_complete().await; Ok(()) } @@ -4553,34 +5633,747 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - // Create a dynamic filter manually - let dynamic_filter = HashJoinExec::create_dynamic_filter(&on); - let dynamic_filter_clone = Arc::clone(&dynamic_filter); + let (join, dynamic_filter) = + hash_join_with_dynamic_filter(left, right, on, JoinType::Inner)?; - // Create HashJoinExec with the dynamic filter - let mut join = HashJoinExec::try_new( + // Execute the join + let stream = join.execute(0, task_ctx)?; + let _batches = common::collect(stream).await?; + + // Even with empty build side, the dynamic filter should be marked as complete + // wait_complete() should return immediately + dynamic_filter.wait_complete().await; + + Ok(()) + } + + #[tokio::test] + async fn test_partitioned_dynamic_filter_reports_empty_canceled_partitions() + -> Result<()> { + let mut session_config = SessionConfig::default(); + session_config + .options_mut() + .optimizer + .enable_dynamic_filter_pushdown = true; + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); + + let child_left_schema = Arc::new(Schema::new(vec![ + Field::new("child_left_payload", DataType::Int32, false), + Field::new("child_key", DataType::Int32, false), + Field::new("child_left_extra", DataType::Int32, false), + ])); + let child_right_schema = Arc::new(Schema::new(vec![ + Field::new("child_right_payload", DataType::Int32, false), + Field::new("child_right_key", DataType::Int32, false), + Field::new("child_right_extra", DataType::Int32, false), + ])); + let parent_left_schema = Arc::new(Schema::new(vec![ + Field::new("parent_payload", DataType::Int32, false), + Field::new("parent_key", DataType::Int32, false), + Field::new("parent_extra", DataType::Int32, false), + ])); + + let child_left: Arc = TestMemoryExec::try_new_exec( + &[ + vec![build_table_i32( + ("child_left_payload", &vec![10]), + ("child_key", &vec![0]), + ("child_left_extra", &vec![100]), + )], + vec![build_table_i32( + ("child_left_payload", &vec![11]), + ("child_key", &vec![1]), + ("child_left_extra", &vec![101]), + )], + vec![build_table_i32( + ("child_left_payload", &vec![12]), + ("child_key", &vec![2]), + ("child_left_extra", &vec![102]), + )], + vec![build_table_i32( + ("child_left_payload", &vec![13]), + ("child_key", &vec![3]), + ("child_left_extra", &vec![103]), + )], + ], + Arc::clone(&child_left_schema), + None, + )?; + let child_right: Arc = TestMemoryExec::try_new_exec( + &[ + vec![build_table_i32( + ("child_right_payload", &vec![20]), + ("child_right_key", &vec![0]), + ("child_right_extra", &vec![200]), + )], + vec![build_table_i32( + ("child_right_payload", &vec![21]), + ("child_right_key", &vec![1]), + ("child_right_extra", &vec![201]), + )], + vec![build_table_i32( + ("child_right_payload", &vec![22]), + ("child_right_key", &vec![2]), + ("child_right_extra", &vec![202]), + )], + vec![build_table_i32( + ("child_right_payload", &vec![23]), + ("child_right_key", &vec![3]), + ("child_right_extra", &vec![203]), + )], + ], + Arc::clone(&child_right_schema), + None, + )?; + let parent_left: Arc = TestMemoryExec::try_new_exec( + &[ + vec![build_table_i32( + ("parent_payload", &vec![30]), + ("parent_key", &vec![0]), + ("parent_extra", &vec![300]), + )], + vec![RecordBatch::new_empty(Arc::clone(&parent_left_schema))], + vec![build_table_i32( + ("parent_payload", &vec![32]), + ("parent_key", &vec![2]), + ("parent_extra", &vec![302]), + )], + vec![RecordBatch::new_empty(Arc::clone(&parent_left_schema))], + ], + Arc::clone(&parent_left_schema), + None, + )?; + + let child_on = vec![( + Arc::new(Column::new_with_schema("child_key", &child_left_schema)?) as _, + Arc::new(Column::new_with_schema( + "child_right_key", + &child_right_schema, + )?) as _, + )]; + let (child_join, _child_dynamic_filter) = hash_join_with_dynamic_filter_and_mode( + child_left, + child_right, + child_on, + JoinType::Inner, + PartitionMode::Partitioned, + )?; + let child_join: Arc = Arc::new(child_join); + + let parent_on = vec![( + Arc::new(Column::new_with_schema("parent_key", &parent_left_schema)?) as _, + Arc::new(Column::new_with_schema("child_key", &child_join.schema())?) as _, + )]; + let parent_join = HashJoinExec::try_new( + parent_left, + child_join, + parent_on, + None, + &JoinType::RightSemi, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + false, + )?; + + let batches = tokio::time::timeout( + std::time::Duration::from_secs(5), + crate::execution_plan::collect(Arc::new(parent_join), task_ctx), + ) + .await + .expect("partitioned right-semi join should not hang")?; + + assert_batches_sorted_eq!( + [ + "+--------------------+-----------+------------------+---------------------+-----------------+-------------------+", + "| child_left_payload | child_key | child_left_extra | child_right_payload | child_right_key | child_right_extra |", + "+--------------------+-----------+------------------+---------------------+-----------------+-------------------+", + "| 10 | 0 | 100 | 20 | 0 | 200 |", + "| 12 | 2 | 102 | 22 | 2 | 202 |", + "+--------------------+-----------+------------------+---------------------+-----------------+-------------------+", + ], + &batches + ); + + Ok(()) + } + + #[tokio::test] + async fn test_hash_join_skips_probe_on_empty_build_after_partition_bounds_report() + -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left, right, on) = empty_build_with_probe_error_inputs(); + + // Keep an extra consumer reference so execute() enables dynamic filter pushdown + // and enters the WaitPartitionBoundsReport path before deciding whether to poll + // the probe side. + let (join, dynamic_filter) = + hash_join_with_dynamic_filter(left, right, on, JoinType::Inner)?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + assert!(batches.is_empty()); + + dynamic_filter.wait_complete().await; + + Ok(()) + } + + #[tokio::test] + async fn test_perfect_hash_join_with_negative_numbers() -> Result<()> { + let task_ctx = prepare_task_ctx(8192, true); + let (left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + Arc::new(Int32Array::from(vec![-1, 0, 1])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![10, 20, 30, 40])) as ArrayRef, + Arc::new(Int32Array::from(vec![1, -1, 0, 2])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( left, right, on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | -1 | 20 | -1 |", + "| 2 | 0 | 30 | 0 |", + "| 3 | 1 | 10 | 1 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, true); + + Ok(()) + } + + #[tokio::test] + async fn test_perfect_hash_join_overflow_full_int64_range() -> Result<()> { + let task_ctx = prepare_task_ctx(8192, true); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from(vec![i64::MIN, i64::MAX]))], + )?; + let left = TestMemoryExec::try_new_exec( + &[vec![batch.clone()]], + Arc::clone(&schema), None, + )?; + let right = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None)?; + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("a", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a", &right.schema())?) as _, + )]; + let (_columns, batches, _metrics) = join_collect( + left, + right, + on, &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2); + Ok(()) + } + + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_phj_null_equals_null_build_no_nulls_probe_has_nulls( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); + let (left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![3, 4])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), None])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNull, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | 10 | 3 | 10 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + + Ok(()) + } + + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_phj_null_equals_nothing_build_probe_all_have_nulls( + batch_size: usize, + use_perfect_hash_join_as_possible: bool, + ) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, use_perfect_hash_join_as_possible); + let (left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), None])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(3), Some(4)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), None])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | 10 | 3 | 10 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, use_perfect_hash_join_as_possible); + + Ok(()) + } + + #[tokio::test] + async fn test_phj_null_equals_null_build_have_nulls() -> Result<()> { + let task_ctx = prepare_task_ctx(8192, true); + let (left_schema, right_schema, on) = build_schema_and_on()?; + + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), Some(20), None])) as ArrayRef, + ], + )?; + let left = TestMemoryExec::try_new_exec(&[vec![left_batch]], left_schema, None)?; + + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![Some(3), Some(4)])) as ArrayRef, + Arc::new(Int32Array::from(vec![Some(10), Some(30)])) as ArrayRef, + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?; + + let (columns, batches, metrics) = join_collect( + left, + right, + on, + &JoinType::Inner, + NullEquality::NullEqualsNull, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "a2", "b1"]); + assert_batches_sorted_eq!( + [ + "+----+----+----+----+", + "| a1 | b1 | a2 | b1 |", + "+----+----+----+----+", + "| 1 | 10 | 3 | 10 |", + "+----+----+----+----+", + ], + &batches + ); + + assert_phj_used(&metrics, false); + + Ok(()) + } + + /// Test null-aware anti join when probe side (right) contains NULL + /// Expected: no rows should be output (NULL in subquery means all results are unknown) + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_probe_null(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table (rows to potentially output) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(2), Some(3), Some(4)]), + ("dummy", &vec![Some(10), Some(20), Some(30), Some(40)]), + ); + + // Build right table (subquery with NULL) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3), None]), + ("dummy", &vec![Some(100), Some(200), Some(300), Some(400)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, None, PartitionMode::CollectLeft, NullEquality::NullEqualsNothing, + true, // null_aware = true )?; - join.dynamic_filter = Some(HashJoinExecDynamicFilter { - filter: dynamic_filter, - build_accumulator: OnceLock::new(), - }); - // Execute the join let stream = join.execute(0, task_ctx)?; - let _batches = common::collect(stream).await?; + let batches = common::collect(stream).await?; - // Even with empty build side, the dynamic filter should be marked as complete - // wait_complete() should return immediately - dynamic_filter_clone.wait_complete().await; + // Expected: empty result (probe side has NULL, so no rows should be output) + allow_duplicates! { + assert_snapshot!(batches_to_sort_string(&batches), @r" + ++ + ++ + "); + } + Ok(()) + } + + /// Test null-aware anti join when build side (left) contains NULL keys + /// Expected: rows with NULL keys should not be output + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_build_null(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table with NULL key (this row should not be output) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(4), None]), + ("dummy", &vec![Some(10), Some(40), Some(0)]), + ); + + // Build right table (no NULL, so probe-side check passes) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3)]), + ("dummy", &vec![Some(100), Some(200), Some(300)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: only c1=4 (not c1=1 which matches, not c1=NULL) + allow_duplicates! { + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-------+ + | c1 | dummy | + +----+-------+ + | 4 | 40 | + +----+-------+ + "); + } + Ok(()) + } + + /// Test null-aware anti join with no NULLs (should work like regular anti join) + #[apply(hash_join_exec_configs)] + #[tokio::test] + async fn test_null_aware_anti_join_no_nulls(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size, false); + + // Build left table (no NULLs) + let left = build_table_two_cols( + ("c1", &vec![Some(1), Some(2), Some(4), Some(5)]), + ("dummy", &vec![Some(10), Some(20), Some(40), Some(50)]), + ); + + // Build right table (no NULLs) + let right = build_table_two_cols( + ("c2", &vec![Some(1), Some(2), Some(3)]), + ("dummy", &vec![Some(100), Some(200), Some(300)]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("c2", &right.schema())?) as _, + )]; + + // Create null-aware anti join + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true + )?; + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + // Expected: c1=4 and c1=5 (they don't match anything in right) + allow_duplicates! { + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+-------+ + | c1 | dummy | + +----+-------+ + | 4 | 40 | + | 5 | 50 | + +----+-------+ + "); + } + Ok(()) + } + + /// Test that null_aware validation rejects non-LeftAnti join types + #[tokio::test] + async fn test_null_aware_validation_wrong_join_type() { + let left = + build_table_two_cols(("c1", &vec![Some(1)]), ("dummy", &vec![Some(10)])); + let right = + build_table_two_cols(("c2", &vec![Some(1)]), ("dummy", &vec![Some(100)])); + + let on = vec![( + Arc::new(Column::new_with_schema("c1", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("c2", &right.schema()).unwrap()) as _, + )]; + + // Try to create null-aware Inner join (should fail) + let result = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true (invalid for Inner join) + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("null_aware can only be true for LeftAnti joins") + ); + } + + /// Test that null_aware validation rejects multi-column joins + #[tokio::test] + async fn test_null_aware_validation_multi_column() { + let left = build_table(("a", &vec![1]), ("b", &vec![2]), ("c", &vec![3])); + let right = build_table(("x", &vec![1]), ("y", &vec![2]), ("z", &vec![3])); + + // Try multi-column join + let on = vec![ + ( + Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("x", &right.schema()).unwrap()) as _, + ), + ( + Arc::new(Column::new_with_schema("b", &left.schema()).unwrap()) as _, + Arc::new(Column::new_with_schema("y", &right.schema()).unwrap()) as _, + ), + ]; + + // Try to create null-aware anti join with 2 columns (should fail) + let result = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::LeftAnti, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + true, // null_aware = true (invalid for multi-column) + ); + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("null_aware anti join only supports single column join key") + ); + } + + #[test] + fn test_lr_is_preserved() { + assert_eq!(lr_is_preserved(JoinType::Inner), (true, true)); + assert_eq!(lr_is_preserved(JoinType::Left), (true, false)); + assert_eq!(lr_is_preserved(JoinType::Right), (false, true)); + assert_eq!(lr_is_preserved(JoinType::Full), (false, false)); + assert_eq!(lr_is_preserved(JoinType::LeftSemi), (true, true)); + assert_eq!(lr_is_preserved(JoinType::LeftAnti), (true, true)); + assert_eq!(lr_is_preserved(JoinType::LeftMark), (true, false)); + assert_eq!(lr_is_preserved(JoinType::RightSemi), (true, true)); + assert_eq!(lr_is_preserved(JoinType::RightAnti), (true, true)); + assert_eq!(lr_is_preserved(JoinType::RightMark), (false, true)); + } + + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let (_, _, on) = build_schema_and_on()?; + let left = build_table(("a1", &vec![1]), ("b1", &vec![1]), ("c1", &vec![1])); + let right = build_table(("a2", &vec![1]), ("b1", &vec![1]), ("c2", &vec![1])); + + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?; + assert!(join.dynamic_filter_expr().is_none()); + + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("b1", 1)) as _], + lit(true), + )); + let join = join.with_dynamic_filter_expr(Arc::clone(&df))?; + + let restored = join + .dynamic_filter_expr() + .expect("should have dynamic filter"); + assert_eq!( + restored + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"), + df.expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"), + ); + Ok(()) + } + + #[test] + fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> { + let (_, _, on) = build_schema_and_on()?; + let left = build_table(("a1", &vec![1]), ("b1", &vec![1]), ("c1", &vec![1])); + let right = build_table(("a2", &vec![1]), ("b1", &vec![1]), ("c2", &vec![1])); + + let join = HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?; + // Column index 99 is out of bounds for the right (probe) side schema. + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(join.with_dynamic_filter_expr(df).is_err()); Ok(()) } } diff --git a/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs b/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs new file mode 100644 index 0000000000000..2fc3201c6363f --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/inlist_builder.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for building InList expressions from hash join build side data + +use std::sync::Arc; + +use arrow::array::{ArrayRef, StructArray}; +use arrow::datatypes::{Field, FieldRef, Fields}; +use arrow_schema::DataType; +use datafusion_common::Result; + +pub(super) fn build_struct_fields(data_types: &[DataType]) -> Result { + data_types + .iter() + .enumerate() + .map(|(i, dt)| Ok(Field::new(format!("c{i}"), dt.clone(), true))) + .collect() +} + +/// Builds InList values from join key column arrays. +/// +/// If `join_key_arrays` is: +/// 1. A single array, let's say Int32, this will produce a flat +/// InList expression where the lookup is expected to be scalar Int32 values, +/// that is: this will produce `IN LIST (1, 2, 3)` expected to be used as `2 IN LIST (1, 2, 3)`. +/// 2. An Int32 array and a Utf8 array, this will produce a Struct InList expression +/// where the lookup is expected to be Struct values with two fields (Int32, Utf8), +/// that is: this will produce `IN LIST ((1, "a"), (2, "b"))` expected to be used as `(2, "b") IN LIST ((1, "a"), (2, "b"))`. +/// The field names of the struct are auto-generated as "c0", "c1", ... and should match the struct expression used in the join keys. +/// +/// Note that this function does not deduplicate values - deduplication will happen later +/// when building an InList expression from this array via `InListExpr::try_new_from_array`. +/// +/// Returns `None` if the estimated size exceeds `max_size_bytes` or if the number of rows +/// exceeds `max_distinct_values`. +pub(super) fn build_struct_inlist_values( + join_key_arrays: &[ArrayRef], +) -> Result> { + // Build the source array/struct + let source_array: ArrayRef = if join_key_arrays.len() == 1 { + // Single column: use directly + Arc::clone(&join_key_arrays[0]) + } else { + // Multi-column: build StructArray once from all columns + let fields = build_struct_fields( + &join_key_arrays + .iter() + .map(|arr| arr.data_type().clone()) + .collect::>(), + )?; + + // Build field references with proper Arc wrapping + let arrays_with_fields: Vec<(FieldRef, ArrayRef)> = fields + .iter() + .cloned() + .zip(join_key_arrays.iter().cloned()) + .collect(); + + Arc::new(StructArray::from(arrays_with_fields)) + }; + + Ok(Some(source_array)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + DictionaryArray, Int8Array, Int32Array, StringArray, StringDictionaryBuilder, + }; + + #[test] + fn test_build_single_column_inlist_array() { + let array = Arc::new(Int32Array::from(vec![1, 2, 3, 2, 1])) as ArrayRef; + let result = build_struct_inlist_values(std::slice::from_ref(&array)) + .unwrap() + .unwrap(); + + assert!(array.eq(&result)); + } + + #[test] + fn test_build_multi_column_inlist() { + let array1 = Arc::new(Int32Array::from(vec![1, 2, 3, 2, 1])) as ArrayRef; + let array2 = + Arc::new(StringArray::from(vec!["a", "b", "c", "b", "a"])) as ArrayRef; + + let result = build_struct_inlist_values(&[array1, array2]) + .unwrap() + .unwrap(); + + assert_eq!( + *result.data_type(), + DataType::Struct( + build_struct_fields(&[DataType::Int32, DataType::Utf8]).unwrap() + ) + ); + } + + #[test] + fn test_build_multi_column_inlist_with_dictionary() { + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("foo"); + builder.append_value("foo"); + builder.append_value("foo"); + let dict_array = Arc::new(builder.finish()) as ArrayRef; + + let int_array = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; + + let result = build_struct_inlist_values(&[dict_array, int_array]) + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!( + *result.data_type(), + DataType::Struct( + build_struct_fields(&[ + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8) + ), + DataType::Int32 + ]) + .unwrap() + ) + ); + } + + #[test] + fn test_build_single_column_dictionary_inlist() { + let keys = Int8Array::from(vec![0i8, 0, 0]); + let values = Arc::new(StringArray::from(vec!["foo"])); + let dict_array = Arc::new(DictionaryArray::new(keys, values)) as ArrayRef; + + let result = build_struct_inlist_values(std::slice::from_ref(&dict_array)) + .unwrap() + .unwrap(); + + assert_eq!(result.len(), 3); + assert_eq!(result.data_type(), dict_array.data_type()); + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 6c073e7a9cff5..b915802ea4015 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -17,9 +17,11 @@ //! [`HashJoinExec`] Partitioned Hash Join Operator -pub use exec::HashJoinExec; +pub use exec::{HashJoinExec, HashJoinExecBuilder}; +pub use partitioned_hash_eval::{HashExpr, HashTableLookupExpr, SeededRandomState}; mod exec; +mod inlist_builder; mod partitioned_hash_eval; mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs index 527642ade07e1..60a25fc2efcff 100644 --- a/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned_hash_eval.rs @@ -17,20 +17,54 @@ //! Hash computation and hash table lookup expressions for dynamic filtering -use std::{any::Any, fmt::Display, hash::Hash, sync::Arc}; +use std::{fmt::Display, hash::Hash, sync::Arc}; -use ahash::RandomState; use arrow::{ - array::UInt64Array, + array::{ArrayRef, UInt64Array}, datatypes::{DataType, Schema}, + record_batch::RecordBatch, }; use datafusion_common::Result; +use datafusion_common::hash_utils::RandomState; +use datafusion_common::hash_utils::{create_hashes, with_hashes}; +#[cfg(feature = "proto")] +use datafusion_common::internal_err; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::physical_expr::{ DynHash, PhysicalExpr, PhysicalExprRef, }; -use crate::hash_utils::create_hashes; +use crate::joins::Map; + +/// RandomState wrapper that preserves the seed used to create it. +/// +/// This is needed because `RandomState` doesn't expose its seed after creation, +/// but we need them for serialization (e.g., protobuf serde). +#[derive(Clone, Debug)] +pub struct SeededRandomState { + random_state: RandomState, + seed: u64, +} + +impl SeededRandomState { + /// Create a new SeededRandomState with the given seed. + pub const fn with_seed(k: u64) -> Self { + Self { + random_state: RandomState::with_seed(k), + seed: k, + } + } + + /// Get the inner RandomState. + pub fn random_state(&self) -> &RandomState { + &self.random_state + } + + /// Get the seed used to create this RandomState. + pub fn seed(&self) -> u64 { + self.seed + } +} /// Physical expression that computes hash values for a set of columns /// @@ -40,11 +74,11 @@ use crate::hash_utils::create_hashes; /// This is used for: /// - Computing routing hashes (with RepartitionExec's 0,0,0,0 seeds) /// - Computing lookup hashes (with HashJoin's 'J','O','I','N' seeds) -pub(super) struct HashExpr { +pub struct HashExpr { /// Columns to hash on_columns: Vec, - /// Random state for hashing - random_state: RandomState, + /// Random state for hashing (with seeds preserved for serialization) + random_state: SeededRandomState, /// Description for display description: String, } @@ -54,11 +88,11 @@ impl HashExpr { /// /// # Arguments /// * `on_columns` - Columns to hash - /// * `random_state` - RandomState for hashing + /// * `random_state` - SeededRandomState for hashing /// * `description` - Description for debugging (e.g., "hash_repartition", "hash_join") - pub(super) fn new( + pub fn new( on_columns: Vec, - random_state: RandomState, + random_state: SeededRandomState, description: String, ) -> Self { Self { @@ -67,6 +101,21 @@ impl HashExpr { description, } } + + /// Get the columns being hashed. + pub fn on_columns(&self) -> &[PhysicalExprRef] { + &self.on_columns + } + + /// Get the seed used for hashing. + pub fn seed(&self) -> u64 { + self.random_state.seed() + } + + /// Get the description. + pub fn description(&self) -> &str { + &self.description + } } impl std::fmt::Debug for HashExpr { @@ -77,7 +126,8 @@ impl std::fmt::Debug for HashExpr { .map(|e| e.to_string()) .collect::>() .join(", "); - write!(f, "{}({})", self.description, cols) + let seed = self.seed(); + write!(f, "{}({cols}, [{seed}])", self.description) } } @@ -85,12 +135,15 @@ impl Hash for HashExpr { fn hash(&self, state: &mut H) { self.on_columns.dyn_hash(state); self.description.hash(state); + self.seed().hash(state); } } impl PartialEq for HashExpr { fn eq(&self, other: &Self) -> bool { - self.on_columns == other.on_columns && self.description == other.description + self.on_columns == other.on_columns + && self.description == other.description + && self.seed() == other.seed() } } @@ -103,10 +156,6 @@ impl Display for HashExpr { } impl PhysicalExpr for HashExpr { - fn as_any(&self) -> &dyn Any { - self - } - fn children(&self) -> Vec<&Arc> { self.on_columns.iter().collect() } @@ -130,22 +179,19 @@ impl PhysicalExpr for HashExpr { Ok(false) } - fn evaluate( - &self, - batch: &arrow::record_batch::RecordBatch, - ) -> Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let num_rows = batch.num_rows(); // Evaluate columns - let keys_values = self - .on_columns - .iter() - .map(|c| c.evaluate(batch)?.into_array(num_rows)) - .collect::>>()?; + let keys_values = evaluate_columns(&self.on_columns, batch)?; // Compute hashes let mut hashes_buffer = vec![0; num_rows]; - create_hashes(&keys_values, &self.random_state, &mut hashes_buffer)?; + create_hashes( + &keys_values, + self.random_state.random_state(), + &mut hashes_buffer, + )?; Ok(ColumnarValue::Array(Arc::new(UInt64Array::from( hashes_buffer, @@ -155,4 +201,640 @@ impl PhysicalExpr for HashExpr { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.description) } + + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + let on_columns = ctx.encode_children_expressions(&self.on_columns)?; + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( + protobuf::PhysicalHashExprNode { + on_columns, + seed0: self.seed(), + description: self.description.clone(), + }, + )), + })) + } +} + +#[cfg(feature = "proto")] +impl HashExpr { + /// Reconstruct a [`HashExpr`] from its protobuf representation. + /// + /// Takes the whole [`PhysicalExprNode`], the exact inverse of what + /// [`PhysicalExpr::try_to_proto`] produces, so every expression's + /// `try_from_proto` shares one signature. Child sub-expressions are + /// decoded recursively via [`PhysicalExprDecodeCtx::decode`]. + /// + /// [`PhysicalExprNode`]: datafusion_proto_models::protobuf::PhysicalExprNode + /// [`PhysicalExpr::try_to_proto`]: datafusion_physical_expr_common::physical_expr::PhysicalExpr::try_to_proto + /// [`PhysicalExprDecodeCtx::decode`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx::decode + pub fn try_from_proto( + node: &datafusion_proto_models::protobuf::PhysicalExprNode, + ctx: &datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + let hash_expr = match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::HashExpr(h)) => h, + _ => return internal_err!("PhysicalExprNode is not a HashExpr"), + }; + let on_columns = ctx.decode_children_expressions(&hash_expr.on_columns)?; + Ok(Arc::new(HashExpr::new( + on_columns, + SeededRandomState::with_seed(hash_expr.seed0), + hash_expr.description.clone(), + ))) + } +} + +/// Physical expression that checks join keys in a [`Map`] (hash table or array map). +/// +/// Returns a [`BooleanArray`](arrow::array::BooleanArray) indicating if join keys (from `on_columns`) exist in the map. +// TODO: rename to MapLookupExpr +pub struct HashTableLookupExpr { + /// Columns in the ON clause used to compute the join key for lookups + on_columns: Vec, + /// Random state for hashing (with seeds preserved for serialization) + random_state: SeededRandomState, + /// Map to check against (hash table or array map) + map: Arc, + /// Description for display + description: String, +} +impl HashTableLookupExpr { + /// Create a new HashTableLookupExpr + /// + /// # Arguments + /// * `on_columns` - Columns in the ON clause used to compute the join key + /// * `random_state` - SeededRandomState for hashing + /// * `map` - Map to check membership (hash table or array map) + /// * `description` - Description for debugging + /// # Note + /// This is public for internal testing purposes only and is not + /// guaranteed to be stable across versions. + pub fn new( + on_columns: Vec, + random_state: SeededRandomState, + map: Arc, + description: String, + ) -> Self { + Self { + on_columns, + random_state, + map, + description, + } + } +} +impl std::fmt::Debug for HashTableLookupExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let cols = self + .on_columns + .iter() + .map(|e| e.to_string()) + .collect::>() + .join(", "); + let seed = self.random_state.seed(); + write!(f, "{}({cols}, [{seed}])", self.description) + } +} + +impl Hash for HashTableLookupExpr { + fn hash(&self, state: &mut H) { + self.on_columns.dyn_hash(state); + self.description.hash(state); + self.random_state.seed().hash(state); + // Note that we compare hash_map by pointer equality. + // Actually comparing the contents of the hash maps would be expensive. + // The way these hash maps are used in actuality is that HashJoinExec creates + // one per partition per query execution, thus it is never possible for two different + // hash maps to have the same content in practice. + // Theoretically this is a public API and users could create identical hash maps, + // but that seems unlikely and not worth paying the cost of deep comparison all the time. + Arc::as_ptr(&self.map).hash(state); + } +} + +impl PartialEq for HashTableLookupExpr { + fn eq(&self, other: &Self) -> bool { + // Note that we compare hash_map by pointer equality. + // Actually comparing the contents of the hash maps would be expensive. + // The way these hash maps are used in actuality is that HashJoinExec creates + // one per partition per query execution, thus it is never possible for two different + // hash maps to have the same content in practice. + // Theoretically this is a public API and users could create identical hash maps, + // but that seems unlikely and not worth paying the cost of deep comparison all the time. + self.on_columns == other.on_columns + && self.description == other.description + && self.random_state.seed() == other.random_state.seed() + && Arc::ptr_eq(&self.map, &other.map) + } +} + +impl Eq for HashTableLookupExpr {} + +impl Display for HashTableLookupExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.description) + } +} + +impl PhysicalExpr for HashTableLookupExpr { + fn children(&self) -> Vec<&Arc> { + self.on_columns.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(HashTableLookupExpr::new( + children, + self.random_state.clone(), + Arc::clone(&self.map), + self.description.clone(), + ))) + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + // Evaluate columns + let join_keys = evaluate_columns(&self.on_columns, batch)?; + + match self.map.as_ref() { + Map::HashMap(map) => { + with_hashes(&join_keys, self.random_state.random_state(), |hashes| { + let array = map.contain_hashes(hashes); + Ok(ColumnarValue::Array(Arc::new(array))) + }) + } + Map::ArrayMap(map) => { + let array = map.contain_keys(&join_keys)?; + Ok(ColumnarValue::Array(Arc::new(array))) + } + } + } + #[cfg(feature = "proto")] + fn try_to_proto( + &self, + _ctx: &datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx<'_>, + ) -> Result> { + use datafusion_proto_models::protobuf; + use datafusion_proto_models::protobuf::physical_expr_node::ExprType; + + // HashTableLookupExpr holds a runtime Arc (the build-side hash + // table) that cannot be serialized, so it is replaced with lit(true). + // + // Dynamic filtering is a performance optimisation only — replacing the + // lookup with lit(true) preserves correctness by allowing all rows + // through. + // + // If a plan is serialized before execution, HashTableLookupExpr is not + // yet present in the dynamic filter expression. + // + // If a plan is serialized after execution, any runtime-created + // HashTableLookupExpr is replaced during serialization. Re-executing + // the plan requires reset_state(), after which HashJoinExec rebuilds + // fresh dynamic filters at runtime. + let value = datafusion_proto_common::ScalarValue { + value: Some(datafusion_proto_common::scalar_value::Value::BoolValue( + true, + )), + }; + Ok(Some(protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(ExprType::Literal(value)), + })) + } + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.description) + } +} + +fn evaluate_columns( + columns: &[PhysicalExprRef], + batch: &RecordBatch, +) -> Result> { + let num_rows = batch.num_rows(); + columns + .iter() + .map(|c| c.evaluate(batch)?.into_array(num_rows)) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::joins::join_hash_map::JoinHashMapU32; + use datafusion_physical_expr::expressions::Column; + use std::collections::hash_map::DefaultHasher; + use std::hash::Hasher; + + fn compute_hash(value: &T) -> u64 { + let mut hasher = DefaultHasher::new(); + value.hash(&mut hasher); + hasher.finish() + } + + #[test] + fn test_hash_expr_eq_same() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let col_b: PhysicalExprRef = Arc::new(Column::new("b", 1)); + + let expr1 = HashExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + let expr2 = HashExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + assert_eq!(expr1, expr2); + } + + #[test] + fn test_hash_expr_eq_different_columns() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let col_b: PhysicalExprRef = Arc::new(Column::new("b", 1)); + let col_c: PhysicalExprRef = Arc::new(Column::new("c", 2)); + + let expr1 = HashExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + let expr2 = HashExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_c)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + assert_ne!(expr1, expr2); + } + + #[test] + fn test_hash_expr_eq_different_description() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + + let expr1 = HashExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + "hash_one".to_string(), + ); + + let expr2 = HashExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + "hash_two".to_string(), + ); + + assert_ne!(expr1, expr2); + } + + #[test] + fn test_hash_expr_eq_different_seeds() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + + let expr1 = HashExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + let expr2 = HashExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(5), + "test_hash".to_string(), + ); + + assert_ne!(expr1, expr2); + } + + #[test] + fn test_hash_expr_hash_consistency() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let col_b: PhysicalExprRef = Arc::new(Column::new("b", 1)); + + let expr1 = HashExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + let expr2 = HashExpr::new( + vec![Arc::clone(&col_a), Arc::clone(&col_b)], + SeededRandomState::with_seed(1), + "test_hash".to_string(), + ); + + // Equal expressions should have equal hashes + assert_eq!(expr1, expr2); + assert_eq!(compute_hash(&expr1), compute_hash(&expr2)); + } + + #[cfg(feature = "proto")] + mod proto_tests { + use super::*; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::internal_datafusion_err; + use datafusion_physical_expr_common::physical_expr::proto_decode::{ + PhysicalExprDecode, PhysicalExprDecodeCtx, + }; + use datafusion_physical_expr_common::physical_expr::proto_encode::{ + PhysicalExprEncode, PhysicalExprEncodeCtx, + }; + use datafusion_proto_models::protobuf; + + struct TestEncoder; + + impl PhysicalExprEncode for TestEncoder { + fn encode( + &self, + expr: &Arc, + ) -> Result { + let ctx = PhysicalExprEncodeCtx::new(self); + expr.try_to_proto(&ctx)?.ok_or_else(|| { + internal_datafusion_err!("test encoder cannot encode {expr:?}") + }) + } + } + + struct TestDecoder; + + impl PhysicalExprDecode for TestDecoder { + fn decode( + &self, + node: &protobuf::PhysicalExprNode, + schema: &Schema, + ) -> Result> { + let ctx = PhysicalExprDecodeCtx::new(schema, self); + match &node.expr_type { + Some(protobuf::physical_expr_node::ExprType::Column(_)) => { + Column::try_from_proto(node, &ctx) + } + _ => internal_err!("test decoder cannot decode {node:?}"), + } + } + } + + fn test_decode_ctx<'a>( + schema: &'a Schema, + decoder: &'a TestDecoder, + ) -> PhysicalExprDecodeCtx<'a> { + PhysicalExprDecodeCtx::new(schema, decoder) + } + + #[test] + fn hash_expr_try_to_proto() { + let expr = HashExpr::new( + vec![Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1))], + SeededRandomState::with_seed(42), + "hash_join".to_string(), + ); + let encoder = TestEncoder; + let ctx = PhysicalExprEncodeCtx::new(&encoder); + + let proto = expr.try_to_proto(&ctx).unwrap().unwrap(); + + assert_eq!(proto.expr_id, None); + let hash_expr = match proto.expr_type.unwrap() { + protobuf::physical_expr_node::ExprType::HashExpr(hash_expr) => hash_expr, + other => panic!("expected HashExpr, got {other:?}"), + }; + assert_eq!(hash_expr.seed0, 42); + assert_eq!(hash_expr.description, "hash_join"); + assert_eq!(hash_expr.on_columns.len(), 2); + assert!( + hash_expr + .on_columns + .iter() + .all(|expr| expr.expr_id.is_none()) + ); + } + + #[test] + fn hash_expr_try_from_proto() { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + let decoder = TestDecoder; + let ctx = test_decode_ctx(&schema, &decoder); + let proto = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::HashExpr( + protobuf::PhysicalHashExprNode { + on_columns: vec![ + protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some( + protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + ), + ), + }, + protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some( + protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "b".to_string(), + index: 1, + }, + ), + ), + }, + ], + seed0: 42, + description: "hash_join".to_string(), + }, + )), + }; + + let expr = HashExpr::try_from_proto(&proto, &ctx).unwrap(); + let expr = expr.downcast_ref::().unwrap(); + + assert_eq!(expr.seed(), 42); + assert_eq!(expr.description(), "hash_join"); + assert_eq!(expr.on_columns().len(), 2); + assert_eq!( + expr.on_columns()[0] + .downcast_ref::() + .map(|col| (col.name(), col.index())), + Some(("a", 0)) + ); + assert_eq!( + expr.on_columns()[1] + .downcast_ref::() + .map(|col| (col.name(), col.index())), + Some(("b", 1)) + ); + } + + #[test] + fn hash_expr_try_from_proto_rejects_wrong_node_type() { + let schema = Schema::empty(); + let decoder = TestDecoder; + let ctx = test_decode_ctx(&schema, &decoder); + let proto = protobuf::PhysicalExprNode { + expr_id: None, + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: "a".to_string(), + index: 0, + }, + )), + }; + + let err = HashExpr::try_from_proto(&proto, &ctx).unwrap_err(); + assert!( + err.to_string() + .contains("PhysicalExprNode is not a HashExpr"), + "{err}" + ); + } + } + + #[test] + fn test_hash_table_lookup_expr_eq_same() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + + let expr1 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup".to_string(), + ); + + let expr2 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup".to_string(), + ); + + assert_eq!(expr1, expr2); + } + + #[test] + fn test_hash_table_lookup_expr_eq_different_columns() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let col_b: PhysicalExprRef = Arc::new(Column::new("b", 1)); + + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + + let expr1 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup".to_string(), + ); + + let expr2 = HashTableLookupExpr::new( + vec![Arc::clone(&col_b)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup".to_string(), + ); + + assert_ne!(expr1, expr2); + } + + #[test] + fn test_hash_table_lookup_expr_eq_different_description() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + + let expr1 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup_one".to_string(), + ); + + let expr2 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup_two".to_string(), + ); + + assert_ne!(expr1, expr2); + } + + #[test] + fn test_hash_table_lookup_expr_eq_different_hash_map() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + + // Two different Arc pointers (even with same content) should not be equal + let hash_map1 = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + let hash_map2 = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + let expr1 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + hash_map1, + "lookup".to_string(), + ); + + let expr2 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + hash_map2, + "lookup".to_string(), + ); + + // Different Arc pointers means not equal (uses Arc::ptr_eq) + assert_ne!(expr1, expr2); + } + + #[test] + fn test_hash_table_lookup_expr_hash_consistency() { + let col_a: PhysicalExprRef = Arc::new(Column::new("a", 0)); + let hash_map = + Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(10)))); + + let expr1 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup".to_string(), + ); + + let expr2 = HashTableLookupExpr::new( + vec![Arc::clone(&col_a)], + SeededRandomState::with_seed(1), + Arc::clone(&hash_map), + "lookup".to_string(), + ); + + // Equal expressions should have equal hashes + assert_eq!(expr1, expr2); + assert_eq!(compute_hash(&expr1), compute_hash(&expr2)); + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index cb727f40a20a2..0af4015ff7239 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -21,30 +21,37 @@ use std::fmt; use std::sync::Arc; -use crate::joins::hash_join::partitioned_hash_eval::HashExpr; -use crate::joins::PartitionMode; use crate::ExecutionPlan; use crate::ExecutionPlanProperties; - -use ahash::RandomState; -use datafusion_common::{Result, ScalarValue}; +use crate::joins::Map; +use crate::joins::PartitionMode; +use crate::joins::hash_join::exec::HASH_JOIN_SEED; +use crate::joins::hash_join::inlist_builder::build_struct_fields; +use crate::joins::hash_join::partitioned_hash_eval::{ + HashExpr, HashTableLookupExpr, SeededRandomState, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DataFusionError, Result, ScalarValue, SharedResult}; use datafusion_expr::Operator; +use datafusion_functions::core::r#struct as struct_func; use datafusion_physical_expr::expressions::{ - lit, BinaryExpr, CaseExpr, DynamicFilterPhysicalExpr, + BinaryExpr, CaseExpr, DynamicFilterPhysicalExpr, InListExpr, lit, }; -use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef, ScalarFunctionExpr}; use parking_lot::Mutex; -use tokio::sync::Barrier; +use tokio::sync::Notify; /// Represents the minimum and maximum values for a specific column. /// Used in dynamic filter pushdown to establish value boundaries. #[derive(Debug, Clone, PartialEq)] pub(crate) struct ColumnBounds { /// The minimum value observed for this column - min: ScalarValue, + pub(crate) min: ScalarValue, /// The maximum value observed for this column - max: ScalarValue, + pub(crate) max: ScalarValue, } impl ColumnBounds { @@ -72,11 +79,71 @@ impl PartitionBounds { } } -/// Creates a bounds predicate from partition bounds. +/// Creates a membership predicate for filter pushdown. /// -/// Returns a bound predicate (col >= min AND col <= max) for all key columns in the ON expression that have computed bounds from the build phase. +/// If `inlist_values` is provided (for small build sides), creates an InList expression. +/// Otherwise, creates a HashTableLookup expression (for large build sides). +/// +/// Supports both single-column and multi-column joins using struct expressions. +fn create_membership_predicate( + on_right: &[PhysicalExprRef], + pushdown: PushdownStrategy, + random_state: &SeededRandomState, + schema: &Schema, +) -> Result>> { + match pushdown { + // Use InList expression for small build sides + PushdownStrategy::InList(in_list_array) => { + // Build the expression to compare against + let expr = if on_right.len() == 1 { + // Single column: col IN (val1, val2, ...) + Arc::clone(&on_right[0]) + } else { + let fields = build_struct_fields( + on_right + .iter() + .map(|r| r.data_type(schema)) + .collect::>>()? + .as_ref(), + )?; + + // The return field name and the function field name don't really matter here. + let return_field = + Arc::new(Field::new("struct", DataType::Struct(fields), true)); + + Arc::new(ScalarFunctionExpr::new( + "struct", + struct_func(), + on_right.to_vec(), + return_field, + Arc::new(ConfigOptions::default()), + )) as Arc + }; + + // Use InListExpr::try_new_from_array() to build an InList with static_filter optimization (hash-based lookup) + Ok(Some(Arc::new(InListExpr::try_new_from_array( + expr, + in_list_array, + false, + schema, + )?))) + } + // Use hash table lookup for large build sides + PushdownStrategy::Map(hash_map) => Ok(Some(Arc::new(HashTableLookupExpr::new( + on_right.to_vec(), + random_state.clone(), + hash_map, + "hash_lookup".to_string(), + )) as Arc)), + // Empty partition - should not create a filter for this + PushdownStrategy::Empty => Ok(None), + } +} + +/// Creates a bounds predicate from partition bounds. /// /// Returns `None` if no column bounds are available. +/// Returns a combined predicate (col >= min AND col <= max) for all columns with bounds. fn create_bounds_predicate( on_right: &[PhysicalExprRef], bounds: &PartitionBounds, @@ -117,6 +184,25 @@ fn create_bounds_predicate( } } +/// Combines a membership predicate and a bounds predicate with logical AND. +/// +/// Returns `None` when neither is available; callers decide the fallback (e.g. +/// skip updating the filter vs. emit a `lit(true)` branch inside a CASE). +fn combine_membership_and_bounds( + membership_expr: Option>, + bounds_expr: Option>, +) -> Option> { + match (membership_expr, bounds_expr) { + (Some(membership), Some(bounds)) => { + Some(Arc::new(BinaryExpr::new(bounds, Operator::And, membership)) + as Arc) + } + (Some(membership), None) => Some(membership), + (None, Some(bounds)) => Some(bounds), + (None, None) => None, + } +} + /// Coordinates build-side information collection across multiple partitions /// /// This structure collects information from the build side (hash tables and/or bounds) and @@ -127,9 +213,12 @@ fn create_bounds_predicate( /// ## Synchronization Strategy /// /// 1. Each partition computes information from its build-side data (hash maps and/or bounds) -/// 2. Information is stored in the shared state -/// 3. A barrier tracks how many partitions have reported -/// 4. When the last partition reports, information is merged and the filter is updated exactly once +/// 2. Information is stored in the shared state, which tracks how many partitions have reported +/// 3. When the last partition reports, one waiter is elected as the finalizer; it merges the +/// collected information, updates the dynamic filter exactly once, and publishes the +/// terminal result by transitioning [`CompletionState`] to `Ready` +/// 4. A [`tokio::sync::Notify`] wakes any other partitions parked in `wait_for_completion`, +/// which then observe the `Ready` state under the mutex and return immediately /// /// ## Hash Map vs Bounds /// @@ -149,53 +238,93 @@ fn create_bounds_predicate( /// partition executions. pub(crate) struct SharedBuildAccumulator { /// Build-side data protected by a single mutex to avoid ordering concerns - inner: Mutex, - barrier: Barrier, + inner: Mutex, + /// Wakes every partition that is parked in [`Self::wait_for_completion`] + /// once [`AccumulatorState::completion`] transitions to + /// [`CompletionState::Ready`]. Notifications are fired once per + /// accumulator lifetime (the elected finalizer publishes the terminal + /// result, then broadcasts), so late subscribers simply re-check the + /// state under the mutex and return immediately. + completion_notify: Notify, /// Dynamic filter for pushdown to probe side dynamic_filter: Arc, /// Right side join expressions needed for creating filter expressions on_right: Vec, /// Random state for partitioning (RepartitionExec's hash function with 0,0,0,0 seeds) /// Used for PartitionedHashLookupPhysicalExpr - repartition_random_state: RandomState, + repartition_random_state: SeededRandomState, + /// Schema of the probe (right) side for evaluating filter expressions + probe_schema: Arc, } +/// Strategy for filter pushdown (decided at collection time) #[derive(Clone)] -pub(crate) enum PartitionBuildDataReport { +pub(crate) enum PushdownStrategy { + /// Use InList for small build sides (< 128MB) + InList(ArrayRef), + /// Use map lookup for large build sides + Map(Arc), + /// There was no data in this partition, do not build a dynamic filter for it + Empty, +} + +/// Build-side data reported by a single partition +pub(crate) enum PartitionBuildData { Partitioned { partition_id: usize, - /// Bounds computed from this partition's build side. - /// If the partition is empty (no rows) this will be None. - bounds: Option, + pushdown: PushdownStrategy, + bounds: PartitionBounds, }, CollectLeft { - /// Bounds computed from the collected build side. - /// If the build side is empty (no rows) this will be None. - bounds: Option, + pushdown: PushdownStrategy, + bounds: PartitionBounds, }, } +/// Per-partition accumulated data (Partitioned mode) #[derive(Clone)] -struct PartitionedBuildData { - partition_id: usize, - bounds: PartitionBounds, -} - -#[derive(Clone)] -struct CollectLeftBuildData { +struct PartitionData { bounds: PartitionBounds, + pushdown: PushdownStrategy, } /// Build-side data organized by partition mode enum AccumulatedBuildData { Partitioned { - partitions: Vec>, + partitions: Vec, + completed_partitions: usize, }, CollectLeft { - data: Option, + data: PartitionStatus, + reported_count: usize, + expected_reports: usize, }, } +enum CompletionState { + Pending, + Finalizing, + Ready(SharedResult<()>), +} + +struct AccumulatorState { + data: AccumulatedBuildData, + completion: CompletionState, +} + +#[derive(Clone)] +enum PartitionStatus { + Pending, + Reported(PartitionData), + CanceledUnknown, +} + +#[derive(Clone)] +enum FinalizeInput { + Partitioned(Vec), + CollectLeft(PartitionStatus), +} + impl SharedBuildAccumulator { /// Creates a new SharedBuildAccumulator configured for the given partition mode /// @@ -228,7 +357,7 @@ impl SharedBuildAccumulator { right_child: &dyn ExecutionPlan, dynamic_filter: Arc, on_right: Vec, - repartition_random_state: RandomState, + repartition_random_state: SeededRandomState, ) -> Self { // Troubleshooting: If partition counts are incorrect, verify this logic matches // the actual execution pattern in collect_build_side() @@ -242,25 +371,39 @@ impl SharedBuildAccumulator { left_child.output_partitioning().partition_count() } // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) - PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + PartitionMode::Auto => unreachable!( + "PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!" + ), }; let mode_data = match partition_mode { PartitionMode::Partitioned => AccumulatedBuildData::Partitioned { - partitions: vec![None; left_child.output_partitioning().partition_count()], + partitions: vec![ + PartitionStatus::Pending; + left_child.output_partitioning().partition_count() + ], + completed_partitions: 0, }, PartitionMode::CollectLeft => AccumulatedBuildData::CollectLeft { - data: None, + data: PartitionStatus::Pending, + reported_count: 0, + expected_reports: expected_calls, }, - PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + PartitionMode::Auto => unreachable!( + "PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!" + ), }; Self { - inner: Mutex::new(mode_data), - barrier: Barrier::new(expected_calls), + inner: Mutex::new(AccumulatorState { + data: mode_data, + completion: CompletionState::Pending, + }), + completion_notify: Notify::new(), dynamic_filter, on_right, repartition_random_state, + probe_schema: right_child.schema(), } } @@ -276,141 +419,274 @@ impl SharedBuildAccumulator { /// /// # Returns /// * `Result<()>` - Ok if successful, Err if filter update failed or mode mismatch - pub(crate) async fn report_build_data( + pub(crate) async fn report_build_data(&self, data: PartitionBuildData) -> Result<()> { + let finalize_input = { + let mut guard = self.inner.lock(); + self.store_build_data(&mut guard, data)?; + self.take_finalize_input_if_ready(&mut guard) + }; + + if let Some(finalize_input) = finalize_input { + self.finish(finalize_input); + } + + self.wait_for_completion().await + } + + pub(crate) fn report_canceled_partition(&self, partition_id: usize) { + let finalize_input = { + let mut guard = self.inner.lock(); + self.store_canceled_partition(&mut guard, partition_id); + self.take_finalize_input_if_ready(&mut guard) + }; + + if let Some(finalize_input) = finalize_input { + self.finish(finalize_input); + } + } + + fn store_build_data( &self, - data: PartitionBuildDataReport, + guard: &mut AccumulatorState, + data: PartitionBuildData, ) -> Result<()> { - // Store data in the accumulator + match (data, &mut guard.data) { + ( + PartitionBuildData::Partitioned { + partition_id, + pushdown, + bounds, + }, + AccumulatedBuildData::Partitioned { + partitions, + completed_partitions, + }, + ) => { + if matches!(partitions[partition_id], PartitionStatus::Pending) { + *completed_partitions += 1; + } + partitions[partition_id] = + PartitionStatus::Reported(PartitionData { pushdown, bounds }); + } + ( + PartitionBuildData::CollectLeft { pushdown, bounds }, + AccumulatedBuildData::CollectLeft { + data, + reported_count, + .. + }, + ) => { + if matches!(data, PartitionStatus::Pending) { + *data = PartitionStatus::Reported(PartitionData { pushdown, bounds }); + } + *reported_count += 1; + } + _ => { + return datafusion_common::internal_err!( + "Build data mode mismatch in report_build_data" + ); + } + } + Ok(()) + } + + fn store_canceled_partition( + &self, + guard: &mut AccumulatorState, + partition_id: usize, + ) { + if let AccumulatedBuildData::Partitioned { + partitions, + completed_partitions, + } = &mut guard.data + && matches!(partitions[partition_id], PartitionStatus::Pending) { - let mut guard = self.inner.lock(); + partitions[partition_id] = PartitionStatus::CanceledUnknown; + *completed_partitions += 1; + } + } + + fn take_finalize_input_if_ready( + &self, + guard: &mut AccumulatorState, + ) -> Option { + if !matches!(guard.completion, CompletionState::Pending) { + return None; + } - match (data, &mut *guard) { - // Partitioned mode - ( - PartitionBuildDataReport::Partitioned { - partition_id, - bounds, - }, - AccumulatedBuildData::Partitioned { partitions }, - ) => { - if let Some(bounds) = bounds { - partitions[partition_id] = Some(PartitionedBuildData { - partition_id, - bounds, - }); + let finalize_input = match &guard.data { + AccumulatedBuildData::Partitioned { + partitions, + completed_partitions, + } if *completed_partitions == partitions.len() => { + Some(FinalizeInput::Partitioned(partitions.clone())) + } + AccumulatedBuildData::CollectLeft { + data, + reported_count, + expected_reports, + } if *reported_count == *expected_reports => { + Some(FinalizeInput::CollectLeft(data.clone())) + } + _ => None, + }?; + + guard.completion = CompletionState::Finalizing; + Some(finalize_input) + } + + fn finish(&self, finalize_input: FinalizeInput) { + let result = self.build_filter(finalize_input).map_err(Arc::new); + self.dynamic_filter.mark_complete(); + + let mut guard = self.inner.lock(); + guard.completion = CompletionState::Ready(result); + drop(guard); + self.completion_notify.notify_waiters(); + } + + async fn wait_for_completion(&self) -> Result<()> { + loop { + let notified = { + let guard = self.inner.lock(); + match &guard.completion { + CompletionState::Ready(Ok(())) => return Ok(()), + CompletionState::Ready(Err(err)) => { + return Err(DataFusionError::Shared(Arc::clone(err))); } - } - // CollectLeft mode (store once, deduplicate across partitions) - ( - PartitionBuildDataReport::CollectLeft { bounds }, - AccumulatedBuildData::CollectLeft { data }, - ) => { - match (bounds, data) { - (None, _) | (_, Some(_)) => { - // No bounds reported or already reported; do nothing - } - (Some(new_bounds), data) => { - // First report, store the bounds - *data = Some(CollectLeftBuildData { bounds: new_bounds }); - } + CompletionState::Pending | CompletionState::Finalizing => { + self.completion_notify.notified() } } - // Mismatched modes - should never happen - _ => { - return datafusion_common::internal_err!( - "Build data mode mismatch in report_build_data" - ); - } - } + }; + notified.await; } + } - // Wait for all partitions to report - if self.barrier.wait().await.is_leader() { - // All partitions have reported, so we can create and update the filter - let inner = self.inner.lock(); - - match &*inner { - // CollectLeft: Simple conjunction of bounds and membership check - AccumulatedBuildData::CollectLeft { data } => { - if let Some(partition_data) = data { - // Create bounds check expression (if bounds available) - let Some(filter_expr) = create_bounds_predicate( - &self.on_right, - &partition_data.bounds, - ) else { - // No bounds available, nothing to update - return Ok(()); - }; - + fn build_filter(&self, finalize_input: FinalizeInput) -> Result<()> { + match finalize_input { + FinalizeInput::CollectLeft(partition) => match partition { + PartitionStatus::Reported(partition_data) => { + let membership_expr = create_membership_predicate( + &self.on_right, + partition_data.pushdown.clone(), + &HASH_JOIN_SEED, + self.probe_schema.as_ref(), + )?; + let bounds_expr = + create_bounds_predicate(&self.on_right, &partition_data.bounds); + + if let Some(filter_expr) = + combine_membership_and_bounds(membership_expr, bounds_expr) + { self.dynamic_filter.update(filter_expr)?; } } - // Partitioned: CASE expression routing to per-partition filters - AccumulatedBuildData::Partitioned { partitions } => { - // Collect all partition data, skipping empty partitions - let partition_data: Vec<_> = - partitions.iter().filter_map(|p| p.as_ref()).collect(); - - if partition_data.is_empty() { - // All partitions are empty: no rows can match, skip the probe side entirely - self.dynamic_filter.update(lit(false))?; - return Ok(()); + PartitionStatus::Pending => { + return datafusion_common::internal_err!( + "attempted to finalize collect-left dynamic filter without reported build data" + ); + } + PartitionStatus::CanceledUnknown => { + return datafusion_common::internal_err!( + "collect-left dynamic filter cannot finalize with canceled build data" + ); + } + }, + FinalizeInput::Partitioned(partitions) => { + let num_partitions = partitions.len(); + let routing_hash_expr = Arc::new(HashExpr::new( + self.on_right.clone(), + self.repartition_random_state.clone(), + "hash_repartition".to_string(), + )) as Arc; + + let modulo_expr = Arc::new(BinaryExpr::new( + routing_hash_expr, + Operator::Modulo, + lit(ScalarValue::UInt64(Some(num_partitions as u64))), + )) as Arc; + + let mut real_branches = Vec::new(); + let mut empty_partition_ids = Vec::new(); + let mut has_canceled_unknown = false; + + for (partition_id, partition) in partitions.iter().enumerate() { + match partition { + PartitionStatus::Reported(partition) + if matches!(partition.pushdown, PushdownStrategy::Empty) => + { + empty_partition_ids.push(partition_id); + } + PartitionStatus::Reported(partition) => { + let membership_expr = create_membership_predicate( + &self.on_right, + partition.pushdown.clone(), + &HASH_JOIN_SEED, + self.probe_schema.as_ref(), + )?; + let bounds_expr = create_bounds_predicate( + &self.on_right, + &partition.bounds, + ); + let then_expr = combine_membership_and_bounds( + membership_expr, + bounds_expr, + ) + .unwrap_or_else(|| lit(true)); + real_branches.push(( + lit(ScalarValue::UInt64(Some(partition_id as u64))), + then_expr, + )); + } + PartitionStatus::CanceledUnknown => { + has_canceled_unknown = true; + } + PartitionStatus::Pending => { + return datafusion_common::internal_err!( + "attempted to finalize dynamic filter with pending partition" + ); + } } + } - // Build a CASE expression that combines range checks AND membership checks - // CASE (hash_repartition(join_keys) % num_partitions) - // WHEN 0 THEN (col >= min_0 AND col <= max_0 AND ...) - // WHEN 1 THEN (col >= min_1 AND col <= max_1 AND ...) - // ... - // ELSE false - // END - - let num_partitions = partitions.len(); - - // Create base expression: hash_repartition(join_keys) % num_partitions - let routing_hash_expr = Arc::new(HashExpr::new( - self.on_right.clone(), - self.repartition_random_state.clone(), - "hash_repartition".to_string(), - )) - as Arc; - - let modulo_expr = Arc::new(BinaryExpr::new( - routing_hash_expr, - Operator::Modulo, - lit(ScalarValue::UInt64(Some(num_partitions as u64))), - )) as Arc; - - // Create WHEN branches for each partition - let when_then_branches: Vec<( - Arc, - Arc, - )> = partition_data + let filter_expr = if has_canceled_unknown { + let mut when_then_branches = empty_partition_ids .into_iter() - .map(|pdata| -> Result<_> { - // WHEN partition_id - let when_expr = - lit(ScalarValue::UInt64(Some(pdata.partition_id as u64))); - - // Create bounds check expression for this partition (if bounds available) - let bounds_expr = - create_bounds_predicate(&self.on_right, &pdata.bounds) - .unwrap_or_else(|| lit(true)); // No bounds means all rows pass - - Ok((when_expr, bounds_expr)) + .map(|partition_id| { + ( + lit(ScalarValue::UInt64(Some(partition_id as u64))), + lit(false), + ) }) - .collect::>>()?; - - let case_expr = Arc::new(CaseExpr::try_new( + .collect::>(); + when_then_branches.extend(real_branches); + + if when_then_branches.is_empty() { + lit(true) + } else { + Arc::new(CaseExpr::try_new( + Some(modulo_expr), + when_then_branches, + Some(lit(true)), + )?) as Arc + } + } else if real_branches.is_empty() { + lit(false) + } else if real_branches.len() == 1 + && empty_partition_ids.len() + 1 == num_partitions + { + Arc::clone(&real_branches[0].1) + } else { + Arc::new(CaseExpr::try_new( Some(modulo_expr), - when_then_branches, - Some(lit(false)), // ELSE false - )?) as Arc; + real_branches, + Some(lit(false)), + )?) as Arc + }; - self.dynamic_filter.update(case_expr)?; - } + self.dynamic_filter.update(filter_expr)?; } - self.dynamic_filter.mark_complete(); } Ok(()) @@ -422,3 +698,116 @@ impl fmt::Debug for SharedBuildAccumulator { write!(f, "SharedBuildAccumulator") } } + +#[cfg(test)] +pub(super) fn make_partitioned_accumulator_for_test( + num_partitions: usize, +) -> SharedBuildAccumulator { + let probe_schema = Arc::new(Schema::new(vec![Field::new( + "probe_key", + DataType::Int32, + false, + )])); + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(vec![], lit(true))); + SharedBuildAccumulator { + inner: Mutex::new(AccumulatorState { + data: AccumulatedBuildData::Partitioned { + partitions: vec![PartitionStatus::Pending; num_partitions], + completed_partitions: 0, + }, + completion: CompletionState::Pending, + }), + completion_notify: Notify::new(), + dynamic_filter, + on_right: vec![], + repartition_random_state: SeededRandomState::with_seed(1), + probe_schema, + } +} + +#[cfg(test)] +pub(super) fn completed_partitions_for_test(acc: &SharedBuildAccumulator) -> usize { + let guard = acc.inner.lock(); + let AccumulatedBuildData::Partitioned { + completed_partitions, + .. + } = &guard.data + else { + panic!("expected partitioned accumulator"); + }; + *completed_partitions +} + +#[cfg(test)] +mod tests { + use super::*; + + fn partitioned_state(acc: &SharedBuildAccumulator) -> (Vec, usize) { + let guard = acc.inner.lock(); + let AccumulatedBuildData::Partitioned { + partitions, + completed_partitions, + } = &guard.data + else { + panic!("expected partitioned accumulator"); + }; + (partitions.clone(), *completed_partitions) + } + + // Regression guard for the build-report lifecycle fix: on `Drop`, a stream + // in `BuildReportState::ReportScheduled` still calls `report_canceled_partition` + // because it cannot tell whether the coordinator has already observed the + // report (first poll of the `OnceFut` runs `store_build_data` synchronously + // before the future's first `.await`, but the stream doesn't learn that + // until `get_shared` returns `Ok`). Correctness therefore relies on + // `store_canceled_partition` being a no-op when the partition is already + // `Reported`. This test pins that invariant. + #[test] + fn report_canceled_partition_is_noop_after_report() { + let acc = make_partitioned_accumulator_for_test(2); + + { + let mut guard = acc.inner.lock(); + acc.store_build_data( + &mut guard, + PartitionBuildData::Partitioned { + partition_id: 0, + pushdown: PushdownStrategy::Empty, + bounds: PartitionBounds::new(vec![]), + }, + ) + .unwrap(); + } + let (partitions, completed) = partitioned_state(&acc); + assert!(matches!(partitions[0], PartitionStatus::Reported(_))); + assert_eq!(completed, 1); + + acc.report_canceled_partition(0); + let (partitions, completed) = partitioned_state(&acc); + assert!( + matches!(partitions[0], PartitionStatus::Reported(_)), + "late cancel must not overwrite a prior Reported status" + ); + assert_eq!(completed, 1, "late cancel must not double-count completion"); + } + + // Drop from the `NotReported` (or first-poll-never-ran) state must + // transition `Pending` -> `CanceledUnknown` and bump `completed_partitions`, + // which is what unblocks sibling partitions waiting on the coordinator. + #[test] + fn report_canceled_partition_marks_pending_partition_canceled() { + let acc = make_partitioned_accumulator_for_test(2); + + acc.report_canceled_partition(0); + let (partitions, completed) = partitioned_state(&acc); + assert!(matches!(partitions[0], PartitionStatus::CanceledUnknown)); + assert_eq!(completed, 1); + + // Idempotent: a second cancel (e.g. a stray double-drop) must not + // double-count completion. + acc.report_canceled_partition(0); + let (partitions, completed) = partitioned_state(&acc); + assert!(matches!(partitions[0], PartitionStatus::CanceledUnknown)); + assert_eq!(completed, 1); + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index a50a6551db4d1..d403fa43cda4b 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -21,40 +21,43 @@ //! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details. use std::sync::Arc; +use std::sync::atomic::Ordering; use std::task::Poll; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; +use crate::joins::Map; +use crate::joins::MapOffset; +use crate::joins::PartitionMode; use crate::joins::hash_join::exec::JoinLeftData; use crate::joins::hash_join::shared_bounds::{ - PartitionBuildDataReport, SharedBuildAccumulator, + PartitionBounds, PartitionBuildData, SharedBuildAccumulator, }; use crate::joins::utils::{ - equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut, + OnceFut, equal_rows_arr, get_final_indices_from_shared_bitmap, }; -use crate::joins::PartitionMode; +use crate::stream::EmptyRecordBatchStream; use crate::{ - handle_state, + RecordBatchStream, SendableRecordBatchStream, handle_state, hash_utils::create_hashes, - joins::join_hash_map::JoinHashMapOffset, joins::utils::{ - adjust_indices_by_join_type, apply_join_filter_to_indices, + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, + StatefulStreamResult, adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_empty_build_side, build_batch_from_indices, - need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - JoinHashMapType, StatefulStreamResult, + need_produce_result_in_final, }, - RecordBatchStream, SendableRecordBatchStream, }; use arrow::array::{Array, ArrayRef, UInt32Array, UInt64Array}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{ - internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result, + JoinSide, JoinType, NullEquality, Result, internal_datafusion_err, internal_err, }; use datafusion_physical_expr::PhysicalExprRef; -use ahash::RandomState; +use datafusion_common::hash_utils::RandomState; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; -use futures::{ready, Stream, StreamExt}; +use futures::{Stream, StreamExt, ready}; /// Represents build-side of hash join. pub(super) enum BuildSide { @@ -154,13 +157,13 @@ pub(super) struct ProcessProbeBatchState { /// Probe-side on expressions values values: Vec, /// Starting offset for JoinHashMap lookups - offset: JoinHashMapOffset, + offset: MapOffset, /// Max joined probe-side index from current batch joined_probe_idx: Option, } impl ProcessProbeBatchState { - fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option) { + fn advance(&mut self, offset: MapOffset, joined_probe_idx: Option) { self.offset = offset; if joined_probe_idx.is_some() { self.joined_probe_idx = joined_probe_idx; @@ -168,6 +171,113 @@ impl ProcessProbeBatchState { } } +/// Lifecycle of this partition's build-data report to the shared coordinator. +/// +/// `Scheduled` means the reporting `OnceFut` has been constructed but is lazy: +/// the coordinator has not necessarily observed the report. Only `Delivered` +/// guarantees the coordinator saw it, so `Drop` must still cancel a `Scheduled` +/// partition — otherwise sibling partitions can wait forever for a report that +/// never runs. +#[derive(Debug, PartialEq, Eq)] +enum BuildReportState { + NotReported, + Scheduled, + Delivered, + Canceled, + Finalized, +} + +/// Owns the stream-side lifecycle for one partition's build-data report. +struct BuildReportHandle { + partition: usize, + mode: PartitionMode, + build_accumulator: Option>, + waiter: Option>, + state: BuildReportState, +} + +impl BuildReportHandle { + fn new( + partition: usize, + mode: PartitionMode, + build_accumulator: Option>, + ) -> Self { + Self { + partition, + mode, + build_accumulator, + waiter: None, + state: BuildReportState::NotReported, + } + } + + fn has_accumulator(&self) -> bool { + self.build_accumulator.is_some() + } + + fn schedule(&mut self, build_data: PartitionBuildData) { + let Some(build_accumulator) = &self.build_accumulator else { + // Defensive no-op terminal state; current callers avoid scheduling + // unless an accumulator is present. + self.finalize(); + return; + }; + + debug_assert!(matches!(self.state, BuildReportState::NotReported)); + let acc = Arc::clone(build_accumulator); + self.waiter = Some(OnceFut::new(async move { + acc.report_build_data(build_data).await + })); + self.state = BuildReportState::Scheduled; + } + + fn poll_delivery(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(ref mut fut) = self.waiter { + ready!(fut.get_shared(cx))?; + if !matches!(self.state, BuildReportState::Delivered) { + debug_assert!(matches!(self.state, BuildReportState::Scheduled)); + self.state = BuildReportState::Delivered; + } + } + Poll::Ready(Ok(())) + } + + fn cancel_pending(&mut self) { + if matches!( + self.state, + BuildReportState::Delivered + | BuildReportState::Canceled + | BuildReportState::Finalized + ) { + return; + } + + if self.mode == PartitionMode::Partitioned + && let Some(build_accumulator) = &self.build_accumulator + { + build_accumulator.report_canceled_partition(self.partition); + self.state = BuildReportState::Canceled; + } else { + self.finalize(); + } + } + + fn finalize(&mut self) { + self.state = BuildReportState::Finalized; + } + + #[cfg(test)] + fn state(&self) -> &BuildReportState { + &self.state + } +} + +impl Drop for BuildReportHandle { + fn drop(&mut self) { + self.cancel_pending(); + } +} + /// [`Stream`] for [`super::HashJoinExec`] that does the actual join. /// /// This stream: @@ -206,16 +316,21 @@ pub(super) struct HashJoinStream { batch_size: usize, /// Scratch space for computing hashes hashes_buffer: Vec, + /// Scratch space for probe indices during hash lookup + probe_indices_buffer: Vec, + /// Scratch space for build indices during hash lookup + build_indices_buffer: Vec, /// Specifies whether the right side has an ordering to potentially preserve right_side_ordered: bool, - /// Shared build accumulator for coordinating dynamic filter updates (collects hash maps and/or bounds, optional) - build_accumulator: Option>, - /// Optional future to signal when build information has been reported by all partitions - /// and the dynamic filter has been updated - build_waiter: Option>, - + /// Owns this partition's build-data report lifecycle. + build_report: BuildReportHandle, /// Partitioning mode to use mode: PartitionMode, + /// Output buffer for coalescing small batches into larger ones with optional fetch limit. + /// Uses `LimitedBatchCoalescer` to efficiently combine batches and absorb limit with 'fetch' + output_buffer: LimitedBatchCoalescer, + /// Whether this is a null-aware anti join + null_aware: bool, } impl RecordBatchStream for HashJoinStream { @@ -272,7 +387,7 @@ impl RecordBatchStream for HashJoinStream { /// Build indices: 4, 5, 6, 6 /// Probe indices: 3, 3, 4, 5 /// ``` -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] pub(super) fn lookup_join_hashmap( build_hashmap: &dyn JoinHashMapType, build_side_values: &[ArrayRef], @@ -280,22 +395,37 @@ pub(super) fn lookup_join_hashmap( null_equality: NullEquality, hashes_buffer: &[u64], limit: usize, - offset: JoinHashMapOffset, -) -> Result<(UInt64Array, UInt32Array, Option)> { - let (probe_indices, build_indices, next_offset) = - build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset); - - let build_indices: UInt64Array = build_indices.into(); - let probe_indices: UInt32Array = probe_indices.into(); - + offset: MapOffset, + probe_indices_buffer: &mut Vec, + build_indices_buffer: &mut Vec, +) -> Result<(UInt64Array, UInt32Array, Option)> { + let next_offset = build_hashmap.get_matched_indices_with_limit_offset( + hashes_buffer, + limit, + offset, + probe_indices_buffer, + build_indices_buffer, + ); + + let build_indices_unfiltered: UInt64Array = + std::mem::take(build_indices_buffer).into(); + let probe_indices_unfiltered: UInt32Array = + std::mem::take(probe_indices_buffer).into(); + + // TODO: optimize equal_rows_arr to avoid allocation of intermediate arrays + // https://github.com/apache/datafusion/issues/12131 let (build_indices, probe_indices) = equal_rows_arr( - &build_indices, - &probe_indices, + &build_indices_unfiltered, + &probe_indices_unfiltered, build_side_values, probe_side_values, null_equality, )?; + // Reclaim buffers + *build_indices_buffer = build_indices_unfiltered.into_parts().1.into(); + *probe_indices_buffer = probe_indices_unfiltered.into_parts().1.into(); + Ok((build_indices, probe_indices, next_offset)) } @@ -329,7 +459,7 @@ fn count_distinct_sorted_indices(indices: &UInt32Array) -> usize { } impl HashJoinStream { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub(super) fn new( partition: usize, schema: Arc, @@ -348,7 +478,13 @@ impl HashJoinStream { right_side_ordered: bool, build_accumulator: Option>, mode: PartitionMode, + null_aware: bool, + fetch: Option, ) -> Self { + // Create output buffer with coalescing and optional fetch limit. + let output_buffer = + LimitedBatchCoalescer::new(Arc::clone(&schema), batch_size, fetch); + Self { partition, schema, @@ -364,11 +500,68 @@ impl HashJoinStream { build_side, batch_size, hashes_buffer, + probe_indices_buffer: Vec::with_capacity(batch_size), + build_indices_buffer: Vec::with_capacity(batch_size), right_side_ordered, - build_accumulator, - build_waiter: None, + build_report: BuildReportHandle::new(partition, mode, build_accumulator), mode, + output_buffer, + null_aware, + } + } + + /// Returns the next state after the build side has been fully collected + /// and any required build-side coordination has completed. + fn state_after_build_ready( + join_type: JoinType, + left_data: &JoinLeftData, + ) -> HashJoinStreamState { + if left_data.map().is_empty() + && join_type.empty_build_side_produces_empty_result() + { + HashJoinStreamState::Completed + } else { + HashJoinStreamState::FetchProbeBatch + } + } + + /// Transitions state after build-side data has been collected, automatically + /// reporting build data to the accumulator when one is present. + /// + /// If a `build_accumulator` is configured, this method constructs the + /// appropriate [`PartitionBuildData`], schedules the reporting future, and + /// returns [`HashJoinStreamState::WaitPartitionBoundsReport`]. Otherwise it + /// delegates to [`Self::state_after_build_ready`]. + fn transition_after_build_collected( + &mut self, + left_data: &Arc, + ) -> HashJoinStreamState { + if !self.build_report.has_accumulator() { + return Self::state_after_build_ready(self.join_type, left_data.as_ref()); } + + let pushdown = left_data.membership().clone(); + let bounds = left_data + .bounds + .clone() + .unwrap_or_else(|| PartitionBounds::new(vec![])); + + let build_data = match self.mode { + PartitionMode::Partitioned => PartitionBuildData::Partitioned { + partition_id: self.partition, + pushdown, + bounds, + }, + PartitionMode::CollectLeft => { + PartitionBuildData::CollectLeft { pushdown, bounds } + } + PartitionMode::Auto => unreachable!( + "PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!" + ), + }; + + self.build_report.schedule(build_data); + HashJoinStreamState::WaitPartitionBoundsReport } /// Separate implementation function that unpins the [`HashJoinStream`] so @@ -378,6 +571,19 @@ impl HashJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>> { loop { + // First, check if we have any completed batches ready to emit + if let Some(batch) = self.output_buffer.next_completed_batch() { + return self + .join_metrics + .baseline + .record_poll(Poll::Ready(Some(Ok(batch)))); + } + + // Check if the coalescer has finished (limit reached and flushed) + if self.output_buffer.is_finished() { + return Poll::Ready(None); + } + return match self.state { HashJoinStreamState::WaitBuildSide => { handle_state!(ready!(self.collect_build_side(cx))) @@ -389,12 +595,16 @@ impl HashJoinStream { handle_state!(ready!(self.fetch_probe_batch(cx))) } HashJoinStreamState::ProcessProbeBatch(_) => { - let poll = handle_state!(self.process_probe_batch()); - self.join_metrics.baseline.record_poll(poll) + handle_state!(self.process_probe_batch()) } HashJoinStreamState::ExhaustedProbeSide => { - let poll = handle_state!(self.process_unmatched_build_batch()); - self.join_metrics.baseline.record_poll(poll) + handle_state!(self.process_unmatched_build_batch()) + } + HashJoinStreamState::Completed if !self.output_buffer.is_empty() => { + // Flush any remaining buffered data + self.output_buffer.finish()?; + // Continue loop to emit the flushed batch + continue; } HashJoinStreamState::Completed => Poll::Ready(None), }; @@ -414,10 +624,10 @@ impl HashJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>>> { - if let Some(ref mut fut) = self.build_waiter { - ready!(fut.get_shared(cx))?; - } - self.state = HashJoinStreamState::FetchProbeBatch; + ready!(self.build_report.poll_delivery(cx))?; + let build_side = self.build_side.try_as_ready()?; + self.state = + Self::state_after_build_ready(self.join_type, build_side.left_data.as_ref()); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -430,46 +640,19 @@ impl HashJoinStream { ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); // build hash table from left (build) side, if not yet done - let left_data = ready!(self - .build_side - .try_as_initial_mut()? - .left_fut - .get_shared(cx))?; + let left_data = ready!( + self.build_side + .try_as_initial_mut()? + .left_fut + .get_shared(cx) + )?; build_timer.done(); - // Handle dynamic filter build-side information accumulation - // - // Dynamic filter coordination between partitions: - // Report hash maps (Partitioned mode) or bounds (CollectLeft mode) to the accumulator - // which will handle synchronization and filter updates - if let Some(ref build_accumulator) = self.build_accumulator { - let build_accumulator = Arc::clone(build_accumulator); - - let left_side_partition_id = match self.mode { - PartitionMode::Partitioned => self.partition, - PartitionMode::CollectLeft => 0, - PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), - }; + // Note: For null-aware anti join, we need to check the probe side (right) for NULLs, + // not the build side (left). The probe-side NULL check happens during process_probe_batch. + // The probe_side_has_null flag will be set there if any probe batch contains NULL. - let build_data = match self.mode { - PartitionMode::Partitioned => PartitionBuildDataReport::Partitioned { - partition_id: left_side_partition_id, - bounds: left_data.bounds.clone(), - }, - PartitionMode::CollectLeft => PartitionBuildDataReport::CollectLeft { - bounds: left_data.bounds.clone(), - }, - PartitionMode::Auto => unreachable!( - "PartitionMode::Auto should not be present at execution time" - ), - }; - self.build_waiter = Some(OnceFut::new(async move { - build_accumulator.report_build_data(build_data).await - })); - self.state = HashJoinStreamState::WaitPartitionBoundsReport; - } else { - self.state = HashJoinStreamState::FetchProbeBatch; - } + self.state = self.transition_after_build_collected(&left_data); self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); Poll::Ready(Ok(StatefulStreamResult::Continue)) @@ -485,15 +668,26 @@ impl HashJoinStream { ) -> Poll>>> { match ready!(self.right.poll_next_unpin(cx)) { None => { + // Release the probe-side input pipeline's resources. The schema + // is preserved so callers that still query `self.right.schema()` + // (e.g. for unmatched-build emission) keep working. + let right_schema = self.right.schema(); + self.right = Box::pin(EmptyRecordBatchStream::new(right_schema)); self.state = HashJoinStreamState::ExhaustedProbeSide; } Some(Ok(batch)) => { // Precalculate hash values for fetched batch let keys_values = evaluate_expressions_to_arrays(&self.on_right, &batch)?; - self.hashes_buffer.clear(); - self.hashes_buffer.resize(batch.num_rows(), 0); - create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; + if let Map::HashMap(_) = self.build_side.try_as_ready()?.left_data.map() { + self.hashes_buffer.clear(); + self.hashes_buffer.resize(batch.num_rows(), 0); + create_hashes( + &keys_values, + &self.random_state, + &mut self.hashes_buffer, + )?; + } self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); @@ -527,8 +721,52 @@ impl HashJoinStream { let timer = self.join_metrics.join_time.timer(); - // if the left side is empty, we can skip the (potentially expensive) join operation - if build_side.left_data.hash_map.is_empty() && self.filter.is_none() { + // Null-aware anti join semantics: + // For LeftAnti: output LEFT (build) rows where LEFT.key NOT IN RIGHT.key + // 1. If RIGHT (probe) contains NULL in any batch, no LEFT rows should be output + // 2. LEFT rows with NULL keys should not be output (handled in final stage) + if self.null_aware { + // Mark that we've seen a probe batch with actual rows (probe side is non-empty) + // Only set this if batch has rows - empty batches don't count + // Use shared atomic state so all partitions can see this global information + if state.batch.num_rows() > 0 { + build_side + .left_data + .probe_side_non_empty + .store(true, Ordering::Relaxed); + } + + // Check if probe side (RIGHT) contains NULL + // Since null_aware validation ensures single column join, we only check the first column + let probe_key_column = &state.values[0]; + if probe_key_column.null_count() > 0 { + // Found NULL in probe side - set shared flag to prevent any output + build_side + .left_data + .probe_side_has_null + .store(true, Ordering::Relaxed); + } + + // If probe side has NULL (detected in this or any other partition), return empty result + if build_side + .left_data + .probe_side_has_null + .load(Ordering::Relaxed) + { + timer.done(); + self.state = HashJoinStreamState::FetchProbeBatch; + return Ok(StatefulStreamResult::Continue); + } + } + + // If the build side is empty, this stream only reaches ProcessProbeBatch for + // join types whose output still depends on probe rows. + let is_empty = build_side.left_data.map().is_empty(); + + if is_empty { + // Invariant: state_after_build_ready should have already completed + // join types whose result is fixed to empty when the build side is empty. + debug_assert!(!self.join_type.empty_build_side_produces_empty_result()); let result = build_batch_empty_build_side( &self.schema, build_side.left_data.batch(), @@ -537,22 +775,41 @@ impl HashJoinStream { self.join_type, )?; timer.done(); - + self.output_buffer.push_batch(result)?; self.state = HashJoinStreamState::FetchProbeBatch; - return Ok(StatefulStreamResult::Ready(Some(result))); + return Ok(StatefulStreamResult::Continue); } // get the matched by join keys indices - let (left_indices, right_indices, next_offset) = lookup_join_hashmap( - build_side.left_data.hash_map(), - build_side.left_data.values(), - &state.values, - self.null_equality, - &self.hashes_buffer, - self.batch_size, - state.offset, - )?; + let (left_indices, right_indices, next_offset) = match build_side.left_data.map() + { + Map::HashMap(map) => lookup_join_hashmap( + map.as_ref(), + build_side.left_data.values(), + &state.values, + self.null_equality, + &self.hashes_buffer, + self.batch_size, + state.offset, + &mut self.probe_indices_buffer, + &mut self.build_indices_buffer, + )?, + Map::ArrayMap(array_map) => { + let next_offset = array_map.get_matched_indices_with_limit_offset( + &state.values, + self.batch_size, + state.offset, + &mut self.probe_indices_buffer, + &mut self.build_indices_buffer, + )?; + ( + UInt64Array::from(self.build_indices_buffer.clone()), + UInt32Array::from(self.probe_indices_buffer.clone()), + next_offset, + ) + } + }; let distinct_right_indices_count = count_distinct_sorted_indices(&right_indices); @@ -576,6 +833,7 @@ impl HashJoinStream { filter, JoinSide::Left, None, + self.join_type, )? } else { (left_indices, right_indices) @@ -628,30 +886,36 @@ impl HashJoinStream { self.right_side_ordered, )?; - let result = if self.join_type == JoinType::RightMark { - build_batch_from_indices( - &self.schema, - &state.batch, - build_side.left_data.batch(), - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Right, - )? - } else { - build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Left, - )? - }; + // Build output batch and push to coalescer + let (build_batch, probe_batch, join_side) = + if self.join_type == JoinType::RightMark { + (&state.batch, build_side.left_data.batch(), JoinSide::Right) + } else { + (build_side.left_data.batch(), &state.batch, JoinSide::Left) + }; + + let batch = build_batch_from_indices( + &self.schema, + build_batch, + probe_batch, + &left_indices, + &right_indices, + &self.column_indices, + join_side, + self.join_type, + )?; + + let push_status = self.output_buffer.push_batch(batch)?; timer.done(); + // If limit reached, finish and move to Completed state + if push_status == PushBatchStatus::LimitReached { + self.output_buffer.finish()?; + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } + if next_offset.is_none() { self.state = HashJoinStreamState::FetchProbeBatch; } else { @@ -662,7 +926,7 @@ impl HashJoinStream { ) }; - Ok(StatefulStreamResult::Ready(Some(result))) + Ok(StatefulStreamResult::Continue) } /// Processes unmatched build-side rows for certain join types and produces output batch @@ -679,38 +943,95 @@ impl HashJoinStream { } let build_side = self.build_side.try_as_ready()?; + + // For null-aware anti join, if probe side had NULL, no rows should be output + // Check shared atomic state to get global knowledge across all partitions + if self.null_aware + && build_side + .left_data + .probe_side_has_null + .load(Ordering::Relaxed) + { + timer.done(); + self.state = HashJoinStreamState::Completed; + return Ok(StatefulStreamResult::Continue); + } if !build_side.left_data.report_probe_completed() { self.state = HashJoinStreamState::Completed; return Ok(StatefulStreamResult::Continue); } // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_shared_bitmap( + let (mut left_side, mut right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, true, ); - let empty_right_batch = RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); + // For null-aware anti join, filter out LEFT rows with NULL in join keys + // BUT only if the probe side (RIGHT) was non-empty. If probe side is empty, + // NULL NOT IN (empty) = TRUE, so NULL rows should be returned. + // Use shared atomic state to get global knowledge across all partitions + if self.null_aware + && self.join_type == JoinType::LeftAnti + && build_side + .left_data + .probe_side_non_empty + .load(Ordering::Relaxed) + { + // Since null_aware validation ensures single column join, we only check the first column + let build_key_column = &build_side.left_data.values()[0]; + + // Filter out indices where the key is NULL + let filtered_indices: Vec = left_side + .iter() + .filter_map(|idx| { + let idx_usize = idx.unwrap() as usize; + if build_key_column.is_null(idx_usize) { + None // Skip rows with NULL keys + } else { + Some(idx.unwrap()) + } + }) + .collect(); + + left_side = UInt64Array::from(filtered_indices); + + // Update right_side to match the new length + let mut builder = arrow::array::UInt32Builder::with_capacity(left_side.len()); + builder.append_nulls(left_side.len()); + right_side = builder.finish(); } + + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(left_side.len()); + timer.done(); self.state = HashJoinStreamState::Completed; - Ok(StatefulStreamResult::Ready(Some(result?))) + // Push final unmatched indices to output buffer + if !left_side.is_empty() { + let empty_right_batch = RecordBatch::new_empty(self.right.schema()); + let batch = build_batch_from_indices( + &self.schema, + build_side.left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + self.join_type, + )?; + let push_status = self.output_buffer.push_batch(batch)?; + + // If limit reached, finish the coalescer + if push_status == PushBatchStatus::LimitReached { + self.output_buffer.finish()?; + } + } + + Ok(StatefulStreamResult::Continue) } } @@ -724,3 +1045,75 @@ impl Stream for HashJoinStream { self.poll_next_impl(cx) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::joins::hash_join::shared_bounds::{ + PushdownStrategy, completed_partitions_for_test, + make_partitioned_accumulator_for_test, + }; + + fn empty_build_data(partition_id: usize) -> PartitionBuildData { + PartitionBuildData::Partitioned { + partition_id, + pushdown: PushdownStrategy::Empty, + bounds: PartitionBounds::new(vec![]), + } + } + + fn partitioned_handle(acc: &Arc) -> BuildReportHandle { + BuildReportHandle::new(0, PartitionMode::Partitioned, Some(Arc::clone(acc))) + } + + #[test] + fn build_report_handle_cancels_scheduled_partition_on_drop() { + let acc = Arc::new(make_partitioned_accumulator_for_test(2)); + + { + let mut handle = partitioned_handle(&acc); + handle.schedule(empty_build_data(0)); + assert_eq!(handle.state(), &BuildReportState::Scheduled); + } + + assert_eq!(completed_partitions_for_test(&acc), 1); + } + + #[test] + fn build_report_handle_does_not_cancel_delivered_partition_on_drop() { + let acc = Arc::new(make_partitioned_accumulator_for_test(1)); + + { + let mut handle = partitioned_handle(&acc); + handle.schedule(empty_build_data(0)); + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + assert!(matches!(handle.poll_delivery(&mut cx), Poll::Ready(Ok(())))); + assert_eq!(handle.state(), &BuildReportState::Delivered); + } + + assert_eq!(completed_partitions_for_test(&acc), 1); + } + + #[test] + fn build_report_handle_cancel_pending_is_idempotent() { + let acc = Arc::new(make_partitioned_accumulator_for_test(2)); + let mut handle = partitioned_handle(&acc); + handle.schedule(empty_build_data(0)); + + handle.cancel_pending(); + handle.cancel_pending(); + + assert_eq!(handle.state(), &BuildReportState::Canceled); + assert_eq!(completed_partitions_for_test(&acc), 1); + } + + #[test] + fn build_report_handle_no_accumulator_finalizes() { + let mut handle = BuildReportHandle::new(0, PartitionMode::Partitioned, None); + + handle.schedule(empty_build_data(0)); + handle.cancel_pending(); + + assert_eq!(handle.state(), &BuildReportState::Finalized); + } +} diff --git a/datafusion/physical-plan/src/joins/join_hash_map.rs b/datafusion/physical-plan/src/joins/join_hash_map.rs index bdd4bfeeb0fbe..8f0fb66b64fbf 100644 --- a/datafusion/physical-plan/src/joins/join_hash_map.rs +++ b/datafusion/physical-plan/src/joins/join_hash_map.rs @@ -22,8 +22,11 @@ use std::fmt::{self, Debug}; use std::ops::Sub; -use hashbrown::hash_table::Entry::{Occupied, Vacant}; +use arrow::array::BooleanArray; +use arrow::buffer::BooleanBuffer; +use arrow::datatypes::ArrowNativeType; use hashbrown::HashTable; +use hashbrown::hash_table::Entry::{Occupied, Vacant}; /// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. /// @@ -93,6 +96,12 @@ use hashbrown::HashTable; /// /// At runtime we choose between using `JoinHashMapU32` and `JoinHashMapU64` which oth implement /// `JoinHashMapType`. +/// +/// ## Note on use of this trait as a public API +/// This is currently a public trait but is mainly intended for internal use within DataFusion. +/// For example, we may compare references to `JoinHashMapType` implementations by pointer equality +/// rather than deep equality of contents, as deep equality would be expensive and in our usage +/// patterns it is impossible for two different hash maps to have identical contents in a practical sense. pub trait JoinHashMapType: Send + Sync { fn extend_zero(&mut self, len: usize); @@ -112,11 +121,19 @@ pub trait JoinHashMapType: Send + Sync { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, - ) -> (Vec, Vec, Option); + offset: MapOffset, + input_indices: &mut Vec, + match_indices: &mut Vec, + ) -> Option; + + /// Returns a BooleanArray indicating which of the provided hashes exist in the map. + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray; /// Returns `true` if the join hash map contains no entries. fn is_empty(&self) -> bool; + + /// Returns the number of entries in the join hash map. + fn len(&self) -> usize; } pub struct JoinHashMapU32 { @@ -169,20 +186,32 @@ impl JoinHashMapType for JoinHashMapU32 { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, - ) -> (Vec, Vec, Option) { + offset: MapOffset, + input_indices: &mut Vec, + match_indices: &mut Vec, + ) -> Option { get_matched_indices_with_limit_offset::( &self.map, &self.next, hash_values, limit, offset, + input_indices, + match_indices, ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } + + fn len(&self) -> usize { + self.map.len() + } } pub struct JoinHashMapU64 { @@ -235,60 +264,37 @@ impl JoinHashMapType for JoinHashMapU64 { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, - ) -> (Vec, Vec, Option) { + offset: MapOffset, + input_indices: &mut Vec, + match_indices: &mut Vec, + ) -> Option { get_matched_indices_with_limit_offset::( &self.map, &self.next, hash_values, limit, offset, + input_indices, + match_indices, ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } -} -// Type of offsets for obtaining indices from JoinHashMap. -pub(crate) type JoinHashMapOffset = (usize, Option); - -// Macro for traversing chained values with limit. -// Early returns in case of reaching output tuples limit. -macro_rules! chain_traverse { - ( - $input_indices:ident, $match_indices:ident, - $hash_values:ident, $next_chain:ident, - $input_idx:ident, $chain_idx:ident, $remaining_output:ident, $one:ident, $zero:ident - ) => {{ - // now `one` and `zero` are in scope from the outer function - let mut match_row_idx = $chain_idx - $one; - loop { - $match_indices.push(match_row_idx.into()); - $input_indices.push($input_idx as u32); - $remaining_output -= 1; - - let next = $next_chain[match_row_idx.into() as usize]; - - if $remaining_output == 0 { - // we compare against `zero` (of type T) here too - let next_offset = if $input_idx == $hash_values.len() - 1 && next == $zero - { - None - } else { - Some(($input_idx, Some(next.into()))) - }; - return ($input_indices, $match_indices, next_offset); - } - if next == $zero { - break; - } - match_row_idx = next - $one; - } - }}; + fn len(&self) -> usize { + self.map.len() + } } +use crate::joins::MapOffset; +use crate::joins::chain::traverse_chain; + pub fn update_from_iter<'a, T>( map: &mut HashTable<(u64, T)>, next: &mut [T], @@ -375,15 +381,18 @@ pub fn get_matched_indices_with_limit_offset( next_chain: &[T], hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, -) -> (Vec, Vec, Option) + offset: MapOffset, + input_indices: &mut Vec, + match_indices: &mut Vec, +) -> Option where T: Copy + TryFrom + PartialOrd + Into + Sub, >::Error: Debug, + T: ArrowNativeType, { - let mut input_indices = Vec::with_capacity(limit); - let mut match_indices = Vec::with_capacity(limit); - let zero = T::try_from(0).unwrap(); + // Clear the buffer before producing new results + input_indices.clear(); + match_indices.clear(); let one = T::try_from(1).unwrap(); // Check if hashmap consists of unique values @@ -397,19 +406,18 @@ where match_indices.push((*idx - one).into()); } } - let next_off = if end == hash_values.len() { + return if end == hash_values.len() { None } else { Some((end, None)) }; - return (input_indices, match_indices, next_off); } let mut remaining_output = limit; // Calculate initial `hash_values` index before iterating let to_skip = match offset { - // None `initial_next_idx` indicates that `initial_idx` processing has'n been started + // None `initial_next_idx` indicates that `initial_idx` processing hasn't been started (idx, None) => idx, // Zero `initial_next_idx` indicates that `initial_idx` has been processed during // previous iteration, and it should be skipped @@ -417,39 +425,73 @@ where // Otherwise, process remaining `initial_idx` matches by traversing `next_chain`, // to start with the next index (idx, Some(next_idx)) => { - let next_idx: T = T::try_from(next_idx as usize).unwrap(); - chain_traverse!( - input_indices, - match_indices, - hash_values, + let next_idx: T = T::usize_as(next_idx as usize); + let is_last = idx == hash_values.len() - 1; + if let Some(next_offset) = traverse_chain( next_chain, idx, next_idx, - remaining_output, - one, - zero - ); + &mut remaining_output, + input_indices, + match_indices, + is_last, + ) { + return Some(next_offset); + } idx + 1 } }; - let mut row_idx = to_skip; - for &hash in &hash_values[to_skip..] { + let hash_values_len = hash_values.len(); + for (i, &hash) in hash_values[to_skip..].iter().enumerate() { + let row_idx = to_skip + i; if let Some((_, idx)) = map.find(hash, |(h, _)| hash == *h) { let idx: T = *idx; - chain_traverse!( - input_indices, - match_indices, - hash_values, + let is_last = row_idx == hash_values_len - 1; + if let Some(next_offset) = traverse_chain( next_chain, row_idx, idx, - remaining_output, - one, - zero - ); + &mut remaining_output, + input_indices, + match_indices, + is_last, + ) { + return Some(next_offset); + } + } + } + None +} + +pub fn contain_hashes(map: &HashTable<(u64, T)>, hash_values: &[u64]) -> BooleanArray { + let buffer = BooleanBuffer::collect_bool(hash_values.len(), |i| { + let hash = hash_values[i]; + map.find(hash, |(h, _)| hash == *h).is_some() + }); + BooleanArray::new(buffer, None) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_contain_hashes() { + let mut hash_map = JoinHashMapU32::with_capacity(10); + hash_map.update_from_iter(Box::new([10u64, 20u64, 30u64].iter().enumerate()), 0); + + let probe_hashes = vec![10, 11, 20, 21, 30, 31]; + let array = hash_map.contain_hashes(&probe_hashes); + + assert_eq!(array.len(), probe_hashes.len()); + + for (i, &hash) in probe_hashes.iter().enumerate() { + if matches!(hash, 10 | 20 | 30) { + assert!(array.value(i), "Hash {hash} should exist in the map"); + } else { + assert!(!array.value(i), "Hash {hash} should NOT exist in the map"); + } } - row_idx += 1; } - (input_indices, match_indices, None) } diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index b0c28cf994f71..2cdfa1e6ac020 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,13 +20,16 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; use datafusion_physical_expr::PhysicalExprRef; -pub use hash_join::HashJoinExec; -pub use nested_loop_join::NestedLoopJoinExec; +pub use hash_join::{ + HashExpr, HashJoinExec, HashJoinExecBuilder, HashTableLookupExpr, SeededRandomState, +}; +pub use nested_loop_join::{NestedLoopJoinExec, NestedLoopJoinExecBuilder}; use parking_lot::Mutex; // Note: SortMergeJoin is not used in plans yet pub use piecewise_merge_join::PiecewiseMergeJoinExec; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; +pub mod chain; mod cross_join; mod hash_join; mod nested_loop_join; @@ -36,8 +39,38 @@ mod stream_join_utils; mod symmetric_hash_join; pub mod utils; +mod array_map; mod join_filter; -mod join_hash_map; +/// Hash map implementations for join operations. +/// +/// Note: This module is public for internal testing purposes only +/// and is not guaranteed to be stable across versions. +pub mod join_hash_map; + +use array_map::ArrayMap; +use utils::JoinHashMapType; + +pub enum Map { + HashMap(Box), + ArrayMap(ArrayMap), +} + +impl Map { + /// Returns the number of elements in the map. + pub fn num_of_distinct_key(&self) -> usize { + match self { + Map::HashMap(map) => map.len(), + Map::ArrayMap(array_map) => array_map.num_of_distinct_key(), + } + } + + /// Returns `true` if the map contains no elements. + pub fn is_empty(&self) -> bool { + self.num_of_distinct_key() == 0 + } +} + +pub(crate) type MapOffset = (usize, Option); #[cfg(test)] pub mod test_utils; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index f16e2220dfbee..0bd053a9db12c 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -17,11 +17,10 @@ //! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates). -use std::any::Any; use std::fmt::Formatter; use std::ops::{BitOr, ControlFlow}; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::Poll; use super::utils::{ @@ -29,54 +28,61 @@ use super::utils::{ reorder_output_after_swap, swap_join_projection, }; use crate::common::can_project; -use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::execution_plan::{EmissionType, boundedness_from_children}; +use crate::joins::SharedBitmapBuilder; use crate::joins::utils::{ + BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, build_join_schema, check_join_is_valid, estimate_join_statistics, - need_produce_right_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - OnceAsync, OnceFut, + need_produce_right_in_final, }; -use crate::joins::SharedBitmapBuilder; use crate::metrics::{ Count, ExecutionPlanMetricsSet, MetricBuilder, MetricType, MetricsSet, RatioMetrics, }; use crate::projection::{ - try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData, - ProjectionExec, + EmbeddedProjection, JoinData, ProjectionExec, try_embed_projection, + try_pushdown_through_join, }; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, + check_if_same_properties, }; use arrow::array::{ - new_null_array, Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, - UInt32Array, UInt64Array, + Array, BooleanArray, BooleanBufferBuilder, RecordBatchOptions, UInt32Array, + UInt64Array, new_null_array, }; use arrow::buffer::BooleanBuffer; use arrow::compute::{ - concat_batches, filter, filter_record_batch, not, take, BatchCoalescer, + BatchCoalescer, concat_batches, filter, filter_record_batch, not, take, }; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_schema::DataType; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ - arrow_err, assert_eq_or_internal_err, internal_datafusion_err, internal_err, - project_schema, unwrap_or_internal_err, DataFusionError, JoinSide, Result, - ScalarValue, Statistics, + JoinSide, NullEquality, Result, ScalarValue, Statistics, arrow_err, + assert_eq_or_internal_err, internal_datafusion_err, internal_err, project_schema, + unwrap_or_internal_err, }; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_expr::JoinType; use datafusion_physical_expr::equivalence::{ - join_equivalence_properties, ProjectionMapping, + ProjectionMapping, join_equivalence_properties, }; +use datafusion_physical_expr::projection::{ProjectionRef, combine_projections}; use futures::{Stream, StreamExt, TryStreamExt}; use log::debug; use parking_lot::Mutex; -#[allow(rustdoc::private_intra_doc_links)] +use crate::metrics::SpillMetrics; +use crate::spill::replayable_spill_input::ReplayableStreamSource; +use crate::spill::spill_manager::SpillManager; + +#[expect(rustdoc::private_intra_doc_links)] /// NestedLoopJoinExec is a build-probe join operator designed for joins that /// do not have equijoin keys in their `ON` clause. /// @@ -159,10 +165,23 @@ use parking_lot::Mutex; /// - The design try to minimize the intermediate data size to approximately /// 1 batch, for better cache locality and memory efficiency. /// -/// # TODO: Memory-limited Execution -/// If the memory budget is exceeded during left-side buffering, fallback -/// strategies such as streaming left batches and re-scanning the right side -/// may be implemented in the future. +/// # Memory-limited Execution +/// When the memory budget is exceeded during left-side buffering, the operator +/// falls back to a multi-pass strategy: +/// 1. Buffer as many left rows as fit in memory (one "chunk") +/// 2. On the first pass, the right side is both processed and spilled to disk +/// 3. For each subsequent left chunk, the right side is re-read from the spill file +/// +/// The fallback is triggered automatically when the initial in-memory load +/// fails with `ResourcesExhausted` and disk spilling is available. Each +/// output partition independently re-executes the left child and manages +/// its own spill state. +/// +/// All join types are supported. For RIGHT/FULL/RIGHT SEMI/RIGHT ANTI/ +/// RIGHT MARK joins, a global right-side bitmap (indexed by right batch +/// sequence number) accumulates matches across all left chunks. After the +/// last left chunk is processed, the right side is replayed one more time +/// to emit unmatched right rows using the accumulated bitmap. /// /// Tracking issue: /// @@ -190,53 +209,131 @@ pub struct NestedLoopJoinExec { /// Each output stream waits on the `OnceAsync` to signal the completion of /// the build(left) side data, and buffer them all for later joining. build_side_data: OnceAsync, + /// Shared left-side spill data for OOM fallback. + /// + /// When `build_side_data` fails with OOM, the first partition to + /// initiate fallback spills the entire left side to disk. Other + /// partitions share the same spill file via this `OnceAsync`, + /// avoiding redundant re-execution of the left child. + left_spill_data: Arc>, /// Information of index and left / right placement of columns column_indices: Vec, /// Projection to apply to the output of the join - projection: Option>, + projection: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } -impl NestedLoopJoinExec { - /// Try to create a new [`NestedLoopJoinExec`] - pub fn try_new( +/// Helps to build [`NestedLoopJoinExec`]. +pub struct NestedLoopJoinExecBuilder { + left: Arc, + right: Arc, + join_type: JoinType, + filter: Option, + projection: Option, +} + +impl NestedLoopJoinExecBuilder { + /// Make a new [`NestedLoopJoinExecBuilder`]. + pub fn new( left: Arc, right: Arc, - filter: Option, - join_type: &JoinType, - projection: Option>, - ) -> Result { + join_type: JoinType, + ) -> Self { + Self { + left, + right, + join_type, + filter: None, + projection: None, + } + } + + /// Set projection from the vector. + pub fn with_projection(self, projection: Option>) -> Self { + self.with_projection_ref(projection.map(Into::into)) + } + + /// Set projection from the shared reference. + pub fn with_projection_ref(mut self, projection: Option) -> Self { + self.projection = projection; + self + } + + /// Set optional filter. + pub fn with_filter(mut self, filter: Option) -> Self { + self.filter = filter; + self + } + + /// Build resulting execution plan. + pub fn build(self) -> Result { + let Self { + left, + right, + join_type, + filter, + projection, + } = self; + let left_schema = left.schema(); let right_schema = right.schema(); check_join_is_valid(&left_schema, &right_schema, &[])?; let (join_schema, column_indices) = - build_join_schema(&left_schema, &right_schema, join_type); + build_join_schema(&left_schema, &right_schema, &join_type); let join_schema = Arc::new(join_schema); - let cache = Self::compute_properties( + let cache = NestedLoopJoinExec::compute_properties( &left, &right, &join_schema, - *join_type, - projection.as_ref(), + join_type, + projection.as_deref(), )?; - Ok(NestedLoopJoinExec { left, right, filter, - join_type: *join_type, + join_type, join_schema, build_side_data: Default::default(), + left_spill_data: Arc::new(OnceAsync::default()), column_indices, projection, metrics: Default::default(), - cache, + cache: Arc::new(cache), }) } +} + +impl From<&NestedLoopJoinExec> for NestedLoopJoinExecBuilder { + fn from(exec: &NestedLoopJoinExec) -> Self { + Self { + left: Arc::clone(exec.left()), + right: Arc::clone(exec.right()), + join_type: exec.join_type, + filter: exec.filter.clone(), + projection: exec.projection.clone(), + } + } +} + +impl NestedLoopJoinExec { + /// Try to create a new [`NestedLoopJoinExec`] + pub fn try_new( + left: Arc, + right: Arc, + filter: Option, + join_type: &JoinType, + projection: Option>, + ) -> Result { + NestedLoopJoinExecBuilder::new(left, right, *join_type) + .with_projection(projection) + .with_filter(filter) + .build() + } /// left side pub fn left(&self) -> &Arc { @@ -258,8 +355,8 @@ impl NestedLoopJoinExec { &self.join_type } - pub fn projection(&self) -> Option<&Vec> { - self.projection.as_ref() + pub fn projection(&self) -> &Option { + &self.projection } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -268,7 +365,7 @@ impl NestedLoopJoinExec { right: &Arc, schema: &SchemaRef, join_type: JoinType, - projection: Option<&Vec>, + projection: Option<&[usize]>, ) -> Result { // Calculate equivalence properties: let mut eq_properties = join_equivalence_properties( @@ -311,7 +408,7 @@ impl NestedLoopJoinExec { if let Some(projection) = projection { // construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; - let out_schema = project_schema(schema, Some(projection))?; + let out_schema = project_schema(schema, Some(&projection))?; output_partitioning = output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); @@ -335,22 +432,14 @@ impl NestedLoopJoinExec { } pub fn with_projection(&self, projection: Option>) -> Result { + let projection = projection.map(Into::into); // check if the projection is valid - can_project(&self.schema(), projection.as_ref())?; - let projection = match projection { - Some(projection) => match &self.projection { - Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), - None => Some(projection), - }, - None => None, - }; - Self::try_new( - Arc::clone(&self.left), - Arc::clone(&self.right), - self.filter.clone(), - &self.join_type, - projection, - ) + can_project(&self.schema(), projection.as_deref())?; + let projection = + combine_projections(projection.as_ref(), self.projection.as_ref())?; + NestedLoopJoinExecBuilder::from(self) + .with_projection_ref(projection) + .build() } /// Returns a new `ExecutionPlan` that runs NestedLoopsJoins with the left @@ -372,7 +461,7 @@ impl NestedLoopJoinExec { swap_join_projection( left.schema().fields().len(), right.schema().fields().len(), - self.projection.as_ref(), + self.projection.as_deref(), self.join_type(), ), )?; @@ -400,6 +489,28 @@ impl NestedLoopJoinExec { Ok(plan) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + build_side_data: Default::default(), + left_spill_data: Arc::new(OnceAsync::default()), + cache: Arc::clone(&self.cache), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + column_indices: self.column_indices.clone(), + projection: self.projection.clone(), + } + } } impl DisplayAs for NestedLoopJoinExec { @@ -450,11 +561,7 @@ impl ExecutionPlan for NestedLoopJoinExec { "NestedLoopJoinExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -477,13 +584,17 @@ impl ExecutionPlan for NestedLoopJoinExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NestedLoopJoinExec::try_new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - self.filter.clone(), - &self.join_type, - self.projection.clone(), - )?)) + check_if_same_properties!(self, children); + Ok(Arc::new( + NestedLoopJoinExecBuilder::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.join_type, + ) + .with_filter(self.filter.clone()) + .with_projection_ref(self.projection.clone()) + .build()?, + )) } fn execute( @@ -499,8 +610,22 @@ impl ExecutionPlan for NestedLoopJoinExec { ); let metrics = NestedLoopJoinMetrics::new(&self.metrics, partition); + let batch_size = context.session_config().batch_size(); - // Initialization reservation for load of inner table + // update column indices to reflect the projection + let column_indices_after_projection = match self.projection.as_ref() { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + let right_partition_count = self.right().output_partitioning().partition_count(); + + // Always try to buffer all left data in memory via OnceFut. + // If that fails with OOM, the stream will fallback to memory-limited + // mode (if conditions allow). let load_reservation = MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) .register(context.memory_pool()); @@ -513,21 +638,39 @@ impl ExecutionPlan for NestedLoopJoinExec { metrics.join_metrics.clone(), load_reservation, need_produce_result_in_final(self.join_type), - self.right().output_partitioning().partition_count(), + right_partition_count, )) })?; - let batch_size = context.session_config().batch_size(); - - let probe_side_data = self.right.execute(partition, context)?; - - // update column indices to reflect the projection - let column_indices_after_projection = match &self.projection { - Some(projection) => projection - .iter() - .map(|i| self.column_indices[*i].clone()) - .collect(), - None => self.column_indices.clone(), + let probe_side_data = self.right.execute(partition, Arc::clone(&context))?; + + // Determine if OOM fallback to memory-limited mode is possible. + // Conditions: + // 1. Disk manager supports temp files (needed for spilling). + // 2. FULL join with multiple right partitions is not yet supported + // in the fallback path. FULL join needs to track BOTH left-side + // matches (for unmatched left rows) AND right-side matches (for + // unmatched right rows). The fallback path builds a per-partition + // `JoinLeftData` with `probe_threads_counter == 1`, so each + // partition emits unmatched left rows based only on its own + // right-side matches, producing incorrect duplicate output for + // left rows that match in another partition. Other join types + // that need only one-sided final emission (LEFT, LEFT SEMI, + // LEFT ANTI, LEFT MARK) have a similar latent issue in the + // fallback path which predates this change; tracking is out of + // scope for this PR. + let full_join_multi_partition = + matches!(self.join_type, JoinType::Full) && right_partition_count > 1; + let spill_state = if context.runtime_env().disk_manager.tmp_files_enabled() + && !full_join_multi_partition + { + SpillState::Pending { + left_plan: Arc::clone(&self.left), + task_context: Arc::clone(&context), + left_spill_data: Arc::clone(&self.left_spill_data), + } + } else { + SpillState::Disabled }; Ok(Box::pin(NestedLoopJoinStream::new( @@ -539,6 +682,7 @@ impl ExecutionPlan for NestedLoopJoinExec { column_indices_after_projection, metrics, batch_size, + spill_state, ))) } @@ -546,22 +690,36 @@ impl ExecutionPlan for NestedLoopJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - if partition.is_some() { - return Ok(Statistics::new_unknown(&self.schema())); - } + fn partition_statistics(&self, partition: Option) -> Result> { + // NestedLoopJoinExec is designed for joins without equijoin keys in the + // ON clause (e.g., `t1 JOIN t2 ON (t1.v1 + t2.v1) % 2 = 0`). Any join + // predicates are stored in `self.filter`, but `estimate_join_statistics` + // currently doesn't support selectivity estimation for such arbitrary + // filter expressions. We pass an empty join column list, which means + // the cardinality estimation cannot use column statistics and returns + // unknown row counts. let join_columns = Vec::new(); - estimate_join_statistics( - self.left.partition_statistics(None)?, - self.right.partition_statistics(None)?, + + // Left side is always a single partition (Distribution::SinglePartition), + // so we always request overall stats with `None`. Right side can have + // multiple partitions, so we forward the partition parameter to get + // partition-specific statistics when requested. + let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(None)?); + let right_stats = Arc::unwrap_or_clone(match partition { + Some(partition) => self.right.partition_statistics(Some(partition))?, + None => self.right.partition_statistics(None)?, + }); + + let stats = estimate_join_statistics( + left_stats, + right_stats, &join_columns, + NullEquality::NullEqualsNothing, &self.join_type, - &self.schema(), - ) + &self.join_schema, + )?; + + Ok(Arc::new(stats.project(self.projection.as_ref()))) } /// Tries to push `projection` down through `nested_loop_join`. If possible, performs the @@ -666,10 +824,10 @@ async fn collect_left_input( let schema = stream.schema(); // Load all batches and count the rows - let (batches, metrics, mut reservation) = stream + let (batches, metrics, reservation) = stream .try_fold( (Vec::new(), join_metrics, reservation), - |(mut batches, metrics, mut reservation), batch| async { + |(mut batches, metrics, reservation), batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch reservation.try_grow(batch_size)?; @@ -717,8 +875,116 @@ enum NLJState { ProbeRight, EmitRightUnmatched, EmitLeftUnmatched, + /// Emit unmatched right rows using the global bitmap accumulated across + /// all left chunks. Only used in memory-limited mode for join types that + /// require tracking right-side matches in the final output (RIGHT, FULL, + /// RIGHT SEMI, RIGHT ANTI, RIGHT MARK). + EmitGlobalRightUnmatched, Done, } +/// Shared data for the left-side spill fallback. +/// +/// When the in-memory `OnceFut` path fails with OOM, the first partition +/// spills the entire left side to disk. This struct holds the spill file +/// reference so other partitions can read from the same file. +pub(crate) struct LeftSpillData { + /// SpillManager used to read the spill file (has the left schema) + spill_manager: SpillManager, + /// The spill file containing all left-side batches + spill_file: RefCountedTempFile, + /// Left-side schema + schema: SchemaRef, +} + +/// Tracks the state of the memory-limited spill fallback for NLJ. +/// +/// The NLJ always starts with the standard OnceFut path. If the in-memory +/// load fails with OOM and conditions allow, the operator falls back to a +/// multi-pass strategy where left data is loaded in chunks and the right +/// side is spilled to disk. +pub(crate) enum SpillState { + /// Fallback is not possible (e.g., join type requires global right bitmap, + /// or disk manager is disabled). OOM errors will propagate as-is. + Disabled, + + /// Fallback is possible but not yet triggered. The operator is still + /// attempting the standard OnceFut path. Holds the context needed to + /// initiate fallback if OOM occurs. + Pending { + /// Left child plan for re-execution + left_plan: Arc, + /// TaskContext for re-execution and SpillManager creation + task_context: Arc, + /// Shared OnceAsync for left-side spill data. The first partition + /// to initiate fallback spills the left side; others share the file. + left_spill_data: Arc>, + }, + + /// Fallback has been triggered. Left data is being loaded in chunks + /// and the right side is spilled to disk for re-scanning. + Active(Box), +} + +/// State for active memory-limited spill execution. +/// Boxed inside [`SpillState::Active`] to reduce enum size. +pub(crate) struct SpillStateActive { + /// Shared future for left-side spill data. All partitions wait on + /// the same future — the first to poll triggers the actual spill. + left_spill_fut: OnceFut, + /// Left input stream for incremental chunk reading (from spill file). + /// None until `left_spill_fut` resolves. + left_stream: Option, + /// Left-side schema (set once `left_spill_fut` resolves) + left_schema: Option, + /// Memory reservation for left-side buffering + reservation: MemoryReservation, + /// Accumulated left batches for the current chunk + pending_batches: Vec, + /// Right input that spills on the first pass and replays from spill later. + right_input: ReplayableStreamSource, + /// Per-batch accumulated right bitmaps across all left chunks. + /// Index = right batch sequence number (0-based, non-empty batches only). + /// Only populated when `should_track_unmatched_right` is true. + global_right_bitmaps: Vec, + /// Separate reservation for `global_right_bitmaps`. These buffers live + /// for the full operator lifetime (not per-chunk), so they must be + /// tracked separately from `reservation`, which gets `resize(0)`-ed + /// between chunks. + global_right_bitmaps_reservation: MemoryReservation, + /// Current right batch sequence index within the current pass. + right_batch_index: usize, +} + +impl SpillStateActive { + /// Merge a per-pass right bitmap into the global accumulator at the + /// given batch index, growing the dedicated reservation when seeing + /// a batch index for the first time. + /// + /// On first encounter of `idx`, the bitmap is stored as-is and its + /// size is reserved. On subsequent encounters (later left chunk + /// passes over the same right batch), the existing entry is OR-merged + /// with `values`. Because `bitor` produces a buffer of the same bit + /// length, the reservation does not need to be adjusted on merge. + fn merge_current_right_bitmap(&mut self, idx: usize, values: BooleanBuffer) { + if idx >= self.global_right_bitmaps.len() { + // First encounter of this right batch — account memory and store. + // The bitmap has one bit per right row, so for very large right + // inputs the accumulated size can be non-negligible (e.g., + // 1M rows ≈ 125 KB per batch). + // Use infallible `grow` because we must accept the bitmap to + // preserve correctness — the fallback path has no other recourse. + let bytes = values.len().div_ceil(8); + self.global_right_bitmaps_reservation.grow(bytes); + self.global_right_bitmaps.push(values); + } else { + // Subsequent left chunk pass — OR merge. Same bit length, so + // no reservation adjustment is needed. + self.global_right_bitmaps[idx] = + self.global_right_bitmaps[idx].bitor(&values); + } + } +} + pub(crate) struct NestedLoopJoinStream { // ======================================================================== // PROPERTIES: @@ -736,7 +1002,8 @@ pub(crate) struct NestedLoopJoinStream { /// type of the join pub(crate) join_type: JoinType, /// the probe-side(right) table data of the nested loop join - pub(crate) right_data: SendableRecordBatchStream, + /// `Option` is used because memory-limited path requires resetting it. + pub(crate) right_data: Option, /// the build-side table data of the nested loop join pub(crate) left_data: OnceFut, /// Projection to construct the output schema from the left and right tables. @@ -784,9 +1051,7 @@ pub(crate) struct NestedLoopJoinStream { /// Should we go back to `BufferingLeft` state again after `EmitLeftUnmatched` /// state is over. left_exhausted: bool, - /// If we can buffer all left data in one pass - /// TODO(now): this is for the (unimplemented) memory-limited execution - #[allow(dead_code)] + /// If we can buffer all left data in one pass (false means memory-limited multi-pass) left_buffered_in_one_pass: bool, // Probe(right) side @@ -796,6 +1061,20 @@ pub(crate) struct NestedLoopJoinStream { // For right join, keep track of matched rows in `current_right_batch` // Constructed when fetching each new incoming right batch in `FetchingRight` state. current_right_batch_matched: Option, + + /// Memory-limited spill fallback state. See [`SpillState`] for details. + spill_state: SpillState, + + /// Whether this stream has already reported probe completion for the current + /// left chunk via [`JoinLeftData::report_probe_completed`]. The shared + /// probe-threads counter must be decremented exactly once per probe stream; + /// without this guard a stream that yields a ready batch while finishing the + /// `EmitLeftUnmatched` state (and is then re-polled with `left_emit_idx` + /// still 0) would decrement the counter twice, driving it to zero + /// prematurely and causing a sibling partition to emit unmatched-left rows + /// before all partitions finished probing (spurious NULL-padded rows). + /// Reset to `false` when starting a new left chunk in memory-limited mode. + probe_completed_reported: bool, } pub(crate) struct NestedLoopJoinMetrics { @@ -803,6 +1082,8 @@ pub(crate) struct NestedLoopJoinMetrics { pub(crate) join_metrics: BuildProbeJoinMetrics, /// Selectivity of the join: output_rows / (left_rows * right_rows) pub(crate) selectivity: RatioMetrics, + /// Spill metrics for memory-limited execution + pub(crate) spill_metrics: SpillMetrics, } impl NestedLoopJoinMetrics { @@ -810,8 +1091,9 @@ impl NestedLoopJoinMetrics { Self { join_metrics: BuildProbeJoinMetrics::new(partition, metrics), selectivity: MetricBuilder::new(metrics) - .with_type(MetricType::SUMMARY) + .with_type(MetricType::Summary) .ratio_metrics("selectivity", partition), + spill_metrics: SpillMetrics::new(metrics, partition), } } } @@ -933,7 +1215,7 @@ impl Stream for NestedLoopJoinStream { match self.handle_probe_right() { ControlFlow::Continue(()) => continue, ControlFlow::Break(poll) => { - return self.metrics.join_metrics.baseline.record_poll(poll) + return self.metrics.join_metrics.baseline.record_poll(poll); } } } @@ -954,7 +1236,7 @@ impl Stream for NestedLoopJoinStream { match self.handle_emit_right_unmatched() { ControlFlow::Continue(()) => continue, ControlFlow::Break(poll) => { - return self.metrics.join_metrics.baseline.record_poll(poll) + return self.metrics.join_metrics.baseline.record_poll(poll); } } } @@ -971,9 +1253,9 @@ impl Stream for NestedLoopJoinStream { // 3. --> Done // It has processed all data, go to the final state and ready // to exit. - // - // TODO: For memory-limited case, go back to `BufferingLeft` - // state again. + // 4. --> BufferingLeft (memory-limited mode only) + // When left data was loaded in chunks and more chunks remain, + // go back to BufferingLeft to load the next chunk. NLJState::EmitLeftUnmatched => { debug!("[NLJState] Entering: {:?}", self.state); @@ -984,7 +1266,26 @@ impl Stream for NestedLoopJoinStream { match self.handle_emit_left_unmatched() { ControlFlow::Continue(()) => continue, ControlFlow::Break(poll) => { - return self.metrics.join_metrics.baseline.record_poll(poll) + return self.metrics.join_metrics.baseline.record_poll(poll); + } + } + } + + // Replay all right batches from spill and emit unmatched + // right rows using the global bitmap accumulated across all + // left chunks. Only entered in memory-limited mode for join + // types where `should_track_unmatched_right` is true + // (RIGHT, FULL, RIGHT SEMI, RIGHT ANTI, RIGHT MARK). + NLJState::EmitGlobalRightUnmatched => { + debug!("[NLJState] Entering: {:?}", self.state); + + let join_metric = self.metrics.join_metrics.join_time.clone(); + let _join_timer = join_metric.timer(); + + match self.handle_emit_global_right_unmatched(cx) { + ControlFlow::Continue(()) => continue, + ControlFlow::Break(poll) => { + return self.metrics.join_metrics.baseline.record_poll(poll); } } } @@ -1014,7 +1315,7 @@ impl RecordBatchStream for NestedLoopJoinStream { } impl NestedLoopJoinStream { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub(crate) fn new( schema: Arc, filter: Option, @@ -1024,12 +1325,13 @@ impl NestedLoopJoinStream { column_indices: Vec, metrics: NestedLoopJoinMetrics, batch_size: usize, + spill_state: SpillState, ) -> Self { Self { output_schema: Arc::clone(&schema), join_filter: filter, join_type, - right_data, + right_data: Some(right_data), column_indices, left_data, metrics, @@ -1045,45 +1347,365 @@ impl NestedLoopJoinStream { left_buffered_in_one_pass: true, handled_empty_output: false, should_track_unmatched_right: need_produce_right_in_final(join_type), + spill_state, + probe_completed_reported: false, } } + /// Returns true if this stream is operating in memory-limited mode + fn is_memory_limited(&self) -> bool { + matches!(self.spill_state, SpillState::Active(_)) + } + + /// Check if we can fall back to memory-limited mode on this error. + fn can_fallback_to_spill(&self, error: &datafusion_common::DataFusionError) -> bool { + matches!(self.spill_state, SpillState::Pending { .. }) + && matches!( + error.find_root(), + datafusion_common::DataFusionError::ResourcesExhausted(_) + ) + } + + /// Switch from the standard OnceFut path to memory-limited mode. + /// + /// Uses the shared `left_spill_data` OnceAsync so that only the first + /// partition to reach this point re-executes the left child and spills + /// it to disk. Other partitions share the same spill file. + fn initiate_fallback(&mut self) -> Result<()> { + // Take ownership of Pending state + let (left_plan, context, left_spill_data) = + match std::mem::replace(&mut self.spill_state, SpillState::Disabled) { + SpillState::Pending { + left_plan, + task_context, + left_spill_data, + } => (left_plan, task_context, left_spill_data), + _ => { + return internal_err!( + "initiate_fallback called in non-Pending spill state" + ); + } + }; + + // Use OnceAsync to ensure only the first partition spills the left + // side. Other partitions will get the same OnceFut that resolves + // to the shared spill file. + let left_spill_fut = left_spill_data.try_once(|| { + let plan = Arc::clone(&left_plan); + let ctx = Arc::clone(&context); + let spill_metrics = self.metrics.spill_metrics.clone(); + Ok(async move { + let mut stream = plan.execute(0, Arc::clone(&ctx))?; + let schema = stream.schema(); + let left_spill_manager = SpillManager::new( + ctx.runtime_env(), + spill_metrics, + Arc::clone(&schema), + ) + .with_compression_type(ctx.session_config().spill_compression()); + + let result = left_spill_manager + .spill_record_batch_stream_and_return_max_batch_memory( + &mut stream, + "NestedLoopJoin left spill", + ) + .await?; + + match result { + Some((file, _max_batch_memory)) => Ok(LeftSpillData { + spill_manager: left_spill_manager, + spill_file: file, + schema, + }), + None => { + internal_err!("Left side produced no data to spill") + } + } + }) + })?; + + // Create reservation with can_spill for fair memory allocation + let reservation = MemoryConsumer::new("NestedLoopJoinLoad[fallback]".to_string()) + .with_can_spill(true) + .register(context.memory_pool()); + + // Separate reservation for the global right bitmaps. These buffers + // persist across all left chunks, whereas `reservation` is reset + // between chunks via `resize(0)`. + let global_right_bitmaps_reservation = + MemoryConsumer::new("NestedLoopJoinGlobalRightBitmaps".to_string()) + .register(context.memory_pool()); + + // Create SpillManager for right-side spilling + let right_schema = self + .right_data + .as_ref() + .expect("right_data must be present before fallback") + .schema(); + let right_data = self + .right_data + .take() + .expect("right_data must be present before fallback"); + let right_spill_manager = SpillManager::new( + context.runtime_env(), + self.metrics.spill_metrics.clone(), + right_schema, + ) + .with_compression_type(context.session_config().spill_compression()); + + self.spill_state = SpillState::Active(Box::new(SpillStateActive { + left_spill_fut, + left_stream: None, + left_schema: None, + reservation, + pending_batches: Vec::new(), + right_input: ReplayableStreamSource::new( + right_data, + right_spill_manager, + "NestedLoopJoin right spill", + ), + global_right_bitmaps: Vec::new(), + global_right_bitmaps_reservation, + right_batch_index: 0, + })); + + // State stays BufferingLeft — next poll will enter + // handle_buffering_left_memory_limited via is_memory_limited() check + self.state = NLJState::BufferingLeft; + + Ok(()) + } + // ==== State handler functions ==== - /// Handle BufferingLeft state - prepare left side batches + /// Handle BufferingLeft state - prepare left side batches. + /// + /// In standard mode, uses OnceFut to load all left data at once. + /// In memory-limited mode, incrementally buffers left batches until the + /// memory budget is reached or the left stream is exhausted. fn handle_buffering_left( &mut self, cx: &mut std::task::Context<'_>, ) -> ControlFlow>>> { - match self.left_data.get_shared(cx) { - Poll::Ready(Ok(left_data)) => { - self.buffered_left_data = Some(left_data); - // TODO: implement memory-limited case - self.left_exhausted = true; - self.state = NLJState::FetchingRight; - // Continue to next state immediately - ControlFlow::Continue(()) + if self.is_memory_limited() { + self.handle_buffering_left_memory_limited(cx) + } else { + // Standard path: use OnceFut + match self.left_data.get_shared(cx) { + Poll::Ready(Ok(left_data)) => { + self.buffered_left_data = Some(left_data); + self.left_exhausted = true; + self.state = NLJState::FetchingRight; + ControlFlow::Continue(()) + } + Poll::Ready(Err(e)) => { + if self.can_fallback_to_spill(&e) { + debug!( + "NestedLoopJoin: OnceFut failed with OOM, \ + falling back to memory-limited mode" + ); + match self.initiate_fallback() { + Ok(()) => ControlFlow::Continue(()), + Err(fallback_err) => { + ControlFlow::Break(Poll::Ready(Some(Err(fallback_err)))) + } + } + } else { + ControlFlow::Break(Poll::Ready(Some(Err(e)))) + } + } + Poll::Pending => ControlFlow::Break(Poll::Pending), } - Poll::Ready(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), - Poll::Pending => ControlFlow::Break(Poll::Pending), } } - /// Handle FetchingRight state - fetch next right batch and prepare for processing + /// Memory-limited path for handle_buffering_left. + /// + /// Incrementally polls the left stream and accumulates batches until: + /// - Memory reservation fails (chunk is full, more data remains) + /// - Left stream is exhausted (this is the last/only chunk) + fn handle_buffering_left_memory_limited( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> ControlFlow>>> { + let SpillState::Active(active) = &mut self.spill_state else { + unreachable!( + "handle_buffering_left_memory_limited called without Active spill state" + ); + }; + + // On first entry (or after re-entry for a new chunk pass when + // left_stream was consumed), wait for the shared left spill + // future to resolve and then open a stream from the spill file. + if active.left_stream.is_none() { + match active.left_spill_fut.get_shared(cx) { + Poll::Ready(Ok(spill_data)) => { + match spill_data + .spill_manager + .read_spill_as_stream(spill_data.spill_file.clone(), None) + { + Ok(stream) => { + active.left_schema = Some(Arc::clone(&spill_data.schema)); + active.left_stream = Some(stream); + } + Err(e) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + } + } + Poll::Ready(Err(e)) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + Poll::Pending => { + return ControlFlow::Break(Poll::Pending); + } + } + } + + let left_stream = active + .left_stream + .as_mut() + .expect("left_stream must be set after spill future resolves"); + + // Poll left stream for more batches. + // Note: pending_batches may already contain a batch from the + // previous chunk iteration (the batch that triggered the memory limit). + loop { + match left_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + continue; + } + let batch_rows = batch.num_rows(); + let batch_size = batch.get_array_memory_size(); + let can_grow = active.reservation.try_grow(batch_size).is_ok(); + + if !can_grow && !active.pending_batches.is_empty() { + // Memory limit reached and we already have data. + // Push this batch into pending (it's already in memory) + // and stop buffering for this chunk. + active.pending_batches.push(batch); + self.left_exhausted = false; + self.left_buffered_in_one_pass = false; + break; + } else if !can_grow { + // No pending batches yet — we must accept this batch + // to make progress, even if it exceeds the budget. + active.reservation.grow(batch_size); + } + + self.metrics.join_metrics.build_mem_used.add(batch_size); + self.metrics.join_metrics.build_input_batches.add(1); + self.metrics.join_metrics.build_input_rows.add(batch_rows); + active.pending_batches.push(batch); + } + Poll::Ready(Some(Err(e))) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + Poll::Ready(None) => { + // Left stream exhausted + self.left_exhausted = true; + break; + } + Poll::Pending => { + return ControlFlow::Break(Poll::Pending); + } + } + } + + // If the left stream is fully exhausted, release its resources so the + // upstream pipeline can be torn down before we move on to probing. + if self.left_exhausted { + active.left_stream = None; + } + + if active.pending_batches.is_empty() { + // No data at all — go directly to Done + self.left_exhausted = true; + self.state = NLJState::Done; + return ControlFlow::Continue(()); + } + + let merged_batch = match concat_batches( + active + .left_schema + .as_ref() + .expect("left_schema must be set"), + &active.pending_batches, + ) { + Ok(batch) => batch, + Err(e) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e.into())))); + } + }; + active.pending_batches.clear(); + + // Build visited bitmap if needed for this join type + let with_visited = need_produce_result_in_final(self.join_type); + let n_rows = merged_batch.num_rows(); + let visited_left_side = if with_visited { + let buffer_size = n_rows.div_ceil(8); + // Use infallible grow for bitmap — it's small + active.reservation.grow(buffer_size); + self.metrics.join_metrics.build_mem_used.add(buffer_size); + let mut buffer = BooleanBufferBuilder::new(n_rows); + buffer.append_n(n_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) + }; + + // Create an empty reservation for JoinLeftData's RAII field. + // The actual memory tracking is managed by the Active state's reservation. + let dummy_reservation = active.reservation.new_empty(); + + let left_data = JoinLeftData::new( + merged_batch, + Mutex::new(visited_left_side), + // In memory-limited mode, only 1 probe thread per chunk + AtomicUsize::new(1), + dummy_reservation, + ); + + self.buffered_left_data = Some(Arc::new(left_data)); + + active.right_batch_index = 0; + match active.right_input.open_pass() { + Ok(stream) => { + self.right_data = Some(stream); + } + Err(e) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + } + + self.state = NLJState::FetchingRight; + ControlFlow::Continue(()) + } + + /// Handle FetchingRight state - fetch next right batch and prepare for processing. + /// + /// In memory-limited mode during the first pass, each right batch is also + /// written to a spill file so it can be re-read on subsequent passes. fn handle_fetching_right( &mut self, cx: &mut std::task::Context<'_>, ) -> ControlFlow>>> { - match self.right_data.poll_next_unpin(cx) { + match self + .right_data + .as_mut() + .expect("right_data must be present while fetching right") + .poll_next_unpin(cx) + { Poll::Ready(result) => match result { Some(Ok(right_batch)) => { // Update metrics - let right_batch_size = right_batch.num_rows(); - self.metrics.join_metrics.input_rows.add(right_batch_size); + let right_batch_rows = right_batch.num_rows(); + self.metrics.join_metrics.input_rows.add(right_batch_rows); self.metrics.join_metrics.input_batches.add(1); // Skip the empty batch - if right_batch_size == 0 { + if right_batch_rows == 0 { return ControlFlow::Continue(()); } @@ -1091,7 +1713,7 @@ impl NestedLoopJoinStream { // Prepare right bitmap if self.should_track_unmatched_right { - let zeroed_buf = BooleanBuffer::new_unset(right_batch_size); + let zeroed_buf = BooleanBuffer::new_unset(right_batch_rows); self.current_right_batch_matched = Some(BooleanArray::new(zeroed_buf, None)); } @@ -1102,7 +1724,6 @@ impl NestedLoopJoinStream { } Some(Err(e)) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), None => { - // Right stream exhausted self.state = NLJState::EmitLeftUnmatched; ControlFlow::Continue(()) } @@ -1157,10 +1778,37 @@ impl NestedLoopJoinStream { } } - /// Handle EmitRightUnmatched state - emit unmatched right rows + /// Handle EmitRightUnmatched state - emit unmatched right rows. + /// + /// In memory-limited mode, instead of emitting unmatched right rows + /// per-batch (which would be incorrect since more left chunks may + /// match those rows), we merge the bitmap into the global accumulator + /// and defer emission to `EmitGlobalRightUnmatched`. fn handle_emit_right_unmatched( &mut self, ) -> ControlFlow>>> { + // In memory-limited mode, merge bitmap into global and move on + if self.is_memory_limited() { + debug_assert!( + self.current_right_batch_matched.is_some(), + "right bitmap must be present" + ); + let bitmap = std::mem::take(&mut self.current_right_batch_matched) + .expect("right bitmap should be available"); + let (values, _nulls) = bitmap.into_parts(); + + if let SpillState::Active(ref mut active) = self.spill_state { + let idx = active.right_batch_index; + active.merge_current_right_bitmap(idx, values); + active.right_batch_index += 1; + } + + self.current_right_batch = None; + self.state = NLJState::FetchingRight; + return ControlFlow::Continue(()); + } + + // Standard (single-pass) mode: emit unmatched right rows immediately // Return any completed batches first if let Some(poll) = self.maybe_flush_ready_batch() { return ControlFlow::Break(poll); @@ -1171,23 +1819,16 @@ impl NestedLoopJoinStream { && self.current_right_batch.is_some(), "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present" ); - // Construct the result batch for unmatched right rows using a utility function match self.process_right_unmatched() { - Ok(Some(batch)) => { - match self.output_buffer.push_batch(batch) { - Ok(()) => { - // Processed all in one pass - // cleared inside `process_right_unmatched` - debug_assert!(self.current_right_batch.is_none()); - self.state = NLJState::FetchingRight; - ControlFlow::Continue(()) - } - Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + Ok(Some(batch)) => match self.output_buffer.push_batch(batch) { + Ok(()) => { + debug_assert!(self.current_right_batch.is_none()); + self.state = NLJState::FetchingRight; + ControlFlow::Continue(()) } - } + Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + }, Ok(None) => { - // Processed all in one pass - // cleared inside `process_right_unmatched` debug_assert!(self.current_right_batch.is_none()); self.state = NLJState::FetchingRight; ControlFlow::Continue(()) @@ -1196,7 +1837,11 @@ impl NestedLoopJoinStream { } } - /// Handle EmitLeftUnmatched state - emit unmatched left rows + /// Handle EmitLeftUnmatched state - emit unmatched left rows. + /// + /// In memory-limited mode, after processing all unmatched rows for the + /// current left chunk, transitions back to `BufferingLeft` to load the + /// next chunk (if the left stream is not yet exhausted). fn handle_emit_left_unmatched( &mut self, ) -> ControlFlow>>> { @@ -1210,11 +1855,45 @@ impl NestedLoopJoinStream { // State unchanged (EmitLeftUnmatched) // Continue processing until we have processed all unmatched rows Ok(true) => ControlFlow::Continue(()), - // To Done state - // We have finished processing all unmatched rows + // We have finished processing all unmatched rows for this chunk Ok(false) => match self.output_buffer.finish_buffered_batch() { Ok(()) => { - self.state = NLJState::Done; + // Flush any completed batch before transitioning. + // This is critical for the memory-limited path: the + // ProbeRight results must be emitted before we discard + // the current chunk and load the next one. + if let Some(poll) = self.maybe_flush_ready_batch() { + return ControlFlow::Break(poll); + } + + if !self.left_exhausted && self.is_memory_limited() { + // More left data to process — free current chunk and + // go back to BufferingLeft for the next chunk + if let SpillState::Active(ref active) = self.spill_state { + active.reservation.resize(0); + } + self.buffered_left_data = None; + self.left_probe_idx = 0; + self.left_emit_idx = 0; + // Each memory-limited chunk gets a fresh per-chunk + // `JoinLeftData`/counter, so allow this stream to report + // completion again for the next chunk. + self.probe_completed_reported = false; + self.state = NLJState::BufferingLeft; + } else if self.is_memory_limited() + && self.should_track_unmatched_right + { + // All left chunks done — emit global right unmatched. + // Drop the exhausted right stream so that + // EmitGlobalRightUnmatched opens a fresh replay pass + // from the spill file. (process_left_unmatched_range + // already ran with right_data still set, so its + // schema access is not affected.) + self.right_data = None; + self.state = NLJState::EmitGlobalRightUnmatched; + } else { + self.state = NLJState::Done; + } ControlFlow::Continue(()) } Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), @@ -1223,6 +1902,103 @@ impl NestedLoopJoinStream { } } + /// Handle EmitGlobalRightUnmatched state. + /// + /// Replays all right batches from the spill file and emits unmatched + /// right rows using the global bitmap accumulated across all left chunks. + fn handle_emit_global_right_unmatched( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> ControlFlow>>> { + // Flush any completed batches first + if let Some(poll) = self.maybe_flush_ready_batch() { + return ControlFlow::Break(poll); + } + + // On first entry, open a new replay pass on the right input + if self.right_data.is_none() { + let SpillState::Active(ref mut active) = self.spill_state else { + unreachable!("EmitGlobalRightUnmatched without Active spill state"); + }; + active.right_batch_index = 0; + match active.right_input.open_pass() { + Ok(stream) => { + self.right_data = Some(stream); + } + Err(e) => { + return ControlFlow::Break(Poll::Ready(Some(Err(e)))); + } + } + } + + // Poll the replay stream for the next right batch + match self + .right_data + .as_mut() + .expect("right_data must be present") + .poll_next_unpin(cx) + { + Poll::Ready(Some(Ok(right_batch))) => { + if right_batch.num_rows() == 0 { + return ControlFlow::Continue(()); + } + + let SpillState::Active(ref mut active) = self.spill_state else { + unreachable!(); + }; + let idx = active.right_batch_index; + active.right_batch_index += 1; + + // Build BooleanArray from the global bitmap + let bitmap = if idx < active.global_right_bitmaps.len() { + BooleanArray::new(active.global_right_bitmaps[idx].clone(), None) + } else { + // Batch never seen — treat all rows as unmatched + BooleanArray::new( + BooleanBuffer::new_unset(right_batch.num_rows()), + None, + ) + }; + + let left_schema = Arc::clone( + active + .left_schema + .as_ref() + .expect("left_schema must be set"), + ); + + match build_unmatched_batch( + &self.output_schema, + &right_batch, + bitmap, + &left_schema, + &self.column_indices, + self.join_type, + JoinSide::Right, + ) { + Ok(Some(batch)) => match self.output_buffer.push_batch(batch) { + Ok(()) => ControlFlow::Continue(()), + Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + }, + Ok(None) => ControlFlow::Continue(()), + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + } + } + Poll::Ready(Some(Err(e))) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), + Poll::Ready(None) => { + // All right batches replayed + match self.output_buffer.finish_buffered_batch() { + Ok(()) => { + self.state = NLJState::Done; + ControlFlow::Continue(()) + } + Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + } + } + Poll::Pending => ControlFlow::Break(Poll::Pending), + } + } + /// Handle Done state - final state processing fn handle_done(&mut self) -> Poll>> { // Return any remaining completed batches before final termination @@ -1294,7 +2070,10 @@ impl NestedLoopJoinStream { left_data.batch().num_rows() - self.left_probe_idx, ); - debug_assert!(l_row_count != 0, "This function should only be entered when there are remaining left rows to process"); + debug_assert!( + l_row_count != 0, + "This function should only be entered when there are remaining left rows to process" + ); let joined_batch = self.process_left_range_join( &left_data, &right_batch, @@ -1434,17 +2213,17 @@ impl NestedLoopJoinStream { let l_index = l_start_index + i / right_rows; let r_index = i % right_rows; - if let Some(bitmap) = left_bitmap.as_mut() { - if is_matched { - // Map local index back to absolute left index within the batch - bitmap.set_bit(l_index, true); - } + if let Some(bitmap) = left_bitmap.as_mut() + && is_matched + { + // Map local index back to absolute left index within the batch + bitmap.set_bit(l_index, true); } - if let Some(bitmap) = local_right_bitmap.as_mut() { - if is_matched { - bitmap.set_bit(r_index, true); - } + if let Some(bitmap) = local_right_bitmap.as_mut() + && is_matched + { + bitmap.set_bit(r_index, true); } } @@ -1556,7 +2335,7 @@ impl NestedLoopJoinStream { return Ok(None); } - if cur_right_bitmap.true_count() == 0 { + if !cur_right_bitmap.has_true() { // If none of the pairs has passed the join predicate/filter Ok(None) } else { @@ -1578,7 +2357,9 @@ impl NestedLoopJoinStream { /// true -> continue in the same EmitLeftUnmatched state /// false -> next state (Done) fn process_left_unmatched(&mut self) -> Result { - let left_data = self.get_left_data()?; + // Clone the shared `Arc` so the immutable borrow of `self` + // ends here and we can update `self.probe_completed_reported` below. + let left_data = Arc::clone(self.get_left_data()?); let left_batch = left_data.batch(); // ======== @@ -1587,9 +2368,25 @@ impl NestedLoopJoinStream { // Early return if join type can't have unmatched rows let join_type_no_produce_left = !need_produce_result_in_final(self.join_type); - // Early return if another thread is already processing unmatched rows - let handled_by_other_partition = - self.left_emit_idx == 0 && !left_data.report_probe_completed(); + // Early return if another thread is already processing unmatched rows. + // + // The shared probe-threads counter must be decremented exactly once per + // probe stream. This function can be re-entered with `left_emit_idx` + // still 0 (e.g. when a ready batch was flushed via an early return in + // `handle_emit_left_unmatched` before the state advanced), so guard the + // decrement with `probe_completed_reported` instead of relying solely on + // `left_emit_idx == 0`. Decrementing twice would drive the counter to + // zero prematurely and let a partition emit unmatched-left rows before + // all partitions finished probing, producing spurious NULL-padded rows. + let handled_by_other_partition = if self.probe_completed_reported { + // Already counted this stream's completion; if we're the designated + // emitter we have `left_emit_idx > 0` (or are mid-emit) and continue, + // otherwise another partition is handling emission. + self.left_emit_idx == 0 + } else { + self.probe_completed_reported = true; + self.left_emit_idx == 0 && !left_data.report_probe_completed() + }; // Stop processing unmatched rows, the caller will go to the next state let finished = self.left_emit_idx >= left_batch.num_rows(); @@ -1605,7 +2402,7 @@ impl NestedLoopJoinStream { let end_idx = std::cmp::min(start_idx + self.batch_size, left_batch.num_rows()); if let Some(batch) = - self.process_left_unmatched_range(left_data, start_idx, end_idx)? + self.process_left_unmatched_range(&left_data, start_idx, end_idx)? { self.output_buffer.push_batch(batch)?; } @@ -1657,7 +2454,11 @@ impl NestedLoopJoinStream { } let bitmap_sliced = BooleanArray::new(bitmap_sliced.finish(), None); - let right_schema = self.right_data.schema(); + let right_schema = self + .right_data + .as_ref() + .expect("right_data must be present when building unmatched batch") + .schema(); build_unmatched_batch( &self.output_schema, &left_batch_sliced, @@ -1712,14 +2513,14 @@ impl NestedLoopJoinStream { /// Flush the `output_buffer` if there are batches ready to output /// None if no result batch ready. fn maybe_flush_ready_batch(&mut self) -> Option>>> { - if self.output_buffer.has_completed_batch() { - if let Some(batch) = self.output_buffer.next_completed_batch() { - // Update output rows for selectivity metric - let output_rows = batch.num_rows(); - self.metrics.selectivity.add_part(output_rows); + if self.output_buffer.has_completed_batch() + && let Some(batch) = self.output_buffer.next_completed_batch() + { + // Update output rows for selectivity metric + let output_rows = batch.num_rows(); + self.metrics.selectivity.add_part(output_rows); - return Some(Poll::Ready(Some(Ok(batch)))); - } + return Some(Poll::Ready(Some(Ok(batch)))); } None @@ -1747,16 +2548,13 @@ impl NestedLoopJoinStream { ) -> Result<()> { let left_data = self.get_left_data()?; - // number of successfully joined pairs from (l_index x cur_right_batch) - let joined_len = r_matched_bitmap.true_count(); - // 1. Maybe update the left bitmap - if need_produce_result_in_final(self.join_type) && (joined_len > 0) { + if need_produce_result_in_final(self.join_type) && r_matched_bitmap.has_true() { let mut bitmap = left_data.bitmap().lock(); bitmap.set_bit(l_index, true); } - // 2. Maybe updateh the right bitmap + // 2. Maybe update the right bitmap if self.should_track_unmatched_right { debug_assert!(self.current_right_batch_matched.is_some()); // after bit-wise or, it will be put back @@ -1930,9 +2728,10 @@ fn build_row_join_batch( // Broadcast the single build-side row to match the filtered // probe-side batch length let original_left_array = build_side_batch.column(column_index.index); - // Avoid using `ScalarValue::to_array_of_size()` for `List(Utf8View)` to avoid - // deep copies for buffers inside `Utf8View` array. See below for details. - // https://github.com/apache/datafusion/issues/18159 + + // Use `arrow::compute::take` directly for `List(Utf8View)` rather + // than going through `ScalarValue::to_array_of_size()`, which + // avoids some intermediate allocations. // // In other cases, `to_array_of_size()` is faster. match original_left_array.data_type() { @@ -2086,7 +2885,7 @@ fn build_unmatched_batch( // 2. Fill left side with nulls let flipped_bitmap = not(&batch_bitmap)?; - // create a recordbatch, with left_schema, of only one row of all nulls + // create a record batch, with left_schema, of only one row of all nulls let left_null_columns: Vec> = another_side_schema .fields() .iter() @@ -2100,9 +2899,7 @@ fn build_unmatched_batch( another_side_schema .fields() .iter() - .map(|field| { - (**field).clone().with_nullable(true) - }) + .map(|field| (**field).clone().with_nullable(true)) .collect::>(), )); let left_null_batch = if nullable_left_schema.fields.is_empty() { @@ -2116,10 +2913,20 @@ fn build_unmatched_batch( debug_assert_ne!(batch_side, JoinSide::None); let opposite_side = batch_side.negate(); - build_row_join_batch(output_schema, &left_null_batch, 0, batch, Some(flipped_bitmap), col_indices, opposite_side) - - }, - JoinType::RightSemi | JoinType::RightAnti | JoinType::LeftSemi | JoinType::LeftAnti => { + build_row_join_batch( + output_schema, + &left_null_batch, + 0, + batch, + Some(flipped_bitmap), + col_indices, + opposite_side, + ) + } + JoinType::RightSemi + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::LeftAnti => { if matches!(join_type, JoinType::RightSemi | JoinType::RightAnti) { debug_assert_eq!(batch_side, JoinSide::Right); } @@ -2127,13 +2934,14 @@ fn build_unmatched_batch( debug_assert_eq!(batch_side, JoinSide::Left); } - let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) { + let bitmap = if matches!(join_type, JoinType::LeftSemi | JoinType::RightSemi) + { batch_bitmap.clone() } else { not(&batch_bitmap)? }; - if bitmap.true_count() == 0 { + if !bitmap.has_true() { return Ok(None); } @@ -2149,8 +2957,11 @@ fn build_unmatched_batch( columns.push(filtered_col); } - Ok(Some(RecordBatch::try_new(Arc::clone(output_schema), columns)?)) - }, + Ok(Some(RecordBatch::try_new( + Arc::clone(output_schema), + columns, + )?)) + } JoinType::RightMark | JoinType::LeftMark => { if join_type == JoinType::RightMark { debug_assert_eq!(batch_side, JoinSide::Right); @@ -2173,32 +2984,41 @@ fn build_unmatched_batch( } else if column_index.side == JoinSide::None { let right_batch_bitmap = std::mem::take(&mut right_batch_bitmap_opt); match right_batch_bitmap { - Some(right_batch_bitmap) => {columns.push(Arc::new(right_batch_bitmap))}, + Some(right_batch_bitmap) => { + columns.push(Arc::new(right_batch_bitmap)) + } None => unreachable!("Should only be one mark column"), } } else { - return internal_err!("Not possible to have this join side for RightMark join"); + return internal_err!( + "Not possible to have this join side for RightMark join" + ); } } - Ok(Some(RecordBatch::try_new(Arc::clone(output_schema), columns)?)) + Ok(Some(RecordBatch::try_new( + Arc::clone(output_schema), + columns, + )?)) } - _ => internal_err!("If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins"), + _ => internal_err!( + "If batch is at right side, this function must be handling Full/Right/RightSemi/RightAnti/RightMark joins" + ), } } #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::test::{assert_join_metrics, TestMemoryExec}; + use crate::test::{TestMemoryExec, assert_join_metrics}; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field}; + use datafusion_common::assert_contains; use datafusion_common::test_util::batches_to_sort_string; - use datafusion_common::{assert_contains, ScalarValue}; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; @@ -2382,13 +3202,13 @@ pub(crate) mod tests { .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | 5 | 5 | 50 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 5 | 5 | 50 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + ")); assert_join_metrics!(metrics, 1); @@ -2412,15 +3232,15 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+-----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+-----+----+----+----+ - | 11 | 8 | 110 | | | | - | 5 | 5 | 50 | 2 | 2 | 80 | - | 9 | 8 | 90 | | | | - +----+----+-----+----+----+----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+-----+----+----+----+ + | 11 | 8 | 110 | | | | + | 5 | 5 | 50 | 2 | 2 | 80 | + | 9 | 8 | 90 | | | | + +----+----+-----+----+----+----+ + ")); assert_join_metrics!(metrics, 3); @@ -2444,15 +3264,15 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+-----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+-----+ - | | | | 10 | 10 | 100 | - | | | | 12 | 10 | 40 | - | 5 | 5 | 50 | 2 | 2 | 80 | - +----+----+----+----+----+-----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+-----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+-----+ + | | | | 10 | 10 | 100 | + | | | | 12 | 10 | 40 | + | 5 | 5 | 50 | 2 | 2 | 80 | + +----+----+----+----+----+-----+ + ")); assert_join_metrics!(metrics, 3); @@ -2476,17 +3296,17 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+-----+----+----+-----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+-----+----+----+-----+ - | | | | 10 | 10 | 100 | - | | | | 12 | 10 | 40 | - | 11 | 8 | 110 | | | | - | 5 | 5 | 50 | 2 | 2 | 80 | - | 9 | 8 | 90 | | | | - +----+----+-----+----+----+-----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+-----+----+----+-----+ + | | | | 10 | 10 | 100 | + | | | | 12 | 10 | 40 | + | 11 | 8 | 110 | | | | + | 5 | 5 | 50 | 2 | 2 | 80 | + | 9 | 8 | 90 | | | | + +----+----+-----+----+----+-----+ + ")); assert_join_metrics!(metrics, 5); @@ -2512,13 +3332,13 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 5 | 5 | 50 | - +----+----+----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 5 | 5 | 50 | + +----+----+----+ + ")); assert_join_metrics!(metrics, 1); @@ -2544,14 +3364,14 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+-----+ - | a1 | b1 | c1 | - +----+----+-----+ - | 11 | 8 | 110 | - | 9 | 8 | 90 | - +----+----+-----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+ + | a1 | b1 | c1 | + +----+----+-----+ + | 11 | 8 | 110 | + | 9 | 8 | 90 | + +----+----+-----+ + ")); assert_join_metrics!(metrics, 2); @@ -2597,13 +3417,13 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a2", "b2", "c2"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+ - | a2 | b2 | c2 | - +----+----+----+ - | 2 | 2 | 80 | - +----+----+----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+ + | a2 | b2 | c2 | + +----+----+----+ + | 2 | 2 | 80 | + +----+----+----+ + ")); assert_join_metrics!(metrics, 1); @@ -2629,14 +3449,14 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a2", "b2", "c2"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+-----+ - | a2 | b2 | c2 | - +----+----+-----+ - | 10 | 10 | 100 | - | 12 | 10 | 40 | - +----+----+-----+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+ + | a2 | b2 | c2 | + +----+----+-----+ + | 10 | 10 | 100 | + | 12 | 10 | 40 | + +----+----+-----+ + ")); assert_join_metrics!(metrics, 2); @@ -2662,15 +3482,15 @@ pub(crate) mod tests { ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+-----+-------+ - | a1 | b1 | c1 | mark | - +----+----+-----+-------+ - | 11 | 8 | 110 | false | - | 5 | 5 | 50 | true | - | 9 | 8 | 90 | false | - +----+----+-----+-------+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+-------+ + | a1 | b1 | c1 | mark | + +----+----+-----+-------+ + | 11 | 8 | 110 | false | + | 5 | 5 | 50 | true | + | 9 | 8 | 90 | false | + +----+----+-----+-------+ + ")); assert_join_metrics!(metrics, 3); @@ -2697,15 +3517,15 @@ pub(crate) mod tests { .await?; assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]); - allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+-----+-------+ - | a2 | b2 | c2 | mark | - +----+----+-----+-------+ - | 10 | 10 | 100 | false | - | 12 | 10 | 40 | false | - | 2 | 2 | 80 | true | - +----+----+-----+-------+ - "#)); + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+-------+ + | a2 | b2 | c2 | mark | + +----+----+-----+-------+ + | 10 | 10 | 100 | false | + | 12 | 10 | 40 | false | + | 2 | 2 | 80 | true | + +----+----+-----+-------+ + ")); assert_join_metrics!(metrics, 3); @@ -2730,42 +3550,57 @@ pub(crate) mod tests { ); let filter = prepare_join_filter(); - let join_types = vec![ + // Join types that support memory-limited fallback should succeed + // even under tight memory limits (they spill to disk instead of OOM). + let fallback_join_types = vec![ JoinType::Inner, JoinType::Left, - JoinType::Right, - JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::Right, JoinType::RightSemi, JoinType::RightAnti, JoinType::RightMark, ]; - for join_type in join_types { + for join_type in &fallback_join_types { let runtime = RuntimeEnvBuilder::new() .with_memory_limit(100, 1.0) .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); - let err = multi_partitioned_join_collect( + // Should succeed via spill fallback, not OOM + let _result = multi_partitioned_join_collect( Arc::clone(&left), Arc::clone(&right), - &join_type, + join_type, Some(filter.clone()), task_ctx, ) - .await - .unwrap_err(); - - assert_contains!( - err.to_string(), - "Resources exhausted: Additional allocation failed for NestedLoopJoinLoad[0] with top memory consumers (across reservations) as:\n NestedLoopJoinLoad[0]" - ); + .await?; } + // FULL JOIN with multiple right partitions is intentionally not + // supported in the fallback path yet (cross-partition left-bitmap + // coordination is missing). It should still OOM under tight memory. + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .build_arc()?; + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); + let err = multi_partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + &JoinType::Full, + Some(filter.clone()), + task_ctx, + ) + .await + .unwrap_err(); + assert_contains!(err.to_string(), "Resources exhausted"); + Ok(()) } @@ -2773,4 +3608,361 @@ pub(crate) mod tests { fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() } + + // ======================================================================== + // Memory-limited execution tests + // ======================================================================== + + /// Helper to run a NLJ using partition 0 and collect results + metrics. + async fn join_collect( + left: Arc, + right: Arc, + join_type: &JoinType, + join_filter: Option, + context: Arc, + ) -> Result<(Vec, Vec, MetricsSet)> { + let nested_loop_join = + NestedLoopJoinExec::try_new(left, right, join_filter, join_type, None)?; + let columns = columns(&nested_loop_join.schema()); + let stream = nested_loop_join.execute(0, context)?; + let batches: Vec = common::collect(stream) + .await? + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect(); + let metrics = nested_loop_join.metrics().unwrap(); + Ok((columns, batches, metrics)) + } + + /// Create a TaskContext with tight memory limit and disk spilling enabled. + fn task_ctx_with_memory_limit( + memory_limit: usize, + batch_size: usize, + ) -> Result> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build_arc()?; + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(batch_size); + let task_ctx = TaskContext::default() + .with_runtime(runtime) + .with_session_config(cfg); + Ok(Arc::new(task_ctx)) + } + + #[tokio::test] + async fn test_nlj_memory_limited_inner_join() -> Result<()> { + // Use a very small memory limit to force OOM → fallback to spill. + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + // Verify spill actually occurred (memory-limited path was taken) + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + // Result should be identical to the non-memory-limited case + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 5 | 5 | 50 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_left_join() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::Left, Some(filter), task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + // Verify spill actually occurred + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+-----+----+----+----+ + | 11 | 8 | 110 | | | | + | 5 | 5 | 50 | 2 | 2 | 80 | + | 9 | 8 | 90 | | | | + +----+----+-----+----+----+----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_fits_in_memory_no_spill() -> Result<()> { + // Use a large memory limit — everything fits, no spilling needed. + let task_ctx = task_ctx_with_memory_limit(10_000_000, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + // Verify no spilling occurred (standard OnceFut path was used) + assert_eq!( + metrics.spill_count().unwrap_or(0), + 0, + "Expected no spilling with generous memory limit" + ); + + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 5 | 5 | 50 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_empty_inputs() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + + // Empty left table + let empty_left = build_table( + ("a1", &vec![]), + ("b1", &vec![]), + ("c1", &vec![]), + None, + Vec::new(), + ); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, _metrics) = + join_collect(empty_left, right, &JoinType::Inner, Some(filter), task_ctx) + .await?; + assert!(batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0)); + + // Empty right table + let task_ctx2 = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let empty_right = build_table( + ("a2", &vec![]), + ("b2", &vec![]), + ("c2", &vec![]), + None, + Vec::new(), + ); + let filter2 = prepare_join_filter(); + + let (_columns, batches, _metrics) = join_collect( + left, + empty_right, + &JoinType::Inner, + Some(filter2), + task_ctx2, + ) + .await?; + assert!(batches.is_empty() || batches.iter().all(|b| b.num_rows() == 0)); + + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_no_disk_falls_back_to_oom() -> Result<()> { + // When disk is disabled, fallback is not possible and OOM should occur. + use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::Disabled), + ) + .build_arc()?; + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let err = join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx) + .await + .unwrap_err(); + + assert_contains!(err.to_string(), "Resources exhausted"); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_right_join() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::Right, Some(filter), task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + // Verify spill actually occurred + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + // Right join: all right rows appear. Unmatched right rows get NULLs on left. + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+-----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+-----+ + | | | | 10 | 10 | 100 | + | | | | 12 | 10 | 40 | + | 5 | 5 | 50 | 2 | 2 | 80 | + +----+----+----+----+----+-----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_full_join() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::Full, Some(filter), task_ctx).await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); + + // Verify spill actually occurred + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + // Full join: unmatched from both sides appear with NULL padding. + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+-----+----+----+-----+ + | | | | 10 | 10 | 100 | + | | | | 12 | 10 | 40 | + | 11 | 8 | 110 | | | | + | 5 | 5 | 50 | 2 | 2 | 80 | + | 9 | 8 | 90 | | | | + +----+----+-----+----+----+-----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_right_semi_join() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::RightSemi, Some(filter), task_ctx) + .await?; + + assert_eq!(columns, vec!["a2", "b2", "c2"]); + + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + // Right semi: only right rows that matched at least one left row. + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+ + | a2 | b2 | c2 | + +----+----+----+ + | 2 | 2 | 80 | + +----+----+----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_right_anti_join() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::RightAnti, Some(filter), task_ctx) + .await?; + + assert_eq!(columns, vec!["a2", "b2", "c2"]); + + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + // Right anti: right rows that did NOT match any left row. + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+ + | a2 | b2 | c2 | + +----+----+-----+ + | 10 | 10 | 100 | + | 12 | 10 | 40 | + +----+----+-----+ + ")); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_memory_limited_right_mark_join() -> Result<()> { + let task_ctx = task_ctx_with_memory_limit(50, 16)?; + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (columns, batches, metrics) = + join_collect(left, right, &JoinType::RightMark, Some(filter), task_ctx) + .await?; + + assert_eq!(columns, vec!["a2", "b2", "c2", "mark"]); + + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to occur under tight memory limit" + ); + + // Right mark: all right rows with a bool column indicating match. + allow_duplicates!(assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+-----+-------+ + | a2 | b2 | c2 | mark | + +----+----+-----+-------+ + | 10 | 10 | 100 | false | + | 12 | 10 | 40 | false | + | 2 | 2 | 80 | true | + +----+----+-----+-------+ + ")); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index 695e73f109f3b..36a043cc7d16b 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -17,8 +17,8 @@ //! Stream Implementation for PiecewiseMergeJoin's Classic Join (Left, Right, Full, Inner) -use arrow::array::{new_null_array, Array, PrimitiveBuilder}; -use arrow::compute::{take, BatchCoalescer}; +use arrow::array::{Array, PrimitiveBuilder, new_null_array}; +use arrow::compute::{BatchCoalescer, take}; use arrow::datatypes::UInt32Type; use arrow::{ array::{ArrayRef, RecordBatch, UInt32Array}, @@ -26,7 +26,7 @@ use arrow::{ }; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::NullEquality; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::PhysicalExprRef; @@ -37,8 +37,9 @@ use std::{sync::Arc, task::Poll}; use crate::handle_state; use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; -use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; +use crate::joins::utils::{JoinKeyComparator, get_final_indices_from_shared_bitmap}; +use crate::stream::EmptyRecordBatchStream; pub(super) enum PiecewiseMergeJoinStreamState { WaitBufferedSide, @@ -70,7 +71,6 @@ pub(super) struct SortedStreamBatch { } impl SortedStreamBatch { - #[allow(dead_code)] fn new(batch: RecordBatch, compare_key_values: Vec) -> Self { Self { batch, @@ -132,7 +132,7 @@ impl RecordBatchStream for ClassicPWMJStream { // `Completed` however for Full and Right we will need to process the unmatched buffered rows. impl ClassicPWMJStream { // Creates a new `PiecewiseMergeJoinStream` instance - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub fn try_new( schema: Arc, on_streamed: PhysicalExprRef, @@ -189,11 +189,12 @@ impl ClassicPWMJStream { cx: &mut std::task::Context<'_>, ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let buffered_data = ready!(self - .buffered_side - .try_as_initial_mut()? - .buffered_fut - .get_shared(cx))?; + let buffered_data = ready!( + self.buffered_side + .try_as_initial_mut()? + .buffered_fut + .get_shared(cx) + )?; build_timer.done(); // We will start fetching stream batches for classic joins @@ -212,6 +213,9 @@ impl ClassicPWMJStream { ) -> Poll>>> { match ready!(self.streamed.poll_next_unpin(cx)) { None => { + // Release the streamed input pipeline's resources. + let streamed_schema = self.streamed.schema(); + self.streamed = Box::pin(EmptyRecordBatchStream::new(streamed_schema)); if self .buffered_side .try_as_ready_mut()? @@ -248,10 +252,7 @@ impl ClassicPWMJStream { // Reset BatchProcessState before processing a new stream batch self.batch_process_state.reset(); self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch( - SortedStreamBatch { - batch: stream_batch, - compare_key_values: vec![stream_values], - }, + SortedStreamBatch::new(stream_batch, vec![stream_values]), ); } Some(Err(err)) => return Poll::Ready(Err(err)), @@ -451,7 +452,6 @@ impl Stream for ClassicPWMJStream { } // For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. -#[allow(clippy::too_many_arguments)] fn resolve_classic_join( buffered_side: &mut BufferedSideReadyState, stream_batch: &SortedStreamBatch, @@ -464,6 +464,14 @@ fn resolve_classic_join( let buffered_len = buffered_side.buffered_data.values().len(); let stream_values = stream_batch.compare_key_values(); + // Build comparator once for the batch pair + let cmp = JoinKeyComparator::new( + &[Arc::clone(&stream_values[0])], + &[Arc::clone(buffered_side.buffered_data.values())], + &[sort_options], + NullEquality::NullEqualsNothing, + )?; + let mut buffer_idx = batch_process_state.start_buffer_idx; let mut stream_idx = batch_process_state.start_stream_idx; @@ -479,22 +487,12 @@ fn resolve_classic_join( // in the previous stream row. for row_idx in stream_idx..stream_batch.batch.num_rows() { while buffer_idx < buffered_len { - let compare = { - let buffered_values = buffered_side.buffered_data.values(); - compare_join_arrays( - &[Arc::clone(&stream_values[0])], - row_idx, - &[Arc::clone(buffered_values)], - buffer_idx, - &[sort_options], - NullEquality::NullEqualsNothing, - )? - }; + let compare = cmp.compare(row_idx, buffer_idx); // If we find a match we append all indices and move to the next stream row index match operator { Operator::Gt | Operator::Lt => { - if matches!(compare, Ordering::Less) { + if compare == Ordering::Less { batch_process_state.found = true; let count = buffered_len - buffer_idx; @@ -553,7 +551,7 @@ fn resolve_classic_join( return internal_err!( "PiecewiseMergeJoin should not contain operator, {}", operator - ) + ); } }; @@ -658,17 +656,15 @@ fn create_unmatched_batch( mod tests { use super::*; use crate::{ - common, + ExecutionPlan, common, joins::PiecewiseMergeJoinExec, - test::{build_table_i32, TestMemoryExec}, - ExecutionPlan, + test::{TestMemoryExec, build_table_i32}, }; use arrow::array::{Date32Array, Date64Array}; use arrow_schema::{DataType, Field}; use datafusion_common::test_util::batches_to_string; use datafusion_execution::TaskContext; - use datafusion_expr::JoinType; - use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; + use datafusion_physical_expr::{PhysicalExpr, expressions::Column}; use insta::assert_snapshot; use std::sync::Arc; @@ -808,7 +804,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -819,7 +815,7 @@ mod tests { | 3 | 1 | 9 | 20 | 3 | 80 | | 3 | 1 | 9 | 10 | 2 | 70 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -859,18 +855,18 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 3 | 7 | 30 | 4 | 90 | - | 2 | 2 | 8 | 30 | 4 | 90 | - | 3 | 1 | 9 | 30 | 4 | 90 | - | 2 | 2 | 8 | 10 | 3 | 70 | - | 3 | 1 | 9 | 10 | 3 | 70 | - | 3 | 1 | 9 | 20 | 2 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 10 | 3 | 70 | + | 3 | 1 | 9 | 10 | 3 | 70 | + | 3 | 1 | 9 | 20 | 2 | 80 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -910,7 +906,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -923,7 +919,7 @@ mod tests { | 2 | 3 | 8 | 10 | 3 | 70 | | 3 | 4 | 9 | 10 | 3 | 70 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -958,12 +954,12 @@ mod tests { ); let (_, batches) = join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1001,7 +997,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+----+----+-----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+-----+----+----+-----+ @@ -1009,7 +1005,7 @@ mod tests { | | | | 10 | 3 | 300 | | 1 | 1 | 100 | | | | +----+----+-----+----+----+-----+ - "#); + "); Ok(()) } @@ -1050,7 +1046,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -1061,7 +1057,7 @@ mod tests { | 3 | 4 | 9 | 10 | 3 | 70 | | 1 | 1 | 7 | | | | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1101,7 +1097,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -1110,7 +1106,7 @@ mod tests { | 3 | 4 | 9 | 20 | 3 | 80 | | | | | 10 | 5 | 70 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1150,7 +1146,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -1160,7 +1156,7 @@ mod tests { | 3 | 1 | 9 | 20 | 3 | 80 | | 3 | 1 | 9 | 10 | 2 | 70 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1201,7 +1197,7 @@ mod tests { join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; // Expected grouping follows right.b1 descending (4, 3, 2) - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -1211,7 +1207,7 @@ mod tests { | 3 | 2 | 9 | 20 | 3 | 80 | | 3 | 2 | 9 | 30 | 2 | 90 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1252,7 +1248,7 @@ mod tests { join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; // Grouped by right in ascending evaluation for > (1,2,3) - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -1261,7 +1257,7 @@ mod tests { | 3 | 4 | 9 | 30 | 2 | 90 | | 3 | 4 | 9 | 10 | 3 | 70 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1295,7 +1291,7 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::LtEq, JoinType::Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ @@ -1303,7 +1299,7 @@ mod tests { | 1 | 5 | 7 | | | | | 2 | 4 | 8 | | | | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1341,14 +1337,14 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::GtEq, JoinType::Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ | | | | 10 | 3 | 70 | | | | | 20 | 5 | 80 | +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1370,13 +1366,13 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+-----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+-----+----+----+----+ | 42 | 5 | 999 | 30 | 7 | 90 | +----+----+-----+----+----+----+ - "#); + "); Ok(()) } @@ -1402,12 +1398,12 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Gt, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +----+----+----+----+----+----+ | a1 | b1 | c1 | a2 | b1 | c2 | +----+----+----+----+----+----+ +----+----+----+----+----+----+ - "#); + "); Ok(()) } @@ -1447,13 +1443,13 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +------------+------------+------------+------------+------------+------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +------------+------------+------------+------------+------------+------------+ - | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | - +------------+------------+------------+------------+------------+------------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "); Ok(()) } @@ -1493,13 +1489,13 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; - assert_snapshot!(batches_to_string(&batches), @r#" + assert_snapshot!(batches_to_string(&batches), @r" +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ | a1 | b1 | c1 | a2 | b1 | c2 | +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - "#); + "); Ok(()) } @@ -1537,14 +1533,14 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | - | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ -"#); + assert_snapshot!(batches_to_string(&batches), @r" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "); Ok(()) } } diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs index a9ea92f2d92da..50e9252a21131 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/exec.rs @@ -23,10 +23,10 @@ use arrow::{ }; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::not_impl_err; -use datafusion_common::{internal_err, JoinSide, Result}; +use datafusion_common::{JoinSide, Result, internal_err}; use datafusion_execution::{ - memory_pool::{MemoryConsumer, MemoryReservation}, SendableRecordBatchStream, + memory_pool::{MemoryConsumer, MemoryReservation}, }; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::equivalence::join_equivalence_properties; @@ -38,10 +38,10 @@ use datafusion_physical_expr_common::physical_expr::fmt_sql; use futures::TryStreamExt; use parking_lot::Mutex; use std::fmt::Formatter; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use std::sync::atomic::AtomicUsize; -use crate::execution_plan::{boundedness_from_children, EmissionType}; +use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::joins::piecewise_merge_join::classic_join::{ ClassicPWMJStream, PiecewiseMergeJoinStreamState, @@ -50,16 +50,19 @@ use crate::joins::piecewise_merge_join::utils::{ build_visited_indices_map, is_existence_join, is_right_existence_join, }; use crate::joins::utils::asymmetric_join_output_partitioning; +use crate::metrics::MetricsSet; use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlanProperties, check_if_same_properties, +}; +use crate::{ + ExecutionPlan, PlanProperties, joins::{ - utils::{build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut}, SharedBitmapBuilder, + utils::{BuildProbeJoinMetrics, OnceAsync, OnceFut, build_join_schema}, }, metrics::ExecutionPlanMetricsSet, spill::get_record_batch_memory_size, - ExecutionPlan, PlanProperties, }; -use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter and show much /// better performance for these workloads than `NestedLoopJoin` @@ -85,7 +88,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// Both sides are sorted so that we can iterate from index 0 to the end on each side. This ordering ensures /// that when we find the first matching pair of rows, we can emit the current stream row joined with all remaining /// probe rows from the match position onward, without rescanning earlier probe rows. -/// +/// /// For `<` and `<=` operators, both inputs are sorted in **descending** order, while for `>` and `>=` operators /// they are sorted in **ascending** order. This choice ensures that the pointer on the buffered side can advance /// monotonically as we stream new batches from the stream side. @@ -128,34 +131,34 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// /// Processing Row 1: /// -/// Sorted Buffered Side Sorted Streamed Side -/// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 100 │ 1 │ 100 │ -/// ├──────────────────┤ ├──────────────────┤ -/// 2 │ 200 │ ─┐ 2 │ 200 │ -/// ├──────────────────┤ │ For row 1 on streamed side with ├──────────────────┤ -/// 3 │ 200 │ │ value 100, we emit rows 2 - 5. 3 │ 500 │ +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ ─┐ 2 │ 200 │ +/// ├──────────────────┤ │ For row 1 on streamed side with ├──────────────────┤ +/// 3 │ 200 │ │ value 100, we emit rows 2 - 5. 3 │ 500 │ /// ├──────────────────┤ │ as matches when the operator is └──────────────────┘ /// 4 │ 300 │ │ `Operator::Lt` (<) Emitting all /// ├──────────────────┤ │ rows after the first match (row /// 5 │ 400 │ ─┘ 2 buffered side; 100 < 200) -/// └──────────────────┘ +/// └──────────────────┘ /// /// Processing Row 2: /// By sorting the streamed side we know /// -/// Sorted Buffered Side Sorted Streamed Side -/// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 100 │ 1 │ 100 │ -/// ├──────────────────┤ ├──────────────────┤ -/// 2 │ 200 │ <- Start here when probing for the 2 │ 200 │ -/// ├──────────────────┤ streamed side row 2. ├──────────────────┤ -/// 3 │ 200 │ 3 │ 500 │ +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ <- Start here when probing for the 2 │ 200 │ +/// ├──────────────────┤ streamed side row 2. ├──────────────────┤ +/// 3 │ 200 │ 3 │ 500 │ /// ├──────────────────┤ └──────────────────┘ -/// 4 │ 300 │ -/// ├──────────────────┤ +/// 4 │ 300 │ +/// ├──────────────────┤ /// 5 │ 400 │ -/// └──────────────────┘ +/// └──────────────────┘ /// ``` /// /// ## Existence Joins (Semi, Anti, Mark) @@ -201,10 +204,10 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// 1 │ 100 │ 1 │ 500 │ /// ├──────────────────┤ ├──────────────────┤ /// 2 │ 200 │ 2 │ 200 │ -/// ├──────────────────┤ ├──────────────────┤ +/// ├──────────────────┤ ├──────────────────┤ /// 3 │ 200 │ 3 │ 300 │ /// ├──────────────────┤ └──────────────────┘ -/// 4 │ 300 │ ─┐ +/// 4 │ 300 │ ─┐ /// ├──────────────────┤ | We emit matches for row 4 - 5 /// 5 │ 400 │ ─┘ on the buffered side. /// └──────────────────┘ @@ -235,11 +238,11 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionPlanProperties}; /// /// # Mark Join: /// Sorts the probe side, then computes the min/max range of the probe keys and scans the buffered side only -/// within that range. +/// within that range. /// Complexity: `O(|S| + scan(R[range]))`. /// /// ## Nested Loop Join -/// Compares every row from `S` with every row from `R`. +/// Compares every row from `S` with every row from `R`. /// Complexity: `O(|S| * |R|)`. /// /// ## Nested Loop Join @@ -272,13 +275,12 @@ pub struct PiecewiseMergeJoinExec { left_child_plan_required_order: LexOrdering, /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations /// Unsorted for mark joins - #[allow(unused)] right_batch_required_orders: LexOrdering, /// This determines the sort order of all join columns used in sorting the stream and buffered execution plans. sort_options: SortOptions, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Number of partitions to process num_partitions: usize, } @@ -321,7 +323,7 @@ impl PiecewiseMergeJoinExec { _ => { return internal_err!( "Cannot contain non-range operator in PiecewiseMergeJoinExec" - ) + ); } }; @@ -372,7 +374,7 @@ impl PiecewiseMergeJoinExec { left_child_plan_required_order, right_batch_required_orders, sort_options, - cache, + cache: Arc::new(cache), num_partitions, }) } @@ -465,6 +467,31 @@ impl PiecewiseMergeJoinExec { pub fn swap_inputs(&self) -> Result> { todo!() } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let buffered = children.swap_remove(0); + let streamed = children.swap_remove(0); + Self { + buffered, + streamed, + on: self.on.clone(), + operator: self.operator, + join_type: self.join_type, + schema: Arc::clone(&self.schema), + left_child_plan_required_order: self.left_child_plan_required_order.clone(), + right_batch_required_orders: self.right_batch_required_orders.clone(), + sort_options: self.sort_options, + cache: Arc::clone(&self.cache), + num_partitions: self.num_partitions, + + // Re-set state. + metrics: ExecutionPlanMetricsSet::new(), + buffered_fut: Default::default(), + } + } } impl ExecutionPlan for PiecewiseMergeJoinExec { @@ -472,11 +499,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { "PiecewiseMergeJoinExec" } - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -510,6 +533,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); match &children[..] { [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new( Arc::clone(left), @@ -526,6 +550,13 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } } + fn reset_state(self: Arc) -> Result> { + Ok(Arc::new(self.with_new_children_and_same_properties(vec![ + Arc::clone(&self.buffered), + Arc::clone(&self.streamed), + ]))) + } + fn execute( &self, partition: usize, @@ -572,6 +603,10 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { ))) } } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } } impl DisplayAs for PiecewiseMergeJoinExec { @@ -615,7 +650,7 @@ async fn build_buffered_data( // Combine batches and record number of rows let initial = (Vec::new(), 0, metrics, reservation); - let (batches, num_rows, metrics, mut reservation) = buffered + let (batches, num_rows, metrics, reservation) = buffered .try_fold(initial, |mut acc, batch| async { let batch_size = get_record_batch_memory_size(&batch); acc.3.try_grow(batch_size)?; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs new file mode 100644 index 0000000000000..ad7312426bd18 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs @@ -0,0 +1,1344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort-merge join stream specialized for semi/anti/mark joins. +//! +//! Instantiated by [`SortMergeJoinExec`](crate::joins::sort_merge_join::SortMergeJoinExec) +//! when the join type is `LeftSemi`, `LeftAnti`, `RightSemi`, `RightAnti`, +//! `LeftMark`, or `RightMark`. +//! +//! # Motivation +//! +//! The general-purpose `MaterializingSortMergeJoinStream` +//! handles semi/anti joins by materializing `(outer, inner)` row pairs, +//! applying a filter, then using a "corrected filter mask" to deduplicate. +//! Semi/anti joins only need a boolean per outer row (does a match exist?), +//! not pairs. The pair-based approach incurs unnecessary memory allocation +//! and intermediate batches. +//! +//! This stream instead tracks matches with a per-outer-batch bitset, +//! avoiding all pair materialization. +//! +//! # "Outer Side" vs "Inner Side" +//! +//! For `Left*` join types, left is outer and right is inner. +//! For `Right*` join types, right is outer and left is inner. +//! The output schema always equals the outer side's schema (for semi/anti) +//! or the outer side's schema plus a boolean mark column (for mark joins). +//! +//! # Algorithm +//! +//! Both inputs must be sorted by the join keys. The stream performs a merge +//! scan across the two sorted inputs: +//! +//! ```text +//! outer cursor ──► [1, 2, 2, 3, 5, 5, 7] +//! inner cursor ──► [2, 2, 4, 5, 6, 7, 7] +//! ▲ +//! compare keys at cursors +//! ``` +//! +//! At each step, the keys at the outer and inner cursors are compared: +//! +//! - **outer < inner**: Skip the outer key group (no match exists). +//! - **outer > inner**: Skip the inner key group. +//! - **outer == inner**: Process the match (see below). +//! +//! Key groups are contiguous runs of equal keys within one side. The scan +//! advances past entire groups at each step. +//! +//! ## Processing a key match +//! +//! **Without filter**: All outer rows in the key group are marked as matched. +//! +//! **With filter**: The inner key group is buffered (may span multiple inner +//! batches). For each buffered inner row, the filter is evaluated against the +//! outer key group as a batch. Results are OR'd into the matched bitset. A +//! short-circuit exits early when all outer rows in the group are matched. +//! +//! ```text +//! matched bitset: [0, 0, 1, 0, 0, ...] +//! ▲── one bit per outer row ──▲ +//! +//! On emit: +//! Semi → filter_record_batch(outer_batch, &matched) +//! Anti → filter_record_batch(outer_batch, &NOT(matched)) +//! Mark → outer_batch + matched as boolean column +//! ``` +//! +//! ## Batch boundaries +//! +//! Key groups can span batch boundaries on either side. The stream handles +//! this by detecting when a group extends to the end of a batch, loading the +//! next batch, and continuing if the key matches. The [`PendingBoundary`] enum +//! preserves loop context across async `Poll::Pending` re-entries. +//! +//! # Memory +//! +//! Memory usage is bounded and independent of total input size: +//! - One outer batch at a time (not tracked by reservation — single batch, +//! cannot be spilled since it's needed for filter evaluation) +//! - One inner batch at a time (streaming) +//! - `matched` bitset: one bit per outer row, re-allocated per batch +//! - Inner key group buffer: only for filtered joins, one key group at a time. +//! Tracked via `MemoryReservation`; spilled to disk when the memory pool +//! limit is exceeded. +//! - `BatchCoalescer`: output buffering to target batch size +//! +//! # Degenerate cases +//! +//! **Highly skewed key (filtered joins only):** When a filter is present, +//! the inner key group is buffered so each inner row can be evaluated +//! against the outer group. If one join key has N inner rows, all N rows +//! are held in memory simultaneously (or spilled to disk if the memory +//! pool limit is reached). With uniform key distribution this is small +//! (inner_rows / num_distinct_keys), but a single hot key can buffer +//! arbitrarily many rows. The no-filter path does not buffer inner +//! rows — it only advances the cursor — so it is unaffected. +//! +//! **Scalar broadcast during filter evaluation:** Each inner row is +//! broadcast to match the outer group length for filter evaluation, +//! allocating one array per inner row × filter column. This is inherent +//! to the `PhysicalExpr::evaluate(RecordBatch)` API, which does not +//! support scalar inputs directly. The total work is +//! O(inner_group × outer_group) per key, but with much lower constant +//! factor than the pair-materialization approach. + +use std::cmp::Ordering; +use std::fs::File; +use std::io::BufReader; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::joins::utils::{JoinFilter, JoinKeyComparator, compare_join_arrays}; +use crate::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, +}; +use crate::spill::spill_manager::SpillManager; +use crate::{EmptyRecordBatchStream, RecordBatchStream}; +use arrow::array::{Array, ArrayRef, BooleanArray, BooleanBufferBuilder, RecordBatch}; +use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch, not}; +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::StreamReader; +use arrow::util::bit_chunk_iterator::UnalignedBitChunk; +use arrow::util::bit_util::apply_bitwise_binary_op; +use datafusion_common::{ + JoinSide, JoinType, NullEquality, Result, ScalarValue, internal_err, +}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; + +use futures::{Stream, StreamExt, ready}; + +/// Evaluates join key expressions against a batch, returning one array per key. +fn evaluate_join_keys( + batch: &RecordBatch, + on: &[PhysicalExprRef], +) -> Result> { + on.iter() + .map(|expr| { + let num_rows = batch.num_rows(); + let val = expr.evaluate(batch)?; + val.into_array(num_rows) + }) + .collect() +} + +/// Find the first index in `key_arrays` starting from `from` where the key +/// differs from the key at `from`. Uses a pre-built `JoinKeyComparator` for +/// zero-alloc ordinal comparison without per-row type dispatch. +/// +/// Optimized for join workloads: checks adjacent and boundary keys before +/// falling back to binary search, since most key groups are small (often 1). +fn find_key_group_end(cmp: &JoinKeyComparator, from: usize, len: usize) -> usize { + let next = from + 1; + if next >= len { + return len; + } + + // Fast path: single-row group (common with unique keys). + if cmp.compare(from, next) != Ordering::Equal { + return next; + } + + // Check if the entire remaining batch shares this key. + let last = len - 1; + if cmp.compare(from, last) == Ordering::Equal { + return len; + } + + // Binary search the interior: key at `next` matches, key at `last` doesn't. + let mut lo = next + 1; + let mut hi = last; + while lo < hi { + let mid = lo + (hi - lo) / 2; + if cmp.compare(from, mid) == Ordering::Equal { + lo = mid + 1; + } else { + hi = mid; + } + } + lo +} + +/// When an outer key group spans a batch boundary, the boundary loop emits +/// the current batch, then polls for the next. If that poll returns Pending, +/// `ready!` exits `poll_join` and we re-enter from the top on the next call. +/// Without this state, the new batch would be processed fresh by the +/// merge-scan — but inner already advanced past this key, so the matching +/// outer rows would be skipped via `Ordering::Less` and never marked. +/// +/// This enum carries the last key (as single-row sliced arrays) from the +/// previous batch so we can check whether the next batch continues the same +/// key group. Stored as `Option`: `None` means normal +/// processing. +#[derive(Debug)] +enum PendingBoundary { + /// Resuming a no-filter boundary loop. + NoFilter { saved_keys: Vec }, + /// Resuming a filtered boundary loop. Inner key data remains in the + /// buffer (or spill file) for the resumed loop. + Filtered { saved_keys: Vec }, +} + +/// Sort-Merge join stream for Semi/Anti/Mark joins. +/// +/// Named "bitwise" because it tracks outer-row matches via a per-batch +/// boolean bitset (`BooleanBufferBuilder`) rather than materializing +/// `(outer, inner)` row pairs. Filter results are OR'd into the bitset +/// in `u64` chunks, and emitting applies the bitset directly. +pub(crate) struct BitwiseSortMergeJoinStream { + join_type: JoinType, + + // Input streams — in the nested-loop model that sort-merge join + // implements, "outer" is the driving loop and "inner" is probed for + // matches. The existing MaterializingSortMergeJoinStream calls these "streamed" + // and "buffered" respectively. For Left* joins, outer=left; for + // Right* joins, outer=right. Output schema equals the outer side. + outer: SendableRecordBatchStream, + inner: SendableRecordBatchStream, + + // Current batches and cursor positions within them + outer_batch: Option, + /// Row index into `outer_batch` — the next unprocessed outer row. + outer_offset: usize, + outer_key_arrays: Vec, + inner_batch: Option, + /// Row index into `inner_batch` — the next unprocessed inner row. + inner_offset: usize, + inner_key_arrays: Vec, + + // Per-outer-batch match tracking, reused across batches. + // Bit-packed (not Vec) so that: + // - emit: finish() yields a BooleanBuffer directly (no packing iteration) + // - OR: apply_bitwise_binary_op ORs filter results in u64 chunks + // - count: UnalignedBitChunk::count_ones uses popcnt + matched: BooleanBufferBuilder, + + // Inner key group buffer: all inner rows sharing the current join key. + // Only populated when a filter is present. Unbounded — a single key + // with many inner rows will buffer them all. See "Degenerate cases" + // in exec.rs. Spilled to disk when memory reservation fails. + inner_key_buffer: Vec, + inner_key_spill: Option, + + // True when buffer_inner_key_group returned Pending after partially + // filling inner_key_buffer. On re-entry, buffer_inner_key_group + // must skip clear() and resume from poll_next_inner_batch (the + // current inner_batch was already sliced and pushed before Pending). + buffering_inner_pending: bool, + + // Boundary re-entry state — see PendingBoundary doc comment. + pending_boundary: Option, + + // Join ON expressions, evaluated against each new batch to produce + // the key arrays used for sorted key comparisons. + on_outer: Vec, + on_inner: Vec, + filter: Option, + sort_options: Vec, + null_equality: NullEquality, + // Decomposed from JoinType: when RightSemi/RightAnti, outer=right, + // inner=left, so we swap sides when building the filter batch. + outer_is_left: bool, + + // Output + coalescer: BatchCoalescer, + schema: SchemaRef, + + // Metrics + join_time: crate::metrics::Time, + input_batches: Count, + input_rows: Count, + baseline_metrics: BaselineMetrics, + peak_mem_used: Gauge, + + // Memory / spill — only the inner key buffer is tracked via reservation, + // matching existing SMJ (which tracks only the buffered side). The outer + // batch is a single batch at a time and cannot be spilled. + reservation: MemoryReservation, + spill_manager: SpillManager, + runtime_env: Arc, + inner_buffer_size: usize, + + // Cached comparators — pre-built to avoid per-row type dispatch. + /// Comparator for outer vs inner key comparison + outer_inner_cmp: Option, + /// Comparator for outer self-comparison (find_key_group_end on outer) + outer_self_cmp: Option, + /// Comparator for inner self-comparison (find_key_group_end on inner) + inner_self_cmp: Option, + + // True once the current outer batch has been emitted. The Equal + // branch's inner loops call emit then `ready!(poll_next_outer_batch)`. + // If that poll returns Pending, poll_join re-enters from the top + // on the next poll — with outer_batch still Some and outer_offset + // past the end. The main loop's step 3 would re-emit without this + // guard. Cleared when poll_next_outer_batch loads a new batch. + batch_emitted: bool, +} + +impl BitwiseSortMergeJoinStream { + #[expect(clippy::too_many_arguments)] + pub fn try_new( + schema: SchemaRef, + sort_options: Vec, + null_equality: NullEquality, + outer: SendableRecordBatchStream, + inner: SendableRecordBatchStream, + on_outer: Vec, + on_inner: Vec, + filter: Option, + join_type: JoinType, + batch_size: usize, + partition: usize, + metrics: &ExecutionPlanMetricsSet, + reservation: MemoryReservation, + spill_manager: SpillManager, + runtime_env: Arc, + ) -> Result { + debug_assert!( + matches!( + join_type, + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ), + "BitwiseSortMergeJoinStream does not handle {join_type:?}" + ); + let outer_is_left = matches!( + join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ); + + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let baseline_metrics = BaselineMetrics::new(metrics, partition); + let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); + + Ok(Self { + join_type, + outer, + inner, + outer_batch: None, + outer_offset: 0, + outer_key_arrays: vec![], + inner_batch: None, + inner_offset: 0, + inner_key_arrays: vec![], + matched: BooleanBufferBuilder::new(0), + inner_key_buffer: vec![], + inner_key_spill: None, + buffering_inner_pending: false, + pending_boundary: None, + on_outer, + on_inner, + filter, + sort_options, + null_equality, + outer_is_left, + coalescer: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Some(batch_size / 2)), + schema, + join_time, + input_batches, + input_rows, + baseline_metrics, + peak_mem_used, + reservation, + spill_manager, + runtime_env, + inner_buffer_size: 0, + outer_inner_cmp: None, + outer_self_cmp: None, + inner_self_cmp: None, + batch_emitted: false, + }) + } + + /// Resize the memory reservation to match current tracked usage. + fn try_resize_reservation(&mut self) -> Result<()> { + let needed = self.inner_buffer_size; + self.reservation.try_resize(needed)?; + self.peak_mem_used.set_max(self.reservation.size()); + Ok(()) + } + + /// Get or build the outer vs inner key comparator. + fn get_outer_inner_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.outer_inner_cmp.is_none() { + self.outer_inner_cmp = Some(JoinKeyComparator::new( + &self.outer_key_arrays, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.outer_inner_cmp.as_ref().unwrap()) + } + + /// Get or build the outer self-comparison comparator. + fn get_outer_self_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.outer_self_cmp.is_none() { + self.outer_self_cmp = Some(JoinKeyComparator::new( + &self.outer_key_arrays, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.outer_self_cmp.as_ref().unwrap()) + } + + /// Get or build the inner self-comparison comparator. + fn get_inner_self_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.inner_self_cmp.is_none() { + self.inner_self_cmp = Some(JoinKeyComparator::new( + &self.inner_key_arrays, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.inner_self_cmp.as_ref().unwrap()) + } + + /// Spill the in-memory inner key buffer to disk and clear it. + fn spill_inner_key_buffer(&mut self) -> Result<()> { + let spill_file = self + .spill_manager + .spill_record_batch_and_finish( + &self.inner_key_buffer, + "semi_anti_smj_inner_key_spill", + )? + .expect("inner_key_buffer is non-empty when spilling"); + self.inner_key_buffer.clear(); + self.inner_buffer_size = 0; + self.inner_key_spill = Some(spill_file); + // Should succeed now — inner buffer has been spilled. + self.try_resize_reservation() + } + + /// Clear inner key group state after processing. Does not resize the + /// reservation — the next key group will resize when buffering, or + /// the stream's Drop will free it. This avoids unnecessary memory + /// pool interactions (see apache/datafusion#20729). + fn clear_inner_key_group(&mut self) { + self.inner_key_buffer.clear(); + self.inner_key_spill = None; + self.inner_buffer_size = 0; + } + + /// Poll for the next outer batch. Returns true if a batch was loaded. + fn poll_next_outer_batch(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.outer.poll_next_unpin(cx)) { + None => { + // Release the outer input pipeline's resources. + let outer_schema = self.outer.schema(); + self.outer = Box::pin(EmptyRecordBatchStream::new(outer_schema)); + return Poll::Ready(Ok(false)); + } + Some(Err(e)) => return Poll::Ready(Err(e)), + Some(Ok(batch)) => { + let batch_num_rows = batch.num_rows(); + self.input_batches.add(1); + self.input_rows.add(batch_num_rows); + if batch_num_rows == 0 { + continue; + } + let keys = evaluate_join_keys(&batch, &self.on_outer)?; + self.outer_batch = Some(batch); + self.outer_offset = 0; + self.outer_key_arrays = keys; + self.outer_inner_cmp = None; + self.outer_self_cmp = None; + self.batch_emitted = false; + self.matched = BooleanBufferBuilder::new(batch_num_rows); + self.matched.append_n(batch_num_rows, false); + return Poll::Ready(Ok(true)); + } + } + } + } + + /// Poll for the next inner batch. Returns true if a batch was loaded. + fn poll_next_inner_batch(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match ready!(self.inner.poll_next_unpin(cx)) { + None => { + // Release the inner input pipeline's resources. + let inner_schema = self.inner.schema(); + self.inner = Box::pin(EmptyRecordBatchStream::new(inner_schema)); + return Poll::Ready(Ok(false)); + } + Some(Err(e)) => return Poll::Ready(Err(e)), + Some(Ok(batch)) => { + let batch_num_rows = batch.num_rows(); + self.input_batches.add(1); + self.input_rows.add(batch_num_rows); + if batch_num_rows == 0 { + continue; + } + let keys = evaluate_join_keys(&batch, &self.on_inner)?; + self.inner_batch = Some(batch); + self.inner_offset = 0; + self.inner_key_arrays = keys; + self.outer_inner_cmp = None; + self.inner_self_cmp = None; + return Poll::Ready(Ok(true)); + } + } + } + } + + /// Emit the current outer batch through the coalescer, applying the + /// matched bitset as a selection mask. No-op if already emitted + /// (see `batch_emitted` field). + fn emit_outer_batch(&mut self) -> Result<()> { + if self.batch_emitted { + return Ok(()); + } + self.batch_emitted = true; + + let batch = self.outer_batch.as_ref().unwrap(); + + // finish() converts the bit-packed builder directly to a + // BooleanBuffer — no iteration or repacking needed. + let matched_buf = self.matched.finish(); + + match self.join_type { + JoinType::LeftMark | JoinType::RightMark => { + // Mark joins emit ALL outer rows with a boolean match column appended. + debug_assert_eq!( + self.schema.fields().len(), + batch.num_columns() + 1, + "Mark join output schema should be outer schema + 1 mark column" + ); + let mark_col = Arc::new(BooleanArray::new(matched_buf, None)) as ArrayRef; + let mut columns = Vec::with_capacity(batch.num_columns() + 1); + columns.extend_from_slice(batch.columns()); + columns.push(mark_col); + let output = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + self.coalescer.push_batch(output)?; + } + JoinType::LeftSemi | JoinType::RightSemi => { + let selection = BooleanArray::new(matched_buf, None); + let filtered = filter_record_batch(batch, &selection)?; + if filtered.num_rows() > 0 { + self.coalescer.push_batch(filtered)?; + } + } + JoinType::LeftAnti | JoinType::RightAnti => { + let selection = not(&BooleanArray::new(matched_buf, None))?; + let filtered = filter_record_batch(batch, &selection)?; + if filtered.num_rows() > 0 { + self.coalescer.push_batch(filtered)?; + } + } + _ => unreachable!(), + } + Ok(()) + } + + /// Process a key match between outer and inner sides (no filter). + /// Sets matched bits for all outer rows sharing the current key. + fn process_key_match_no_filter(&mut self) -> Result<()> { + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + self.get_outer_self_cmp()?; + let outer_group_end = find_key_group_end( + self.outer_self_cmp.as_ref().unwrap(), + self.outer_offset, + num_outer, + ); + + for i in self.outer_offset..outer_group_end { + self.matched.set_bit(i, true); + } + + self.outer_offset = outer_group_end; + Ok(()) + } + + /// Advance inner past the current key group. Returns Ok(true) if inner + /// is exhausted. + fn advance_inner_past_key_group( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => return Poll::Ready(Ok(true)), + }; + let num_inner = inner_batch.num_rows(); + + self.get_inner_self_cmp()?; + let group_end = find_key_group_end( + self.inner_self_cmp.as_ref().unwrap(), + self.inner_offset, + num_inner, + ); + + if group_end < num_inner { + self.inner_offset = group_end; + return Poll::Ready(Ok(false)); + } + + // Key group extends to end of batch — need to check next batch + let saved_inner_keys = slice_keys(&self.inner_key_arrays, num_inner - 1); + + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + return Poll::Ready(Ok(true)); + } + Ok(true) => { + if keys_match( + &saved_inner_keys, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )? { + continue; + } else { + return Poll::Ready(Ok(false)); + } + } + } + } + } + + /// Buffer inner key group for filter evaluation. Collects all inner rows + /// with the current key across batch boundaries. + /// + /// If poll_next_inner_batch returns Pending, we save progress via + /// buffering_inner_pending. On re-entry (from the Equal branch in + /// poll_join), we skip clear() and the slice+push for the current + /// batch (which was already buffered before Pending), and go directly + /// to polling for the next inner batch. + fn buffer_inner_key_group(&mut self, cx: &mut Context<'_>) -> Poll> { + // On re-entry after Pending: don't clear the partially-filled + // buffer. The current inner_batch was already sliced and pushed + // before Pending, so jump to polling for the next batch. + let mut resume_from_poll = false; + if self.buffering_inner_pending { + self.buffering_inner_pending = false; + resume_from_poll = true; + } else { + self.clear_inner_key_group(); + } + + loop { + if self.inner_batch.is_none() { + return Poll::Ready(Ok(true)); + } + let num_inner = self.inner_batch.as_ref().unwrap().num_rows(); + self.get_inner_self_cmp()?; + let group_end = find_key_group_end( + self.inner_self_cmp.as_ref().unwrap(), + self.inner_offset, + num_inner, + ); + + if !resume_from_poll { + let inner_batch = self.inner_batch.as_ref().unwrap(); + let slice = + inner_batch.slice(self.inner_offset, group_end - self.inner_offset); + self.inner_buffer_size += slice.get_array_memory_size(); + self.inner_key_buffer.push(slice); + + // Reserve memory for the newly buffered slice. If the pool + // is exhausted, spill the entire buffer to disk. + if self.try_resize_reservation().is_err() { + if self.runtime_env.disk_manager.tmp_files_enabled() { + self.spill_inner_key_buffer()?; + } else { + // Re-attempt to get the error message + self.try_resize_reservation().map_err(|e| { + datafusion_common::DataFusionError::Execution(format!( + "{e}. Disk spilling disabled." + )) + })?; + } + } + + if group_end < num_inner { + self.inner_offset = group_end; + return Poll::Ready(Ok(false)); + } + } + resume_from_poll = false; + + // Key group extends to end of batch — check next + let saved_inner_keys = slice_keys(&self.inner_key_arrays, num_inner - 1); + + // If poll returns Pending, the current batch is already + // in inner_key_buffer. + self.buffering_inner_pending = true; + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => { + self.buffering_inner_pending = false; + return Poll::Ready(Err(e)); + } + Ok(false) => { + self.buffering_inner_pending = false; + return Poll::Ready(Ok(true)); + } + Ok(true) => { + self.buffering_inner_pending = false; + if keys_match( + &saved_inner_keys, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )? { + continue; + } else { + return Poll::Ready(Ok(false)); + } + } + } + } + } + + /// Process a key match with a filter. For each inner row in the buffered + /// key group, evaluates the filter against the outer key group and ORs + /// the results into the matched bitset using u64-chunked bitwise ops. + fn process_key_match_with_filter(&mut self) -> Result<()> { + self.get_outer_self_cmp()?; + let filter = self.filter.as_ref().unwrap(); + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + // buffer_inner_key_group must be called before this function + debug_assert!( + !self.inner_key_buffer.is_empty() || self.inner_key_spill.is_some(), + "process_key_match_with_filter called with no inner key data" + ); + debug_assert!( + self.outer_offset < num_outer, + "outer_offset must be within the current batch" + ); + debug_assert!( + self.matched.len() == num_outer, + "matched vector must be sized for the current outer batch" + ); + + let outer_group_end = find_key_group_end( + self.outer_self_cmp.as_ref().unwrap(), + self.outer_offset, + num_outer, + ); + let outer_group_len = outer_group_end - self.outer_offset; + let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len); + + // Count already-matched bits using popcnt on u64 chunks (zero-copy). + let mut matched_count = UnalignedBitChunk::new( + self.matched.as_slice(), + self.outer_offset, + outer_group_len, + ) + .count_ones(); + + // Process spilled inner batches first (read back from disk). + if let Some(spill_file) = &self.inner_key_spill { + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + for batch_result in reader { + let inner_slice = batch_result?; + matched_count = eval_filter_for_inner_slice( + self.outer_is_left, + filter, + &outer_slice, + &inner_slice, + &mut self.matched, + self.outer_offset, + outer_group_len, + matched_count, + )?; + if matched_count == outer_group_len { + break; + } + } + } + + // Then process in-memory inner batches. + // evaluate_filter_for_inner_row is a free function (not &self method) + // so that Rust can split the struct borrow: &mut self.matched coexists + // with &self.inner_key_buffer and &self.filter inside this loop. + if matched_count < outer_group_len { + 'outer: for inner_slice in &self.inner_key_buffer { + matched_count = eval_filter_for_inner_slice( + self.outer_is_left, + filter, + &outer_slice, + inner_slice, + &mut self.matched, + self.outer_offset, + outer_group_len, + matched_count, + )?; + if matched_count == outer_group_len { + break 'outer; + } + } + } + + self.outer_offset = outer_group_end; + Ok(()) + } + + /// Continue processing an outer key group that spans multiple outer + /// batches. Returns `true` if this outer batch was fully consumed + /// by the key group and the caller should load another. + fn resume_boundary(&mut self) -> Result { + debug_assert!( + self.outer_batch.is_some(), + "caller must load outer_batch first" + ); + match self.pending_boundary.take() { + Some(PendingBoundary::NoFilter { saved_keys }) => { + let same_key = keys_match( + &saved_keys, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?; + if same_key { + self.process_key_match_no_filter()?; + let num_outer = self.outer_batch.as_ref().unwrap().num_rows(); + if self.outer_offset >= num_outer { + self.pending_boundary = Some(PendingBoundary::NoFilter { + saved_keys: slice_keys(&self.outer_key_arrays, num_outer - 1), + }); + self.emit_outer_batch()?; + self.outer_batch = None; + return Ok(true); + } + } + } + Some(PendingBoundary::Filtered { saved_keys }) => { + debug_assert!( + !self.inner_key_buffer.is_empty() || self.inner_key_spill.is_some(), + "Filtered pending boundary entered but no inner key data exists" + ); + let same_key = keys_match( + &saved_keys, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?; + if same_key { + self.process_key_match_with_filter()?; + let num_outer = self.outer_batch.as_ref().unwrap().num_rows(); + if self.outer_offset >= num_outer { + self.pending_boundary = Some(PendingBoundary::Filtered { + saved_keys: slice_keys(&self.outer_key_arrays, num_outer - 1), + }); + self.emit_outer_batch()?; + self.outer_batch = None; + return Ok(true); + } + } + self.clear_inner_key_group(); + } + None => {} + } + Ok(false) + } + + /// Main loop: drive the merge-scan to produce output batches. + fn poll_join(&mut self, cx: &mut Context<'_>) -> Poll>> { + let join_time = self.join_time.clone(); + let _timer = join_time.timer(); + + loop { + // 1. Ensure we have an outer batch + if self.outer_batch.is_none() { + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + // Outer exhausted — flush coalescer + self.pending_boundary = None; + self.coalescer.finish_buffered_batch()?; + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + return Poll::Ready(Ok(None)); + } + Ok(true) => { + if self.resume_boundary()? { + continue; + } + } + } + } + + // 2. Ensure we have an inner batch (unless inner is exhausted). + // Skip this when resuming a pending boundary — inner was already + // advanced past the key group before the boundary loop started. + if self.inner_batch.is_none() && self.pending_boundary.is_none() { + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + // Inner exhausted — emit remaining outer batches. + // For semi: no more matches possible. + // For anti: all remaining outer rows are unmatched. + self.emit_outer_batch()?; + self.outer_batch = None; + + loop { + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => break, + Ok(true) => { + self.emit_outer_batch()?; + self.outer_batch = None; + } + } + } + + self.coalescer.finish_buffered_batch()?; + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + return Poll::Ready(Ok(None)); + } + Ok(true) => {} + } + } + + // 3. Main merge-scan loop + let outer_batch = self.outer_batch.as_ref().unwrap(); + let num_outer = outer_batch.num_rows(); + + if self.outer_offset >= num_outer { + self.emit_outer_batch()?; + self.outer_batch = None; + + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + continue; + } + + let inner_batch = match &self.inner_batch { + Some(b) => b, + None => { + self.emit_outer_batch()?; + self.outer_batch = None; + continue; + } + }; + let num_inner = inner_batch.num_rows(); + + if self.inner_offset >= num_inner { + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.inner_batch = None; + continue; + } + Ok(true) => continue, + } + } + + // 4. Compare keys at current positions + self.get_outer_inner_cmp()?; + let cmp = self + .outer_inner_cmp + .as_ref() + .unwrap() + .compare(self.outer_offset, self.inner_offset); + + match cmp { + Ordering::Less => { + self.get_outer_self_cmp()?; + let group_end = find_key_group_end( + self.outer_self_cmp.as_ref().unwrap(), + self.outer_offset, + num_outer, + ); + self.outer_offset = group_end; + } + Ordering::Greater => { + self.get_inner_self_cmp()?; + let group_end = find_key_group_end( + self.inner_self_cmp.as_ref().unwrap(), + self.inner_offset, + num_inner, + ); + if group_end >= num_inner { + let saved_keys = + slice_keys(&self.inner_key_arrays, num_inner - 1); + match ready!(self.poll_next_inner_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.inner_batch = None; + continue; + } + Ok(true) => { + if keys_match( + &saved_keys, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )? { + match ready!(self.advance_inner_past_key_group(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(_) => continue, + } + } + continue; + } + } + } else { + self.inner_offset = group_end; + } + } + Ordering::Equal => { + if self.filter.is_some() { + // Buffer inner key group (may span batches) + match ready!(self.buffer_inner_key_group(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(_inner_exhausted) => {} + } + + // Process outer rows against buffered inner group + // (may need to handle outer batch boundary) + loop { + self.process_key_match_with_filter()?; + + let outer_batch = self.outer_batch.as_ref().unwrap(); + if self.outer_offset >= outer_batch.num_rows() { + let saved_keys = slice_keys( + &self.outer_key_arrays, + outer_batch.num_rows() - 1, + ); + + self.emit_outer_batch()?; + debug_assert!( + !self.inner_key_buffer.is_empty() + || self.inner_key_spill.is_some(), + "Filtered pending boundary requires inner key data in buffer or spill" + ); + self.pending_boundary = + Some(PendingBoundary::Filtered { saved_keys }); + + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.pending_boundary = None; + self.outer_batch = None; + break; + } + Ok(true) => { + let Some(PendingBoundary::Filtered { + saved_keys, + }) = self.pending_boundary.take() + else { + unreachable!() + }; + let same = keys_match( + &saved_keys, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?; + if same { + continue; + } + break; + } + } + } else { + break; + } + } + + self.clear_inner_key_group(); + } else { + // No filter: advance inner past key group, then + // mark all outer rows with this key as matched. + match ready!(self.advance_inner_past_key_group(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(_inner_exhausted) => {} + } + + loop { + self.process_key_match_no_filter()?; + + let num_outer = self.outer_batch.as_ref().unwrap().num_rows(); + if self.outer_offset >= num_outer { + let saved_keys = + slice_keys(&self.outer_key_arrays, num_outer - 1); + + self.emit_outer_batch()?; + self.pending_boundary = + Some(PendingBoundary::NoFilter { saved_keys }); + + match ready!(self.poll_next_outer_batch(cx)) { + Err(e) => return Poll::Ready(Err(e)), + Ok(false) => { + self.pending_boundary = None; + self.outer_batch = None; + break; + } + Ok(true) => { + let Some(PendingBoundary::NoFilter { + saved_keys, + }) = self.pending_boundary.take() + else { + unreachable!() + }; + let same_key = keys_match( + &saved_keys, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?; + if same_key { + continue; + } + break; + } + } + } else { + break; + } + } + } + } + } + + // Check for completed coalescer batch + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Ok(Some(batch))); + } + } + } +} + +/// Evaluate the filter for all rows in an inner slice against the outer group, +/// OR-ing results into the matched bitset. Returns the updated matched count. +/// Extracted as a free function so Rust can split borrows on the stream struct. +#[expect(clippy::too_many_arguments)] +fn eval_filter_for_inner_slice( + outer_is_left: bool, + filter: &JoinFilter, + outer_slice: &RecordBatch, + inner_slice: &RecordBatch, + matched: &mut BooleanBufferBuilder, + outer_offset: usize, + outer_group_len: usize, + // Passed in to avoid recounting bits we just counted at the call site. + mut matched_count: usize, +) -> Result { + debug_assert_eq!( + matched_count, + UnalignedBitChunk::new(matched.as_slice(), outer_offset, outer_group_len) + .count_ones() + ); + for inner_row in 0..inner_slice.num_rows() { + if matched_count == outer_group_len { + break; + } + + let filter_result = evaluate_filter_for_inner_row( + outer_is_left, + filter, + outer_slice, + inner_slice, + inner_row, + )?; + + // OR filter results into the matched bitset. Both sides are + // bit-packed [u8] buffers, so apply_bitwise_binary_op + // processes 64 bits per loop iteration (not 1 bit at a time). + // + // The offsets handle alignment: outer_offset is the bit + // position within matched where this key group starts, + // and filter_buf.offset() is the BooleanBuffer's internal + // bit offset (usually 0, but not guaranteed by Arrow). + let filter_buf = filter_result.values(); + apply_bitwise_binary_op( + matched.as_slice_mut(), + outer_offset, + filter_buf.inner().as_slice(), + filter_buf.offset(), + outer_group_len, + |a, b| a | b, + ); + + // Recount matched bits after the OR. UnalignedBitChunk is + // zero-copy — it reads the bytes in place and uses popcnt. + matched_count = + UnalignedBitChunk::new(matched.as_slice(), outer_offset, outer_group_len) + .count_ones(); + } + Ok(matched_count) +} + +/// Slice each key array to a single row at `idx`. +fn slice_keys(keys: &[ArrayRef], idx: usize) -> Vec { + keys.iter().map(|a| a.slice(idx, 1)).collect() +} + +/// Compare the first row of two key arrays using sort options to determine +/// equality. The left side is expected to be single-row slices (from +/// `slice_keys`); the right side can be any length (row 0 is compared). +fn keys_match( + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + sort_options: &[SortOptions], + null_equality: NullEquality, +) -> Result { + debug_assert!(left_arrays.iter().all(|a| a.len() == 1)); + let cmp = compare_join_arrays( + left_arrays, + 0, + right_arrays, + 0, + sort_options, + null_equality, + )?; + Ok(cmp == Ordering::Equal) +} + +/// Evaluate the join filter for one inner row against a slice of outer rows. +/// +/// Free function (not a method on BitwiseSortMergeJoinStream) so that Rust +/// can split the struct borrow in process_key_match_with_filter: the caller +/// holds &mut self.matched and &self.inner_key_buffer simultaneously, which +/// is impossible if this borrows all of &self. +fn evaluate_filter_for_inner_row( + outer_is_left: bool, + filter: &JoinFilter, + outer_slice: &RecordBatch, + inner_batch: &RecordBatch, + inner_idx: usize, +) -> Result { + let num_outer_rows = outer_slice.num_rows(); + + // Build filter input columns in the order the filter expects + let mut columns: Vec = Vec::with_capacity(filter.column_indices().len()); + for col_idx in filter.column_indices() { + let (side_batch, side_idx) = if outer_is_left { + match col_idx.side { + JoinSide::Left => (outer_slice, None), + JoinSide::Right => (inner_batch, Some(inner_idx)), + JoinSide::None => { + return internal_err!("Unexpected JoinSide::None in filter"); + } + } + } else { + match col_idx.side { + JoinSide::Left => (inner_batch, Some(inner_idx)), + JoinSide::Right => (outer_slice, None), + JoinSide::None => { + return internal_err!("Unexpected JoinSide::None in filter"); + } + } + }; + + match side_idx { + None => { + columns.push(Arc::clone(side_batch.column(col_idx.index))); + } + Some(idx) => { + // Broadcasts inner scalar to N-element array. Arrow's + // BinaryExpr handles Scalar×Array natively via the Datum + // trait, but Column::evaluate always returns Array, so + // we'd need a custom expr to avoid this broadcast. + let scalar = ScalarValue::try_from_array( + side_batch.column(col_idx.index).as_ref(), + idx, + )?; + columns.push(scalar.to_array_of_size(num_outer_rows)?); + } + } + } + + let filter_batch = RecordBatch::try_new(Arc::clone(filter.schema()), columns)?; + let result = filter + .expression() + .evaluate(&filter_batch)? + .into_array(num_outer_rows)?; + let bool_arr = result + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion_common::DataFusionError::Internal( + "Filter expression did not return BooleanArray".to_string(), + ) + })?; + // Treat nulls as false + if bool_arr.null_count() > 0 { + Ok(arrow::compute::prep_null_mask_filter(bool_arr)) + } else { + Ok(bool_arr.clone()) + } +} + +impl Stream for BitwiseSortMergeJoinStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let poll = self.poll_join(cx).map(|result| result.transpose()); + self.baseline_metrics.record_poll(poll) + } +} + +impl RecordBatchStream for BitwiseSortMergeJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index b5b4325798f9d..a86cb647e4bff 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -19,39 +19,40 @@ //! A Sort-Merge join plan consumes two sorted children plans and produces //! joined output by given join type and other options. -use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; -use crate::execution_plan::{boundedness_from_children, EmissionType}; +use super::bitwise_stream::BitwiseSortMergeJoinStream; +use super::materializing_stream::MaterializingSortMergeJoinStream; +use super::metrics::SortMergeJoinMetrics; +use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::expressions::PhysicalSortExpr; -use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; -use crate::joins::sort_merge_join::stream::SortMergeJoinStream; use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn, - JoinOnRef, + JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid, + estimate_join_statistics, reorder_output_after_swap, + symmetric_join_output_partitioning, }; -use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics}; use crate::projection::{ - join_allows_pushdown, join_table_borders, new_join_children, - physical_to_column_exprs, update_join_on, ProjectionExec, + ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children, + physical_to_column_exprs, update_join_on, }; +use crate::spill::spill_manager::SpillManager; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, SendableRecordBatchStream, Statistics, + PlanProperties, SendableRecordBatchStream, Statistics, check_if_same_properties, }; use arrow::compute::SortOptions; use arrow::datatypes::SchemaRef; use datafusion_common::{ - assert_eq_or_internal_err, internal_err, plan_err, JoinSide, JoinType, NullEquality, - Result, + JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err, + plan_err, }; -use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_physical_expr::equivalence::join_equivalence_properties; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; /// Join execution plan that executes equi-join predicates on multiple partitions using Sort-Merge @@ -127,7 +128,7 @@ pub struct SortMergeJoinExec { /// Defines the null equality for the join. pub null_equality: NullEquality, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl SortMergeJoinExec { @@ -198,7 +199,7 @@ impl SortMergeJoinExec { right_sort_exprs, sort_options, null_equality, - cache, + cache: Arc::new(cache), }) } @@ -334,12 +335,28 @@ impl SortMergeJoinExec { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) { Ok(Arc::new(new_join)) } else { reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) } } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for SortMergeJoinExec { @@ -353,14 +370,15 @@ impl DisplayAs for SortMergeJoinExec { .collect::>() .join(", "); let display_null_equality = - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { ", NullsEqual: true" } else { "" }; write!( f, - "SortMergeJoin: join_type={:?}, on=[{}]{}{}", + "{}: join_type={:?}, on=[{}]{}{}", + Self::static_name(), self.join_type, on, self.filter.as_ref().map_or_else( @@ -385,7 +403,7 @@ impl DisplayAs for SortMergeJoinExec { } writeln!(f, "on={on}")?; - if matches!(self.null_equality(), NullEquality::NullEqualsNull) { + if self.null_equality() == NullEquality::NullEqualsNull { writeln!(f, "NullsEqual: true")?; } @@ -400,11 +418,7 @@ impl ExecutionPlan for SortMergeJoinExec { "SortMergeJoinExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -439,6 +453,7 @@ impl ExecutionPlan for SortMergeJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); match &children[..] { [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( Arc::clone(left), @@ -488,54 +503,89 @@ impl ExecutionPlan for SortMergeJoinExec { let streamed = streamed.execute(partition, Arc::clone(&context))?; let buffered = buffered.execute(partition, Arc::clone(&context))?; - // create output buffer let batch_size = context.session_config().batch_size(); - - // create memory reservation let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]")) .register(context.memory_pool()); + let spill_manager = SpillManager::new( + context.runtime_env(), + SpillMetrics::new(&self.metrics, partition), + buffered.schema(), + ) + .with_compression_type(context.session_config().spill_compression()); - // create join stream - Ok(Box::pin(SortMergeJoinStream::try_new( - context.session_config().spill_compression(), - Arc::clone(&self.schema), - self.sort_options.clone(), - self.null_equality, - streamed, - buffered, - on_streamed, - on_buffered, - self.filter.clone(), + if matches!( self.join_type, - batch_size, - SortMergeJoinMetrics::new(partition, &self.metrics), - reservation, - context.runtime_env(), - )?)) + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark + ) { + Ok(Box::pin(BitwiseSortMergeJoinStream::try_new( + Arc::clone(&self.schema), + self.sort_options.clone(), + self.null_equality, + streamed, + buffered, + on_streamed, + on_buffered, + self.filter.clone(), + self.join_type, + batch_size, + partition, + &self.metrics, + reservation, + spill_manager, + context.runtime_env(), + )?)) + } else { + Ok(Box::pin(MaterializingSortMergeJoinStream::try_new( + Arc::clone(&self.schema), + self.sort_options.clone(), + self.null_equality, + streamed, + buffered, + on_streamed, + on_buffered, + self.filter.clone(), + self.join_type, + batch_size, + SortMergeJoinMetrics::new(partition, &self.metrics), + reservation, + spill_manager, + context.runtime_env(), + )?)) + } } fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - if partition.is_some() { - return Ok(Statistics::new_unknown(&self.schema())); - } + fn partition_statistics(&self, partition: Option) -> Result> { + // SortMergeJoinExec uses symmetric hash partitioning where both left and right + // inputs are hash-partitioned on the join keys. This means partition `i` of the + // left input is joined with partition `i` of the right input. + // + // Therefore, partition-specific statistics can be computed by getting the + // partition-specific statistics from both children and combining them via + // `estimate_join_statistics`. + // // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` - estimate_join_statistics( - self.left.partition_statistics(None)?, - self.right.partition_statistics(None)?, + let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(partition)?); + let right_stats = + Arc::unwrap_or_clone(self.right.partition_statistics(partition)?); + Ok(Arc::new(estimate_join_statistics( + left_stats, + right_stats, &self.on, + self.null_equality, &self.join_type, &self.schema, - ) + )?)) } /// Tries to swap the projection with its input [`SortMergeJoinExec`]. If it can be done, diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs new file mode 100644 index 0000000000000..4fc6cccaa8838 --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs @@ -0,0 +1,388 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Filter handling for Sort-Merge Join +//! +//! This module encapsulates the complexity of join filter evaluation, including: +//! - Immediate filtering for INNER joins +//! - Deferred filtering for outer joins +//! - Metadata tracking for grouping output rows by input row +//! - Correcting filter masks to handle multiple matches per input row + +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, RecordBatch, + RecordBatchOptions, UInt64Array, UInt64Builder, new_null_array, +}; +use arrow::compute::kernels::zip::zip; +use arrow::compute::{self, filter_record_batch}; +use arrow::datatypes::SchemaRef; +use datafusion_common::{JoinSide, JoinType, Result}; + +use crate::joins::utils::JoinFilter; + +/// Metadata for tracking filter results during deferred filtering +/// +/// When a join filter is present and we need to ensure each input row produces +/// at least one output (outer joins), we can't filter immediately. Instead, +/// we accumulate all joined rows with metadata, then post-process to determine +/// which rows to output. +#[derive(Debug)] +pub struct FilterMetadata { + /// Did each output row pass the join filter? + /// Used to detect if an input row found ANY match + pub filter_mask: BooleanBuilder, + + /// Which input row (within batch) produced each output row? + /// Used for grouping output rows by input row + pub row_indices: UInt64Builder, + + /// Which input batch did each output row come from? + /// Used to disambiguate row_indices across multiple batches + pub batch_ids: Vec, +} + +impl FilterMetadata { + /// Create new empty filter metadata + pub fn new() -> Self { + Self { + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + } + } + + /// Returns (row_indices, filter_mask, batch_ids_ref) and clears builders + pub fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) { + let row_indices = self.row_indices.finish(); + let filter_mask = self.filter_mask.finish(); + (row_indices, filter_mask, &self.batch_ids) + } + + /// Add metadata for null-joined rows (no filter applied) + pub fn append_nulls(&mut self, num_rows: usize) { + self.filter_mask.append_nulls(num_rows); + self.row_indices.append_nulls(num_rows); + self.batch_ids.resize( + self.batch_ids.len() + num_rows, + 0, // batch_id = 0 for null-joined rows + ); + } + + /// Add metadata for filtered rows + pub fn append_filter_metadata( + &mut self, + row_indices: &UInt64Array, + filter_mask: &BooleanArray, + batch_id: usize, + ) { + debug_assert_eq!( + row_indices.len(), + filter_mask.len(), + "row_indices and filter_mask must have same length" + ); + + self.filter_mask.extend(filter_mask); + self.row_indices.extend(row_indices); + self.batch_ids + .resize(self.batch_ids.len() + row_indices.len(), batch_id); + } + + /// Verify that metadata arrays are aligned (same length) + pub fn debug_assert_metadata_aligned(&self) { + if self.filter_mask.len() > 0 { + debug_assert_eq!( + self.filter_mask.len(), + self.row_indices.len(), + "filter_mask and row_indices must have same length when metadata is used" + ); + debug_assert_eq!( + self.filter_mask.len(), + self.batch_ids.len(), + "filter_mask and batch_ids must have same length when metadata is used" + ); + } else { + debug_assert_eq!( + self.filter_mask.len(), + 0, + "filter_mask should be empty when batches is empty" + ); + } + } +} + +impl Default for FilterMetadata { + fn default() -> Self { + Self::new() + } +} + +/// Determines if a join type needs deferred filtering +/// +/// Deferred filtering is required when: +/// - A filter exists AND +/// - The join type requires ensuring each input row produces at least one output +pub fn needs_deferred_filtering( + filter: &Option, + join_type: JoinType, +) -> bool { + filter.is_some() + && matches!(join_type, JoinType::Left | JoinType::Right | JoinType::Full) +} + +/// Gets the arrays which join filters are applied on +/// +/// Extracts the columns needed for filter evaluation from left and right batch columns +pub fn get_filter_columns( + join_filter: &Option, + left_columns: &[ArrayRef], + right_columns: &[ArrayRef], +) -> Vec { + let mut filter_columns = vec![]; + + if let Some(f) = join_filter { + let left_columns: Vec = f + .column_indices() + .iter() + .filter(|col_index| col_index.side == JoinSide::Left) + .map(|i| Arc::clone(&left_columns[i.index])) + .collect(); + let right_columns: Vec = f + .column_indices() + .iter() + .filter(|col_index| col_index.side == JoinSide::Right) + .map(|i| Arc::clone(&right_columns[i.index])) + .collect(); + + filter_columns.extend(left_columns); + filter_columns.extend(right_columns); + } + + filter_columns +} + +/// Determines if current index is the last occurrence of a row +/// +/// Used during filter mask correction to detect row boundaries when grouping +/// output rows by input row. +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + batch_ids: &[usize], + indices_len: usize, +) -> bool { + debug_assert_eq!( + indices.len(), + indices_len, + "indices.len() should match indices_len parameter" + ); + debug_assert_eq!( + batch_ids.len(), + indices_len, + "batch_ids.len() should match indices_len" + ); + debug_assert!( + row_index < indices_len, + "row_index {row_index} should be < indices_len {indices_len}", + ); + + // If this is the last index overall, it's definitely the last for this row + if row_index == indices_len - 1 { + return true; + } + + // Check if next row has different (batch_id, index) pair + let current_batch_id = batch_ids[row_index]; + let next_batch_id = batch_ids[row_index + 1]; + + if current_batch_id != next_batch_id { + return true; + } + + // Same batch_id, check if row index is different + // Both current and next should be non-null (already joined rows) + if indices.is_null(row_index) || indices.is_null(row_index + 1) { + return true; + } + + indices.value(row_index) != indices.value(row_index + 1) +} + +/// Corrects the filter mask for joins with deferred filtering +/// +/// When an input row joins with multiple buffered rows, we get multiple output rows. +/// This function groups them by input row and applies join-type-specific logic: +/// +/// - **Outer joins**: Keep first matching row, convert rest to nulls, add null-joined for unmatched +/// +/// # Arguments +/// * `join_type` - The type of join being performed +/// * `row_indices` - Which input row produced each output row +/// * `batch_ids` - Which batch each output row came from +/// * `filter_mask` - Whether each output row passed the filter +/// * `expected_size` - Total number of input rows (for adding unmatched) +/// +/// # Returns +/// Corrected mask indicating which rows to include in final output: +/// - `true`: Include this row +/// - `false`: Convert to null-joined row (outer joins) +/// - `null`: Discard this row +pub fn get_corrected_filter_mask( + join_type: JoinType, + row_indices: &UInt64Array, + batch_ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let row_indices_length = row_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(row_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left | JoinType::Right | JoinType::Full => { + // For each input row group: keep first filter-passing row, + // discard (null) remaining matches, null-join if none passed. + // Null metadata entries are already-null-joined rows that + // flow through unchanged to preserve output ordering. + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.is_null(i) { + corrected_mask.append_value(true); + } else if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); + } else { + corrected_mask.append_value(false); + } + + if last_index { + seen_true = false; + } + } + + corrected_mask.append_n(expected_size - corrected_mask.len(), false); + Some(corrected_mask.finish()) + } + JoinType::LeftMark + | JoinType::RightMark + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti => { + unreachable!("Semi/anti/mark joins are handled by BitwiseSortMergeJoinStream") + } + JoinType::Inner => None, + } +} + +/// Applies corrected filter mask to record batch based on join type +/// +/// The corrected mask has three possible values per row: +/// - `true`: Keep the row as-is (matched and passed filter) +/// - `false`: Convert to null-joined row (all filter matches failed for this input row) +/// - `null`: Discard the row entirely (duplicate match for an already-output input row) +/// +/// This function preserves input row ordering by processing each row in place +/// rather than separating matched/unmatched rows. +pub fn filter_record_batch_by_join_type( + record_batch: &RecordBatch, + corrected_mask: &BooleanArray, + join_type: JoinType, + schema: &SchemaRef, + buffered_schema: &SchemaRef, +) -> Result { + match join_type { + JoinType::Left | JoinType::Right | JoinType::Full => { + if record_batch.num_rows() == 0 { + return Ok(record_batch.clone()); + } + + // Discard null-masked rows (keep true + false only) + let keep_mask = compute::is_not_null(corrected_mask)?; + let kept_batch = filter_record_batch(record_batch, &keep_mask)?; + + if kept_batch.num_rows() == 0 { + return Ok(kept_batch); + } + + let kept_corrected = compute::filter(corrected_mask, &keep_mask)?; + let kept_corrected = kept_corrected + .as_any() + .downcast_ref::() + .unwrap(); + + // All rows passed the filter — no null-joining needed + if !kept_corrected.has_false() { + return Ok(kept_batch); + } + + // For false entries: replace the non-preserved side with nulls. + // This preserves row ordering unlike filter+concat. + let (null_side_start, null_side_len) = match join_type { + JoinType::Left => { + // Left join: null out right (buffered) columns + let left_cols = + schema.fields().len() - buffered_schema.fields().len(); + (left_cols, buffered_schema.fields().len()) + } + JoinType::Right => { + // Right join: null out left (buffered) columns + (0, buffered_schema.fields().len()) + } + JoinType::Full => { + // Full join: null out buffered columns for streamed rows + // that matched but failed the filter. Unmatched buffered + // rows are null-joined on the streamed side separately + // when the buffered batch is drained. + let left_cols = + schema.fields().len() - buffered_schema.fields().len(); + (left_cols, buffered_schema.fields().len()) + } + _ => unreachable!(), + }; + + let num_rows = kept_batch.num_rows(); + let mut columns: Vec = kept_batch.columns().to_vec(); + + for col in columns.iter_mut().skip(null_side_start).take(null_side_len) { + let null_array = new_null_array(col.data_type(), num_rows); + *col = zip(kept_corrected, &*col, &null_array)?; + } + + let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); + Ok(RecordBatch::try_new_with_options( + Arc::clone(schema), + columns, + &options, + )?) + } + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark => unreachable!( + "Semi/anti/mark joins are handled by SemiAntiMarkSortMergeJoinStream" + ), + JoinType::Inner => Ok(filter_record_batch(record_batch, corrected_mask)?), + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs new file mode 100644 index 0000000000000..9bcc749c23dce --- /dev/null +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -0,0 +1,1948 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort-Merge Join execution +//! +//! This module implements the runtime state machine for the Sort-Merge Join +//! operator. It drives two sorted input streams (the *streamed* side and the +//! *buffered* side), compares join keys, and produces joined `RecordBatch`es. + +use std::cmp::Ordering; +use std::collections::{HashMap, VecDeque}; +use std::fs::File; +use std::io::BufReader; +use std::mem::size_of; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; +use std::task::{Context, Poll}; + +use crate::joins::sort_merge_join::filter::{ + FilterMetadata, filter_record_batch_by_join_type, get_corrected_filter_mask, + get_filter_columns, needs_deferred_filtering, +}; +use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; +use crate::joins::utils::{JoinFilter, JoinKeyComparator}; +use crate::metrics::RecordOutput; +use crate::spill::spill_manager::SpillManager; +use crate::stream::EmptyRecordBatchStream; +use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; + +use arrow::array::{types::UInt64Type, *}; +use arrow::compute::{ + self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, interleave, + take, take_arrays, +}; +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::StreamReader; +use datafusion_common::cast::as_uint64_array; +use datafusion_common::{JoinType, NullEquality, Result, exec_err, internal_err}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; + +use futures::{Stream, StreamExt}; + +/// State of SMJ stream +#[derive(Debug, PartialEq, Eq)] +pub(super) enum SortMergeJoinState { + /// Init joining with a new streamed row or a new buffered batches + Init, + /// Polling one streamed row or one buffered batch, or both + Polling, + /// Joining polled data and making output + JoinOutput, + /// Emit ready data if have any and then go back to [`Self::Init`] state + EmitReadyThenInit, + /// No more output + Exhausted, +} + +/// State of streamed data stream +#[derive(Debug, PartialEq, Eq)] +pub(super) enum StreamedState { + /// Init polling + Init, + /// Polling one streamed row + Polling, + /// Ready to produce one streamed row + Ready, + /// No more streamed row + Exhausted, +} + +/// State of buffered data stream +#[derive(Debug, PartialEq, Eq)] +pub(super) enum BufferedState { + /// Init polling + Init, + /// Polling first row in the next batch + PollingFirst, + /// Polling rest rows in the next batch + PollingRest, + /// Ready to produce one batch + Ready, + /// No more buffered batches + Exhausted, +} + +/// Represents a chunk of joined data from streamed and buffered side +pub(super) struct StreamedJoinedChunk { + /// Index of batch in buffered_data + buffered_batch_idx: Option, + /// Array builder for streamed indices + streamed_indices: UInt64Builder, + /// Array builder for buffered indices + /// This could contain nulls if the join is null-joined + buffered_indices: UInt64Builder, +} + +/// Represents a record batch from streamed input. +/// +/// Also stores information of matching rows from buffered batches. +pub(super) struct StreamedBatch { + /// The streamed record batch + pub batch: RecordBatch, + /// The index of row in the streamed batch to compare with buffered batches + pub idx: usize, + /// The join key arrays of streamed batch which are used to compare with buffered batches + /// and to produce output. They are produced by evaluating `on` expressions. + pub join_arrays: Vec, + /// Chunks of indices from buffered side (may be nulls) joined to streamed + pub output_indices: Vec, + /// Total number of output rows across all chunks in `output_indices` + pub num_output_rows: usize, + /// Index of currently scanned batch from buffered data + pub buffered_batch_idx: Option, +} + +impl StreamedBatch { + fn new(batch: RecordBatch, on_column: &[Arc]) -> Self { + let join_arrays = join_arrays(&batch, on_column); + StreamedBatch { + batch, + idx: 0, + join_arrays, + output_indices: vec![], + num_output_rows: 0, + buffered_batch_idx: None, + } + } + + fn new_empty(schema: SchemaRef) -> Self { + StreamedBatch { + batch: RecordBatch::new_empty(schema), + idx: 0, + join_arrays: vec![], + output_indices: vec![], + num_output_rows: 0, + buffered_batch_idx: None, + } + } + + /// Number of unfrozen output pairs in this streamed batch + fn num_output_rows(&self) -> usize { + self.num_output_rows + } + + /// Appends new pair consisting of current streamed index and `buffered_idx` + /// index of buffered batch with `buffered_batch_idx` index. + fn append_output_pair( + &mut self, + buffered_batch_idx: Option, + buffered_idx: Option, + batch_size: usize, + ) { + // If no current chunk exists or current chunk is not for current buffered batch, + // create a new chunk + if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx + { + // Compute capacity only when creating a new chunk (infrequent operation). + // The capacity is the remaining space to reach batch_size. + // This should always be >= 1 since we only call this when num_output_rows < batch_size. + debug_assert!( + batch_size > self.num_output_rows, + "batch_size ({batch_size}) must be > num_output_rows ({})", + self.num_output_rows + ); + let capacity = batch_size - self.num_output_rows; + self.output_indices.push(StreamedJoinedChunk { + buffered_batch_idx, + streamed_indices: UInt64Builder::with_capacity(capacity), + buffered_indices: UInt64Builder::with_capacity(capacity), + }); + self.buffered_batch_idx = buffered_batch_idx; + }; + let current_chunk = self.output_indices.last_mut().unwrap(); + + // Append index of streamed batch and index of buffered batch into current chunk + current_chunk.streamed_indices.append_value(self.idx as u64); + if let Some(idx) = buffered_idx { + current_chunk.buffered_indices.append_value(idx as u64); + } else { + current_chunk.buffered_indices.append_null(); + } + self.num_output_rows += 1; + } +} + +/// Per-row filter outcome tracking for full outer joins. +/// +/// In a full outer join with a filter, buffered rows that match on join +/// keys but fail every filter evaluation must be emitted with NULLs on +/// the streamed side. Three states are needed because a simple boolean +/// cannot distinguish "never matched" (handled by [`BufferedBatch::null_joined`]) +/// from "matched but all filters failed" (must be emitted as null-joined). +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum FilterState { + /// Row never appeared in a matched pair. + Unvisited = 0, + /// Row matched streamed rows, but all filter evaluations failed. + AllFailed = 1, + /// Row matched and at least one filter evaluation passed. + SomePassed = 2, +} + +/// A buffered batch that contains contiguous rows with same join key +/// +/// `BufferedBatch` can exist as either an in-memory `RecordBatch` or a `RefCountedTempFile` on disk. +#[derive(Debug)] +pub(super) struct BufferedBatch { + /// Represents in memory or spilled record batch + pub batch: BufferedBatchState, + /// The range in which the rows share the same join key + pub range: Range, + /// Array refs of the join key + pub join_arrays: Vec, + /// Buffered joined index (null joining buffered) + pub null_joined: Vec, + /// Size estimation used for reserving / releasing memory + pub size_estimation: usize, + /// Memory footprint of `join_arrays` cached at construction time. + /// Used during spill to track the residual memory that remains after + /// the main batch is written to disk. + pub join_arrays_mem: usize, + /// Actual amount tracked in the memory reservation for this batch. + /// + /// - `InMemory`: equals `size_estimation` (full batch + join_arrays + metadata) + /// - `Spilled`: equals `join_arrays_mem` (join key arrays stay in memory) + /// + /// Invariant: `free_reservation()` shrinks by exactly this amount, so we never + /// shrink by more than we grew. + pub reserved_amount: usize, + /// Tracks filter outcomes for buffered rows in full outer joins. + /// Indexed by absolute row position within the batch. See [`FilterState`]. + pub join_filter_status: Vec, + /// Current buffered batch number of rows. Equal to batch.num_rows() + /// but if batch is spilled to disk this property is preferable + /// and less expensive + pub num_rows: usize, +} + +impl BufferedBatch { + fn new( + batch: RecordBatch, + range: Range, + on_column: &[PhysicalExprRef], + ) -> Self { + let join_arrays = join_arrays(&batch, on_column); + + // Estimation is calculated as + // inner batch size + // + join keys size + // + worst case null_joined (as vector capacity * element size) + // + Range size + // + size of this estimation + let join_arrays_mem: usize = join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum(); + + let size_estimation = batch.get_array_memory_size() + + join_arrays_mem + + batch.num_rows().next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + + let num_rows = batch.num_rows(); + BufferedBatch { + batch: BufferedBatchState::InMemory(batch), + range, + join_arrays, + null_joined: vec![], + size_estimation, + join_arrays_mem, + reserved_amount: 0, + join_filter_status: vec![FilterState::Unvisited; num_rows], + num_rows, + } + } +} + +// TODO: Spill join arrays (https://github.com/apache/datafusion/pull/17429) +// Used to represent whether the buffered data is currently in memory or written to disk +#[derive(Debug)] +pub(super) enum BufferedBatchState { + // In memory record batch + InMemory(RecordBatch), + // Spilled temp file + Spilled(RefCountedTempFile), +} + +/// Sort-Merge join stream for Inner/Left/Right/Full joins. +/// +/// Named "materializing" because it builds explicit `(streamed, buffered)` row +/// pairs in [`JoinedRecordBatches`] to produce output columns from both sides +/// of the join. +pub(super) struct MaterializingSortMergeJoinStream { + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== + /// Output schema + pub schema: SchemaRef, + /// Defines the null equality for the join. + pub null_equality: NullEquality, + /// Sort options of join columns used to sort streamed and buffered data stream + pub sort_options: Vec, + /// optional join filter + pub filter: Option, + /// How the join is performed + pub join_type: JoinType, + /// Target output batch size + pub batch_size: usize, + + // ======================================================================== + // STREAMED FIELDS: + // These fields manage the properties and state of the streamed input. + // ======================================================================== + /// Input schema of streamed + pub streamed_schema: SchemaRef, + /// Streamed data stream + pub streamed: SendableRecordBatchStream, + /// Current processing record batch of streamed + pub streamed_batch: StreamedBatch, + /// (used in outer join) Is current streamed row joined at least once? + pub streamed_joined: bool, + /// State of streamed + pub streamed_state: StreamedState, + /// Join key columns of streamed + pub on_streamed: Vec, + + // ======================================================================== + // BUFFERED FIELDS: + // These fields manage the properties and state of the buffered input. + // ======================================================================== + /// Input schema of buffered + pub buffered_schema: SchemaRef, + /// Buffered data stream + pub buffered: SendableRecordBatchStream, + /// Current buffered data + pub buffered_data: BufferedData, + /// (used in outer join) Is current buffered batches joined at least once? + pub buffered_joined: bool, + /// State of buffered + pub buffered_state: BufferedState, + /// Join key columns of buffered + pub on_buffered: Vec, + + // ======================================================================== + // MERGE JOIN STATES: + // These fields track the execution state of merge join and are updated + // during the execution. + // ======================================================================== + /// Current state of the stream + pub state: SortMergeJoinState, + /// Staging output array builders + pub joined_record_batches: JoinedRecordBatches, + /// Output buffer. Currently used by filtering as it requires double buffering + /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches` + pub output: BatchCoalescer, + /// The comparison result of current streamed row and buffered batches + pub current_ordering: Ordering, + /// Manages the process of spilling and reading back intermediate data + pub spill_manager: SpillManager, + + // ======================================================================== + // CACHED COMPARATORS: + // Pre-built comparators to avoid per-row type dispatch in hot loops. + // ======================================================================== + /// Comparator for streamed vs buffered head batch key comparison + pub streamed_buffered_cmp: Option, + /// Comparator for buffered head vs tail batch equality check + pub buffered_equality_cmp: Option, + + // ======================================================================== + // EXECUTION RESOURCES: + // Fields related to managing execution resources and monitoring performance. + // ======================================================================== + /// Metrics + pub join_metrics: SortMergeJoinMetrics, + /// Memory reservation + pub reservation: MemoryReservation, + /// Runtime env + pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, +} + +/// Staging area for joined data before output +/// +/// Accumulates joined rows until either: +/// - Target batch size reached (for efficiency) +/// - Stream exhausted (flush remaining data) +pub(super) struct JoinedRecordBatches { + /// Joined batches. Each batch is already joined columns from left and right sources + pub(super) joined_batches: BatchCoalescer, + /// Filter metadata for deferred filtering + pub(super) filter_metadata: FilterMetadata, +} + +impl JoinedRecordBatches { + /// Concatenates all accumulated batches into a single RecordBatch + /// + /// Must drain ALL batches from BatchCoalescer for filtered joins to ensure + /// metadata alignment when applying get_corrected_filter_mask(). + pub(super) fn concat_batches(&mut self, schema: &SchemaRef) -> Result { + self.joined_batches.finish_buffered_batch()?; + + let mut all_batches = vec![]; + while let Some(batch) = self.joined_batches.next_completed_batch() { + all_batches.push(batch); + } + + match all_batches.as_slice() { + [] => unreachable!("concat_batches called with empty BatchCoalescer"), + [single_batch] => Ok(single_batch.clone()), + multiple_batches => Ok(concat_batches(schema, multiple_batches)?), + } + } + + /// Clears batches without touching metadata (for early return when no filtering needed) + fn clear_batches(&mut self, schema: &SchemaRef, batch_size: usize) { + self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)); + } + + /// Asserts that if batches is empty, metadata is also empty + #[inline] + fn debug_assert_empty_consistency(&self) { + if self.joined_batches.is_empty() { + debug_assert_eq!( + self.filter_metadata.filter_mask.len(), + 0, + "filter_mask should be empty when batches is empty" + ); + debug_assert_eq!( + self.filter_metadata.row_indices.len(), + 0, + "row_indices should be empty when batches is empty" + ); + debug_assert_eq!( + self.filter_metadata.batch_ids.len(), + 0, + "batch_ids should be empty when batches is empty" + ); + } + } + + /// Pushes a batch with null metadata (rows that need no filter correction) + /// + /// Used for: (1) Full join buffered rows with no streamed match, and + /// (2) outer join streamed rows with no buffered match. These rows are + /// already in final form but must flow through the deferred filtering + /// pipeline to preserve output ordering. Null metadata causes + /// get_corrected_filter_mask() to pass them through unchanged. + /// + /// Maintains invariant: N rows → N metadata entries (nulls) + fn push_batch_with_null_metadata(&mut self, batch: RecordBatch, join_type: JoinType) { + debug_assert!( + matches!(join_type, JoinType::Left | JoinType::Right | JoinType::Full), + "push_batch_with_null_metadata should only be called for deferred-filtered joins" + ); + + let num_rows = batch.num_rows(); + + self.filter_metadata.append_nulls(num_rows); + + self.filter_metadata.debug_assert_metadata_aligned(); + self.joined_batches + .push_batch(batch) + .expect("Failed to push batch to BatchCoalescer"); + } + + /// Pushes a batch with filter metadata (filtered outer joins) + /// + /// Deferred filtering: An input row may join with multiple buffered rows, but we + /// don't know yet if all matches failed the filter. We track metadata so + /// `get_corrected_filter_mask()` can later group by input row and decide: + /// - If any match passed: emit passing rows + /// - If all matches failed: emit null-joined row + /// + /// Maintains invariant: N rows → N metadata entries + fn push_batch_with_filter_metadata( + &mut self, + batch: RecordBatch, + row_indices: &UInt64Array, + filter_mask: &BooleanArray, + streamed_batch_id: usize, + join_type: JoinType, + ) { + debug_assert!( + matches!(join_type, JoinType::Left | JoinType::Right | JoinType::Full), + "push_batch_with_filter_metadata should only be called for outer joins that need deferred filtering" + ); + + debug_assert_eq!( + row_indices.len(), + filter_mask.len(), + "row_indices and filter_mask must have same length" + ); + + self.filter_metadata.append_filter_metadata( + row_indices, + filter_mask, + streamed_batch_id, + ); + + self.filter_metadata.debug_assert_metadata_aligned(); + self.joined_batches + .push_batch(batch) + .expect("Failed to push batch to BatchCoalescer"); + } + + /// Pushes a batch without metadata (non-filtered joins) + /// + /// No deferred filtering needed. Either every join match is output (Inner), + /// or null-joined rows are handled separately. No need to track which input + /// row produced which output row. + fn push_batch_without_metadata(&mut self, batch: RecordBatch) { + self.joined_batches + .push_batch(batch) + .expect("Failed to push batch to BatchCoalescer"); + } + + fn clear(&mut self, schema: &SchemaRef, batch_size: usize) { + self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)); + self.filter_metadata = FilterMetadata::new(); + self.debug_assert_empty_consistency(); + } +} +impl RecordBatchStream for MaterializingSortMergeJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl Stream for MaterializingSortMergeJoinStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let join_time = self.join_metrics.join_time().clone(); + let _timer = join_time.timer(); + loop { + match &self.state { + SortMergeJoinState::Init => { + let streamed_exhausted = + self.streamed_state == StreamedState::Exhausted; + let buffered_exhausted = + self.buffered_state == BufferedState::Exhausted; + self.state = if streamed_exhausted && buffered_exhausted { + SortMergeJoinState::Exhausted + } else { + match self.current_ordering { + Ordering::Less | Ordering::Equal => { + if !streamed_exhausted { + // Batch deferred filtering: process_filtered_batches() + // only when >= batch_size rows have accumulated. + // Without this gate, unique keys cause per-row pipeline + // execution (concat + correct_mask + filter_by_type), + // which dominates runtime. + // + // Accumulated rows are bounded to ~2*batch_size: + // one batch_size worth from freeze_dequeuing_buffered() + // (when an input batch is fully consumed), plus up to + // batch_size pairs accumulating toward the next freeze. + // This does not reintroduce the unbounded buffering + // fixed by PR #20482. Exhausted state flushes remainder. + if needs_deferred_filtering( + &self.filter, + self.join_type, + ) { + let accumulated = self.num_unfrozen_pairs() + + self + .joined_record_batches + .filter_metadata + .filter_mask + .len(); + if accumulated >= self.batch_size { + match self.process_filtered_batches()? { + Poll::Ready(Some(batch)) => { + return Poll::Ready(Some(Ok(batch))); + } + Poll::Ready(None) | Poll::Pending => {} + } + } + } + + self.streamed_joined = false; + self.streamed_state = StreamedState::Init; + } + } + Ordering::Greater => { + if !buffered_exhausted { + self.buffered_joined = false; + self.buffered_state = BufferedState::Init; + } + } + } + SortMergeJoinState::Polling + }; + } + SortMergeJoinState::Polling => { + if ![StreamedState::Exhausted, StreamedState::Ready] + .contains(&self.streamed_state) + { + match self.poll_streamed_row(cx)? { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } + + if ![BufferedState::Exhausted, BufferedState::Ready] + .contains(&self.buffered_state) + { + match self.poll_buffered_batches(cx)? { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } + let streamed_exhausted = + self.streamed_state == StreamedState::Exhausted; + let buffered_exhausted = + self.buffered_state == BufferedState::Exhausted; + if streamed_exhausted && buffered_exhausted { + self.state = SortMergeJoinState::Exhausted; + continue; + } + self.current_ordering = self.compare_streamed_buffered()?; + self.state = SortMergeJoinState::JoinOutput; + } + SortMergeJoinState::EmitReadyThenInit => { + // If have data to emit, emit it and if no more, change to next + + // Verify metadata alignment before checking if we have batches to output + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + // For filtered joins, skip output and let Init state handle it + if needs_deferred_filtering(&self.filter, self.join_type) { + self.state = SortMergeJoinState::Init; + continue; + } + + // For non-filtered joins, only output if we have a completed batch + // (opportunistic output when target batch size is reached) + if self + .joined_record_batches + .joined_batches + .has_completed_batch() + { + let record_batch = self + .joined_record_batches + .joined_batches + .next_completed_batch() + .expect("has_completed_batch was true"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + self.state = SortMergeJoinState::Init; + } + SortMergeJoinState::JoinOutput => { + self.join_partial()?; + + if self.num_unfrozen_pairs() < self.batch_size { + if self.buffered_data.scanning_finished() { + self.buffered_data.scanning_reset(); + self.state = SortMergeJoinState::EmitReadyThenInit; + } + } else { + self.freeze_all()?; + + // Verify metadata alignment before checking if we have batches to output + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + // For filtered joins, skip output and let Init state handle it + if needs_deferred_filtering(&self.filter, self.join_type) { + continue; + } + + // For non-filtered joins, only output if we have a completed batch + // (opportunistic output when target batch size is reached) + if self + .joined_record_batches + .joined_batches + .has_completed_batch() + { + let record_batch = self + .joined_record_batches + .joined_batches + .next_completed_batch() + .expect("has_completed_batch was true"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + // Otherwise keep buffering (don't output yet) + } + } + SortMergeJoinState::Exhausted => { + self.freeze_all()?; + + // Verify metadata alignment before final output + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + // For filtered joins, must concat and filter ALL data at once + if needs_deferred_filtering(&self.filter, self.join_type) + && !self.joined_record_batches.joined_batches.is_empty() + { + let record_batch = self.filter_joined_batch()?; + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + + // For non-filtered joins, finish buffered data first + if !self.joined_record_batches.joined_batches.is_empty() { + self.joined_record_batches + .joined_batches + .finish_buffered_batch()?; + } + + // Output one completed batch at a time (stay in Exhausted until empty) + if self + .joined_record_batches + .joined_batches + .has_completed_batch() + { + let record_batch = self + .joined_record_batches + .joined_batches + .next_completed_batch() + .expect("has_completed_batch was true"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + + // Finally check self.output BatchCoalescer (used by filtered joins) + return if !self.output.is_empty() { + self.output.finish_buffered_batch()?; + let record_batch = self + .output + .next_completed_batch() + .expect("Failed to get last batch"); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + Poll::Ready(Some(Ok(record_batch))) + } else { + Poll::Ready(None) + }; + } + } + } + } +} + +impl MaterializingSortMergeJoinStream { + #[expect(clippy::too_many_arguments)] + pub fn try_new( + schema: SchemaRef, + sort_options: Vec, + null_equality: NullEquality, + streamed: SendableRecordBatchStream, + buffered: SendableRecordBatchStream, + on_streamed: Vec>, + on_buffered: Vec>, + filter: Option, + join_type: JoinType, + batch_size: usize, + join_metrics: SortMergeJoinMetrics, + reservation: MemoryReservation, + spill_manager: SpillManager, + runtime_env: Arc, + ) -> Result { + let streamed_schema = streamed.schema(); + let buffered_schema = buffered.schema(); + debug_assert!( + matches!( + join_type, + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full + ), + "MaterializingSortMergeJoinStream does not handle {join_type:?}; \ + semi/anti/mark joins use BitwiseSortMergeJoinStream" + ); + Ok(Self { + state: SortMergeJoinState::Init, + sort_options, + null_equality, + schema: Arc::clone(&schema), + streamed_schema: Arc::clone(&streamed_schema), + buffered_schema, + streamed, + buffered, + streamed_batch: StreamedBatch::new_empty(streamed_schema), + buffered_data: BufferedData::default(), + streamed_joined: false, + buffered_joined: false, + streamed_state: StreamedState::Init, + buffered_state: BufferedState::Init, + current_ordering: Ordering::Equal, + on_streamed, + on_buffered, + filter, + joined_record_batches: JoinedRecordBatches { + joined_batches: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)), + filter_metadata: FilterMetadata::new(), + }, + output: BatchCoalescer::new(schema, batch_size) + .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)), + batch_size, + join_type, + join_metrics, + reservation, + runtime_env, + spill_manager, + streamed_buffered_cmp: None, + buffered_equality_cmp: None, + streamed_batch_counter: AtomicUsize::new(0), + }) + } + + /// Build a comparator for streamed vs buffered head batch keys. + fn rebuild_streamed_buffered_cmp(&mut self) -> Result<()> { + if self.streamed_batch.join_arrays.is_empty() + || !self.buffered_data.has_buffered_rows() + { + self.streamed_buffered_cmp = None; + return Ok(()); + } + self.streamed_buffered_cmp = Some(JoinKeyComparator::new( + &self.streamed_batch.join_arrays, + &self.buffered_data.head_batch().join_arrays, + &self.sort_options, + self.null_equality, + )?); + Ok(()) + } + + /// Build a comparator for buffered head vs tail batch equality. + fn rebuild_buffered_equality_cmp(&mut self) -> Result<()> { + if self.buffered_data.batches.is_empty() { + self.buffered_equality_cmp = None; + return Ok(()); + } + self.buffered_equality_cmp = Some(JoinKeyComparator::new( + &self.buffered_data.head_batch().join_arrays, + &self.buffered_data.tail_batch().join_arrays, + &self.sort_options, + // is_join_arrays_equal treats both-null as equal + NullEquality::NullEqualsNull, + )?); + Ok(()) + } + + /// Number of unfrozen output pairs (used to decide when to freeze + output) + fn num_unfrozen_pairs(&self) -> usize { + self.streamed_batch.num_output_rows() + } + + /// Process accumulated batches for filtered joins + /// + /// Freezes unfrozen pairs, applies deferred filtering, and outputs if ready. + /// Returns Poll::Ready with a batch if one is available, otherwise Poll::Pending. + fn process_filtered_batches(&mut self) -> Poll>> { + self.freeze_all()?; + + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + if !self.joined_record_batches.joined_batches.is_empty() { + let out_filtered_batch = self.filter_joined_batch()?; + self.output + .push_batch(out_filtered_batch) + .expect("Failed to push output batch"); + + if self.output.has_completed_batch() { + let record_batch = self + .output + .next_completed_batch() + .expect("Failed to get output batch"); + (&record_batch).record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); + } + } + + Poll::Pending + } + + /// Poll next streamed row + fn poll_streamed_row(&mut self, cx: &mut Context) -> Poll>> { + loop { + match &self.streamed_state { + StreamedState::Init => { + if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows() + { + self.streamed_batch.idx += 1; + self.streamed_state = StreamedState::Ready; + return Poll::Ready(Some(Ok(()))); + } else { + self.streamed_state = StreamedState::Polling; + } + } + StreamedState::Polling => match self.streamed.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + // Release the streamed input pipeline's resources. + let streamed_schema = self.streamed.schema(); + self.streamed = + Box::pin(EmptyRecordBatchStream::new(streamed_schema)); + self.streamed_state = StreamedState::Exhausted; + } + Poll::Ready(Some(batch)) => { + if batch.num_rows() > 0 { + self.freeze_streamed()?; + self.join_metrics.input_batches().add(1); + self.join_metrics.input_rows().add(batch.num_rows()); + self.streamed_batch = + StreamedBatch::new(batch, &self.on_streamed); + self.rebuild_streamed_buffered_cmp()?; + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.streamed_state = StreamedState::Ready; + } + } + }, + StreamedState::Ready => { + return Poll::Ready(Some(Ok(()))); + } + StreamedState::Exhausted => { + return Poll::Ready(None); + } + } + } + } + + fn free_reservation(&mut self, buffered_batch: &BufferedBatch) { + if buffered_batch.reserved_amount > 0 { + self.reservation.shrink(buffered_batch.reserved_amount); + } + } + + fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + match self.reservation.try_grow(buffered_batch.size_estimation) { + Ok(_) => { + buffered_batch.reserved_amount = buffered_batch.size_estimation; + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { + // Spill buffered batch to disk + + match buffered_batch.batch { + BufferedBatchState::InMemory(batch) => { + let spill_file = self + .spill_manager + .spill_record_batch_and_finish( + &[batch], + "sort_merge_join_buffered_spill", + )? + .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled + + buffered_batch.batch = BufferedBatchState::Spilled(spill_file); + + // Join key arrays remain in memory after the batch is + // spilled — the comparator needs them for key boundary + // detection. Force-grow the reservation so the pool + // reflects actual memory usage even if this pushes + // pool.reserved() above the configured limit. This is + // safe because the memory is physically consumed and + // not tracking it would let other operators over-allocate + // against a stale pool view. + let join_arrays_mem = buffered_batch.join_arrays_mem; + self.reservation.grow(join_arrays_mem); + buffered_batch.reserved_amount = join_arrays_mem; + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size()); + + Ok(()) + } + _ => internal_err!("Buffered batch has empty body"), + } + } + Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), + }?; + + self.buffered_data.batches.push_back(buffered_batch); + Ok(()) + } + + /// Poll next buffered batches + fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { + loop { + match &self.buffered_state { + BufferedState::Init => { + // pop previous buffered batches + let mut head_changed = false; + while !self.buffered_data.batches.is_empty() { + let head_batch = self.buffered_data.head_batch(); + // If the head batch is fully processed, dequeue it and produce output of it. + if head_batch.range.end == head_batch.num_rows { + self.freeze_dequeuing_buffered()?; + if let Some(mut buffered_batch) = + self.buffered_data.batches.pop_front() + { + self.produce_buffered_not_matched(&mut buffered_batch)?; + self.free_reservation(&buffered_batch); + head_changed = true; + } + } else { + // If the head batch is not fully processed, break the loop. + // Streamed batch will be joined with the head batch in the next step. + break; + } + } + if head_changed { + self.streamed_buffered_cmp = None; + self.buffered_equality_cmp = None; + } + if self.buffered_data.batches.is_empty() { + self.buffered_state = BufferedState::PollingFirst; + } else { + let tail_batch = self.buffered_data.tail_batch_mut(); + tail_batch.range.start = tail_batch.range.end; + tail_batch.range.end += 1; + self.buffered_state = BufferedState::PollingRest; + } + } + BufferedState::PollingFirst => match self.buffered.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + // Release the buffered input pipeline's resources. + let buffered_schema = self.buffered.schema(); + self.buffered = + Box::pin(EmptyRecordBatchStream::new(buffered_schema)); + self.buffered_state = BufferedState::Exhausted; + return Poll::Ready(None); + } + Poll::Ready(Some(batch)) => { + self.join_metrics.input_batches().add(1); + self.join_metrics.input_rows().add(batch.num_rows()); + + if batch.num_rows() > 0 { + let buffered_batch = + BufferedBatch::new(batch, 0..1, &self.on_buffered); + + self.allocate_reservation(buffered_batch)?; + self.streamed_buffered_cmp = None; + self.buffered_state = BufferedState::PollingRest; + } + } + }, + BufferedState::PollingRest => { + if self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().num_rows + { + if self.buffered_equality_cmp.is_none() { + self.rebuild_buffered_equality_cmp()?; + } + while self.buffered_data.tail_batch().range.end + < self.buffered_data.tail_batch().num_rows + { + if self.buffered_equality_cmp.as_ref().unwrap().is_equal( + self.buffered_data.head_batch().range.start, + self.buffered_data.tail_batch().range.end, + ) { + self.buffered_data.tail_batch_mut().range.end += 1; + } else { + self.buffered_state = BufferedState::Ready; + return Poll::Ready(Some(Ok(()))); + } + } + } else { + match self.buffered.poll_next_unpin(cx)? { + Poll::Pending => { + return Poll::Pending; + } + Poll::Ready(None) => { + // Release the buffered input pipeline's resources. + let buffered_schema = self.buffered.schema(); + self.buffered = Box::pin(EmptyRecordBatchStream::new( + buffered_schema, + )); + self.buffered_state = BufferedState::Ready; + } + Poll::Ready(Some(batch)) => { + // Polling batches coming concurrently as multiple partitions + self.join_metrics.input_batches().add(1); + self.join_metrics.input_rows().add(batch.num_rows()); + if batch.num_rows() > 0 { + let buffered_batch = BufferedBatch::new( + batch, + 0..0, + &self.on_buffered, + ); + self.allocate_reservation(buffered_batch)?; + self.buffered_equality_cmp = None; + } + } + } + } + } + BufferedState::Ready => { + return Poll::Ready(Some(Ok(()))); + } + BufferedState::Exhausted => { + return Poll::Ready(None); + } + } + } + } + + /// Get comparison result of streamed row and buffered batches + fn compare_streamed_buffered(&mut self) -> Result { + if self.streamed_state == StreamedState::Exhausted { + return Ok(Ordering::Greater); + } + if !self.buffered_data.has_buffered_rows() { + return Ok(Ordering::Less); + } + + if self.streamed_buffered_cmp.is_none() { + self.rebuild_streamed_buffered_cmp()?; + } + Ok(self.streamed_buffered_cmp.as_ref().unwrap().compare( + self.streamed_batch.idx, + self.buffered_data.head_batch().range.start, + )) + } + + /// Produce join and fill output buffer until reaching target batch size + /// or the join is finished + fn join_partial(&mut self) -> Result<()> { + // Whether to join streamed rows + let mut join_streamed = false; + // Whether to join buffered rows + let mut join_buffered = false; + + // determine whether we need to join streamed/buffered rows + match self.current_ordering { + Ordering::Less => { + if matches!( + self.join_type, + JoinType::Left | JoinType::Right | JoinType::Full + ) { + join_streamed = !self.streamed_joined; + } + } + Ordering::Equal => { + join_streamed = true; + join_buffered = true; + } + Ordering::Greater => { + if self.join_type == JoinType::Full { + join_buffered = !self.buffered_joined; + }; + } + } + if !join_streamed && !join_buffered { + // no joined data + self.buffered_data.scanning_finish(); + return Ok(()); + } + + if join_buffered { + // joining streamed/nulls and buffered + while !self.buffered_data.scanning_finished() + && self.num_unfrozen_pairs() < self.batch_size + { + let scanning_idx = self.buffered_data.scanning_idx(); + if join_streamed { + // Join streamed row and buffered row + self.streamed_batch.append_output_pair( + Some(self.buffered_data.scanning_batch_idx), + Some(scanning_idx), + self.batch_size, + ); + } else { + // Join nulls and buffered row for FULL join + self.buffered_data + .scanning_batch_mut() + .null_joined + .push(scanning_idx); + } + self.buffered_data.scanning_advance(); + + if self.buffered_data.scanning_finished() { + self.streamed_joined = join_streamed; + self.buffered_joined = true; + } + } + } else { + // joining streamed and nulls + let scanning_batch_idx = if self.buffered_data.scanning_finished() { + None + } else { + Some(self.buffered_data.scanning_batch_idx) + }; + self.streamed_batch.append_output_pair( + scanning_batch_idx, + None, + self.batch_size, + ); + self.buffered_data.scanning_finish(); + self.streamed_joined = true; + } + Ok(()) + } + + fn freeze_all(&mut self) -> Result<()> { + self.freeze_buffered(self.buffered_data.batches.len())?; + self.freeze_streamed()?; + + // After freezing, metadata should be aligned + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + Ok(()) + } + + // Produces and stages record batches to ensure dequeued buffered batch + // no longer needed: + // 1. freezes all indices joined to streamed side + // 2. freezes NULLs joined to dequeued buffered batch to "release" it + fn freeze_dequeuing_buffered(&mut self) -> Result<()> { + self.freeze_streamed()?; + // Only freeze and produce the first batch in buffered_data as the batch is fully processed + self.freeze_buffered(1)?; + + // After freezing, metadata should be aligned + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + Ok(()) + } + + // Produces and stages record batch from buffered indices with corresponding + // NULLs on streamed side. + // + // Applicable only in case of Full join. + // + fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { + if self.join_type != JoinType::Full { + return Ok(()); + } + for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) { + let buffered_indices = UInt64Array::from_iter_values( + buffered_batch.null_joined.iter().map(|&index| index as u64), + ); + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.joined_record_batches + .push_batch_with_null_metadata(record_batch, self.join_type); + } + buffered_batch.null_joined.clear(); + } + Ok(()) + } + + fn produce_buffered_not_matched( + &mut self, + buffered_batch: &mut BufferedBatch, + ) -> Result<()> { + if self.join_type != JoinType::Full { + return Ok(()); + } + + // Collect buffered rows that matched on join keys but had every + // filter evaluation fail — these must be emitted with NULLs on + // the streamed side to satisfy full outer join semantics. + let not_matched_buffered_indices = buffered_batch + .join_filter_status + .iter() + .enumerate() + .filter_map(|(i, state)| { + matches!(state, FilterState::AllFailed).then_some(i as u64) + }) + .collect::>(); + + let buffered_indices = + UInt64Array::from_iter_values(not_matched_buffered_indices.iter().copied()); + + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.joined_record_batches + .push_batch_with_null_metadata(record_batch, self.join_type); + } + buffered_batch + .join_filter_status + .fill(FilterState::Unvisited); + + Ok(()) + } + + // Produces and stages record batch for all output indices found + // for current streamed batch and clears staged output indices. + // + // Null-joined chunks (no buffered match) are pushed immediately. + // Matched chunks are collected and processed together in + // freeze_streamed_matched() to amortize filter evaluation overhead. + fn freeze_streamed(&mut self) -> Result<()> { + let mut matched_chunks: Vec<(usize, UInt64Array, UInt64Array)> = Vec::new(); + let mut total_matched_rows: usize = 0; + + for chunk in self.streamed_batch.output_indices.iter_mut() { + let left_indices = chunk.streamed_indices.finish(); + if left_indices.is_empty() { + continue; + } + let right_indices: UInt64Array = chunk.buffered_indices.finish(); + + if chunk.buffered_batch_idx.is_none() { + let left_columns = + materialize_left_columns(&self.streamed_batch.batch, &left_indices)?; + let right_columns = + create_unmatched_columns(&self.buffered_schema, left_indices.len()); + + let columns = if self.join_type != JoinType::Right { + [left_columns, right_columns].concat() + } else { + [right_columns, left_columns].concat() + }; + let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + // Null-joined rows (no buffered match) need no filter correction, + // but must flow through the same pipeline as matched rows to + // preserve output ordering. Use null metadata as a sentinel so + // get_corrected_filter_mask() passes them through unchanged. + if needs_deferred_filtering(&self.filter, self.join_type) { + self.joined_record_batches + .push_batch_with_null_metadata(batch, self.join_type); + } else { + self.joined_record_batches + .push_batch_without_metadata(batch); + } + continue; + } + + total_matched_rows += left_indices.len(); + matched_chunks.push(( + chunk.buffered_batch_idx.unwrap(), + left_indices, + right_indices, + )); + } + + if !matched_chunks.is_empty() { + self.freeze_streamed_matched(&matched_chunks, total_matched_rows)?; + } + + self.streamed_batch.output_indices.clear(); + self.streamed_batch.num_output_rows = 0; + Ok(()) + } + + /// Materializes columns, evaluates the join filter, and pushes output + /// for all matched chunks in a single batch. This avoids per-chunk + /// RecordBatch construction and filter evaluation, which dominates + /// cost when keys are near-unique (1 row per chunk). + fn freeze_streamed_matched( + &mut self, + matched_chunks: &[(usize, UInt64Array, UInt64Array)], + total_matched_rows: usize, + ) -> Result<()> { + debug_assert!( + !matched_chunks.is_empty(), + "caller guards this with an is_empty check before calling" + ); + debug_assert!( + matched_chunks.iter().all(|(idx, left, right)| { + left.len() == right.len() && *idx < self.buffered_data.batches.len() + }), + "left/right indices are built in pairs from the same streamed×buffered cross, \ + and batch_idx comes from iterating buffered_data.batches" + ); + debug_assert_eq!( + matched_chunks + .iter() + .map(|(_, l, _)| l.len()) + .sum::(), + total_matched_rows, + "total_matched_rows is accumulated from the same chunks in freeze_streamed" + ); + + let combined_left_indices = if matched_chunks.len() == 1 { + matched_chunks[0].1.clone() + } else { + let refs: Vec<&dyn Array> = + matched_chunks.iter().map(|c| &c.1 as &dyn Array).collect(); + as_uint64_array(&compute::concat(&refs)?)?.clone() + }; + + let left_columns = + materialize_left_columns(&self.streamed_batch.batch, &combined_left_indices)?; + + let right_columns = + self.materialize_right_columns(matched_chunks, total_matched_rows)?; + + let filter_columns = if self.join_type == JoinType::Right { + get_filter_columns(&self.filter, &right_columns, &left_columns) + } else { + get_filter_columns(&self.filter, &left_columns, &right_columns) + }; + + let columns = if self.join_type != JoinType::Right { + [left_columns, right_columns].concat() + } else { + [right_columns, left_columns].concat() + }; + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + if !filter_columns.is_empty() { + if let Some(f) = &self.filter { + let filter_batch = + RecordBatch::try_new(Arc::clone(f.schema()), filter_columns)?; + let filter_result = f + .expression() + .evaluate(&filter_batch)? + .into_array(filter_batch.num_rows())?; + + let filter_result_mask = + datafusion_common::cast::as_boolean_array(&filter_result)?; + + // Convert NULL filter results to false — NULL means "not satisfied" + // per SQL semantics, same as Left/Right outer joins. + let mask = if filter_result_mask.null_count() > 0 { + compute::prep_null_mask_filter(filter_result_mask) + } else { + filter_result_mask.clone() + }; + + if needs_deferred_filtering(&self.filter, self.join_type) { + self.joined_record_batches.push_batch_with_filter_metadata( + output_batch, + &combined_left_indices, + &mask, + self.streamed_batch_counter.load(Relaxed), + self.join_type, + ); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.joined_record_batches + .push_batch_without_metadata(filtered_batch); + } + + // Track which buffered rows had all filter matches fail, + // so full join can emit them as null-joined later. + if self.join_type == JoinType::Full { + let mut offset = 0usize; + for (batch_idx, _left, right) in matched_chunks { + let chunk_len = right.len(); + let buffered_batch = &mut self.buffered_data.batches[*batch_idx]; + + for i in 0..chunk_len { + if right.is_null(i) { + continue; + } + let idx = right.value(i) as usize; + match buffered_batch.join_filter_status[idx] { + FilterState::SomePassed => {} + _ if mask.value(offset + i) => { + buffered_batch.join_filter_status[idx] = + FilterState::SomePassed; + } + _ => { + buffered_batch.join_filter_status[idx] = + FilterState::AllFailed; + } + } + } + offset += chunk_len; + } + debug_assert_eq!( + offset, total_matched_rows, + "offset must advance through every chunk exactly once" + ); + } + } + } else { + self.joined_record_batches + .push_batch_without_metadata(output_batch); + } + + Ok(()) + } + + /// Materializes right-side columns across all matched chunks. + /// + /// When chunks reference a single buffered batch, indices are concatenated + /// for a single fetch. When multiple batches are involved, `interleave` + /// gathers columns across sources. A null-row sentinel at source index 0 + /// handles null right indices (unmatched streamed rows). + fn materialize_right_columns( + &mut self, + matched_chunks: &[(usize, UInt64Array, UInt64Array)], + total_matched_rows: usize, + ) -> Result> { + let first_batch_idx = matched_chunks[0].0; + let single_source = matched_chunks.iter().all(|c| c.0 == first_batch_idx); + + if single_source { + let combined_right_indices = if matched_chunks.len() == 1 { + matched_chunks[0].2.clone() + } else { + let refs: Vec<&dyn Array> = + matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect(); + as_uint64_array(&compute::concat(&refs)?)?.clone() + }; + + let spill_reservation = self.reservation.new_empty(); + if matches!( + &self.buffered_data.batches[first_batch_idx].batch, + BufferedBatchState::Spilled(_) + ) { + spill_reservation + .grow(self.buffered_data.batches[first_batch_idx].size_estimation); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size() + spill_reservation.size()); + } + + return fetch_right_columns_by_idxs( + &self.buffered_data, + first_batch_idx, + &combined_right_indices, + ); + } + + // Multiple source batches: map each buffered_batch_idx to a + // contiguous source index, reserving source 0 for a null sentinel. + let mut batch_idx_to_source: HashMap = HashMap::new(); + let mut source_batches: Vec = Vec::new(); + for (batch_idx, _, _) in matched_chunks { + batch_idx_to_source.entry(*batch_idx).or_insert_with(|| { + let idx = source_batches.len() + 1; + source_batches.push(*batch_idx); + idx + }); + } + + let mut interleave_indices: Vec<(usize, usize)> = + Vec::with_capacity(total_matched_rows); + for (batch_idx, _, right) in matched_chunks { + let source = batch_idx_to_source[batch_idx]; + for i in 0..right.len() { + if right.is_null(i) { + interleave_indices.push((0, 0)); + } else { + interleave_indices.push((source, right.value(i) as usize)); + } + } + } + + let num_right_cols = self.buffered_schema.fields().len(); + + // Read each source batch once (spilled batches require disk I/O). + // Track memory for each spilled batch at the point of deserialization + // so the pool reflects actual usage as it grows. + let spill_reservation = self.reservation.new_empty(); + let mut source_data: Vec> = + Vec::with_capacity(source_batches.len()); + for &idx in &source_batches { + let bb = &self.buffered_data.batches[idx]; + match &bb.batch { + BufferedBatchState::InMemory(batch) => { + source_data.push(Some(batch.clone())); + } + BufferedBatchState::Spilled(spill_file) => { + spill_reservation.grow(bb.size_estimation); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size() + spill_reservation.size()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + source_data.push(reader.into_iter().next().transpose()?); + } + } + } + + let mut right_columns = Vec::with_capacity(num_right_cols); + for col_idx in 0..num_right_cols { + let dtype = self.buffered_schema.field(col_idx).data_type(); + let null_array = new_null_array(dtype, 1); + + let mut source_arrays: Vec<&dyn Array> = + Vec::with_capacity(source_batches.len() + 1); + source_arrays.push(null_array.as_ref()); + + for data in &source_data { + match data { + Some(batch) => source_arrays.push(batch.column(col_idx).as_ref()), + None => { + return internal_err!( + "Failed to read spilled buffered batch during interleave" + ); + } + } + } + right_columns.push(interleave(&source_arrays, &interleave_indices)?); + } + + Ok(right_columns) + } + + fn filter_joined_batch(&mut self) -> Result { + // Metadata should be aligned before processing + self.joined_record_batches + .filter_metadata + .debug_assert_metadata_aligned(); + + let record_batch = self.joined_record_batches.concat_batches(&self.schema)?; + let (mut out_indices, mut out_mask, mut batch_ids) = + self.joined_record_batches.filter_metadata.finish_metadata(); + let default_batch_ids = vec![0; record_batch.num_rows()]; + + // If only nulls come in and indices sizes doesn't match with expected record batch count + // generate missing indices + // Happens for null joined batches for Full Join + if out_indices.null_count() == out_indices.len() + && out_indices.len() != record_batch.num_rows() + { + out_mask = BooleanArray::from(vec![None; record_batch.num_rows()]); + out_indices = UInt64Array::from(vec![None; record_batch.num_rows()]); + batch_ids = &default_batch_ids; + } + + // After potential reconstruction, metadata should align with batch row count + debug_assert_eq!( + out_indices.len(), + record_batch.num_rows(), + "out_indices length should match record_batch row count" + ); + debug_assert_eq!( + out_mask.len(), + record_batch.num_rows(), + "out_mask length should match record_batch row count (unless empty)" + ); + debug_assert_eq!( + batch_ids.len(), + record_batch.num_rows(), + "batch_ids length should match record_batch row count" + ); + + if out_mask.is_empty() { + self.joined_record_batches + .clear_batches(&self.schema, self.batch_size); + return Ok(record_batch); + } + + // Validate inputs to get_corrected_filter_mask + debug_assert_eq!( + out_indices.len(), + out_mask.len(), + "out_indices and out_mask must have same length for get_corrected_filter_mask" + ); + debug_assert_eq!( + batch_ids.len(), + out_mask.len(), + "batch_ids and out_mask must have same length for get_corrected_filter_mask" + ); + + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + self.filter_record_batch_by_join_type(&record_batch, corrected_mask) + } + + fn filter_record_batch_by_join_type( + &mut self, + record_batch: &RecordBatch, + corrected_mask: &BooleanArray, + ) -> Result { + let filtered_record_batch = filter_record_batch_by_join_type( + record_batch, + corrected_mask, + self.join_type, + &self.schema, + &self.buffered_schema, + )?; + + self.joined_record_batches + .clear(&self.schema, self.batch_size); + + Ok(filtered_record_batch) + } +} + +/// Materialize left (streamed) columns using slice or take. +fn materialize_left_columns( + batch: &RecordBatch, + indices: &UInt64Array, +) -> Result> { + if let Some(range) = is_contiguous_range(indices) { + Ok(batch.slice(range.start, range.len()).columns().to_vec()) + } else { + Ok(take_arrays(batch.columns(), indices, None)?) + } +} + +fn create_unmatched_columns(schema: &SchemaRef, size: usize) -> Vec { + schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), size)) + .collect::>() +} + +fn produce_buffered_null_batch( + schema: &SchemaRef, + streamed_schema: &SchemaRef, + buffered_indices: &PrimitiveArray, + buffered_batch: &BufferedBatch, +) -> Result> { + if buffered_indices.is_empty() { + return Ok(None); + } + + // Take buffered (right) columns + let right_columns = + fetch_right_columns_from_batch_by_idxs(buffered_batch, buffered_indices)?; + + // Create null streamed (left) columns + let mut left_columns = streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + left_columns.extend(right_columns); + + Ok(Some(RecordBatch::try_new( + Arc::clone(schema), + left_columns, + )?)) +} + +/// Checks if a `UInt64Array` contains a contiguous ascending range (e.g. \[3,4,5,6\]). +/// Returns `Some(start..start+len)` if so, `None` otherwise. +/// This allows replacing an O(n) `take` with an O(1) `slice`. +#[inline] +fn is_contiguous_range(indices: &UInt64Array) -> Option> { + if indices.is_empty() || indices.null_count() > 0 { + return None; + } + let values = indices.values(); + let start = values[0]; + let len = values.len() as u64; + // Quick rejection: if last element doesn't match expected, not contiguous + if values[values.len() - 1] != start + len - 1 { + return None; + } + // Verify every element is sequential (handles duplicates and gaps) + for i in 1..values.len() { + if values[i] != start + i as u64 { + return None; + } + } + Some(start as usize..(start + len) as usize) +} + +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices +#[inline(always)] +fn fetch_right_columns_by_idxs( + buffered_data: &BufferedData, + buffered_batch_idx: usize, + buffered_indices: &UInt64Array, +) -> Result> { + fetch_right_columns_from_batch_by_idxs( + &buffered_data.batches[buffered_batch_idx], + buffered_indices, + ) +} + +#[inline(always)] +fn fetch_right_columns_from_batch_by_idxs( + buffered_batch: &BufferedBatch, + buffered_indices: &UInt64Array, +) -> Result> { + match &buffered_batch.batch { + // In memory batch + // In memory batch + BufferedBatchState::InMemory(batch) => { + // When indices form a contiguous range (common in SMJ since the + // buffered side is scanned sequentially), use zero-copy slice. + if let Some(range) = is_contiguous_range(buffered_indices) { + Ok(batch.slice(range.start, range.len()).columns().to_vec()) + } else { + Ok(take_arrays(batch.columns(), buffered_indices, None)?) + } + } + // If the batch was spilled to disk, less likely + BufferedBatchState::Spilled(spill_file) => { + let mut buffered_cols: Vec = + Vec::with_capacity(buffered_indices.len()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + + for batch in reader { + batch?.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok(buffered_cols) + } + } +} + +/// Buffered data contains all buffered batches with one unique join key +#[derive(Debug, Default)] +pub(super) struct BufferedData { + /// Buffered batches with the same key + pub batches: VecDeque, + /// current scanning batch index used in join_partial() + pub scanning_batch_idx: usize, + /// current scanning offset used in join_partial() + pub scanning_offset: usize, +} + +impl BufferedData { + pub fn head_batch(&self) -> &BufferedBatch { + self.batches.front().unwrap() + } + + pub fn tail_batch(&self) -> &BufferedBatch { + self.batches.back().unwrap() + } + + pub fn tail_batch_mut(&mut self) -> &mut BufferedBatch { + self.batches.back_mut().unwrap() + } + + pub fn has_buffered_rows(&self) -> bool { + self.batches.iter().any(|batch| !batch.range.is_empty()) + } + + pub fn scanning_reset(&mut self) { + self.scanning_batch_idx = 0; + self.scanning_offset = 0; + } + + pub fn scanning_advance(&mut self) { + self.scanning_offset += 1; + while !self.scanning_finished() && self.scanning_batch_finished() { + self.scanning_batch_idx += 1; + self.scanning_offset = 0; + } + } + + pub fn scanning_batch(&self) -> &BufferedBatch { + &self.batches[self.scanning_batch_idx] + } + + pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch { + &mut self.batches[self.scanning_batch_idx] + } + + pub fn scanning_idx(&self) -> usize { + self.scanning_batch().range.start + self.scanning_offset + } + + pub fn scanning_batch_finished(&self) -> bool { + self.scanning_offset == self.scanning_batch().range.len() + } + + pub fn scanning_finished(&self) -> bool { + self.scanning_batch_idx == self.batches.len() + } + + pub fn scanning_finish(&mut self) { + self.scanning_batch_idx = self.batches.len(); + self.scanning_offset = 0; + } +} + +/// Get join array refs of given batch and join columns +fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec { + on_column + .iter() + .map(|c| { + let num_rows = batch.num_rows(); + let c = c.evaluate(batch).unwrap(); + c.into_array(num_rows).unwrap() + }) + .collect() +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs b/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs index ac476853d5d75..62efb77f877ab 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/metrics.rs @@ -18,12 +18,11 @@ //! Module for tracking Sort Merge Join metrics use crate::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, SpillMetrics, - Time, + BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, + MetricCategory, Time, }; /// Metrics for SortMergeJoinExec -#[allow(dead_code)] pub(super) struct SortMergeJoinMetrics { /// Total time for joining probe-side batches to the build-side batches join_time: Time, @@ -36,19 +35,20 @@ pub(super) struct SortMergeJoinMetrics { /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: Gauge, - /// Metrics related to spilling - spill_metrics: SpillMetrics, } impl SortMergeJoinMetrics { - #[allow(dead_code)] pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); - let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); - let spill_metrics = SpillMetrics::new(metrics, partition); + let input_batches = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("input_rows", partition); + let peak_mem_used = MetricBuilder::new(metrics) + .with_category(MetricCategory::Bytes) + .gauge("peak_mem_used", partition); let baseline_metrics = BaselineMetrics::new(metrics, partition); @@ -58,7 +58,6 @@ impl SortMergeJoinMetrics { input_rows, baseline_metrics, peak_mem_used, - spill_metrics, } } @@ -81,8 +80,4 @@ impl SortMergeJoinMetrics { pub fn peak_mem_used(&self) -> Gauge { self.peak_mem_used.clone() } - - pub fn spill_metrics(&self) -> SpillMetrics { - self.spill_metrics.clone() - } } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs index 82f18e7414095..2fdb0924e723d 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs @@ -19,9 +19,11 @@ pub use exec::SortMergeJoinExec; +pub(crate) mod bitwise_stream; mod exec; +mod filter; +pub(crate) mod materializing_stream; mod metrics; -mod stream; #[cfg(test)] mod tests; diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index f91bffbed78fc..c4377b3189ff7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -24,42 +24,52 @@ //! //! Add relevant tests under the specified sections. +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; +use super::bitwise_stream::BitwiseSortMergeJoinStream; +use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; +use crate::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; +use crate::spill::spill_manager::SpillManager; +use crate::test::TestMemoryExec; +use crate::test::exec::BarrierExec; +use crate::test::{build_table_i32, build_table_i32_two_cols}; +use crate::{ExecutionPlan, RecordBatchStream, common}; +use crate::{ + expressions::Column, joins::sort_merge_join::filter::get_corrected_filter_mask, + joins::sort_merge_join::materializing_stream::JoinedRecordBatches, +}; use arrow::array::{ - builder::{BooleanBuilder, UInt64Builder}, BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray, Int32Array, RecordBatch, UInt64Array, }; -use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; +use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch}; use arrow::datatypes::{DataType, Field, Schema}; - +use arrow_ord::sort::SortColumn; +use arrow_schema::SchemaRef; use datafusion_common::JoinType::*; use datafusion_common::{ - assert_batches_eq, assert_contains, JoinType, NullEquality, Result, + JoinSide, internal_err, + test_util::{batches_to_sort_string, batches_to_string}, }; use datafusion_common::{ - test_util::{batches_to_sort_string, batches_to_string}, - JoinSide, + JoinType, NullEquality, Result, ScalarValue, assert_batches_eq, assert_contains, }; +use datafusion_common_runtime::JoinSet; use datafusion_execution::config::SessionConfig; use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode}; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::runtime_env::RuntimeEnvBuilder; -use datafusion_execution::TaskContext; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; -use insta::{allow_duplicates, assert_snapshot}; - -use crate::{ - expressions::Column, - joins::sort_merge_join::stream::{get_corrected_filter_mask, JoinedRecordBatches}, -}; - -use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn}; -use crate::joins::SortMergeJoinExec; -use crate::test::TestMemoryExec; -use crate::test::{build_table_i32, build_table_i32_two_cols}; -use crate::{common, ExecutionPlan}; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExprRef; +use futures::{Stream, StreamExt}; +use insta::assert_snapshot; +use itertools::Itertools; fn build_table( a: (&str, &Vec), @@ -365,15 +375,15 @@ async fn join_inner_one() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 5 | 9 | 20 | 5 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 5 | 9 | 20 | 5 | 80 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -403,15 +413,15 @@ async fn join_inner_two() -> Result<()> { let (_columns, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 1 | 1 | 70 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 2 | 2 | 9 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 1 | 1 | 70 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 2 | 2 | 9 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -441,16 +451,16 @@ async fn join_inner_two_two() -> Result<()> { let (_columns, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 1 | 1 | 70 | - | 1 | 1 | 7 | 1 | 1 | 80 | - | 1 | 1 | 8 | 1 | 1 | 70 | - | 1 | 1 | 8 | 1 | 1 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 1 | 1 | 70 | + | 1 | 1 | 7 | 1 | 1 | 80 | + | 1 | 1 | 8 | 1 | 1 | 70 | + | 1 | 1 | 8 | 1 | 1 | 80 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -479,15 +489,15 @@ async fn join_inner_with_nulls() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | | 1 | 1 | 70 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 2 | 2 | 9 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | | 1 | 1 | 70 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 2 | 2 | 9 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -529,16 +539,16 @@ async fn join_inner_with_nulls_with_options() -> Result<()> { ) .await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 2 | 2 | 9 | 2 | 2 | 80 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 1 | 1 | | 1 | 1 | 70 | - | 1 | | 1 | 1 | | 10 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 2 | 2 | 9 | 2 | 2 | 80 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 1 | 1 | | 1 | 1 | 70 | + | 1 | | 1 | 1 | | 10 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -570,15 +580,15 @@ async fn join_inner_output_two_batches() -> Result<()> { assert_eq!(batches[0].num_rows(), 2); assert_eq!(batches[1].num_rows(), 1); // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b2 | c1 | a1 | b2 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 1 | 1 | 70 | - | 2 | 2 | 8 | 2 | 2 | 80 | - | 2 | 2 | 9 | 2 | 2 | 80 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b2 | c1 | a1 | b2 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 1 | 1 | 70 | + | 2 | 2 | 8 | 2 | 2 | 80 | + | 2 | 2 | 9 | 2 | 2 | 80 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -601,15 +611,15 @@ async fn join_left_one() -> Result<()> { let (_, batches) = join_collect(left, right, on, Left).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 7 | 9 | | | | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 7 | 9 | | | | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -632,15 +642,15 @@ async fn join_right_one() -> Result<()> { let (_, batches) = join_collect(left, right, on, Right).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | | | | 30 | 6 | 90 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | | | | 30 | 6 | 90 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -690,15 +700,15 @@ async fn join_right_different_columns_count_with_filter() -> Result<()> { let (_, batches) = join_collect_with_filter(left, right, on, filter, Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | - +----+----+----+----+----+ - | | | | 10 | 4 | - | 21 | 5 | 8 | 20 | 5 | - | | | | 30 | 6 | - +----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | + +----+----+----+----+----+ + | | | | 10 | 4 | + | 21 | 5 | 8 | 20 | 5 | + | | | | 30 | 6 | + +----+----+----+----+----+ + "); Ok(()) } @@ -748,15 +758,15 @@ async fn join_left_different_columns_count_with_filter() -> Result<()> { let (_, batches) = join_collect_with_filter(left, right, on, filter, Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+ - | a2 | b1 | a1 | b1 | c1 | - +----+----+----+----+----+ - | 10 | 4 | 1 | 4 | 7 | - | 20 | 5 | | | | - | 30 | 6 | | | | - +----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+ + | a2 | b1 | a1 | b1 | c1 | + +----+----+----+----+----+ + | 10 | 4 | 1 | 4 | 7 | + | 20 | 5 | | | | + | 30 | 6 | | | | + +----+----+----+----+----+ + "); Ok(()) } @@ -807,15 +817,15 @@ async fn join_left_mark_different_columns_count_with_filter() -> Result<()> { let (_, batches) = join_collect_with_filter(left, right, on, filter, LeftMark).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+-------+ - | a2 | b1 | mark | - +----+----+-------+ - | 10 | 4 | true | - | 20 | 5 | false | - | 30 | 6 | false | - +----+----+-------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+-------+ + | a2 | b1 | mark | + +----+----+-------+ + | 10 | 4 | true | + | 20 | 5 | false | + | 30 | 6 | false | + +----+----+-------+ + "); Ok(()) } @@ -866,15 +876,15 @@ async fn join_right_mark_different_columns_count_with_filter() -> Result<()> { let (_, batches) = join_collect_with_filter(left, right, on, filter, RightMark).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+-------+ - | a2 | b1 | mark | - +----+----+-------+ - | 10 | 4 | false | - | 20 | 5 | true | - | 30 | 6 | false | - +----+----+-------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+-------+ + | a2 | b1 | mark | + +----+----+-------+ + | 10 | 4 | false | + | 20 | 5 | true | + | 30 | 6 | false | + +----+----+-------+ + "); Ok(()) } @@ -897,16 +907,16 @@ async fn join_full_one() -> Result<()> { let (_, batches) = join_collect(left, right, on, Full).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 30 | 6 | 90 | - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | 3 | 7 | 9 | | | | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 30 | 6 | 90 | + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | 3 | 7 | 9 | | | | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -930,14 +940,14 @@ async fn join_left_anti() -> Result<()> { let (_, batches) = join_collect(left, right, on, LeftAnti).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 3 | 7 | 9 | - | 5 | 7 | 11 | - +----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 3 | 7 | 9 | + | 5 | 7 | 11 | + +----+----+----+ + "); Ok(()) } @@ -956,13 +966,13 @@ async fn join_right_anti_one_one() -> Result<()> { let (_, batches) = join_collect(left, right, on, RightAnti).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+ - | a2 | b1 | - +----+----+ - | 30 | 6 | - +----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+ + | a2 | b1 | + +----+----+ + | 30 | 6 | + +----+----+ + "); let left2 = build_table( ("a1", &vec![1, 2, 2]), @@ -982,13 +992,13 @@ async fn join_right_anti_one_one() -> Result<()> { let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches2), @r#" - +----+----+----+ - | a2 | b1 | c2 | - +----+----+----+ - | 30 | 6 | 90 | - +----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches2), @r" + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 30 | 6 | 90 | + +----+----+----+ + "); Ok(()) } @@ -1014,15 +1024,15 @@ async fn join_right_anti_two_two() -> Result<()> { let (_, batches) = join_collect(left, right, on, RightAnti).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+ - | a2 | b1 | - +----+----+ - | 10 | 4 | - | 20 | 5 | - | 30 | 6 | - +----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+ + | a2 | b1 | + +----+----+ + | 10 | 4 | + | 20 | 5 | + | 30 | 6 | + +----+----+ + "); let left = build_table( ("a1", &vec![1, 2, 2]), @@ -1099,13 +1109,68 @@ async fn join_right_anti_two_with_filter() -> Result<()> { ); let (_, batches) = join_collect_with_filter(left, right, on, filter, RightAnti).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c2 | - +----+----+----+ - | 1 | 10 | 20 | - +----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c2 | + +----+----+----+ + | 1 | 10 | 20 | + +----+----+----+ + "); + Ok(()) +} + +#[tokio::test] +async fn join_right_anti_filtered_with_mismatched_columns() -> Result<()> { + let left = build_table_two_cols(("a1", &vec![31, 31]), ("b1", &vec![32, 33])); + let right = build_table( + ("a2", &vec![31, 31]), + ("b2", &vec![32, 35]), + ("c2", &vec![108, 109]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b1", 0)), + Operator::LtEq, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("b1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])), + ); + + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightAnti).await?; + + let expected = [ + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 31 | 35 | 109 |", + "+----+----+-----+", + ]; + assert_batches_eq!(expected, &batches); Ok(()) } @@ -1134,13 +1199,13 @@ async fn join_right_anti_with_nulls() -> Result<()> { let (_, batches) = join_collect(left, right, on, RightAnti).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c2 | - +----+----+----+ - | 2 | | 8 | - +----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c2 | + +----+----+----+ + | 2 | | 8 | + +----+----+----+ + "); Ok(()) } @@ -1184,15 +1249,15 @@ async fn join_right_anti_with_nulls_with_options() -> Result<()> { .await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c2 | - +----+----+----+ - | 3 | | 9 | - | 2 | 5 | | - | 2 | 5 | 8 | - +----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c2 | + +----+----+----+ + | 3 | | 9 | + | 2 | 5 | | + | 2 | 5 | 8 | + +----+----+----+ + "); Ok(()) } @@ -1221,18 +1286,19 @@ async fn join_right_anti_output_two_batches() -> Result<()> { let (_, batches) = join_collect_batch_size_equals_two(left, right, on, LeftAnti).await?; - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].num_rows(), 2); - assert_eq!(batches[1].num_rows(), 1); - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 4 | 7 | - | 2 | 5 | 8 | - | 2 | 5 | 8 | - +----+----+----+ - "#); + // BitwiseSortMergeJoinStream uses a coalescer, so batch boundaries differ + // from the old stream. Only assert data correctness. + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 4 | 7 | + | 2 | 5 | 8 | + | 2 | 5 | 8 | + +----+----+----+ + "); Ok(()) } @@ -1255,15 +1321,15 @@ async fn join_left_semi() -> Result<()> { let (_, batches) = join_collect(left, right, on, LeftSemi).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 4 | 7 | - | 2 | 5 | 8 | - | 2 | 5 | 8 | - +----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 4 | 7 | + | 2 | 5 | 8 | + | 2 | 5 | 8 | + +----+----+----+ + "); Ok(()) } @@ -1509,9 +1575,10 @@ async fn join_right_semi_output_two_batches() -> Result<()> { "| 2 | 5 | 8 |", "+----+----+----+", ]; - assert_eq!(batches.len(), 2); - assert_eq!(batches[0].num_rows(), 2); - assert_eq!(batches[1].num_rows(), 1); + // BitwiseSortMergeJoinStream uses a coalescer, so batch boundaries differ + // from the old stream. Only assert data correctness. + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); assert_batches_eq!(expected, &batches); Ok(()) } @@ -1535,16 +1602,16 @@ async fn join_left_mark() -> Result<()> { let (_, batches) = join_collect(left, right, on, LeftMark).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+-------+ - | a1 | b1 | c1 | mark | - +----+----+----+-------+ - | 1 | 4 | 7 | true | - | 2 | 5 | 8 | true | - | 2 | 5 | 8 | true | - | 3 | 7 | 9 | false | - +----+----+----+-------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+-------+ + | a1 | b1 | c1 | mark | + +----+----+----+-------+ + | 1 | 4 | 7 | true | + | 2 | 5 | 8 | true | + | 2 | 5 | 8 | true | + | 3 | 7 | 9 | false | + +----+----+----+-------+ + "); Ok(()) } @@ -1567,16 +1634,16 @@ async fn join_right_mark() -> Result<()> { let (_, batches) = join_collect(left, right, on, RightMark).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+-------+ - | a2 | b1 | c2 | mark | - +----+----+----+-------+ - | 10 | 4 | 60 | true | - | 20 | 4 | 70 | true | - | 30 | 5 | 80 | true | - | 40 | 6 | 90 | false | - +----+----+----+-------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+-------+ + | a2 | b1 | c2 | mark | + +----+----+----+-------+ + | 10 | 4 | 60 | true | + | 20 | 4 | 70 | true | + | 30 | 5 | 80 | true | + | 40 | 6 | 90 | false | + +----+----+----+-------+ + "); Ok(()) } @@ -1600,14 +1667,14 @@ async fn join_with_duplicated_column_names() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +---+---+---+----+---+----+ - | a | b | c | a | b | c | - +---+---+---+----+---+----+ - | 1 | 4 | 7 | 10 | 1 | 70 | - | 2 | 5 | 8 | 20 | 2 | 80 | - +---+---+---+----+---+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +---+---+---+----+---+----+ + | a | b | c | a | b | c | + +---+---+---+----+---+----+ + | 1 | 4 | 7 | 10 | 1 | 70 | + | 2 | 5 | 8 | 20 | 2 | 80 | + +---+---+---+----+---+----+ + "); Ok(()) } @@ -1632,15 +1699,15 @@ async fn join_date32() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +------------+------------+------------+------------+------------+------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +------------+------------+------------+------------+------------+------------+ - | 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 | - | 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 | - | 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 | - +------------+------------+------------+------------+------------+------------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 | + | 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 | + | 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 | + +------------+------------+------------+------------+------------+------------+ + "); Ok(()) } @@ -1665,15 +1732,15 @@ async fn join_date64() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 | - | 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | - | 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 | + | 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "); Ok(()) } @@ -1712,15 +1779,15 @@ async fn join_binary() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +--------+----+----+--------+-----+----+ - | a1 | b1 | c1 | a1 | b2 | c2 | - +--------+----+----+--------+-----+----+ - | c0ffee | 5 | 7 | c0ffee | 105 | 70 | - | decade | 10 | 8 | decade | 110 | 80 | - | facade | 15 | 9 | facade | 115 | 90 | - +--------+----+----+--------+-----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "); Ok(()) } @@ -1759,15 +1826,15 @@ async fn join_fixed_size_binary() -> Result<()> { let (_, batches) = join_collect(left, right, on, Inner).await?; // The output order is important as SMJ preserves sortedness - assert_snapshot!(batches_to_string(&batches), @r#" - +--------+----+----+--------+-----+----+ - | a1 | b1 | c1 | a1 | b2 | c2 | - +--------+----+----+--------+-----+----+ - | c0ffee | 5 | 7 | c0ffee | 105 | 70 | - | decade | 10 | 8 | decade | 110 | 80 | - | facade | 15 | 9 | facade | 115 | 90 | - +--------+----+----+--------+-----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +--------+----+----+--------+-----+----+ + | a1 | b1 | c1 | a1 | b2 | c2 | + +--------+----+----+--------+-----+----+ + | c0ffee | 5 | 7 | c0ffee | 105 | 70 | + | decade | 10 | 8 | decade | 110 | 80 | + | facade | 15 | 9 | facade | 115 | 90 | + +--------+----+----+--------+-----+----+ + "); Ok(()) } @@ -1789,20 +1856,20 @@ async fn join_left_sort_order() -> Result<()> { )]; let (_, batches) = join_collect(left, right, on, Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | 0 | 3 | 4 | | | | - | 1 | 4 | 5 | 10 | 4 | 60 | - | 2 | 5 | 6 | | | | - | 3 | 6 | 7 | 20 | 6 | 70 | - | 3 | 6 | 7 | 30 | 6 | 80 | - | 4 | 6 | 8 | 20 | 6 | 70 | - | 4 | 6 | 8 | 30 | 6 | 80 | - | 5 | 7 | 9 | | | | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 0 | 3 | 4 | | | | + | 1 | 4 | 5 | 10 | 4 | 60 | + | 2 | 5 | 6 | | | | + | 3 | 6 | 7 | 20 | 6 | 70 | + | 3 | 6 | 7 | 30 | 6 | 80 | + | 4 | 6 | 8 | 20 | 6 | 70 | + | 4 | 6 | 8 | 30 | 6 | 80 | + | 5 | 7 | 9 | | | | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -1824,16 +1891,16 @@ async fn join_right_sort_order() -> Result<()> { )]; let (_, batches) = join_collect(left, right, on, Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 0 | 2 | 60 | - | 1 | 4 | 7 | 10 | 4 | 70 | - | 2 | 5 | 8 | 20 | 5 | 80 | - | | | | 30 | 6 | 90 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 0 | 2 | 60 | + | 1 | 4 | 7 | 10 | 4 | 70 | + | 2 | 5 | 8 | 20 | 5 | 80 | + | | | | 30 | 6 | 90 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -1867,21 +1934,21 @@ async fn join_left_multiple_batches() -> Result<()> { )]; let (_, batches) = join_collect(left, right, on, Left).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | 0 | 3 | 4 | | | | - | 1 | 4 | 5 | 10 | 4 | 60 | - | 2 | 5 | 6 | | | | - | 3 | 6 | 7 | 20 | 6 | 70 | - | 3 | 6 | 7 | 30 | 6 | 80 | - | 4 | 6 | 8 | 20 | 6 | 70 | - | 4 | 6 | 8 | 30 | 6 | 80 | - | 5 | 7 | 9 | | | | - | 6 | 9 | 9 | | | | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | 0 | 3 | 4 | | | | + | 1 | 4 | 5 | 10 | 4 | 60 | + | 2 | 5 | 6 | | | | + | 3 | 6 | 7 | 20 | 6 | 70 | + | 3 | 6 | 7 | 30 | 6 | 80 | + | 4 | 6 | 8 | 20 | 6 | 70 | + | 4 | 6 | 8 | 30 | 6 | 80 | + | 5 | 7 | 9 | | | | + | 6 | 9 | 9 | | | | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -1915,21 +1982,21 @@ async fn join_right_multiple_batches() -> Result<()> { )]; let (_, batches) = join_collect(left, right, on, Right).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 0 | 3 | 4 | - | 10 | 4 | 60 | 1 | 4 | 5 | - | | | | 2 | 5 | 6 | - | 20 | 6 | 70 | 3 | 6 | 7 | - | 30 | 6 | 80 | 3 | 6 | 7 | - | 20 | 6 | 70 | 4 | 6 | 8 | - | 30 | 6 | 80 | 4 | 6 | 8 | - | | | | 5 | 7 | 9 | - | | | | 6 | 9 | 9 | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 0 | 3 | 4 | + | 10 | 4 | 60 | 1 | 4 | 5 | + | | | | 2 | 5 | 6 | + | 20 | 6 | 70 | 3 | 6 | 7 | + | 30 | 6 | 80 | 3 | 6 | 7 | + | 20 | 6 | 70 | 4 | 6 | 8 | + | 30 | 6 | 80 | 4 | 6 | 8 | + | | | | 5 | 7 | 9 | + | | | | 6 | 9 | 9 | + +----+----+----+----+----+----+ + "); Ok(()) } @@ -1963,23 +2030,125 @@ async fn join_full_multiple_batches() -> Result<()> { )]; let (_, batches) = join_collect(left, right, on, Full).await?; - assert_snapshot!(batches_to_sort_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b2 | c2 | - +----+----+----+----+----+----+ - | | | | 0 | 2 | 50 | - | | | | 40 | 8 | 90 | - | 0 | 3 | 4 | | | | - | 1 | 4 | 5 | 10 | 4 | 60 | - | 2 | 5 | 6 | | | | - | 3 | 6 | 7 | 20 | 6 | 70 | - | 3 | 6 | 7 | 30 | 6 | 80 | - | 4 | 6 | 8 | 20 | 6 | 70 | - | 4 | 6 | 8 | 30 | 6 | 80 | - | 5 | 7 | 9 | | | | - | 6 | 9 | 9 | | | | - +----+----+----+----+----+----+ - "#); + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b2 | c2 | + +----+----+----+----+----+----+ + | | | | 0 | 2 | 50 | + | | | | 40 | 8 | 90 | + | 0 | 3 | 4 | | | | + | 1 | 4 | 5 | 10 | 4 | 60 | + | 2 | 5 | 6 | | | | + | 3 | 6 | 7 | 20 | 6 | 70 | + | 3 | 6 | 7 | 30 | 6 | 80 | + | 4 | 6 | 8 | 20 | 6 | 70 | + | 4 | 6 | 8 | 30 | 6 | 80 | + | 5 | 7 | 9 | | | | + | 6 | 9 | 9 | | | | + +----+----+----+----+----+----+ + "); + Ok(()) +} + +/// Full outer join where the filter evaluates to NULL due to a nullable column. +/// NULL filter results must be treated as unmatched, not matched. +/// Reproducer for SPARK-43113. +#[tokio::test] +async fn join_full_null_filter_result() -> Result<()> { + // Left: (a, b) all non-null, sorted on a + let left = build_table_two_cols( + ("a1", &vec![1, 1, 2, 2, 3, 3]), + ("b1", &vec![1, 2, 1, 2, 1, 2]), + ); + + // Right: (a, b) with b nullable, sorted on a + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, false), + Field::new("b2", DataType::Int32, true), + ])); + let right_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Int32Array::from(vec![None, Some(2)])), + ], + )?; + let right = + TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None).unwrap(); + + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + )]; + + // Filter: b1 < (b2 + 1) AND b1 < (a2 + 1) + // When b2 is NULL, (b2 + 1) is NULL, so b1 < NULL is NULL → unmatched. + let lit_1: PhysicalExprRef = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + let b1_lt_b2_plus_1: PhysicalExprRef = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b1", 0)), + Operator::Lt, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b2", 1)), + Operator::Plus, + Arc::clone(&lit_1), + )), + )); + let b1_lt_a2_plus_1: PhysicalExprRef = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b1", 0)), + Operator::Lt, + Arc::new(BinaryExpr::new( + Arc::new(Column::new("a2", 2)), + Operator::Plus, + Arc::clone(&lit_1), + )), + )); + let filter_expr: PhysicalExprRef = Arc::new(BinaryExpr::new( + b1_lt_b2_plus_1, + Operator::And, + b1_lt_a2_plus_1, + )); + + let filter = JoinFilter::new( + filter_expr, + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("b1", DataType::Int32, true), + Field::new("b2", DataType::Int32, true), + Field::new("a2", DataType::Int32, true), + ])), + ); + + let (_, batches) = join_collect_with_filter(left, right, on, filter, Full).await?; + + // r=(1,NULL): b2 is NULL → b1 < (NULL+1) is NULL → all a=1 rows unmatched + // r=(2,2): b1 < 3 AND b1 < 3 → both l=(2,1) and l=(2,2) match + // l=(3,*): no right row with a=3 → unmatched + assert_snapshot!(batches_to_sort_string(&batches), @r" + +----+----+----+----+ + | a1 | b1 | a2 | b2 | + +----+----+----+----+ + | | | 1 | | + | 1 | 1 | | | + | 1 | 2 | | | + | 2 | 1 | 2 | 2 | + | 2 | 2 | 2 | 2 | + | 3 | 1 | | | + | 3 | 2 | | | + +----+----+----+----+ + "); Ok(()) } @@ -2002,7 +2171,9 @@ async fn overallocation_single_batch_no_spill() -> Result<()> { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = vec![ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested in bitwise_sort_merge_join/tests.rs. + Inner, Left, Right, Full, ]; // Disable DiskManager to prevent spilling @@ -2083,7 +2254,9 @@ async fn overallocation_multi_batch_no_spill() -> Result<()> { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = vec![ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested in bitwise_sort_merge_join/tests.rs. + Inner, Left, Right, Full, ]; // Disable DiskManager to prevent spilling @@ -2143,7 +2316,9 @@ async fn overallocation_single_batch_spill() -> Result<()> { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = [ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested in bitwise_sort_merge_join/tests.rs. + Inner, Left, Right, Full, ]; // Enable DiskManager to allow spilling @@ -2247,7 +2422,9 @@ async fn overallocation_multi_batch_spill() -> Result<()> { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = [ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + // Semi/anti/mark joins use BitwiseSortMergeJoinStream which only tracks + // inner key buffer memory; tested in bitwise_sort_merge_join/tests.rs. + Inner, Left, Right, Full, ]; // Enable DiskManager to allow spilling @@ -2310,672 +2487,2463 @@ async fn overallocation_multi_batch_spill() -> Result<()> { Ok(()) } -fn build_joined_record_batches() -> Result { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - Field::new("x", DataType::Int32, true), - Field::new("y", DataType::Int32, true), - ])); - - let mut batches = JoinedRecordBatches { - batches: vec![], - filter_mask: BooleanBuilder::new(), - row_indices: UInt64Builder::new(), - batch_ids: vec![], - }; +/// Verifies that `peak_mem_used` reflects join_arrays memory on the spill path. +/// +/// Uses a memory limit smaller than a single batch's `size_estimation` so that +/// every batch spills — the `Ok` arm of `allocate_reservation` is never hit. +/// Before the fix, `peak_mem_used` would stay 0 because `set_max` was only +/// called in the `Ok` arm. After the fix, the spill path calls +/// `grow(join_arrays_mem)` + `set_max`, so `peak_mem_used > 0`. +#[tokio::test] +async fn spill_join_arrays_memory_accounting() -> Result<()> { + use arrow::array::Array; - // Insert already prejoined non-filtered rows - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![10, 10])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 9])), - ], - )?); + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size(); + + // Memory limit: too small for a full batch, large enough for join_arrays. + // Every batch hits the Err arm → spills → grow(join_arrays_mem). + let memory_limit = (size_estimation + join_arrays_mem) / 2; + assert!( + memory_limit < size_estimation && memory_limit > join_arrays_mem, + "limit {memory_limit} must be between join_arrays_mem {join_arrays_mem} \ + and size_estimation {size_estimation}" + ); - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![11])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); + let left_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec = (0..2) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 12])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![11, 13])), - ], - )?); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![13])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![12])), - ], - )?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; - batches.batches.push(RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![14, 14])), - Arc::new(Int32Array::from(vec![1, 1])), - Arc::new(Int32Array::from(vec![12, 11])), - ], - )?); + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); - let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![0; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; - let streamed_indices = vec![1]; - batches.batch_ids.extend(vec![0; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); - let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![1; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); + assert!(!result.is_empty(), "Expected non-empty join result"); - let streamed_indices = vec![0]; - batches.batch_ids.extend(vec![2; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); - let streamed_indices = vec![0, 0]; - batches.batch_ids.extend(vec![3; streamed_indices.len()]); - batches - .row_indices - .extend(&UInt64Array::from(streamed_indices)); + // Before the fix, peak_mem_used was 0 here because set_max was only + // called in the Ok arm of allocate_reservation, which is never reached + // when every batch spills. After the fix, the spill path calls + // grow(join_arrays_mem) + set_max unconditionally. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= join_arrays_mem, + "peak_mem_used ({peak_mem}) should be >= join_arrays_mem ({join_arrays_mem})" + ); - batches - .filter_mask - .extend(&BooleanArray::from(vec![true, false])); - batches.filter_mask.extend(&BooleanArray::from(vec![true])); - batches - .filter_mask - .extend(&BooleanArray::from(vec![false, true])); - batches.filter_mask.extend(&BooleanArray::from(vec![false])); - batches - .filter_mask - .extend(&BooleanArray::from(vec![false, false])); + // All memory must be released (grow/shrink balanced, no underflow) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); - Ok(batches) + Ok(()) } +/// Test the no-headroom scenario: pool is so tight that even +/// join_arrays_mem exceeds the pool limit. With force-grow, the +/// reservation still tracks the join_arrays unconditionally so the +/// pool reflects actual memory usage. #[tokio::test] -async fn test_left_outer_join_filtered_mask() -> Result<()> { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); +async fn spill_join_arrays_no_headroom() -> Result<()> { + use arrow::array::Array; - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + let join_arrays_mem = Int32Array::from(vec![1, 1]).get_array_memory_size(); - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true, false, false, false, false, false, false, false]) + // Pool smaller than join_arrays_mem: try_grow(size_estimation) fails → spill. + // Force-grow(join_arrays_mem) succeeds unconditionally → reserved_amount > 0. + let memory_limit = join_arrays_mem / 2; + assert!( + memory_limit < join_arrays_mem, + "limit {memory_limit} must be smaller than join_arrays_mem {join_arrays_mem}" ); - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![false, false, false, false, false, false, false, false]) - ); + let left_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec = (0..2) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), ) - .unwrap(), - BooleanArray::from(vec![true, true, false, false, false, false, false, false]) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), ); - assert_eq!( - get_corrected_filter_mask( - Left, + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // Force-grow means peak_mem_used is always tracked, even when pool is tight. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= join_arrays_mem, + "peak_mem_used ({peak_mem}) should be >= join_arrays_mem ({join_arrays_mem})" + ); + + // Pool should be fully released (grow/shrink balanced) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} + +/// Build a c1 < c2 filter on the third column of each side. +fn build_c1_lt_c2_filter(left_schema: &Schema, right_schema: &Schema) -> JoinFilter { + JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c1", 0)), + Operator::Lt, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + left_schema + .field_with_name("c1") + .unwrap() + .clone() + .with_nullable(true), + right_schema + .field_with_name("c2") + .unwrap() + .clone() + .with_nullable(true), + ])), + ) +} + +#[tokio::test] +async fn spill_with_filter_deferred() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let filter = build_c1_lt_c2_filter(&left.schema(), &right.schema()); + + // Deferred filtering join types handled by the main MaterializingSortMergeJoinStream + let join_types = [Left, Right, Full]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + // Run with spilling + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)), + ); + let join = join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!( + join.metrics().unwrap().spill_count().unwrap() > 0, + "Expected spilling for {join_type:?} batch_size={batch_size}" + ); + + // Run without spilling + let task_ctx_no_spill = Arc::new( + TaskContext::default().with_session_config(session_config.clone()), + ); + let join_no_spill = join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join_no_spill.execute(0, task_ctx_no_spill)?; + let no_spill_result = common::collect(stream).await.unwrap(); + + let spilled_str = batches_to_sort_string(&spilled_result); + let no_spill_str = batches_to_sort_string(&no_spill_result); + assert_eq!( + spilled_str, no_spill_str, + "Spill vs no-spill mismatch for {join_type:?} batch_size={batch_size}" + ); + } + } + + Ok(()) +} + +#[tokio::test] +async fn spill_with_filter_multi_batch() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let filter = build_c1_lt_c2_filter(&left.schema(), &right.schema()); + + let join_types = [Left, Right, Full]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + // Run with spilling + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)), + ); + let join = join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!( + join.metrics().unwrap().spill_count().unwrap() > 0, + "Expected spilling for {join_type:?} batch_size={batch_size}" + ); + + // Run without spilling + let task_ctx_no_spill = Arc::new( + TaskContext::default().with_session_config(session_config.clone()), + ); + let join_no_spill = join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + *join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join_no_spill.execute(0, task_ctx_no_spill)?; + let no_spill_result = common::collect(stream).await.unwrap(); + + let spilled_str = batches_to_sort_string(&spilled_result); + let no_spill_str = batches_to_sort_string(&no_spill_result); + assert_eq!( + spilled_str, no_spill_str, + "Spill vs no-spill mismatch for {join_type:?} batch_size={batch_size}" + ); + } + } + + Ok(()) +} + +/// FULL join where all buffered rows match on key but fail the filter. +/// Verifies produce_buffered_not_matched emits null-joined rows under spill. +#[tokio::test] +async fn spill_full_join_filter_not_matched() -> Result<()> { + // c1 values (100..105) are always > c2 values (1..5), so c1 < c2 always fails + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4]), + ("b1", &vec![1, 1, 1, 1, 1]), + ("c1", &vec![100, 101, 102, 103, 104]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40, 50]), + ("b2", &vec![1, 1, 1, 1, 1]), + ("c2", &vec![1, 2, 3, 4, 5]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let filter = build_c1_lt_c2_filter(&left.schema(), &right.schema()); + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + // Run with spilling + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)), + ); + let join = join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + Full, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_result = common::collect(stream).await.unwrap(); + + assert!( + join.metrics().unwrap().spill_count().unwrap() > 0, + "Expected spilling for FULL batch_size={batch_size}" + ); + + // Run without spilling + let task_ctx_no_spill = + Arc::new(TaskContext::default().with_session_config(session_config.clone())); + let join_no_spill = join_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + Full, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join_no_spill.execute(0, task_ctx_no_spill)?; + let no_spill_result = common::collect(stream).await.unwrap(); + + // All filter evaluations fail, so FULL join should produce: + // - 5 rows with left columns + null right columns (unmatched left) + // - 5 rows with null left columns + right columns (unmatched right) + let total_rows: usize = no_spill_result.iter().map(|b| b.num_rows()).sum(); + assert_eq!( + total_rows, 10, + "FULL join with all-failing filter should produce 10 rows, got {total_rows}" + ); + + let spilled_str = batches_to_sort_string(&spilled_result); + let no_spill_str = batches_to_sort_string(&no_spill_result); + assert_eq!( + spilled_str, no_spill_str, + "Spill vs no-spill mismatch for FULL join batch_size={batch_size}" + ); + } + + Ok(()) +} + +fn build_joined_record_batches() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut batches = JoinedRecordBatches { + joined_batches: BatchCoalescer::new(Arc::clone(&schema), 8192), + filter_metadata: crate::joins::sort_merge_join::filter::FilterMetadata::new(), + }; + + // Insert already prejoined non-filtered rows + batches.joined_batches.push_batch(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?)?; + + batches.joined_batches.push_batch(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?)?; + + batches.joined_batches.push_batch(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?)?; + + batches.joined_batches.push_batch(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?)?; + + batches.joined_batches.push_batch(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?)?; + + let streamed_indices = vec![0, 0]; + batches + .filter_metadata + .batch_ids + .extend(vec![0; streamed_indices.len()]); + batches + .filter_metadata + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + batches + .filter_metadata + .batch_ids + .extend(vec![0; streamed_indices.len()]); + batches + .filter_metadata + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches + .filter_metadata + .batch_ids + .extend(vec![1; streamed_indices.len()]); + batches + .filter_metadata + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + batches + .filter_metadata + .batch_ids + .extend(vec![2; streamed_indices.len()]); + batches + .filter_metadata + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches + .filter_metadata + .batch_ids + .extend(vec![3; streamed_indices.len()]); + batches + .filter_metadata + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![true, false])); + batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![true])); + batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![false, true])); + batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![false])); + batches + .filter_metadata + .filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + Ok(batches) +} + +#[tokio::test] +async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.joined_batches.schema(); + + let output = joined_batches.concat_batches(&schema)?; + let out_mask = joined_batches.filter_metadata.filter_mask.finish(); + let out_indices = joined_batches.filter_metadata.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, false, false, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![false, false, false, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, false, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, &UInt64Array::from(vec![0, 0, 0]), &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), + &BooleanArray::from(vec![false, false, true]), output.num_rows() ) .unwrap(), - BooleanArray::from(vec![true, true, true, false, false, false, false, false]) + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + let corrected_mask = get_corrected_filter_mask( + Left, + &out_indices, + &joined_batches.filter_metadata.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_snapshot!(batches_to_string(&[filtered_rb]), @r" + +---+----+---+----+ + | a | b | x | y | + +---+----+---+----+ + | 1 | 10 | 1 | 11 | + | 1 | 11 | 1 | 12 | + | 1 | 12 | 1 | 13 | + +---+----+---+----+ + "); + + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_snapshot!(batches_to_string(&[null_joined_batch]), @r" + +---+----+---+----+ + | a | b | x | y | + +---+----+---+----+ + | 1 | 13 | 1 | 12 | + | 1 | 14 | 1 | 11 | + +---+----+---+----+ + "); + Ok(()) +} + +#[test] +fn test_partition_statistics() -> Result<()> { + use crate::ExecutionPlan; + use datafusion_common::stats::Precision; + + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + // Test different join types to ensure partition_statistics works correctly for all + let join_types = vec![ + (Inner, 6), // left cols + right cols + (Left, 6), // left cols + right cols + (Right, 6), // left cols + right cols + (Full, 6), // left cols + right cols + (LeftSemi, 3), // only left cols + (LeftAnti, 3), // only left cols + (RightSemi, 3), // only right cols + (RightAnti, 3), // only right cols + ]; + + for (join_type, expected_cols) in join_types { + let join_exec = + join(Arc::clone(&left), Arc::clone(&right), on.clone(), join_type)?; + + // Test aggregate statistics (partition = None) + // Should return meaningful statistics computed from both inputs + let stats = join_exec.partition_statistics(None)?; + assert_eq!( + stats.column_statistics.len(), + expected_cols, + "Aggregate stats column count failed for {join_type:?}" + ); + // Verify that aggregate statistics have a meaningful num_rows (not Absent) + assert!( + stats.num_rows != Precision::Absent, + "Aggregate stats should have meaningful num_rows for {join_type:?}, got {:?}", + stats.num_rows + ); + + // Test partition-specific statistics (partition = Some(0)) + // The implementation correctly passes `partition` to children. + // Since the child TestMemoryExec returns unknown stats for specific partitions, + // the join output will also have Absent num_rows. This is expected behavior + // as the statistics depend on what the children can provide. + let partition_stats = join_exec.partition_statistics(Some(0))?; + assert_eq!( + partition_stats.column_statistics.len(), + expected_cols, + "Partition stats column count failed for {join_type:?}" + ); + // When children return unknown stats, the join's partition stats will be Absent + assert!( + partition_stats.num_rows == Precision::Absent, + "Partition stats should have Absent num_rows when children return unknown for {join_type:?}, got {:?}", + partition_stats.num_rows + ); + } + + Ok(()) +} + +fn build_batches( + a: (&str, &[Vec]), + b: (&str, &[Vec]), + c: (&str, &[Vec]), +) -> (Vec, SchemaRef) { + assert_eq!(a.1.len(), b.1.len()); + let mut batches = vec![]; + + let schema = Arc::new(Schema::new(vec![ + Field::new(a.0, DataType::Boolean, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ])); + + for i in 0..a.1.len() { + batches.push( + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(BooleanArray::from(a.1[i].clone())), + Arc::new(Int32Array::from(b.1[i].clone())), + Arc::new(Int32Array::from(c.1[i].clone())), + ], + ) + .unwrap(), + ); + } + let schema = batches[0].schema(); + (batches, schema) +} + +fn build_batched_finish_barrier_table( + a: (&str, &[Vec]), + b: (&str, &[Vec]), + c: (&str, &[Vec]), +) -> (Arc, Arc) { + let (batches, schema) = build_batches(a, b, c); + + let memory_exec = TestMemoryExec::try_new_exec( + std::slice::from_ref(&batches), + Arc::clone(&schema), + None, + ) + .unwrap(); + + let barrier_exec = Arc::new( + BarrierExec::new(vec![batches], schema) + .with_log(false) + .without_start_barrier() + .with_finish_barrier(), + ); + + (barrier_exec, memory_exec) +} + +/// Concat and sort batches by all the columns to make sure we can compare them with different join +fn prepare_record_batches_for_cmp(output: Vec) -> RecordBatch { + let output_batch = arrow::compute::concat_batches(output[0].schema_ref(), &output) + .expect("failed to concat batches"); + + // Sort on all columns to make sure we have a deterministic order for the assertion + let sort_columns = output_batch + .columns() + .iter() + .map(|c| SortColumn { + values: Arc::clone(c), + options: None, + }) + .collect::>(); + + let sorted_columns = + arrow::compute::lexsort(&sort_columns, None).expect("failed to sort"); + + RecordBatch::try_new(output_batch.schema(), sorted_columns) + .expect("failed to create batch") +} + +#[expect(clippy::too_many_arguments)] +async fn join_get_stream_and_get_expected( + left: Arc, + right: Arc, + oracle_left: Arc, + oracle_right: Arc, + on: JoinOn, + join_type: JoinType, + filter: Option, + batch_size: usize, +) -> Result<(SendableRecordBatchStream, RecordBatch)> { + let sort_options = vec![SortOptions::default(); on.len()]; + let null_equality = NullEquality::NullEqualsNothing; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(SessionConfig::default().with_batch_size(batch_size)), + ); + + let expected_output = { + let oracle = HashJoinExec::try_new( + oracle_left, + oracle_right, + on.clone(), + filter.clone(), + &join_type, + None, + PartitionMode::Partitioned, + null_equality, + false, + )?; + + let stream = oracle.execute(0, Arc::clone(&task_ctx))?; + + let batches = common::collect(stream).await?; + + prepare_record_batches_for_cmp(batches) + }; + + let join = SortMergeJoinExec::try_new( + left, + right, + on, + filter, + join_type, + sort_options, + null_equality, + )?; + + let stream = join.execute(0, task_ctx)?; + + Ok((stream, expected_output)) +} + +fn generate_data_for_emit_early_test( + batch_size: usize, + number_of_batches: usize, + join_type: JoinType, +) -> ( + Arc, + Arc, + Arc, + Arc, +) { + let number_of_rows_per_batch = number_of_batches * batch_size; + // Prepare data + let left_a1 = (0..number_of_rows_per_batch as i32) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + let left_b1 = (0..1000000) + .filter(|item| { + match join_type { + LeftAnti | RightAnti => { + let remainder = item % (batch_size as i32); + + // Make sure to have one that match and one that don't + remainder == 0 || remainder == 1 + } + // Have at least 1 that is not matching + _ => item % batch_size as i32 != 0, + } + }) + .take(number_of_rows_per_batch) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + + let left_bool_col1 = left_a1 + .clone() + .into_iter() + .map(|b| { + b.into_iter() + // Mostly true but have some false that not overlap with the right column + .map(|a| a % (batch_size as i32) != (batch_size as i32) - 2) + .collect::>() + }) + .collect::>(); + + let (left, left_memory) = build_batched_finish_barrier_table( + ("bool_col1", left_bool_col1.as_slice()), + ("b1", left_b1.as_slice()), + ("a1", left_a1.as_slice()), + ); + + let right_a2 = (0..number_of_rows_per_batch as i32) + .map(|item| item * 11) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + let right_b1 = (0..1000000) + .filter(|item| { + match join_type { + LeftAnti | RightAnti => { + let remainder = item % (batch_size as i32); + + // Make sure to have one that match and one that don't + remainder == 1 || remainder == 2 + } + // Have at least 1 that is not matching + _ => item % batch_size as i32 != 1, + } + }) + .take(number_of_rows_per_batch) + .chunks(batch_size) + .into_iter() + .map(|chunk| chunk.collect::>()) + .collect::>(); + let right_bool_col2 = right_a2 + .clone() + .into_iter() + .map(|b| { + b.into_iter() + // Mostly true but have some false that not overlap with the left column + .map(|a| a % (batch_size as i32) != (batch_size as i32) - 1) + .collect::>() + }) + .collect::>(); + + let (right, right_memory) = build_batched_finish_barrier_table( + ("bool_col2", right_bool_col2.as_slice()), + ("b1", right_b1.as_slice()), + ("a2", right_a2.as_slice()), + ); + + (left, right, left_memory, right_memory) +} + +#[tokio::test] +async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> { + for with_filtering in [false, true] { + let join_types = vec![ + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, + ]; + const BATCH_SIZE: usize = 10; + for join_type in join_types { + for output_batch_size in [ + BATCH_SIZE / 3, + BATCH_SIZE / 2, + BATCH_SIZE, + BATCH_SIZE * 2, + BATCH_SIZE * 3, + ] { + // Make sure the number of batches is enough for all join type to emit some output + let number_of_batches = if output_batch_size <= BATCH_SIZE { + 100 + } else { + // Have enough batches + (output_batch_size * 100) / BATCH_SIZE + }; + + let (left, right, left_memory, right_memory) = + generate_data_for_emit_early_test( + BATCH_SIZE, + number_of_batches, + join_type, + ); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let join_filter = if with_filtering { + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("bool_col1", 0)), + Operator::And, + Arc::new(Column::new("bool_col2", 1)), + )), + vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("bool_col1", DataType::Boolean, true), + Field::new("bool_col2", DataType::Boolean, true), + ])), + ); + Some(filter) + } else { + None + }; + + // select * + // from t1 + // right join t2 on t1.b1 = t2.b1 and t1.bool_col1 AND t2.bool_col2 + let (mut output_stream, expected) = join_get_stream_and_get_expected( + Arc::clone(&left) as Arc, + Arc::clone(&right) as Arc, + left_memory as Arc, + right_memory as Arc, + on, + join_type, + join_filter, + output_batch_size, + ) + .await?; + + let (output_batched, output_batches_after_finish) = + consume_stream_until_finish_barrier_reached(left, right, &mut output_stream).await.unwrap_or_else(|e| panic!("Failed to consume stream for join type: '{join_type}' and with filtering '{with_filtering}': {e:?}")); + + // It should emit more than that, but we are being generous + // and to make sure the test pass for all + const MINIMUM_OUTPUT_BATCHES: usize = 5; + assert!( + MINIMUM_OUTPUT_BATCHES <= number_of_batches / 5, + "Make sure that the minimum output batches is realistic" + ); + // Test to make sure that we are not waiting for input to be fully consumed to emit some output + assert!( + output_batched.len() >= MINIMUM_OUTPUT_BATCHES, + "[Sort Merge Join {join_type}] Stream must have at least emit {} batches, but only got {} batches", + MINIMUM_OUTPUT_BATCHES, + output_batched.len() + ); + + // Just sanity test to make sure we are still producing valid output + { + let output = [output_batched, output_batches_after_finish].concat(); + let actual_prepared = prepare_record_batches_for_cmp(output); + + assert_eq!(actual_prepared.columns(), expected.columns()); + } + } + } + } + Ok(()) +} + +/// Polls the stream until both barriers are reached, +/// collecting the emitted batches along the way. +/// +/// If the stream is pending for too long (5s) without emitting any batches, +/// it panics to avoid hanging the test indefinitely. +/// +/// Note: The left and right BarrierExec might be the input of the output stream +async fn consume_stream_until_finish_barrier_reached( + left: Arc, + right: Arc, + output_stream: &mut SendableRecordBatchStream, +) -> Result<(Vec, Vec)> { + let mut switch_to_finish_barrier = false; + let mut output_batched = vec![]; + let mut after_finish_barrier_reached = vec![]; + let mut background_task = JoinSet::new(); + + let mut start_time_since_last_ready = datafusion_common::instant::Instant::now(); + loop { + let next_item = output_stream.next(); + + // Manual polling + let poll_output = futures::poll!(next_item); + + // Wake up the stream to make sure it makes progress + tokio::task::yield_now().await; + + match poll_output { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() == 0 { + return internal_err!("join stream should not emit empty batch"); + } + if switch_to_finish_barrier { + after_finish_barrier_reached.push(batch); + } else { + output_batched.push(batch); + } + start_time_since_last_ready = datafusion_common::instant::Instant::now(); + } + Poll::Ready(Some(Err(e))) => return Err(e), + Poll::Ready(None) if !switch_to_finish_barrier => { + unreachable!("Stream should not end before manually finishing it") + } + Poll::Ready(None) => { + break; + } + Poll::Pending => { + if right.is_finish_barrier_reached() + && left.is_finish_barrier_reached() + && !switch_to_finish_barrier + { + switch_to_finish_barrier = true; + + let right = Arc::clone(&right); + background_task.spawn(async move { + right.wait_finish().await; + }); + let left = Arc::clone(&left); + background_task.spawn(async move { + left.wait_finish().await; + }); + } + + // Make sure the test doesn't run forever + if start_time_since_last_ready.elapsed() + > std::time::Duration::from_secs(5) + { + return internal_err!( + "Stream should have emitted data by now, but it's still pending. Output batches so far: {}", + output_batched.len() + ); + } + } + } + } + + Ok((output_batched, after_finish_barrier_reached)) +} + +/// Exercises the multi-source interleave path in `materialize_right_columns`. +/// +/// When the right (buffered) side is split into many small batches with unique +/// keys, a single `freeze_streamed()` call references multiple `BufferedBatch`es. +/// This forces the `interleave` kernel instead of the single-source `take` path. +/// Without this test, the interleave path has zero coverage from unit tests +/// (fuzz tests use ~100 unique keys across 1000 rows, so all keys fit in one +/// buffered batch). +#[tokio::test] +async fn join_filtered_with_multiple_buffered_batches() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, false), + Field::new("val_l", DataType::Int32, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, false), + Field::new("val_r", DataType::Int32, false), + ])); + + // Left: single batch, keys 1..=6 + let left_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])), + Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50, 60])), + ], + )?; + let left = build_table_from_batches(vec![left_batch]); + + // Right: one row per batch so each key lives in a separate BufferedBatch + let right_batches: Vec = (1..=6) + .map(|k| { + RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![k])), + Arc::new(Int32Array::from(vec![k * 100])), + ], + ) + .unwrap() + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("key", &left.schema())?) as _, + Arc::new(Column::new_with_schema("key", &right.schema())?) as _, + )]; + + // Filter: val_l + val_r < 350 — passes for keys 1-3, fails for 4-6 + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("val_l", 0)), + Operator::Plus, + Arc::new(Column::new("val_r", 1)), + )), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(350)))), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("val_l", DataType::Int32, true), + Field::new("val_r", DataType::Int32, true), + ])), + ); + + // Inner: only rows passing the filter + let (_, batches) = join_collect_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + Inner, + ) + .await?; + let result = batches_to_sort_string(&batches); + assert_snapshot!(result, @r" + +-----+-------+-----+-------+ + | key | val_l | key | val_r | + +-----+-------+-----+-------+ + | 1 | 10 | 1 | 100 | + | 2 | 20 | 2 | 200 | + | 3 | 30 | 3 | 300 | + +-----+-------+-----+-------+ + "); + + // Left: unmatched left rows get null right columns + let (_, batches) = join_collect_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + Left, + ) + .await?; + let result = batches_to_sort_string(&batches); + assert_snapshot!(result, @r" + +-----+-------+-----+-------+ + | key | val_l | key | val_r | + +-----+-------+-----+-------+ + | 1 | 10 | 1 | 100 | + | 2 | 20 | 2 | 200 | + | 3 | 30 | 3 | 300 | + | 4 | 40 | | | + | 5 | 50 | | | + | 6 | 60 | | | + +-----+-------+-----+-------+ + "); + + // Full: unmatched rows on both sides get null columns + let (_, batches) = join_collect_with_filter( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + filter.clone(), + Full, + ) + .await?; + let result = batches_to_sort_string(&batches); + assert_snapshot!(result, @r" + +-----+-------+-----+-------+ + | key | val_l | key | val_r | + +-----+-------+-----+-------+ + | | | 4 | 400 | + | | | 5 | 500 | + | | | 6 | 600 | + | 1 | 10 | 1 | 100 | + | 2 | 20 | 2 | 200 | + | 3 | 30 | 3 | 300 | + | 4 | 40 | | | + | 5 | 50 | | | + | 6 | 60 | | | + +-----+-------+-----+-------+ + "); + + Ok(()) +} + +/// Returns the column names on the schema +fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() +} + +// ==================== BitwiseSortMergeJoinStream direct tests ==================== +// +// These tests construct a BitwiseSortMergeJoinStream directly (bypassing exec) +// to exercise async re-entry and spill edge cases using PendingStream. + +/// Create test memory/spill resources for stream-level tests. +fn test_stream_resources( + inner_schema: SchemaRef, + metrics: &ExecutionPlanMetricsSet, +) -> ( + datafusion_execution::memory_pool::MemoryReservation, + SpillManager, + Arc, +) { + let ctx = TaskContext::default(); + let runtime_env = ctx.runtime_env(); + let reservation = MemoryConsumer::new("test").register(ctx.memory_pool()); + let spill_manager = SpillManager::new( + Arc::clone(&runtime_env), + SpillMetrics::new(metrics, 0), + inner_schema, ); + (reservation, spill_manager, runtime_env) +} - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) +/// A RecordBatch stream that yields Poll::Pending once before delivering +/// each batch at a specified index. This simulates the behavior of +/// repartitioned tokio::sync::mpsc channels where data isn't immediately +/// available. +struct PendingStream { + batches: Vec, + index: usize, + /// If pending_before[i] is true, yield Pending once before delivering + /// the batch at index i. + pending_before: Vec, + /// True if we've already yielded Pending for the current index. + yielded_pending: bool, + schema: SchemaRef, +} + +impl PendingStream { + fn new(batches: Vec, pending_before: Vec) -> Self { + assert_eq!(batches.len(), pending_before.len()); + let schema = batches[0].schema(); + Self { + batches, + index: 0, + pending_before, + yielded_pending: false, + schema, + } + } +} + +impl Stream for PendingStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if self.index >= self.batches.len() { + return Poll::Ready(None); + } + if self.pending_before[self.index] && !self.yielded_pending { + self.yielded_pending = true; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + self.yielded_pending = false; + let batch = self.batches[self.index].clone(); + self.index += 1; + Poll::Ready(Some(Ok(batch))) + } +} + +impl RecordBatchStream for PendingStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +/// Helper: collect all output from a BitwiseSortMergeJoinStream. +async fn collect_stream(stream: BitwiseSortMergeJoinStream) -> Result> { + common::collect(Box::pin(stream)).await +} + +/// Reproduces the buffer_inner_key_group re-entry bug: +/// +/// When buffer_inner_key_group buffers inner rows across batch boundaries +/// and poll_next_inner_batch returns Pending mid-way, the ready! macro +/// exits poll_join. On re-entry, the merge-scan reaches Equal again and +/// calls buffer_inner_key_group a second time -- which starts with +/// clear(), destroying the partially collected inner rows. Previously +/// consumed batches are gone, so re-buffering misses them. +/// +/// Setup: +/// - Inner: 3 single-row batches, all with key=1, filter values c2=[10, 20, 30] +/// - Outer: 1 row, key=1, filter value c1=10 +/// - Filter: c1 == c2 (only first inner row c2=10 matches) +/// - Pending injected before 3rd inner batch +/// +/// Without the bug: outer row emitted (match via c2=10) +/// With the bug: outer row missing (c2=10 batch lost on re-entry) +#[tokio::test] +async fn filter_buffer_pending_loses_inner_rows() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + + // Outer: 1 row, key=1, c1=10 + let outer_batch = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), // join key + Arc::new(Int32Array::from(vec![10])), // filter value + ], + )?; + + // Inner: 3 single-row batches, key=1, c2=[10, 20, 30] + let inner_batch1 = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![100])), + Arc::new(Int32Array::from(vec![1])), // join key + Arc::new(Int32Array::from(vec![10])), // matches filter + ], + )?; + let inner_batch2 = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![200])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![20])), // doesn't match + ], + )?; + let inner_batch3 = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![300])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![30])), // doesn't match + ], + )?; + + let outer: SendableRecordBatchStream = Box::pin(PendingStream::new( + vec![outer_batch], + vec![false], // outer delivers immediately + )); + let inner: SendableRecordBatchStream = Box::pin(PendingStream::new( + vec![inner_batch1, inner_batch2, inner_batch3], + vec![false, false, true], // Pending before 3rd batch + )); + + // Filter: c1 == c2 + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c1", 0)), + Operator::Eq, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])), ); + let on_outer: Vec = vec![Arc::new(Column::new("b1", 1))]; + let on_inner: Vec = vec![Arc::new(Column::new("b1", 1))]; + + let metrics = ExecutionPlanMetricsSet::new(); + let inner_schema = inner.schema(); + let (reservation, spill_manager, runtime_env) = + test_stream_resources(inner_schema, &metrics); + let stream = BitwiseSortMergeJoinStream::try_new( + left_schema, // output schema = outer schema for semi + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + outer, + inner, + on_outer, + on_inner, + Some(filter), + LeftSemi, + 8192, + 0, + &metrics, + reservation, + spill_manager, + runtime_env, + )?; + + let batches = collect_stream(stream).await?; + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - None, - None, - Some(true), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) + total, 1, + "LeftSemi with filter: outer row should be emitted because \ + inner row c2=10 matches filter c1==c2. Got {total} rows." ); + Ok(()) +} + +/// Reproduces the no-filter boundary Pending re-entry bug: +/// +/// When an outer key group spans a batch boundary, the no-filter path +/// emits the current batch, then polls for the next outer batch. If +/// poll returns Pending, poll_join exits. On re-entry, without the +/// PendingBoundary fix, the new batch is processed fresh by the +/// merge-scan. Since inner already advanced past this key, the outer +/// rows with the matching key are skipped via Ordering::Less. +/// +/// Setup: +/// - Outer: 2 single-row batches, both with key=1 (key group spans boundary) +/// - Inner: 1 row with key=1 +/// - Pending injected on outer before 2nd batch +/// +/// Without fix: only first outer row emitted (second lost on re-entry) +/// With fix: both outer rows emitted +#[tokio::test] +async fn no_filter_boundary_pending_loses_outer_rows() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + + // Outer: 2 single-row batches, both key=1 + let outer_batch1 = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![10])), + ], + )?; + let outer_batch2 = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![2])), + Arc::new(Int32Array::from(vec![1])), // same key + Arc::new(Int32Array::from(vec![20])), + ], + )?; + + // Inner: 1 row, key=1 + let inner_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![100])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![50])), + ], + )?; + let outer: SendableRecordBatchStream = Box::pin(PendingStream::new( + vec![outer_batch1, outer_batch2], + vec![false, true], // Pending before 2nd outer batch + )); + let inner: SendableRecordBatchStream = + Box::pin(PendingStream::new(vec![inner_batch], vec![false])); + + let on_outer: Vec = vec![Arc::new(Column::new("b1", 1))]; + let on_inner: Vec = vec![Arc::new(Column::new("b1", 1))]; + + let metrics = ExecutionPlanMetricsSet::new(); + let inner_schema = inner.schema(); + let (reservation, spill_manager, runtime_env) = + test_stream_resources(inner_schema, &metrics); + let stream = BitwiseSortMergeJoinStream::try_new( + left_schema, + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + outer, + inner, + on_outer, + on_inner, + None, // no filter + LeftSemi, + 8192, + 0, + &metrics, + reservation, + spill_manager, + runtime_env, + )?; + + let batches = collect_stream(stream).await?; + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - None, - Some(true), - Some(true), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) + total, 2, + "LeftSemi no filter: both outer rows (key=1) should be emitted \ + because inner has key=1. Got {total} rows." ); + Ok(()) +} - assert_eq!( - get_corrected_filter_mask( - Left, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![ - None, - None, - Some(false), - Some(false), - Some(false), - Some(false), - Some(false), - Some(false) - ]) +/// Tests the filtered boundary Pending re-entry: outer key group spans +/// batches with a filter, and poll_next_outer_batch returns Pending. +/// +/// Setup: +/// - Outer: 2 single-row batches, both key=1, c1=[10, 20] +/// - Inner: 1 row, key=1, c2=10 +/// - Filter: c1 == c2 (first outer row matches, second doesn't) +/// - Pending before 2nd outer batch +/// +/// Expected: 1 row (only the first outer row c1=10 passes the filter) +#[tokio::test] +async fn filtered_boundary_pending_outer_rows() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); + + let outer_batch1 = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![10])), // matches filter + ], + )?; + let outer_batch2 = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![2])), + Arc::new(Int32Array::from(vec![1])), // same key + Arc::new(Int32Array::from(vec![20])), // doesn't match + ], + )?; + + let inner_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(vec![100])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![10])), + ], + )?; + + let outer: SendableRecordBatchStream = Box::pin(PendingStream::new( + vec![outer_batch1, outer_batch2], + vec![false, true], // Pending before 2nd outer batch + )); + let inner: SendableRecordBatchStream = + Box::pin(PendingStream::new(vec![inner_batch], vec![false])); + + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c1", 0)), + Operator::Eq, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])), ); - let corrected_mask = get_corrected_filter_mask( - Left, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); + let on_outer: Vec = vec![Arc::new(Column::new("b1", 1))]; + let on_inner: Vec = vec![Arc::new(Column::new("b1", 1))]; + + let metrics = ExecutionPlanMetricsSet::new(); + let inner_schema = inner.schema(); + let (reservation, spill_manager, runtime_env) = + test_stream_resources(inner_schema, &metrics); + let stream = BitwiseSortMergeJoinStream::try_new( + left_schema, + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + outer, + inner, + on_outer, + on_inner, + Some(filter), + LeftSemi, + 8192, + 0, + &metrics, + reservation, + spill_manager, + runtime_env, + )?; + let batches = collect_stream(stream).await?; + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - Some(false), - None, - Some(false) - ]) + total, 1, + "LeftSemi filtered boundary: only first outer row (c1=10) matches \ + filter c1==c2. Got {total} rows." + ); + Ok(()) +} + +// ── Bitwise stream spill tests ───────────────────────────────────────────── + +/// Exercises inner key group spilling under memory pressure. +/// +/// Uses a tiny memory limit (100 bytes) with disk spilling enabled. Since our +/// operator only buffers inner rows when a filter is present, this test includes +/// a filter (c1 < c2, always true). Verifies: +/// 1. Spill metrics are recorded (spill_count, spilled_bytes, spilled_rows > 0) +/// 2. Results match a non-spilled run +#[tokio::test] +async fn bitwise_spill_with_filter() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3, 4, 5, 6]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40, 50]), + ("b1", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + // c1 < c2 is always true for matching keys + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c1", 0)), + Operator::Lt, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])), ); - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - - assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 10 | 1 | 11 | - | 1 | 11 | 1 | 12 | - | 1 | 12 | 1 | 13 | - +---+----+---+----+ - "#); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; - // output null rows + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - Some(true), - None, - Some(true) - ]) - ); + for join_type in [LeftSemi, LeftAnti, RightSemi, RightAnti] { + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)), + ); - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + let join = SortMergeJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Some(filter.clone()), + join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join.execute(0, task_ctx)?; + let spilled_result = common::collect(stream).await.unwrap(); + + assert!( + join.metrics().is_some(), + "metrics missing for {join_type:?}" + ); + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "expected spill_count > 0 for {join_type:?}, batch_size={batch_size}" + ); + assert!( + metrics.spilled_bytes().unwrap() > 0, + "expected spilled_bytes > 0 for {join_type:?}, batch_size={batch_size}" + ); + assert!( + metrics.spilled_rows().unwrap() > 0, + "expected spilled_rows > 0 for {join_type:?}, batch_size={batch_size}" + ); + + // Run without spilling and compare results + let task_ctx_no_spill = Arc::new( + TaskContext::default().with_session_config(session_config.clone()), + ); + let join_no_spill = SortMergeJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Some(filter.clone()), + join_type, + sort_options.clone(), + NullEquality::NullEqualsNothing, + )?; + let stream = join_no_spill.execute(0, task_ctx_no_spill)?; + let no_spill_result = common::collect(stream).await.unwrap(); + + let no_spill_metrics = join_no_spill.metrics().unwrap(); + assert_eq!( + no_spill_metrics.spill_count(), + Some(0), + "unexpected spill for {join_type:?} without memory limit" + ); + + assert_eq!( + spilled_result, no_spill_result, + "spilled vs non-spilled results differ for {join_type:?}, batch_size={batch_size}" + ); + } + } - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 13 | 1 | 12 | - | 1 | 14 | 1 | 11 | - +---+----+---+----+ - "#); Ok(()) } +/// Reproduces a bug where `resume_boundary` for the Filtered pending case +/// only checks `inner_key_buffer.is_empty()` but ignores `inner_key_spill`. +/// After spilling, the in-memory buffer is cleared while the spill file +/// holds the data. If the outer key group spans a batch boundary, the +/// second outer batch's rows are never evaluated against the inner group. +/// +/// Setup: +/// - Outer: 2 single-row batches, both key=1, c1=[10, 10] +/// - Inner: 1 batch with many rows all key=1 (enough to trigger spill) +/// - Filter: c1 == c2 (matches when c2=10) +/// - Memory limit: tiny (100 bytes) to force spilling +/// - Pending before 2nd outer batch to trigger boundary re-entry +/// +/// Expected: both outer rows match (semi=2 rows, anti=0 rows) +/// Bug: second outer row is skipped because resume_boundary sees empty +/// inner_key_buffer and skips re-evaluation. #[tokio::test] -async fn test_semi_join_filtered_mask() -> Result<()> { - for join_type in [LeftSemi, RightSemi] { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); +async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { + let left_schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + ])); + let right_schema = Arc::new(Schema::new(vec![ + Field::new("a2", DataType::Int32, false), + Field::new("b1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])); - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + // Two single-row outer batches with the same key -- key group spans boundary + let outer_batch1 = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), // key=1 + Arc::new(Int32Array::from(vec![10])), // matches filter + ], + )?; + let outer_batch2 = RecordBatch::try_new( + Arc::clone(&left_schema), + vec![ + Arc::new(Int32Array::from(vec![2])), + Arc::new(Int32Array::from(vec![1])), // same key=1 + Arc::new(Int32Array::from(vec![10])), // also matches filter + ], + )?; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![true]) - ); + // Inner: many rows with key=1 to force spilling, followed by key=2. + // c2=10 so the filter c1==c2 passes for both outer rows. + // The key=2 row ensures the inner cursor advances past the key group + // (buffer_inner_key_group returns Ok(false) instead of Ok(true)). + let n_inner = 200; + let mut inner_a = vec![100; n_inner]; + inner_a.push(101); + let mut inner_b = vec![1; n_inner]; + inner_b.push(2); // different key -- forces inner cursor past key=1 + let mut inner_c = vec![10; n_inner]; + inner_c.push(10); + let inner_batch = RecordBatch::try_new( + Arc::clone(&right_schema), + vec![ + Arc::new(Int32Array::from(inner_a)), + Arc::new(Int32Array::from(inner_b)), + Arc::new(Int32Array::from(inner_c)), + ], + )?; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); + // Filter: c1 == c2 + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("c1", 0)), + Operator::Eq, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])), + ); - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None]) - ); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) + let on_outer: Vec = vec![Arc::new(Column::new("b1", 1))]; + let on_inner: Vec = vec![Arc::new(Column::new("b1", 1))]; + + for join_type in [LeftSemi, LeftAnti] { + let outer: SendableRecordBatchStream = Box::pin(PendingStream::new( + vec![outer_batch1.clone(), outer_batch2.clone()], + vec![false, true], // Pending before 2nd outer batch + )); + let inner: SendableRecordBatchStream = + Box::pin(PendingStream::new(vec![inner_batch.clone()], vec![false])); + + let metrics = ExecutionPlanMetricsSet::new(); + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + let spill_manager = SpillManager::new( + Arc::clone(&runtime), + SpillMetrics::new(&metrics, 0), + Arc::clone(&right_schema), ); - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![Some(true), None, None]) - ); + let stream = BitwiseSortMergeJoinStream::try_new( + Arc::clone(&left_schema), + vec![SortOptions::default()], + NullEquality::NullEqualsNothing, + outer, + inner, + on_outer.clone(), + on_inner.clone(), + Some(filter.clone()), + join_type, + 8192, + 0, + &metrics, + reservation, + spill_manager, + Arc::clone(&runtime), + )?; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true),]) - ); + let batches = collect_stream(stream).await?; + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + + match join_type { + LeftSemi => { + assert_eq!( + total, 2, + "LeftSemi spill+boundary: both outer rows match filter, \ + expected 2 rows, got {total}" + ); + } + LeftAnti => { + assert_eq!( + total, 0, + "LeftAnti spill+boundary: both outer rows match filter, \ + expected 0 rows, got {total}" + ); + } + _ => unreachable!(), + } + } - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - output.num_rows() - ) - .unwrap(), - BooleanArray::from(vec![None, Some(true), None]) - ); + Ok(()) +} - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - output.num_rows() +/// Verifies that `peak_mem_used` reflects spill read-back memory during +/// output materialization (multi-source path). +/// +/// When spilled buffered batches are read back from disk to produce join +/// output, a scoped `MemoryReservation` (via `new_empty()`) tracks the +/// transient memory. Its `Drop` guarantees the pool is balanced on every +/// exit path — normal return or early `?` error. +#[tokio::test] +async fn spill_read_back_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // All rows share the same join key (b=1) to force multiple buffered + // batches in the same key group — triggering spill read-back during + // output materialization. + let left_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); - let corrected_mask = get_corrected_filter_mask( - join_type, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - Some(true), - None, - Some(true), - None, - Some(true), - None, - None, - None - ]) - ); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - - assert_batches_eq!( - &[ - "+---+----+---+----+", - "| a | b | x | y |", - "+---+----+---+----+", - "| 1 | 10 | 1 | 11 |", - "| 1 | 11 | 1 | 12 |", - "| 1 | 12 | 1 | 13 |", - "+---+----+---+----+", - ], - &[filtered_rb] - ); + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - Some(false), - None, - Some(false), - None, - Some(false), - None, - None, - None - ]) - ); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); - assert_batches_eq!( - &[ - "+---+---+---+---+", - "| a | b | x | y |", - "+---+---+---+---+", - "+---+---+---+---+", - ], - &[null_joined_batch] - ); - } - Ok(()) -} + assert!(!result.is_empty(), "Expected non-empty join result"); -#[tokio::test] -async fn test_anti_join_filtered_mask() -> Result<()> { - for join_type in [LeftAnti, RightAnti] { - let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); - let output = concat_batches(&schema, &joined_batches.batches)?; - let out_mask = joined_batches.filter_mask.finish(); - let out_indices = joined_batches.row_indices.finish(); + // peak_mem_used should reflect the spill read-back: when buffered + // batches are read from disk during output materialization, grow() + // temporarily reserves size_estimation. This pushes peak above what + // join_arrays_mem alone would show. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because spill read-back temporarily loads full batch into memory" + ); - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![true]), - 1 - ) - .unwrap(), - BooleanArray::from(vec![None]) - ); + // All memory must be released (grow/shrink balanced) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0]), - &[0usize], - &BooleanArray::from(vec![false]), - 1 - ) - .unwrap(), - BooleanArray::from(vec![Some(true)]) - ); + Ok(()) +} - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0]), - &[0usize; 2], - &BooleanArray::from(vec![true, true]), - 2 - ) - .unwrap(), - BooleanArray::from(vec![None, None]) - ); +/// Verifies spill read-back memory tracking for the single-source path. +/// +/// When only ONE buffered batch exists for a key group and it's spilled, +/// `fetch_right_columns_by_idxs` reads it back. A scoped `MemoryReservation` +/// (via `new_empty()`) tracks the transient memory and releases it on drop. +#[tokio::test] +async fn spill_read_back_single_source() -> Result<()> { + use arrow::array::Array; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, true, true]), - 3 + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // Multiple distinct keys so each key group has exactly ONE buffered batch. + // This ensures the single-source path is exercised. + let left_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![i, i]), + ("c1", &vec![100 + i, 101 + i]), ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); - - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![true, false, true]), - 3 + }) + .collect(); + let left = build_table_from_batches(left_batches); + + // One batch per key — each key group has single source + let right_batches: Vec = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![i, i]), + ("c2", &vec![200 + i, 201 + i]), ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + }) + .collect(); + let right = build_table_from_batches(right_batches); - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, true, true]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, None]) - ); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; - assert_eq!( - get_corrected_filter_mask( - join_type, - &UInt64Array::from(vec![0, 0, 0]), - &[0usize; 3], - &BooleanArray::from(vec![false, false, false]), - 3 - ) - .unwrap(), - BooleanArray::from(vec![None, None, Some(true)]) - ); + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); - let corrected_mask = get_corrected_filter_mask( - join_type, - &out_indices, - &joined_batches.batch_ids, - &out_mask, - output.num_rows(), - ) - .unwrap(); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; - assert_eq!( - corrected_mask, - BooleanArray::from(vec![ - None, - None, - None, - None, - None, - Some(true), - None, - Some(true) - ]) - ); + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); - let filtered_rb = filter_record_batch(&output, &corrected_mask)?; - - allow_duplicates! { - assert_snapshot!(batches_to_string(&[filtered_rb]), @r#" - +---+----+---+----+ - | a | b | x | y | - +---+----+---+----+ - | 1 | 13 | 1 | 12 | - | 1 | 14 | 1 | 11 | - +---+----+---+----+ - "#); - } + assert!(!result.is_empty(), "Expected non-empty join result"); - // output null rows - let null_mask = arrow::compute::not(&corrected_mask)?; - assert_eq!( - null_mask, - BooleanArray::from(vec![ - None, - None, - None, - None, - None, - Some(false), - None, - Some(false), - ]) - ); + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); - let null_joined_batch = filter_record_batch(&output, &null_mask)?; + // peak_mem_used should reflect the single-batch read-back + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because single-source spill read-back loads full batch" + ); - allow_duplicates! { - assert_snapshot!(batches_to_string(&[null_joined_batch]), @r#" - +---+---+---+---+ - | a | b | x | y | - +---+---+---+---+ - +---+---+---+---+ - "#); - } - } + // All memory must be released + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); Ok(()) } - -/// Returns the column names on the schema -fn columns(schema: &Schema) -> Vec { - schema.fields().iter().map(|f| f.name().clone()).collect() -} diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index f4a3cd92f16da..571c199abb448 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -22,24 +22,26 @@ use std::collections::{HashMap, VecDeque}; use std::mem::size_of; use std::sync::Arc; +use crate::joins::MapOffset; use crate::joins::join_hash_map::{ - get_matched_indices, get_matched_indices_with_limit_offset, update_from_iter, - JoinHashMapOffset, + contain_hashes, get_matched_indices, get_matched_indices_with_limit_offset, + update_from_iter, }; use crate::joins::utils::{JoinFilter, JoinHashMapType}; -use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{metrics, ExecutionPlan}; +use crate::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, +}; +use crate::{ExecutionPlan, metrics}; use arrow::array::{ - ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, + ArrowPrimitiveType, BooleanArray, BooleanBufferBuilder, NativeAdapter, + PrimitiveArray, RecordBatch, }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::utils::memory::estimate_memory_size; -use datafusion_common::{ - arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, -}; +use datafusion_common::{HashSet, JoinSide, Result, ScalarValue, arrow_datafusion_err}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; @@ -79,8 +81,10 @@ impl JoinHashMapType for PruningJoinHashMap { &self, hash_values: &[u64], limit: usize, - offset: JoinHashMapOffset, - ) -> (Vec, Vec, Option) { + offset: MapOffset, + input_indices: &mut Vec, + match_indices: &mut Vec, + ) -> Option { // Flatten the deque let next: Vec = self.next.iter().copied().collect(); get_matched_indices_with_limit_offset::( @@ -89,12 +93,22 @@ impl JoinHashMapType for PruningJoinHashMap { hash_values, limit, offset, + input_indices, + match_indices, ) } + fn contain_hashes(&self, hash_values: &[u64]) -> BooleanArray { + contain_hashes(&self.map, hash_values) + } + fn is_empty(&self) -> bool { self.map.is_empty() } + + fn len(&self) -> usize { + self.map.len() + } } /// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with @@ -363,7 +377,7 @@ fn convert_filter_columns( column_map: &HashMap, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(if let Some(col) = input.as_any().downcast_ref::() { + Ok(if let Some(col) = input.downcast_ref::() { // If the downcast is successful, retrieve the corresponding filter column. column_map.get(col).map(|c| Arc::new(c.clone()) as _) } else { @@ -688,24 +702,31 @@ pub struct StreamJoinMetrics { impl StreamJoinMetrics { pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let input_batches = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("left_input_batches", partition); + let input_rows = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("left_input_rows", partition); let left = StreamJoinSideMetrics { input_batches, input_rows, }; - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let input_batches = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("right_input_batches", partition); + let input_rows = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("right_input_rows", partition); let right = StreamJoinSideMetrics { input_batches, input_rows, }; - let stream_memory_usage = - MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + let stream_memory_usage = MetricBuilder::new(metrics) + .with_category(MetricCategory::Bytes) + .gauge("stream_memory_usage", partition); Self { left, @@ -1014,46 +1035,54 @@ pub mod tests { let left_schema = Arc::new(left_schema); let right_schema = Arc::new(right_schema); - assert!(build_filter_input_order( - JoinSide::Left, - &filter, - &left_schema, - &PhysicalSortExpr { - expr: col("la1", left_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_some()); - assert!(build_filter_input_order( - JoinSide::Left, - &filter, - &left_schema, - &PhysicalSortExpr { - expr: col("lt1", left_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_none()); - assert!(build_filter_input_order( - JoinSide::Right, - &filter, - &right_schema, - &PhysicalSortExpr { - expr: col("ra1", right_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_some()); - assert!(build_filter_input_order( - JoinSide::Right, - &filter, - &right_schema, - &PhysicalSortExpr { - expr: col("rb1", right_schema.as_ref())?, - options: SortOptions::default(), - } - )? - .is_none()); + assert!( + build_filter_input_order( + JoinSide::Left, + &filter, + &left_schema, + &PhysicalSortExpr { + expr: col("la1", left_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_some() + ); + assert!( + build_filter_input_order( + JoinSide::Left, + &filter, + &left_schema, + &PhysicalSortExpr { + expr: col("lt1", left_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_none() + ); + assert!( + build_filter_input_order( + JoinSide::Right, + &filter, + &right_schema, + &PhysicalSortExpr { + expr: col("ra1", right_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_some() + ); + assert!( + build_filter_input_order( + JoinSide::Right, + &filter, + &right_schema, + &PhysicalSortExpr { + expr: col("rb1", right_schema.as_ref())?, + options: SortOptions::default(), + } + )? + .is_none() + ); Ok(()) } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 9c778ad131846..ef92964fadf84 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -25,36 +25,37 @@ //! This plan uses the [`OneSideHashJoiner`] object to facilitate join calculations //! for both its children. -use std::any::Any; use std::fmt::{self, Debug}; use std::mem::{size_of, size_of_val}; use std::sync::Arc; use std::task::{Context, Poll}; use std::vec; +use crate::check_if_same_properties; use crate::common::SharedMemoryReservation; use crate::execution_plan::{boundedness_from_children, emission_type_from_children}; use crate::joins::stream_join_utils::{ + PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, - PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, }; use crate::joins::utils::{ - apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, equal_rows_arr, symmetric_join_output_partitioning, update_hash, BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, - JoinOnRef, NoopBatchTransformer, StatefulStreamResult, + JoinOnRef, NoopBatchTransformer, StatefulStreamResult, apply_join_filter_to_indices, + build_batch_from_indices, build_join_schema, check_join_is_valid, equal_rows_arr, + symmetric_join_output_partitioning, update_hash, }; use crate::projection::{ - join_allows_pushdown, join_table_borders, new_join_children, - physical_to_column_exprs, update_join_filter, update_join_on, ProjectionExec, + ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children, + physical_to_column_exprs, update_join_filter, update_join_on, }; +use crate::stream::EmptyRecordBatchStream; use crate::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; use arrow::array::{ @@ -67,21 +68,20 @@ use arrow::record_batch::RecordBatch; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; use datafusion_common::{ - assert_eq_or_internal_err, plan_err, HashSet, JoinSide, JoinType, NullEquality, - Result, + HashSet, JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, + plan_err, }; -use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql}; use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; -use ahash::RandomState; +use datafusion_common::hash_utils::RandomState; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; -use futures::{ready, Stream, StreamExt}; -use parking_lot::Mutex; +use futures::{Stream, StreamExt, ready}; const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; @@ -197,7 +197,7 @@ pub struct SymmetricHashJoinExec { /// Partition Mode mode: StreamJoinPartitionMode, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl SymmetricHashJoinExec { @@ -207,7 +207,7 @@ impl SymmetricHashJoinExec { /// - It is not possible to join the left and right sides on keys `on`, or /// - It fails to construct `SortedFilterExpr`s, or /// - It fails to create the [ExprIntervalGraph]. - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub fn try_new( left: Arc, right: Arc, @@ -237,7 +237,7 @@ impl SymmetricHashJoinExec { build_join_schema(&left_schema, &right_schema, join_type); // Initialize the random state for the join operation: - let random_state = RandomState::with_seeds(0, 0, 0, 0); + let random_state = RandomState::with_seed(0); let schema = Arc::new(schema); let cache = Self::compute_properties(&left, &right, schema, *join_type, &on)?; Ok(SymmetricHashJoinExec { @@ -253,7 +253,7 @@ impl SymmetricHashJoinExec { left_sort_exprs, right_sort_exprs, mode, - cache, + cache: Arc::new(cache), }) } @@ -360,6 +360,20 @@ impl SymmetricHashJoinExec { } Ok(false) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + let left = children.swap_remove(0); + let right = children.swap_remove(0); + Self { + left, + right, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for SymmetricHashJoinExec { @@ -407,11 +421,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { "SymmetricHashJoinExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -453,6 +463,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(SymmetricHashJoinExec::try_new( Arc::clone(&children[0]), Arc::clone(&children[1]), @@ -470,11 +481,6 @@ impl ExecutionPlan for SymmetricHashJoinExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - // TODO stats: it is not possible in general to know the output size of joins - Ok(Statistics::new_unknown(&self.schema())) - } - fn execute( &self, partition: usize, @@ -525,12 +531,12 @@ impl ExecutionPlan for SymmetricHashJoinExec { let enforce_batch_size_in_joins = context.session_config().enforce_batch_size_in_joins(); - let reservation = Arc::new(Mutex::new( + let reservation = Arc::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) .register(context.memory_pool()), - )); + ); if let Some(g) = graph.as_ref() { - reservation.lock().try_grow(g.size())?; + reservation.try_grow(g.size())?; } if enforce_batch_size_in_joins { @@ -930,6 +936,7 @@ pub(crate) fn build_side_determined_results( &probe_indices, column_indices, build_hash_joiner.build_side, + join_type, ) .map(|batch| (batch.num_rows() > 0).then_some(batch)) } else { @@ -957,7 +964,7 @@ pub(crate) fn build_side_determined_results( /// /// A [Result] containing an optional record batch if the join type is not one of `LeftAnti`, `RightAnti`, `LeftSemi` or `RightSemi`. /// If the join type is one of the above four, the function will return [None]. -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] pub(crate) fn join_with_probe_batch( build_hash_joiner: &mut OneSideHashJoiner, probe_hash_joiner: &mut OneSideHashJoiner, @@ -993,6 +1000,7 @@ pub(crate) fn join_with_probe_batch( filter, build_hash_joiner.build_side, None, + join_type, )? } else { (build_indices, probe_indices) @@ -1031,6 +1039,7 @@ pub(crate) fn join_with_probe_batch( &probe_indices, column_indices, build_hash_joiner.build_side, + join_type, ) .map(|batch| (batch.num_rows() > 0).then_some(batch)) } @@ -1055,7 +1064,7 @@ pub(crate) fn join_with_probe_batch( /// /// A [Result] containing a tuple with two equal length arrays, representing indices of rows from build and probe side, /// matched by join key columns. -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn lookup_join_hashmap( build_hashmap: &PruningJoinHashMap, build_batch: &RecordBatch, @@ -1380,6 +1389,19 @@ impl SymmetricHashJoinStream { } } } + + /// Release the right input pipeline's resources. + fn cleanup_depleted_right_stream(&mut self) { + let right_schema = self.right_stream.schema(); + self.right_stream = Box::pin(EmptyRecordBatchStream::new(right_schema)); + } + + /// Release the left input pipeline's resources. + fn cleanup_depleted_left_stream(&mut self) { + let left_schema = self.left_stream.schema(); + self.left_stream = Box::pin(EmptyRecordBatchStream::new(left_schema)); + } + /// Asynchronously pulls the next batch from the right stream. /// /// This default implementation checks for the next value in the right stream. @@ -1403,6 +1425,7 @@ impl SymmetricHashJoinStream { } Some(Err(e)) => Poll::Ready(Err(e)), None => { + self.cleanup_depleted_right_stream(); self.set_state(SHJStreamState::RightExhausted); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -1432,6 +1455,7 @@ impl SymmetricHashJoinStream { } Some(Err(e)) => Poll::Ready(Err(e)), None => { + self.cleanup_depleted_left_stream(); self.set_state(SHJStreamState::LeftExhausted); Poll::Ready(Ok(StatefulStreamResult::Continue)) } @@ -1461,6 +1485,7 @@ impl SymmetricHashJoinStream { } Some(Err(e)) => Poll::Ready(Err(e)), None => { + self.cleanup_depleted_left_stream(); self.set_state(SHJStreamState::BothExhausted { final_result: false, }); @@ -1492,6 +1517,7 @@ impl SymmetricHashJoinStream { } Some(Err(e)) => Poll::Ready(Err(e)), None => { + self.cleanup_depleted_right_stream(); self.set_state(SHJStreamState::BothExhausted { final_result: false, }); @@ -1718,7 +1744,7 @@ impl SymmetricHashJoinStream { let result = combine_two_batches(&self.schema, equal_result, anti_result)?; let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); - self.reservation.lock().try_resize(capacity)?; + self.reservation.try_resize(capacity)?; Ok(result) } } @@ -1769,7 +1795,7 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr::expressions::{Column, binary, col, lit}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use rstest::*; diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 58338bd860214..0455fb2a1eb6e 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -25,11 +25,11 @@ use crate::joins::{ }; use crate::repartition::RepartitionExec; use crate::test::TestMemoryExec; -use crate::{common, ExecutionPlan, ExecutionPlanProperties, Partitioning}; +use crate::{ExecutionPlan, ExecutionPlanProperties, Partitioning, common}; use arrow::array::{ - types::IntervalDayTime, ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, - RecordBatch, TimestampMillisecondArray, + ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, + TimestampMillisecondArray, types::IntervalDayTime, }; use arrow::datatypes::{DataType, Schema}; use arrow::util::pretty::pretty_format_batches; @@ -152,6 +152,7 @@ pub async fn partitioned_hash_join_with_filter( None, PartitionMode::Partitioned, null_equality, + false, // null_aware )?); let mut batches = vec![]; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 9087ac415f4b1..5687be04ad867 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,7 +17,7 @@ //! Join related functionality used both on logical and physical plans -use std::cmp::{min, Ordering}; +use std::cmp::{Ordering, min}; use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; @@ -28,7 +28,8 @@ use std::task::{Context, Poll}; use crate::joins::SharedBitmapBuilder; use crate::metrics::{ - self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricType, + self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, + MetricType, }; use crate::projection::{ProjectionExec, ProjectionExpr}; use crate::{ @@ -39,47 +40,49 @@ pub use super::join_filter::JoinFilter; pub use super::join_hash_map::JoinHashMapType; pub use crate::joins::{JoinOn, JoinOnRef}; -use ahash::RandomState; use arrow::array::{ - builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType, - BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, - UInt32Array, UInt32Builder, UInt64Array, + Array, ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, + RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, + builder::UInt64Builder, downcast_array, make_array, new_null_array, }; use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, StringArray, + Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int8Array, + Int16Array, Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, StringArray, StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt8Array, + TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, }; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::cmp::eq; -use arrow::compute::{self, and, take, FilterBuilder}; +use arrow::compute::{self, FilterBuilder, and, take}; use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use arrow_ord::cmp::not_distinct; +use arrow_ord::ord::{DynComparator, make_comparator}; use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; +use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; +use datafusion_common::utils::normalize_float_zero; use datafusion_common::{ - not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, - SharedResult, + DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult, + not_impl_err, plan_err, }; -use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - add_offset_to_expr, add_offset_to_physical_sort_exprs, LexOrdering, PhysicalExpr, - PhysicalExprRef, + LexOrdering, PhysicalExpr, PhysicalExprRef, add_offset_to_expr, + add_offset_to_physical_sort_exprs, }; use datafusion_physical_expr_common::datum::compare_op_for_nested; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::future::{BoxFuture, Shared}; -use futures::{ready, FutureExt}; +use futures::{FutureExt, ready}; use parking_lot::Mutex; /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. @@ -142,6 +145,13 @@ pub fn adjust_right_output_partitioning( .collect::>()?; Partitioning::Hash(new_exprs, *size) } + Partitioning::Range(_) => { + // Range partitioning optimizer propagation is tracked in + // https://github.com/apache/datafusion/issues/22395 + return not_impl_err!( + "Join output partitioning with range partitioning is not implemented" + ); + } result => result.clone(), }; Ok(result) @@ -159,20 +169,21 @@ pub fn calculate_join_output_ordering( match maintains_input_order { [true, false] => { // Special case, we can prefix ordering of right side with the ordering of left side. - if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { - if let Some(right_ordering) = right_ordering.cloned() { - let right_offset = add_offset_to_physical_sort_exprs( - right_ordering, - left_columns_len as _, - )?; - return if let Some(left_ordering) = left_ordering { - let mut result = left_ordering.clone(); - result.extend(right_offset); - Ok(Some(result)) - } else { - Ok(LexOrdering::new(right_offset)) - }; - } + if join_type == JoinType::Inner + && probe_side == Some(JoinSide::Left) + && let Some(right_ordering) = right_ordering.cloned() + { + let right_offset = add_offset_to_physical_sort_exprs( + right_ordering, + left_columns_len as _, + )?; + return if let Some(left_ordering) = left_ordering { + let mut result = left_ordering.clone(); + result.extend(right_offset); + Ok(Some(result)) + } else { + Ok(LexOrdering::new(right_offset)) + }; } Ok(left_ordering.cloned()) } @@ -407,25 +418,66 @@ impl Clone for OnceFut { #[derive(Clone, Debug, Default)] struct PartialJoinStatistics { pub num_rows: usize, + pub total_byte_size: Precision, pub column_statistics: Vec, } -/// Estimate the statistics for the given join's output. +/// Estimates the output statistics for a join operation based on input statistics. +/// +/// # Statistics Propagation +/// +/// This function estimates join output statistics using the following approach: +/// - **Row count estimation**: Uses the `on` parameter (equijoin keys) to estimate +/// output cardinality via [`estimate_join_cardinality`]. The estimation is based on +/// column-level statistics (distinct counts, min/max values) of the join keys. +/// - **Column statistics**: Combines column statistics from both inputs. For join types +/// that preserve all columns (Inner, Left, Right, Full), statistics from both sides +/// are concatenated. For semi/anti joins, the preserved side's statistics are +/// normalized as subset estimates. +/// - **Byte size**: For semi/anti joins, sums normalized column byte-size estimates +/// when every output column has one. Other join types return `Precision::Absent` +/// because join output size is difficult to estimate without knowing the actual data. +/// +/// # The `on` Parameter +/// +/// The `on` parameter represents equijoin keys (e.g., `t1.id = t2.id`). When `on` is +/// empty (as in NestedLoopJoinExec which handles non-equijoin predicates), the +/// cardinality estimation cannot compute selectivity from join keys, and this function +/// returns unknown statistics (`num_rows: Precision::Absent`). +/// +/// # Limitations +/// +/// - Does not account for selectivity of arbitrary join filter expressions +/// (e.g., `(t1.v1 + t2.v1) % 2 = 0`). Such filters, common in NestedLoopJoinExec, +/// are not factored into the cardinality estimation. +/// - Column statistics for inner/outer joins are simply combined from inputs +/// without adjusting for join selectivity (acknowledged in the code as +/// needing "filter selectivity analysis"). pub(crate) fn estimate_join_statistics( left_stats: Statistics, right_stats: Statistics, on: &JoinOn, + null_equality: NullEquality, join_type: &JoinType, schema: &Schema, ) -> Result { - let join_stats = estimate_join_cardinality(join_type, left_stats, right_stats, on); - let (num_rows, column_statistics) = match join_stats { - Some(stats) => (Precision::Inexact(stats.num_rows), stats.column_statistics), - None => (Precision::Absent, Statistics::unknown_column(schema)), + let join_stats = + estimate_join_cardinality(join_type, left_stats, right_stats, on, null_equality); + let (num_rows, total_byte_size, column_statistics) = match join_stats { + Some(stats) => ( + Precision::Inexact(stats.num_rows), + stats.total_byte_size, + stats.column_statistics, + ), + None => ( + Precision::Absent, + Precision::Absent, + Statistics::unknown_column(schema), + ), }; Ok(Statistics { num_rows, - total_byte_size: Precision::Absent, + total_byte_size, column_statistics, }) } @@ -436,23 +488,24 @@ fn estimate_join_cardinality( left_stats: Statistics, right_stats: Statistics, on: &JoinOn, + null_equality: NullEquality, ) -> Option { - let (left_col_stats, right_col_stats) = on + let on_column_indices = on .iter() - .map(|(left, right)| { - match ( - left.as_any().downcast_ref::(), - right.as_any().downcast_ref::(), - ) { - (Some(left), Some(right)) => ( - left_stats.column_statistics[left.index()].clone(), - right_stats.column_statistics[right.index()].clone(), - ), - _ => ( - ColumnStatistics::new_unknown(), - ColumnStatistics::new_unknown(), - ), - } + .map(|(left, right)| equijoin_column_indices(left, right)) + .collect::>(); + + let (left_key_stats, right_key_stats) = on_column_indices + .iter() + .map(|indices| match indices { + Some((left_index, right_index)) => ( + left_stats.column_statistics[*left_index].clone(), + right_stats.column_statistics[*right_index].clone(), + ), + None => ( + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ), }) .unzip::<_, _, Vec<_>, Vec<_>>(); @@ -462,12 +515,12 @@ fn estimate_join_cardinality( Statistics { num_rows: left_stats.num_rows, total_byte_size: Precision::Absent, - column_statistics: left_col_stats, + column_statistics: left_key_stats, }, Statistics { num_rows: right_stats.num_rows, total_byte_size: Precision::Absent, - column_statistics: right_col_stats, + column_statistics: right_key_stats, }, )?; @@ -488,6 +541,7 @@ fn estimate_join_cardinality( Some(PartialJoinStatistics { num_rows: *cardinality.get_value()?, + total_byte_size: Precision::Absent, // We don't do anything specific here, just combine the existing // statistics which might yield subpar results (although it is // true, esp regarding min/max). For a better estimation, we need @@ -500,36 +554,88 @@ fn estimate_join_cardinality( }) } - // For SemiJoins estimation result is either zero, in cases when inputs - // are non-overlapping according to statistics, or equal to number of rows - // for outer input - JoinType::LeftSemi | JoinType::RightSemi => { - let (outer_stats, inner_stats) = match join_type { - JoinType::LeftSemi => (left_stats, right_stats), - _ => (right_stats, left_stats), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti => { + let is_left = matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti); + let is_anti = matches!(join_type, JoinType::LeftAnti | JoinType::RightAnti); + + let (outer_stats, inner_stats, outer_key_stats, inner_key_stats) = if is_left + { + (left_stats, right_stats, left_key_stats, right_key_stats) + } else { + (right_stats, left_stats, right_key_stats, left_key_stats) + }; + + let outer_rows = *outer_stats.num_rows.get_value()?; + + let outer_join_key_stats = Statistics { + num_rows: outer_stats.num_rows, + total_byte_size: Precision::Absent, + column_statistics: outer_key_stats.clone(), }; - let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) { - Some(estimation) => *estimation.get_value()?, - None => *outer_stats.num_rows.get_value()?, + let inner_join_key_stats = Statistics { + num_rows: inner_stats.num_rows, + total_byte_size: Precision::Absent, + column_statistics: inner_key_stats.clone(), }; - Some(PartialJoinStatistics { - num_rows: cardinality, - column_statistics: outer_stats.column_statistics, - }) - } + let semi_cardinality = + if estimate_disjoint_inputs(&outer_join_key_stats, &inner_join_key_stats) + .is_some() + { + // If join keys are disjoint, no rows will match + Some(0) + } else { + estimate_semi_join_cardinality( + &outer_stats.num_rows, + &inner_stats.num_rows, + &outer_key_stats, + &inner_key_stats, + null_equality, + ) + }; - // For AntiJoins estimation always equals to outer statistics, as - // non-overlapping inputs won't affect estimation - JoinType::LeftAnti | JoinType::RightAnti => { - let outer_stats = match join_type { - JoinType::LeftAnti => left_stats, - _ => right_stats, + // Semi joins keep the matching rows; anti joins keep the rest. When no + // estimate is available, conservatively assume all outer rows pass. + let cardinality = match (semi_cardinality, is_anti) { + (Some(semi), true) => outer_rows.saturating_sub(semi), + (Some(semi), false) => semi, + (None, _) => outer_rows, }; + // The outer side is the one whose columns a semi/anti join emits, so + // its statistics are the ones to normalize into the subset estimate. + let Statistics { + num_rows: preserved_num_rows, + column_statistics: preserved_column_statistics, + .. + } = outer_stats; + let preserved_join_key_indices = on_column_indices + .iter() + .filter_map(|&indices| { + indices.map( + |(left_index, right_index)| { + if is_left { left_index } else { right_index } + }, + ) + }) + .collect::>(); + let column_statistics = normalize_semi_anti_join_column_statistics( + preserved_column_statistics, + &preserved_num_rows, + cardinality, + &preserved_join_key_indices, + is_anti, + null_equality, + ); + let total_byte_size = + total_byte_size_from_column_statistics(&column_statistics); Some(PartialJoinStatistics { - num_rows: *outer_stats.num_rows.get_value()?, - column_statistics: outer_stats.column_statistics, + num_rows: cardinality, + total_byte_size, + column_statistics, }) } @@ -539,6 +645,7 @@ fn estimate_join_cardinality( column_statistics.push(ColumnStatistics::new_unknown()); Some(PartialJoinStatistics { num_rows, + total_byte_size: Precision::Absent, column_statistics, }) } @@ -548,12 +655,132 @@ fn estimate_join_cardinality( column_statistics.push(ColumnStatistics::new_unknown()); Some(PartialJoinStatistics { num_rows, + total_byte_size: Precision::Absent, column_statistics, }) } } } +fn equijoin_column_indices( + left: &PhysicalExprRef, + right: &PhysicalExprRef, +) -> Option<(usize, usize)> { + Some(( + left.downcast_ref::()?.index(), + right.downcast_ref::()?.index(), + )) +} + +/// Adjusts the preserved input's column statistics to describe the subset of +/// rows a semi or anti join emits. Most values become estimates (marked +/// inexact) bounded by the smaller output row count: +/// +/// - `null_count` and `byte_size` are scaled by the output/input row ratio. +/// - `distinct_count` is capped at the number of non-null output rows. +/// - `sum_value` is dropped, since the input sum does not apply to the subset. +/// +/// Join-key columns are the exception for `null_count`: under regular SQL +/// equality, null keys never match, so a semi join keeps none of those rows and +/// an anti join keeps all of them. Under null-equal joins, null keys can match +/// and are treated like the rest of the subset. +fn normalize_semi_anti_join_column_statistics( + column_statistics: Vec, + input_num_rows: &Precision, + output_num_rows: usize, + join_key_indices: &[usize], + is_anti: bool, + null_equality: NullEquality, +) -> Vec { + let input_num_rows = input_num_rows.get_value().copied().unwrap_or(0); + + column_statistics + .into_iter() + .enumerate() + .map(|(idx, stats)| { + let mut stats = stats.to_inexact(); + stats.null_count = if join_key_indices.contains(&idx) { + normalize_semi_anti_join_key_null_count( + stats.null_count, + input_num_rows, + output_num_rows, + is_anti, + null_equality, + ) + } else { + scale_subset_count(stats.null_count, input_num_rows, output_num_rows) + .min(&Precision::Inexact(output_num_rows)) + }; + let max_distinct_count = stats + .null_count + .get_value() + .map(|null_count| output_num_rows.saturating_sub(*null_count)) + .unwrap_or(output_num_rows); + stats.distinct_count = stats + .distinct_count + .min(&Precision::Inexact(max_distinct_count)); + stats.byte_size = + scale_subset_count(stats.byte_size, input_num_rows, output_num_rows); + stats.sum_value = Precision::Absent; + stats + }) + .collect() +} + +fn normalize_semi_anti_join_key_null_count( + null_count: Precision, + input_num_rows: usize, + output_num_rows: usize, + is_anti: bool, + null_equality: NullEquality, +) -> Precision { + match (is_anti, null_equality) { + (false, NullEquality::NullEqualsNothing) => Precision::Exact(0), + (true, NullEquality::NullEqualsNothing) => null_count + .to_inexact() + .min(&Precision::Inexact(output_num_rows)), + (_, NullEquality::NullEqualsNull) => { + scale_subset_count(null_count, input_num_rows, output_num_rows) + .min(&Precision::Inexact(output_num_rows)) + } + } +} + +// Scale a column-level count to an estimated row subset. Rounding up keeps a +// small non-zero count from disappearing solely because the subset is small. +fn scale_subset_count( + count: Precision, + input_num_rows: usize, + output_num_rows: usize, +) -> Precision { + let scaled = match count { + Precision::Exact(count) | Precision::Inexact(count) => { + if input_num_rows == 0 { + 0 + } else { + (count as u128 * output_num_rows as u128).div_ceil(input_num_rows as u128) + as usize + } + } + Precision::Absent => return Precision::Absent, + }; + + Precision::Inexact(scaled) +} + +fn total_byte_size_from_column_statistics( + column_statistics: &[ColumnStatistics], +) -> Precision { + column_statistics + .iter() + .map(|stats| stats.byte_size.get_value().copied()) + .try_fold(0usize, |acc, byte_size| { + byte_size.map(|byte_size| acc.saturating_add(byte_size)) + }) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent) +} + /// Estimate the inner join cardinality by using the basic building blocks of /// column-level statistics and the total row count. This is a very naive and /// a very conservative implementation that can quickly give up if there is not @@ -578,8 +805,15 @@ fn estimate_inner_join_cardinality( .. } = right_stats; - // The algorithm here is partly based on the non-histogram selectivity estimation - // from Spark's Catalyst optimizer. + if left_num_rows == Precision::Exact(0) || right_num_rows == Precision::Exact(0) { + return Some(Precision::Exact(0)); + } + if left_num_rows == Precision::Inexact(0) || right_num_rows == Precision::Inexact(0) { + return Some(Precision::Inexact(0)); + } + + // Follow Spark Catalyst's conservative NDV join estimate: for multi-key + // joins, use the most selective key instead of multiplying all key denominators. let mut join_selectivity = Precision::Absent; for (left_stat, right_stat) in left_column_statistics .iter() @@ -592,7 +826,11 @@ fn estimate_inner_join_cardinality( // Seems like there are a few implementations of this algorithm that implement // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs // further exploration. - join_selectivity = max_distinct; + join_selectivity = if join_selectivity.get_value().is_some() { + join_selectivity.max(&max_distinct) + } else { + max_distinct + }; } } @@ -668,6 +906,106 @@ fn estimate_disjoint_inputs( None } +/// Estimates the number of outer rows that have at least one matching +/// key on the inner side (i.e. semi join cardinality) using NDV +/// (Number of Distinct Values) statistics. +/// +/// Assuming the smaller domain is contained in the larger, the number +/// of overlapping distinct values is `min(outer_ndv, inner_ndv)`. +/// Under the uniformity assumption (each distinct value contributes +/// equally to row counts), the surviving fraction of outer rows is: +/// +/// Under regular SQL equality, null rows cannot match, so each column's +/// selectivity is further reduced by the outer null fraction: +/// +/// ```text +/// null_frac_i = outer_null_count_i / outer_rows +/// selectivity_i = min(outer_ndv_i, inner_ndv_i) / outer_ndv_i * (1 - null_frac_i) +/// ``` +/// +/// For multi-column join keys the overall selectivity is the product +/// of per-column factors: +/// +/// ```text +/// semi_cardinality = outer_rows * product_i(selectivity_i) +/// ``` +/// +/// Anti join cardinality is derived as the complement: +/// `outer_rows - semi_cardinality`. +/// +/// With `NullEqualsNothing`, boundary cases are: +/// * `inner_ndv >= outer_ndv` → selectivity = `1.0 - null_frac` +/// * `null_frac = 1.0` → selectivity = 0.0 (no non-null rows can match) +/// * Missing NDV statistics → returns `None` (fallback to `outer_rows`) +/// +/// PostgreSQL uses a similar approach in `eqjoinsel_semi` +/// (`src/backend/utils/adt/selfuncs.c`). When NDV statistics are +/// available on both sides it computes selectivity as `nd2 / nd1`, +/// which is equivalent to `min(outer_ndv, inner_ndv) / outer_ndv`. +/// If either side lacks statistics it falls back to a default. +fn estimate_semi_join_cardinality( + outer_num_rows: &Precision, + inner_num_rows: &Precision, + outer_key_stats: &[ColumnStatistics], + inner_key_stats: &[ColumnStatistics], + null_equality: NullEquality, +) -> Option { + let outer_rows = *outer_num_rows.get_value()?; + if outer_rows == 0 { + return Some(0); + } + let inner_rows = *inner_num_rows.get_value()?; + if inner_rows == 0 { + return Some(0); + } + + let mut selectivity = 1.0_f64; + let mut has_selectivity_estimate = false; + + for (outer_stat, inner_stat) in outer_key_stats.iter().zip(inner_key_stats.iter()) { + let outer_has_stats = outer_stat.distinct_count.get_value().is_some() + || (outer_stat.min_value.get_value().is_some() + && outer_stat.max_value.get_value().is_some()); + let inner_has_stats = inner_stat.distinct_count.get_value().is_some() + || (inner_stat.min_value.get_value().is_some() + && inner_stat.max_value.get_value().is_some()); + if !outer_has_stats || !inner_has_stats { + continue; + } + + let outer_ndv = max_distinct_count(outer_num_rows, outer_stat); + let inner_ndv = max_distinct_count(inner_num_rows, inner_stat); + + if let (Some(&o), Some(&i)) = (outer_ndv.get_value(), inner_ndv.get_value()) + && o > 0 + { + let null_frac = if null_equality == NullEquality::NullEqualsNothing { + outer_stat + .null_count + .get_value() + .map(|&nc| { + if nc > outer_rows { + 0.0 + } else { + nc as f64 / outer_rows as f64 + } + }) + .unwrap_or(0.0) + } else { + 0.0 + }; + selectivity *= (o.min(i) as f64) / (o as f64) * (1.0 - null_frac); + has_selectivity_estimate = true; + } + } + + if has_selectivity_estimate { + Some((outer_rows as f64 * selectivity).ceil() as usize) + } else { + None + } +} + /// Estimate the number of maximum distinct values that can be present in the /// given column from its statistics. If distinct_count is available, uses it /// directly. Otherwise, if the column is numeric and has min/max values, it @@ -678,7 +1016,19 @@ fn max_distinct_count( stats: &ColumnStatistics, ) -> Precision { match &stats.distinct_count { - &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc, + &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => { + // NDV can never exceed the number of rows + match num_rows { + Precision::Absent => dc, + _ => { + if dc.get_value() <= num_rows.get_value() { + dc + } else { + num_rows.to_inexact() + } + } + } + } _ => { // The number can never be greater than the number of rows we have // minus the nulls (since they don't count as distinct values). @@ -693,38 +1043,37 @@ fn max_distinct_count( } } Precision::Exact(count) => { - let count = count - stats.null_count.get_value().unwrap_or(&0); + let null_count = *stats.null_count.get_value().unwrap_or(&0); + let non_null_count = count.checked_sub(null_count).unwrap_or(0); if stats.null_count.is_exact().unwrap_or(false) { - Precision::Exact(count) + Precision::Exact(non_null_count) } else { - Precision::Inexact(count) + Precision::Inexact(non_null_count) } } }; // Cap the estimate using the number of possible values: if let (Some(min), Some(max)) = (stats.min_value.get_value(), stats.max_value.get_value()) - { - if let Some(range_dc) = Interval::try_new(min.clone(), max.clone()) + && let Some(range_dc) = Interval::try_new(min.clone(), max.clone()) .ok() .and_then(|e| e.cardinality()) + { + let range_dc = range_dc as usize; + // Note that the `unwrap` calls in the below statement are safe. + return if result == Precision::Absent + || &range_dc < result.get_value().unwrap() { - let range_dc = range_dc as usize; - // Note that the `unwrap` calls in the below statement are safe. - return if matches!(result, Precision::Absent) - || &range_dc < result.get_value().unwrap() + if stats.min_value.is_exact().unwrap() + && stats.max_value.is_exact().unwrap() { - if stats.min_value.is_exact().unwrap() - && stats.max_value.is_exact().unwrap() - { - Precision::Exact(range_dc) - } else { - Precision::Inexact(range_dc) - } + Precision::Exact(range_dc) } else { - result - }; - } + Precision::Inexact(range_dc) + } + } else { + result + }; } result @@ -883,6 +1232,7 @@ pub(crate) fn get_final_indices_from_bit_map( (left_indices, right_indices) } +#[expect(clippy::too_many_arguments)] pub(crate) fn apply_join_filter_to_indices( build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, @@ -891,6 +1241,7 @@ pub(crate) fn apply_join_filter_to_indices( filter: &JoinFilter, build_side: JoinSide, max_intermediate_size: Option, + join_type: JoinType, ) -> Result<(UInt64Array, UInt32Array)> { if build_indices.is_empty() && probe_indices.is_empty() { return Ok((build_indices, probe_indices)); @@ -911,6 +1262,7 @@ pub(crate) fn apply_join_filter_to_indices( &probe_indices.slice(i, len), filter.column_indices(), build_side, + join_type, )?; let filter_result = filter .expression() @@ -932,6 +1284,7 @@ pub(crate) fn apply_join_filter_to_indices( &probe_indices, filter.column_indices(), build_side, + join_type, )?; filter @@ -950,8 +1303,20 @@ pub(crate) fn apply_join_filter_to_indices( )) } +/// Creates a [RecordBatch] with zero columns but the given row count. +/// Used when a join has an empty projection (e.g. `SELECT count(1) ...`). +fn new_empty_schema_batch(schema: &Schema, row_count: usize) -> Result { + let options = RecordBatchOptions::new().with_row_count(Some(row_count)); + Ok(RecordBatch::try_new_with_options( + Arc::new(schema.clone()), + vec![], + &options, + )?) +} + /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. /// The resulting batch has [Schema] `schema`. +#[expect(clippy::too_many_arguments)] pub(crate) fn build_batch_from_indices( schema: &Schema, build_input_buffer: &RecordBatch, @@ -960,17 +1325,17 @@ pub(crate) fn build_batch_from_indices( probe_indices: &UInt32Array, column_indices: &[ColumnIndex], build_side: JoinSide, + join_type: JoinType, ) -> Result { if schema.fields().is_empty() { - let options = RecordBatchOptions::new() - .with_match_field_names(true) - .with_row_count(Some(build_indices.len())); - - return Ok(RecordBatch::try_new_with_options( - Arc::new(schema.clone()), - vec![], - &options, - )?); + // For RightAnti and RightSemi joins, after `adjust_indices_by_join_type` + // the build_indices were untouched so only probe_indices hold the actual + // row count. + let row_count = match join_type { + JoinType::RightAnti | JoinType::RightSemi => probe_indices.len(), + _ => build_indices.len(), + }; + return new_empty_schema_batch(schema, row_count); } // build the columns of the new [RecordBatch]: @@ -1017,44 +1382,35 @@ pub(crate) fn build_batch_empty_build_side( column_indices: &[ColumnIndex], join_type: JoinType, ) -> Result { - match join_type { - // these join types only return data if the left side is not empty, so we return an - // empty RecordBatch - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::LeftMark => Ok(RecordBatch::new_empty(Arc::new(schema.clone()))), - - // the remaining joins will return data for the right columns and null for the left ones - JoinType::Right | JoinType::Full | JoinType::RightAnti | JoinType::RightMark => { - let num_rows = probe_batch.num_rows(); - let mut columns: Vec> = - Vec::with_capacity(schema.fields().len()); - - for column_index in column_indices { - let array = match column_index.side { - // left -> null array - JoinSide::Left => new_null_array( - build_batch.column(column_index.index).data_type(), - num_rows, - ), - // right -> respective right array - JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)), - // right mark -> unset boolean array as there are no matches on the left side - JoinSide::None => Arc::new(BooleanArray::new( - BooleanBuffer::new_unset(num_rows), - None, - )), - }; + if join_type.empty_build_side_produces_empty_result() { + // These join types only return data if the left side is not empty. + return Ok(RecordBatch::new_empty(Arc::new(schema.clone()))); + } + + // The remaining joins return right-side rows and nulls for the left side. + let num_rows = probe_batch.num_rows(); + if schema.fields().is_empty() { + return new_empty_schema_batch(schema, num_rows); + } - columns.push(array); + let columns = column_indices + .iter() + .map(|column_index| match column_index.side { + // left -> null array + JoinSide::Left => new_null_array( + build_batch.column(column_index.index).data_type(), + num_rows, + ), + // right -> respective right array + JoinSide::Right => Arc::clone(probe_batch.column(column_index.index)), + // right mark -> unset boolean array as there are no matches on the left side + JoinSide::None => { + Arc::new(BooleanArray::new(BooleanBuffer::new_unset(num_rows), None)) } + }) + .collect(); - Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) - } - } + Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?) } /// The input is the matched indices for left and right and @@ -1200,7 +1556,9 @@ pub(crate) fn append_right_indices( } } -/// Returns `range` indices which are not present in `input_indices` +/// Returns `range` indices which are not present in `input_indices`. +/// +/// `input_indices` must be sorted ascending and contain no nulls. pub(crate) fn get_anti_indices( range: Range, input_indices: &PrimitiveArray, @@ -1208,18 +1566,51 @@ pub(crate) fn get_anti_indices( where NativeAdapter: From<::Native>, { - let bitmap = build_range_bitmap(&range, input_indices); - let offset = range.start; + debug_assert_eq!( + input_indices.null_count(), + 0, + "get_anti_indices requires non-null input_indices" + ); + debug_assert!( + input_indices + .values() + .windows(2) + .all(|w| w[0].as_usize() <= w[1].as_usize()), + "get_anti_indices requires ascending input_indices" + ); - // get the anti index - (range) - .filter_map(|idx| { - (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) - }) - .collect() + let mut next_unmatched_idx = range.start; + let mut output: Vec = Vec::with_capacity(range.len()); + + for &v in input_indices.values() { + let idx = v.as_usize(); + + if idx < range.start { + continue; + } + if idx >= range.end { + break; + } + + if next_unmatched_idx < idx { + output.extend((next_unmatched_idx..idx).map(|idx| { + T::Native::from_usize(idx).expect("join index exceeds output index type") + })); + } + next_unmatched_idx = idx + 1; + } + + if next_unmatched_idx < range.end { + output.extend((next_unmatched_idx..range.end).map(|idx| { + T::Native::from_usize(idx).expect("join index exceeds output index type") + })); + } + PrimitiveArray::::new(output.into(), None) } -/// Returns intersection of `range` and `input_indices` omitting duplicates +/// Returns the intersection of `range` and `input_indices`, omitting duplicates. +/// +/// `input_indices` must be sorted ascending and contain no nulls. pub(crate) fn get_semi_indices( range: Range, input_indices: &PrimitiveArray, @@ -1227,14 +1618,38 @@ pub(crate) fn get_semi_indices( where NativeAdapter: From<::Native>, { - let bitmap = build_range_bitmap(&range, input_indices); - let offset = range.start; - // get the semi index - (range) - .filter_map(|idx| { - (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) - }) - .collect() + debug_assert_eq!( + input_indices.null_count(), + 0, + "get_semi_indices requires non-null input_indices" + ); + debug_assert!( + input_indices + .values() + .windows(2) + .all(|w| w[0].as_usize() <= w[1].as_usize()), + "get_semi_indices requires ascending input_indices" + ); + + let mut prev_idx: Option = None; + let mut output = Vec::with_capacity(input_indices.len().min(range.len())); + + for &v in input_indices.values() { + let idx = v.as_usize(); + + if idx < range.start { + continue; + } + if idx >= range.end { + break; + } + + if prev_idx.replace(idx) != Some(idx) { + output.push(v); + } + } + + PrimitiveArray::::new(output.into(), None) } pub(crate) fn get_mark_indices( @@ -1300,7 +1715,7 @@ fn append_probe_indices_in_order( for (build_index, probe_index) in build_indices .values() .into_iter() - .zip(probe_indices.values().into_iter()) + .zip(probe_indices.values()) { // Append values between previous and current probe index with null build index: for value in prev_index..*probe_index { @@ -1373,26 +1788,32 @@ impl BuildProbeJoinMetrics { let build_time = MetricBuilder::new(metrics).subset_time("build_time", partition); - let build_input_batches = - MetricBuilder::new(metrics).counter("build_input_batches", partition); + let build_input_batches = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("build_input_batches", partition); - let build_input_rows = - MetricBuilder::new(metrics).counter("build_input_rows", partition); + let build_input_rows = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("build_input_rows", partition); - let build_mem_used = - MetricBuilder::new(metrics).gauge("build_mem_used", partition); + let build_mem_used = MetricBuilder::new(metrics) + .with_category(MetricCategory::Bytes) + .gauge("build_mem_used", partition); - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); + let input_batches = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let input_rows = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("input_rows", partition); let probe_hit_rate = MetricBuilder::new(metrics) - .with_type(MetricType::SUMMARY) + .with_type(MetricType::Summary) .ratio_metrics("probe_hit_rate", partition); let avg_fanout = MetricBuilder::new(metrics) - .with_type(MetricType::SUMMARY) + .with_type(MetricType::Summary) .ratio_metrics("avg_fanout", partition); Self { @@ -1647,7 +2068,7 @@ fn swap_reverting_projection( pub fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, - projection: Option<&Vec>, + projection: Option<&[usize]>, join_type: &JoinType, ) -> Option> { match join_type { @@ -1658,7 +2079,7 @@ pub fn swap_join_projection( | JoinType::RightAnti | JoinType::RightSemi | JoinType::LeftMark - | JoinType::RightMark => projection.cloned(), + | JoinType::RightMark => projection.map(|p| p.to_vec()), _ => projection.map(|p| { p.iter() .map(|i| { @@ -1683,7 +2104,7 @@ pub fn swap_join_projection( /// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, /// which allows to keep either first (if set to true) or last (if set to false) row index /// as a chain head for rows with equal hash values. -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] pub fn update_hash( on: &[PhysicalExprRef], batch: &RecordBatch, @@ -1774,12 +2195,144 @@ fn eq_dyn_null( }; return Ok(compare_op_for_nested(op, &left, &right)?); } + // Arrow's `eq` / `not_distinct` use IEEE 754 totalOrder semantics for + // floats, so `-0.0` and `+0.0` would compare unequal. Normalize float + // operands first; non-float types dispatch directly to avoid the + // `make_array(to_data())` round-trip. + if !matches!( + left.data_type(), + DataType::Float16 | DataType::Float32 | DataType::Float64 + ) { + return match null_equality { + NullEquality::NullEqualsNothing => eq(&left, &right), + NullEquality::NullEqualsNull => not_distinct(&left, &right), + }; + } + let left_arr: ArrayRef = make_array(left.to_data()); + let right_arr: ArrayRef = make_array(right.to_data()); + let left_norm = normalize_float_zero(&left_arr); + let right_norm = normalize_float_zero(&right_arr); + let left = left_norm.as_ref(); + let right = right_norm.as_ref(); match null_equality { NullEquality::NullEqualsNothing => eq(&left, &right), NullEquality::NullEqualsNull => not_distinct(&left, &right), } } +/// Pre-built comparator for join key columns that eliminates per-row type +/// dispatch. Wraps `arrow_ord::ord::DynComparator` closures built once per +/// batch pair, used for all row comparisons within those batches. +/// +/// The first key column is stored separately so that single-column joins +/// (the common case) avoid Vec iteration entirely, and multi-column joins +/// short-circuit without entering the loop when the first column is +/// selective. +/// +/// Null handling is baked into the closures at construction time: +/// - `NullEqualsNull`: `make_comparator` returns `Equal` for both-null, which +/// is the desired behavior. Closures are used as-is. +/// - `NullEqualsNothing`: columns where both sides contain nulls get a wrapper +/// that returns `Less` for both-null. Columns where one side has no nulls +/// skip the wrapper since both-null is impossible. +/// +/// Because `NullEqualsNothing` wraps comparators to return `Less` for +/// both-null, `is_equal` will return `false` for both-null rows when that +/// mode is active. Callers needing both-null == equal semantics (e.g., +/// buffered head/tail equality in SMJ) should construct with +/// `NullEqualsNull`. +pub struct JoinKeyComparator { + first: DynComparator, + rest: Vec, +} + +impl JoinKeyComparator { + /// Build comparators for each join key column pair. + pub fn new( + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + sort_options: &[SortOptions], + null_equality: NullEquality, + ) -> Result { + debug_assert_eq!(left_arrays.len(), right_arrays.len()); + debug_assert_eq!(left_arrays.len(), sort_options.len()); + + let mut iter = left_arrays + .iter() + .zip(right_arrays.iter()) + .zip(sort_options.iter()) + .map(|((l, r), opts)| { + // `make_comparator` uses IEEE 754 totalOrder for floats and + // treats `-0.0` / `+0.0` as distinct. Normalize float arrays + // so SMJ / piecewise-merge equi-keys honor SQL equality; + // no-op (Arc::clone) for non-floats and for float arrays + // that contain no `-0.0`. `normalize_float_zero` preserves + // null positions, so the original null masks below remain + // valid. + let l_norm = normalize_float_zero(l); + let r_norm = normalize_float_zero(r); + let inner = make_comparator(l_norm.as_ref(), r_norm.as_ref(), *opts)?; + if null_equality == NullEquality::NullEqualsNothing { + let ln = l.logical_nulls().filter(|n| n.null_count() > 0); + let rn = r.logical_nulls().filter(|n| n.null_count() > 0); + match (ln, rn) { + // Both sides have nulls — wrap to override both-null. + (Some(ln), Some(rn)) => Ok(Box::new(move |i, j| { + if ln.is_null(i) && rn.is_null(j) { + Ordering::Less + } else { + inner(i, j) + } + }) + as DynComparator), + // One side has no nulls — both-null impossible, no wrap. + _ => Ok(inner), + } + } else { + Ok(inner) + } + }); + + let first = iter.next().expect("join must have at least one key")?; + let rest = iter.collect::>>()?; + Ok(Self { first, rest }) + } + + /// Compare row `left` (in the left arrays) with row `right` (in the right + /// arrays). Returns the lexicographic ordering across all key columns. + #[inline] + pub fn compare(&self, left: usize, right: usize) -> Ordering { + let ord = (self.first)(left, right); + if ord != Ordering::Equal || self.rest.is_empty() { + return ord; + } + for cmp_fn in &self.rest { + let ord = cmp_fn(left, right); + if ord != Ordering::Equal { + return ord; + } + } + Ordering::Equal + } + + /// Check equality of row `left` (in the left arrays) with row `right` + /// (in the right arrays). Both-null is treated as equal when constructed + /// with `NullEqualsNull`. With `NullEqualsNothing`, both-null returns + /// `false` because the override is baked into the comparators. + #[inline] + pub fn is_equal(&self, left: usize, right: usize) -> bool { + if (self.first)(left, right) != Ordering::Equal { + return false; + } + for cmp_fn in &self.rest { + if cmp_fn(left, right) != Ordering::Equal { + return false; + } + } + true + } +} + /// Get comparison result of two rows of join arrays pub fn compare_join_arrays( left_arrays: &[ArrayRef], @@ -1880,61 +2433,136 @@ mod tests { use super::*; - use arrow::array::Int32Array; use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; - use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use datafusion_common::{ScalarValue, arrow_datafusion_err, arrow_err}; use datafusion_physical_expr::PhysicalSortExpr; use rstest::rstest; - fn check( - left: &[Column], - right: &[Column], - on: &[(PhysicalExprRef, PhysicalExprRef)], - ) -> Result<()> { - let left = left - .iter() - .map(|x| x.to_owned()) - .collect::>(); - let right = right - .iter() - .map(|x| x.to_owned()) - .collect::>(); - check_join_set_is_valid(&left, &right, on) + fn assert_u32_values(array: &UInt32Array, expected: &[u32]) { + assert_eq!(array.values().as_ref(), expected); } #[test] - fn check_valid() -> Result<()> { - let left = vec![Column::new("a", 0), Column::new("b1", 1)]; - let right = vec![Column::new("a", 0), Column::new("b2", 1)]; - let on = &[( - Arc::new(Column::new("a", 0)) as _, - Arc::new(Column::new("a", 0)) as _, - )]; + fn get_anti_indices_returns_unmatched_range_indices() { + let input = UInt32Array::from(vec![3, 5, 5]); - check(&left, &right, on)?; - Ok(()) + let result = get_anti_indices(2..8, &input); + + assert_u32_values(&result, &[2, 4, 6, 7]); } #[test] - fn check_not_in_right() { - let left = vec![Column::new("a", 0), Column::new("b", 1)]; - let right = vec![Column::new("b", 0)]; - let on = &[( - Arc::new(Column::new("a", 0)) as _, - Arc::new(Column::new("a", 0)) as _, - )]; + fn get_anti_indices_ignores_out_of_range_indices() { + let input = UInt32Array::from(vec![0, 1, 3, 5, 8, 12]); - assert!(check(&left, &right, on).is_err()); + let result = get_anti_indices(2..8, &input); + + assert_u32_values(&result, &[2, 4, 6, 7]); } - #[tokio::test] - async fn check_error_nesting() { - let once_fut = OnceFut::<()>::new(async { - arrow_err!(ArrowError::CsvError("some error".to_string())) - }); + #[test] + fn get_anti_indices_handles_dense_matches() { + let input = UInt32Array::from(vec![2, 3, 4, 5]); + + let result = get_anti_indices(2..6, &input); + + assert!(result.is_empty()); + } + + #[test] + fn get_anti_indices_handles_sparse_matches() { + let input = UInt32Array::from(vec![0, 8]); + + let result = get_anti_indices(2..6, &input); + + assert_u32_values(&result, &[2, 3, 4, 5]); + } + + #[test] + fn get_semi_indices_returns_distinct_matches_in_range() { + let input = UInt32Array::from(vec![1, 3, 3, 3, 5, 8]); + + let result = get_semi_indices(2..7, &input); + + assert_u32_values(&result, &[3, 5]); + } + + #[test] + fn get_semi_indices_ignores_out_of_range_indices() { + let input = UInt32Array::from(vec![0, 1, 3, 5, 8, 12]); + + let result = get_semi_indices(2..8, &input); + + assert_u32_values(&result, &[3, 5]); + } + + #[test] + fn get_semi_indices_handles_dense_matches() { + let input = UInt32Array::from(vec![2, 3, 4, 5]); + + let result = get_semi_indices(2..6, &input); + + assert_u32_values(&result, &[2, 3, 4, 5]); + } + + #[test] + fn get_semi_indices_handles_empty_input() { + let input = UInt32Array::from(Vec::::new()); + + let result = get_semi_indices(2..6, &input); + + assert!(result.is_empty()); + } + + fn check( + left: &[Column], + right: &[Column], + on: &[(PhysicalExprRef, PhysicalExprRef)], + ) -> Result<()> { + let left = left + .iter() + .map(|x| x.to_owned()) + .collect::>(); + let right = right + .iter() + .map(|x| x.to_owned()) + .collect::>(); + check_join_set_is_valid(&left, &right, on) + } + + #[test] + fn check_valid() -> Result<()> { + let left = vec![Column::new("a", 0), Column::new("b1", 1)]; + let right = vec![Column::new("a", 0), Column::new("b2", 1)]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("a", 0)) as _, + )]; + + check(&left, &right, on)?; + Ok(()) + } + + #[test] + fn check_not_in_right() { + let left = vec![Column::new("a", 0), Column::new("b", 1)]; + let right = vec![Column::new("b", 0)]; + let on = &[( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("a", 0)) as _, + )]; + + assert!(check(&left, &right, on).is_err()); + } + + #[tokio::test] + async fn check_error_nesting() { + let once_fut = OnceFut::<()>::new(async { + arrow_err!(ArrowError::CsvError("some error".to_string())) + }); struct TestFut(OnceFut<()>); impl Future for TestFut { @@ -2084,6 +2712,7 @@ mod tests { max_value: max.map(ScalarValue::from), sum_value: Absent, null_count, + byte_size: Absent, } } @@ -2245,6 +2874,22 @@ mod tests { (10, Inexact(1), Inexact(10), Absent, Absent), Some(Inexact(0)), ), + // NDV > num_rows: distinct count should be capped at row count + ( + (5, Inexact(1), Inexact(100), Inexact(50), Absent), + (10, Inexact(1), Inexact(100), Inexact(50), Absent), + // max_distinct_count caps: left NDV=min(50,5)=5, right NDV=min(50,10)=10 + // cardinality = (5 * 10) / max(5, 10) = 50 / 10 = 5 + Some(Inexact(5)), + ), + // NDV > num_rows on one side only + ( + (3, Inexact(1), Inexact(100), Inexact(100), Absent), + (10, Inexact(1), Inexact(100), Inexact(5), Absent), + // max_distinct_count caps: left NDV=min(100,3)=3, right NDV=min(5,10)=5 + // cardinality = (3 * 10) / max(3, 5) = 30 / 5 = 6 + Some(Inexact(6)), + ), ]; for (left_info, right_info, expected_cardinality) in cases { @@ -2291,6 +2936,7 @@ mod tests { create_stats(Some(left_num_rows), left_col_stats.clone(), false), create_stats(Some(right_num_rows), right_col_stats.clone(), false), &join_on, + NullEquality::NullEqualsNothing, ); assert_eq!( @@ -2384,11 +3030,14 @@ mod tests { // y: min=0, max=100, distinct=None // // Join on a=c, b=d (ignore x/y) + // Right column d has NDV=2500 but only 2000 rows, so NDV is capped + // to 2000. join_selectivity = max(500, 2000) = 2000. + // Inner cardinality = (1000 * 2000) / 2000 = 1000 let cases = vec![ - (JoinType::Inner, 800), + (JoinType::Inner, 1000), (JoinType::Left, 1000), (JoinType::Right, 2000), - (JoinType::Full, 2200), + (JoinType::Full, 2000), ]; let left_col_stats = vec![ @@ -2420,6 +3069,7 @@ mod tests { create_stats(Some(1000), left_col_stats.clone(), false), create_stats(Some(2000), right_col_stats.clone(), false), &join_on, + NullEquality::NullEqualsNothing, ) .unwrap(); assert_eq!(partial_join_stats.num_rows, expected_num_rows); @@ -2432,6 +3082,70 @@ mod tests { Ok(()) } + #[test] + fn test_join_cardinality_key_order() -> Result<()> { + // Reversing join key order should not change estimated cardinality + let left_col_stats = vec![ + create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent), + create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent), + create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent), + ]; + + let right_col_stats = vec![ + create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent), + create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent), + create_column_stats(Inexact(0), Inexact(100), Absent, Absent), + ]; + + let join_on_ab = vec![ + ( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("c", 0)) as _, + ), + ( + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("d", 1)) as _, + ), + ]; + let join_on_ba = vec![ + ( + Arc::new(Column::new("b", 1)) as _, + Arc::new(Column::new("d", 1)) as _, + ), + ( + Arc::new(Column::new("a", 0)) as _, + Arc::new(Column::new("c", 0)) as _, + ), + ]; + + let stats_ab = estimate_join_cardinality( + &JoinType::Inner, + create_stats(Some(1000), left_col_stats.clone(), false), + create_stats(Some(2000), right_col_stats.clone(), false), + &join_on_ab, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + let stats_ba = estimate_join_cardinality( + &JoinType::Inner, + create_stats(Some(1000), left_col_stats.clone(), false), + create_stats(Some(2000), right_col_stats.clone(), false), + &join_on_ba, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + assert_eq!(stats_ab.num_rows, 1000); + assert_eq!(stats_ba.num_rows, stats_ab.num_rows); + assert_eq!(stats_ba.column_statistics, stats_ab.column_statistics); + assert_eq!( + stats_ab.column_statistics, + [left_col_stats, right_col_stats].concat() + ); + + Ok(()) + } + #[test] fn test_join_cardinality_when_one_column_is_disjoint() -> Result<()> { // Left table (rows=1000) @@ -2491,6 +3205,7 @@ mod tests { create_stats(Some(1000), left_col_stats.clone(), true), create_stats(Some(2000), right_col_stats.clone(), true), &join_on, + NullEquality::NullEqualsNothing, ) .unwrap(); assert_eq!(partial_join_stats.num_rows, expected_num_rows); @@ -2519,7 +3234,7 @@ mod tests { JoinType::LeftSemi, (50, Inexact(10), Inexact(20), Absent, Absent), (10, Inexact(15), Inexact(25), Absent, Absent), - Some(50), + Some(46), ), ( JoinType::RightSemi, @@ -2555,13 +3270,13 @@ mod tests { JoinType::LeftAnti, (50, Inexact(10), Inexact(20), Absent, Absent), (10, Inexact(15), Inexact(25), Absent, Absent), - Some(50), + Some(4), ), ( JoinType::RightAnti, (50, Inexact(10), Inexact(20), Absent, Absent), (10, Inexact(15), Inexact(25), Absent, Absent), - Some(10), + Some(0), ), ( JoinType::LeftAnti, @@ -2587,6 +3302,108 @@ mod tests { (10, Inexact(30), Absent, Absent, Absent), Some(50), ), + // NDV-based semi join: outer_ndv=20, inner_ndv=10 + // selectivity = 10/20 = 0.5, cardinality = ceil(50 * 0.5) = 25 + ( + JoinType::LeftSemi, + (50, Inexact(1), Inexact(100), Inexact(20), Absent), + (10, Inexact(1), Inexact(100), Inexact(10), Absent), + Some(25), + ), + // inner_ndv(30) >= outer_ndv(20) -> selectivity 1.0, no reduction + ( + JoinType::LeftSemi, + (50, Inexact(1), Inexact(100), Inexact(20), Absent), + (100, Inexact(1), Inexact(100), Inexact(30), Absent), + Some(50), + ), + // NDV-based anti join: semi=25, anti = 50 - 25 = 25 + ( + JoinType::LeftAnti, + (50, Inexact(1), Inexact(100), Inexact(20), Absent), + (10, Inexact(1), Inexact(100), Inexact(10), Absent), + Some(25), + ), + // inner covers all outer: semi=50, anti = 0 + ( + JoinType::LeftAnti, + (50, Inexact(1), Inexact(100), Inexact(20), Absent), + (100, Inexact(1), Inexact(100), Inexact(30), Absent), + Some(0), + ), + // RightSemi with explicit NDV (NDV within row count, used as-is): + // For RightSemi, sides are swapped: outer = right (20 rows, ndv=10), + // inner = left (50 rows, ndv=5). selectivity = min(10,5)/10 = 0.5, + // cardinality = ceil(20 * 0.5) = 10. + ( + JoinType::RightSemi, + (50, Inexact(1), Inexact(100), Inexact(5), Absent), + (20, Inexact(1), Inexact(100), Inexact(10), Absent), + Some(10), + ), + // RightAnti with explicit NDV: anti = outer_rows - semi = 20 - 10 = 10. + ( + JoinType::RightAnti, + (50, Inexact(1), Inexact(100), Inexact(5), Absent), + (20, Inexact(1), Inexact(100), Inexact(10), Absent), + Some(10), + ), + // RightSemi where right-side NDV (20) exceeds right-side row count (10): + // NDV is clamped to 10, so outer_ndv=10, inner_ndv=10, + // selectivity = min(10,10)/10 = 1.0, cardinality = ceil(10 * 1.0) = 10. + ( + JoinType::RightSemi, + (50, Inexact(1), Inexact(100), Inexact(10), Absent), + (10, Inexact(1), Inexact(100), Inexact(20), Absent), + Some(10), + ), + // RightAnti with NDV clamped by row count: anti = 10 - 10 = 0. + ( + JoinType::RightAnti, + (50, Inexact(1), Inexact(100), Inexact(10), Absent), + (10, Inexact(1), Inexact(100), Inexact(20), Absent), + Some(0), + ), + // Empty inner table: no match possible, semi → 0 + ( + JoinType::LeftSemi, + (100, Absent, Absent, Absent, Absent), + (0, Absent, Absent, Absent, Absent), + Some(0), + ), + // NDV-based semi with nulls on outer side: + // outer_ndv=20, inner_ndv=10, null_frac=10/100=0.1 + // selectivity = 10/20 * (1-0.1) = 0.5 * 0.9 = 0.45 + // semi = ceil(100 * 0.45) = 45 + ( + JoinType::LeftSemi, + (100, Absent, Absent, Inexact(20), Inexact(10)), + (200, Absent, Absent, Inexact(10), Absent), + Some(45), + ), + // Anti-join with nulls on outer side: + // semi=45, anti = 100 - 45 = 55 + ( + JoinType::LeftAnti, + (100, Absent, Absent, Inexact(20), Inexact(10)), + (200, Absent, Absent, Inexact(10), Absent), + Some(55), + ), + // All outer rows are null: null_frac=1.0 + // selectivity = 10/20 * (1-1.0) = 0.0, semi = 0 + ( + JoinType::LeftSemi, + (100, Absent, Absent, Inexact(20), Inexact(100)), + (200, Absent, Absent, Inexact(10), Absent), + Some(0), + ), + // All outer rows are null (anti): anti = 100 - 0 = 100 + ( + JoinType::LeftAnti, + (100, Absent, Absent, Inexact(20), Inexact(100)), + (200, Absent, Absent, Inexact(10), Absent), + Some(100), + ), ]; let join_on = vec![( @@ -2624,6 +3441,7 @@ mod tests { column_statistics: inner_col_stats, }, &join_on, + NullEquality::NullEqualsNothing, ) .map(|cardinality| cardinality.num_rows); @@ -2658,6 +3476,7 @@ mod tests { column_statistics: dummy_column_stats.clone(), }, &join_on, + NullEquality::NullEqualsNothing, ); assert!( absent_outer_estimation.is_none(), @@ -2668,18 +3487,22 @@ mod tests { &JoinType::LeftSemi, Statistics { num_rows: Inexact(500), - total_byte_size: Absent, + total_byte_size: Absent, column_statistics: dummy_column_stats.clone(), }, Statistics { num_rows: Absent, - total_byte_size: Absent, + total_byte_size: Absent, column_statistics: dummy_column_stats.clone(), }, &join_on, + NullEquality::NullEqualsNothing, ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows"); - assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"); + assert_eq!( + absent_inner_estimation.num_rows, 500, + "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows" + ); let absent_inner_estimation = estimate_join_cardinality( &JoinType::LeftSemi, @@ -2694,12 +3517,511 @@ mod tests { column_statistics: dummy_column_stats, }, &join_on, + NullEquality::NullEqualsNothing, + ); + assert!( + absent_inner_estimation.is_none(), + "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows" + ); + + Ok(()) + } + + #[test] + fn test_semi_join_multi_column_and_mixed_stats() -> Result<()> { + let join_on = vec![ + ( + Arc::new(Column::new("l_col0", 0)) as _, + Arc::new(Column::new("r_col0", 0)) as _, + ), + ( + Arc::new(Column::new("l_col1", 1)) as _, + Arc::new(Column::new("r_col1", 1)) as _, + ), + ]; + + // Multi-column: both columns have NDV on both sides. + // col0: outer_ndv=20, inner_ndv=10 → selectivity = 10/20 = 0.5 + // col1: outer_ndv=40, inner_ndv=10 → selectivity = 10/40 = 0.25 + // total selectivity = 0.5 * 0.25 = 0.125 + // semi = ceil(100 * 0.125) = 13 + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(20), Absent), + create_column_stats(Absent, Absent, Inexact(40), Absent), + ], + }, + Statistics { + num_rows: Inexact(200), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(10), Absent), + create_column_stats(Absent, Absent, Inexact(10), Absent), + ], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!(result, Some(13), "multi-column semi join"); + + // Multi-column anti: anti = 100 - 13 = 87 + let result = estimate_join_cardinality( + &JoinType::LeftAnti, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(20), Absent), + create_column_stats(Absent, Absent, Inexact(40), Absent), + ], + }, + Statistics { + num_rows: Inexact(200), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(10), Absent), + create_column_stats(Absent, Absent, Inexact(10), Absent), + ], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!(result, Some(87), "multi-column anti join"); + + // Mixed stats: col0 has NDV on both sides, col1 has NDV only on outer. + // col1 is skipped (either side missing), so selectivity comes from col0 only. + // col0: outer_ndv=20, inner_ndv=10 → selectivity = 0.5 + // semi = ceil(100 * 0.5) = 50 + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(20), Absent), + create_column_stats(Absent, Absent, Inexact(40), Absent), + ], + }, + Statistics { + num_rows: Inexact(200), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(10), Absent), + create_column_stats(Absent, Absent, Absent, Absent), + ], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!(result, Some(50), "mixed stats: col1 skipped"); + + // Mixed stats: neither column has stats on both sides → fallback to outer_rows + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(20), Absent), + create_column_stats(Absent, Absent, Absent, Absent), + ], + }, + Statistics { + num_rows: Inexact(200), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Absent, Absent), + create_column_stats(Absent, Absent, Inexact(10), Absent), + ], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!(result, Some(100), "no column has stats on both sides"); + + // Multi-column with nulls on one column: + // col0: outer_ndv=20, inner_ndv=10, null_frac=0.0 → 10/20 * 1.0 = 0.5 + // col1: outer_ndv=40, inner_ndv=10, null_frac=20/100=0.2 → 10/40 * 0.8 = 0.2 + // total selectivity = 0.5 * 0.2 = 0.1 + // semi = ceil(100 * 0.1) = 10 + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(20), Absent), + create_column_stats(Absent, Absent, Inexact(40), Inexact(20)), + ], + }, + Statistics { + num_rows: Inexact(200), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Absent, Absent, Inexact(10), Absent), + create_column_stats(Absent, Absent, Inexact(10), Absent), + ], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!( + result, + Some(10), + "multi-column semi join with nulls on one column" ); - assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"); Ok(()) } + #[test] + fn test_semi_anti_join_disjoint_check_uses_only_join_keys() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + // Ranges for the join key overlap; ranges for the other column are disjoint + let left_stats = Statistics { + num_rows: Inexact(50), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Inexact(1), Inexact(10), Absent, Absent), + create_column_stats(Inexact(100), Inexact(200), Absent, Absent), + ], + }; + let right_stats = Statistics { + num_rows: Inexact(10), + total_byte_size: Absent, + column_statistics: vec![ + create_column_stats(Inexact(1), Inexact(10), Absent, Absent), + create_column_stats(Inexact(1000), Inexact(2000), Absent, Absent), + ], + }; + + let left_semi = estimate_join_cardinality( + &JoinType::LeftSemi, + left_stats.clone(), + right_stats.clone(), + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!(left_semi, Some(50)); + + let left_anti = estimate_join_cardinality( + &JoinType::LeftAnti, + left_stats, + right_stats, + &join_on, + NullEquality::NullEqualsNothing, + ) + .map(|c| c.num_rows); + assert_eq!(left_anti, Some(0)); + } + + #[test] + fn test_semi_join_scales_preserved_column_statistics() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(432_187), + total_byte_size: Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Exact(7_196), + min_value: Exact(ScalarValue::from(1_i64)), + max_value: Exact(ScalarValue::from(432_187_i64)), + sum_value: Absent, + distinct_count: Absent, + byte_size: Exact(3_457_496), + }, + ColumnStatistics { + null_count: Exact(7_196), + min_value: Exact(ScalarValue::from(1_i64)), + max_value: Exact(ScalarValue::from(432_187_i64)), + sum_value: Exact(ScalarValue::from(1_000_000_i64)), + distinct_count: Exact(500_000), + byte_size: Exact(3_457_496), + }, + ], + }, + Statistics { + num_rows: Inexact(32), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Inexact(1), + Inexact(32), + Absent, + Absent, + )], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .expect("semi join cardinality should be estimated"); + + assert_eq!(result.num_rows, 32); + assert_eq!(result.total_byte_size, Inexact(512)); + assert_eq!(result.column_statistics[0].null_count, Exact(0)); + assert_eq!(result.column_statistics[0].distinct_count, Absent); + assert_eq!( + result.column_statistics[0].min_value, + Inexact(ScalarValue::from(1_i64)) + ); + assert_eq!( + result.column_statistics[0].max_value, + Inexact(ScalarValue::from(432_187_i64)) + ); + assert_eq!(result.column_statistics[0].byte_size, Inexact(256)); + assert_eq!(result.column_statistics[1].null_count, Inexact(1)); + // distinct_count is capped at the non-null output rows (32 - 1). + assert_eq!(result.column_statistics[1].distinct_count, Inexact(31)); + assert_eq!(result.column_statistics[1].sum_value, Absent); + assert_eq!(result.column_statistics[1].byte_size, Inexact(256)); + } + + #[test] + fn test_semi_join_null_equals_null_scales_join_key_nulls() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Absent, + Absent, + Inexact(100), + Exact(20), + )], + }, + Statistics { + num_rows: Inexact(10), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Absent, + Absent, + Inexact(10), + Absent, + )], + }, + &join_on, + NullEquality::NullEqualsNull, + ) + .expect("semi join cardinality should be estimated"); + + assert_eq!(result.num_rows, 10); + assert_eq!(result.column_statistics[0].null_count, Inexact(2)); + assert_eq!(result.column_statistics[0].distinct_count, Inexact(8)); + } + + #[test] + fn test_semi_join_total_byte_size_absent_if_any_column_byte_size_absent() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + let result = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Exact(0), + min_value: Exact(ScalarValue::from(1_i64)), + max_value: Exact(ScalarValue::from(100_i64)), + sum_value: Absent, + distinct_count: Absent, + byte_size: Exact(800), + }, + ColumnStatistics { + null_count: Exact(0), + min_value: Absent, + max_value: Absent, + sum_value: Absent, + distinct_count: Absent, + byte_size: Absent, + }, + ], + }, + Statistics { + num_rows: Inexact(10), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Inexact(1), + Inexact(10), + Absent, + Absent, + )], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .expect("semi join cardinality should be estimated"); + + assert_eq!(result.num_rows, 10); + assert_eq!(result.total_byte_size, Absent); + } + + #[test] + fn test_anti_join_preserves_join_key_nulls() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + let result = estimate_join_cardinality( + &JoinType::LeftAnti, + Statistics { + num_rows: Inexact(1_000_000), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Absent, + Absent, + Inexact(900_000), + Exact(100_000), + )], + }, + Statistics { + num_rows: Inexact(900_000), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Absent, + Absent, + Inexact(900_000), + Absent, + )], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .expect("anti join cardinality should be estimated"); + + assert_eq!(result.num_rows, 100_000); + assert_eq!(result.column_statistics[0].null_count, Inexact(100_000)); + assert_eq!(result.column_statistics[0].distinct_count, Inexact(0)); + } + + #[test] + fn test_anti_join_null_equals_null_scales_join_key_nulls() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + let result = estimate_join_cardinality( + &JoinType::LeftAnti, + Statistics { + num_rows: Inexact(100), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Absent, + Absent, + Inexact(100), + Exact(20), + )], + }, + Statistics { + num_rows: Inexact(10), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Absent, + Absent, + Inexact(10), + Absent, + )], + }, + &join_on, + NullEquality::NullEqualsNull, + ) + .expect("anti join cardinality should be estimated"); + + assert_eq!(result.num_rows, 90); + assert_eq!(result.column_statistics[0].null_count, Inexact(18)); + assert_eq!(result.column_statistics[0].distinct_count, Inexact(72)); + } + + #[test] + fn test_right_semi_join_scales_preserved_column_statistics() { + let join_on = vec![( + Arc::new(Column::new("l_key", 0)) as _, + Arc::new(Column::new("r_key", 0)) as _, + )]; + + // For a right semi join the right input is preserved, so its column + // statistics (and right join-key index) are the ones normalized. + let result = estimate_join_cardinality( + &JoinType::RightSemi, + Statistics { + num_rows: Inexact(32), + total_byte_size: Absent, + column_statistics: vec![create_column_stats( + Inexact(1), + Inexact(32), + Absent, + Absent, + )], + }, + Statistics { + num_rows: Inexact(432_187), + total_byte_size: Absent, + column_statistics: vec![ + ColumnStatistics { + null_count: Exact(7_196), + min_value: Exact(ScalarValue::from(1_i64)), + max_value: Exact(ScalarValue::from(432_187_i64)), + sum_value: Absent, + distinct_count: Absent, + byte_size: Exact(3_457_496), + }, + ColumnStatistics { + null_count: Exact(7_196), + min_value: Exact(ScalarValue::from(1_i64)), + max_value: Exact(ScalarValue::from(432_187_i64)), + sum_value: Exact(ScalarValue::from(1_000_000_i64)), + distinct_count: Exact(500_000), + byte_size: Exact(3_457_496), + }, + ], + }, + &join_on, + NullEquality::NullEqualsNothing, + ) + .expect("right semi join cardinality should be estimated"); + + assert_eq!(result.num_rows, 32); + // Join-key column: null counts collapse to exact zero (null keys never match). + assert_eq!(result.column_statistics[0].null_count, Exact(0)); + assert_eq!(result.column_statistics[0].byte_size, Inexact(256)); + // Non-key column: counts scaled to the subset, sum dropped, distinct + // capped at the non-null output rows (32 - 1). + assert_eq!(result.column_statistics[1].null_count, Inexact(1)); + assert_eq!(result.column_statistics[1].distinct_count, Inexact(31)); + assert_eq!(result.column_statistics[1].sum_value, Absent); + assert_eq!(result.column_statistics[1].byte_size, Inexact(256)); + } + #[test] fn test_calculate_join_output_ordering() -> Result<()> { let left_ordering = LexOrdering::new(vec![ @@ -2825,7 +4147,6 @@ mod tests { fn assert_col_expr(expr: &Arc, name: &str, index: usize) { let col = expr - .as_any() .downcast_ref::() .expect("Projection items should be Column expression"); assert_eq!(col.name(), name); @@ -2855,4 +4176,238 @@ mod tests { Ok(()) } + + #[test] + fn test_build_batch_empty_build_side_empty_schema() -> Result<()> { + // When the output schema has no fields (empty projection pushed into + // the join), build_batch_empty_build_side should return a RecordBatch + // with the correct row count but no columns. + let empty_schema = Schema::empty(); + + let build_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + + let probe_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![4, 5, 6, 7]))], + )?; + + let result = build_batch_empty_build_side( + &empty_schema, + &build_batch, + &probe_batch, + &[], // no column indices with empty projection + JoinType::Right, + )?; + + assert_eq!(result.num_rows(), 4); + assert_eq!(result.num_columns(), 0); + + Ok(()) + } + + #[test] + fn test_max_distinct_count_no_overflow_when_null_count_exceeds_num_rows() { + let num_rows = Exact(2); + let stats = ColumnStatistics { + distinct_count: Absent, + null_count: Exact(5), + min_value: Absent, + max_value: Absent, + sum_value: Absent, + byte_size: Absent, + }; + let result = max_distinct_count(&num_rows, &stats); + assert_eq!(result, Exact(0)); + } + + #[test] + fn test_join_key_comparator_multi_column() { + let left_a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 3])); + let left_b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let right_a: ArrayRef = Arc::new(Int32Array::from(vec![2, 2, 3, 4])); + let right_b: ArrayRef = Arc::new(StringArray::from(vec!["b", "d", "a", "a"])); + + let opts = vec![SortOptions::default(), SortOptions::default()]; + let cmp = JoinKeyComparator::new( + &[left_a, left_b], + &[right_a, right_b], + &opts, + NullEquality::NullEqualsNull, + ) + .unwrap(); + + // left[0]=(1,"a") vs right[0]=(2,"b") -> Less (first column) + assert_eq!(cmp.compare(0, 0), Ordering::Less); + // left[1]=(2,"b") vs right[0]=(2,"b") -> Equal + assert_eq!(cmp.compare(1, 0), Ordering::Equal); + assert!(cmp.is_equal(1, 0)); + // left[2]=(2,"c") vs right[1]=(2,"d") -> Less (second column) + assert_eq!(cmp.compare(2, 1), Ordering::Less); + // left[3]=(3,"d") vs right[0]=(2,"b") -> Greater + assert_eq!(cmp.compare(3, 0), Ordering::Greater); + } + + #[test] + fn test_join_key_comparator_null_equals_null() { + let left: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(2)])); + let right: ArrayRef = + Arc::new(Int32Array::from(vec![None, None, Some(1), Some(2)])); + + let opts = vec![SortOptions { + descending: false, + nulls_first: true, + }]; + let cmp = JoinKeyComparator::new( + &[left], + &[right], + &opts, + NullEquality::NullEqualsNull, + ) + .unwrap(); + + // left[1]=NULL vs right[1]=NULL -> Equal (NullEqualsNull) + assert_eq!(cmp.compare(1, 1), Ordering::Equal); + assert!(cmp.is_equal(1, 1)); + // left[0]=1 vs right[0]=NULL -> Greater (nulls_first, non-null > null) + assert_eq!(cmp.compare(0, 0), Ordering::Greater); + // left[3]=2 vs right[3]=2 -> Equal + assert_eq!(cmp.compare(3, 3), Ordering::Equal); + } + + #[test] + fn test_join_key_comparator_null_equals_nothing() { + let left: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(2)])); + let right: ArrayRef = + Arc::new(Int32Array::from(vec![None, None, Some(1), Some(2)])); + + let opts = vec![SortOptions { + descending: false, + nulls_first: true, + }]; + let cmp = JoinKeyComparator::new( + &[left], + &[right], + &opts, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + // left[1]=NULL vs right[1]=NULL -> Less (NullEqualsNothing) + assert_eq!(cmp.compare(1, 1), Ordering::Less); + // left[0]=1 vs right[0]=NULL -> Greater (nulls_first) + assert_eq!(cmp.compare(0, 0), Ordering::Greater); + // left[3]=2 vs right[3]=2 -> Equal + assert_eq!(cmp.compare(3, 3), Ordering::Equal); + } + + #[test] + fn test_join_key_comparator_nulls_first_ordering() { + let left: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1)])); + let right: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None])); + + // nulls_first = true: null < non-null + let cmp_nf = JoinKeyComparator::new( + &[Arc::clone(&left)], + &[Arc::clone(&right)], + &[SortOptions { + descending: false, + nulls_first: true, + }], + NullEquality::NullEqualsNull, + ) + .unwrap(); + assert_eq!(cmp_nf.compare(0, 0), Ordering::Less); + assert_eq!(cmp_nf.compare(1, 1), Ordering::Greater); + + // nulls_first = false: null > non-null + let cmp_nl = JoinKeyComparator::new( + &[left], + &[right], + &[SortOptions { + descending: false, + nulls_first: false, + }], + NullEquality::NullEqualsNull, + ) + .unwrap(); + assert_eq!(cmp_nl.compare(0, 0), Ordering::Greater); + assert_eq!(cmp_nl.compare(1, 1), Ordering::Less); + } + + #[test] + fn test_max_distinct_count_preserves_precision_when_not_capped() { + assert_eq!( + max_distinct_count( + &Exact(10), + &ColumnStatistics { + distinct_count: Exact(5), + ..Default::default() + } + ), + Exact(5) + ); + assert_eq!( + max_distinct_count( + &Exact(10), + &ColumnStatistics { + distinct_count: Inexact(5), + ..Default::default() + } + ), + Inexact(5) + ); + // Inexact num_rows does not affect an exact NDV that is within bounds + assert_eq!( + max_distinct_count( + &Inexact(10), + &ColumnStatistics { + distinct_count: Exact(5), + ..Default::default() + } + ), + Exact(5) + ); + } + + #[test] + fn test_max_distinct_count_demotes_to_inexact_when_capped() { + // Exact NDV > Exact num_rows is an illegal state (NDV <= num_rows is a + // mathematical invariant), but the code handles it defensively by + // capping and demoting to inexact + assert_eq!( + max_distinct_count( + &Exact(10), + &ColumnStatistics { + distinct_count: Exact(15), + ..Default::default() + } + ), + Inexact(10) + ); + assert_eq!( + max_distinct_count( + &Inexact(10), + &ColumnStatistics { + distinct_count: Exact(15), + ..Default::default() + } + ), + Inexact(10) + ); + assert_eq!( + max_distinct_count( + &Exact(10), + &ColumnStatistics { + distinct_count: Inexact(15), + ..Default::default() + } + ), + Inexact(10) + ); + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index b74baf2d0672c..c7b1d4729e21d 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Traits for physical query plan, supporting parallel execution for partitioned relations. @@ -33,26 +31,27 @@ pub use datafusion_common::hash_utils; pub use datafusion_common::utils::project_schema; -pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +pub use datafusion_common::{ColumnStatistics, Statistics, internal_err}; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; pub use datafusion_expr::{Accumulator, ColumnarValue}; -pub use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::PhysicalSortExpr; +pub use datafusion_physical_expr::window::WindowExpr; pub use datafusion_physical_expr::{ - expressions, Distribution, Partitioning, PhysicalExpr, + Distribution, Partitioning, PhysicalExpr, RangePartitioning, SplitPoint, expressions, }; pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; pub use crate::execution_plan::{ - collect, collect_partitioned, displayable, execute_input_stream, execute_stream, - execute_stream_partitioned, get_plan_string, with_new_children_if_necessary, - ExecutionPlan, ExecutionPlanProperties, PlanProperties, + ExecutionPlan, ExecutionPlanProperties, PlanProperties, collect, collect_partitioned, + displayable, execute_input_stream, execute_stream, execute_stream_partitioned, + get_plan_string, with_new_children_if_necessary, }; pub use crate::metrics::Metric; pub use crate::ordering::InputOrderMode; +pub use crate::sort_pushdown::SortOrderPushdownResult; pub use crate::stream::EmptyRecordBatchStream; pub use crate::topk::TopK; -pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; +pub use crate::visitor::{ExecutionPlanVisitor, accept, visit_execution_plan}; pub use crate::work_table::WorkTable; pub use spill::spill_manager::SpillManager; @@ -64,9 +63,11 @@ mod visitor; pub mod aggregates; pub mod analyze; pub mod async_func; +pub mod buffer; pub mod coalesce; pub mod coalesce_batches; pub mod coalesce_partitions; +pub mod column_rewriter; pub mod common; pub mod coop; pub mod display; @@ -79,10 +80,13 @@ pub mod joins; pub mod limit; pub mod memory; pub mod metrics; +pub mod operator_statistics; pub mod placeholder_row; pub mod projection; pub mod recursive_query; pub mod repartition; +pub mod scalar_subquery; +pub mod sort_pushdown; pub mod sorts; pub mod spill; pub mod stream; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 4646e8ebc3132..7f42c33a79ca0 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -17,7 +17,6 @@ //! Defines the LIMIT plan -use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -28,13 +27,17 @@ use super::{ SendableRecordBatchStream, Statistics, }; use crate::execution_plan::{Boundedness, CardinalityEffect}; -use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning}; +use crate::{ + DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + check_if_same_properties, +}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{assert_eq_or_internal_err, internal_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_execution::TaskContext; +use datafusion_physical_expr::LexOrdering; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -50,7 +53,10 @@ pub struct GlobalLimitExec { fetch: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + /// Does the limit have to preserve the order of its input, and if so what is it? + /// Some optimizations may reorder the input if no particular sort is required + required_ordering: Option, + cache: Arc, } impl GlobalLimitExec { @@ -62,7 +68,8 @@ impl GlobalLimitExec { skip, fetch, metrics: ExecutionPlanMetricsSet::new(), - cache, + required_ordering: None, + cache: Arc::new(cache), } } @@ -91,6 +98,27 @@ impl GlobalLimitExec { Boundedness::Bounded, ) } + + /// Get the required ordering from limit + pub fn required_ordering(&self) -> &Option { + &self.required_ordering + } + + /// Set the required ordering for limit + pub fn set_required_ordering(&mut self, required_ordering: Option) { + self.required_ordering = required_ordering; + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for GlobalLimitExec { @@ -125,11 +153,7 @@ impl ExecutionPlan for GlobalLimitExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -151,10 +175,11 @@ impl ExecutionPlan for GlobalLimitExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(GlobalLimitExec::new( - Arc::clone(&children[0]), + children.swap_remove(0), self.skip, self.fetch, ))) @@ -194,14 +219,9 @@ impl ExecutionPlan for GlobalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - self.input - .partition_statistics(partition)? - .with_fetch(self.fetch, self.skip, 1) + fn partition_statistics(&self, partition: Option) -> Result> { + let stats = Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + Ok(Arc::new(stats.with_fetch(self.fetch, self.skip, 1)?)) } fn fetch(&self) -> Option { @@ -214,7 +234,7 @@ impl ExecutionPlan for GlobalLimitExec { } /// LocalLimitExec applies a limit to a single partition -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct LocalLimitExec { /// Input execution plan input: Arc, @@ -222,7 +242,10 @@ pub struct LocalLimitExec { fetch: usize, /// Execution metrics metrics: ExecutionPlanMetricsSet, - cache: PlanProperties, + /// If the child plan is a sort node, after the sort node is removed during + /// physical optimization, we should add the required ordering to the limit node + required_ordering: Option, + cache: Arc, } impl LocalLimitExec { @@ -233,7 +256,8 @@ impl LocalLimitExec { input, fetch, metrics: ExecutionPlanMetricsSet::new(), - cache, + required_ordering: None, + cache: Arc::new(cache), } } @@ -257,6 +281,27 @@ impl LocalLimitExec { Boundedness::Bounded, ) } + + /// Get the required ordering from limit + pub fn required_ordering(&self) -> &Option { + &self.required_ordering + } + + /// Set the required ordering for limit + pub fn set_required_ordering(&mut self, required_ordering: Option) { + self.required_ordering = required_ordering; + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for LocalLimitExec { @@ -282,11 +327,7 @@ impl ExecutionPlan for LocalLimitExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -306,6 +347,7 @@ impl ExecutionPlan for LocalLimitExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); match children.len() { 1 => Ok(Arc::new(LocalLimitExec::new( Arc::clone(&children[0]), @@ -320,7 +362,12 @@ impl ExecutionPlan for LocalLimitExec { partition: usize, context: Arc, ) -> Result { - trace!("Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); let stream = self.input.execute(partition, context)?; Ok(Box::pin(LimitStream::new( @@ -335,14 +382,9 @@ impl ExecutionPlan for LocalLimitExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - self.input - .partition_statistics(partition)? - .with_fetch(Some(self.fetch), 0, 1) + fn partition_statistics(&self, partition: Option) -> Result> { + let stats = Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + Ok(Arc::new(stats.with_fetch(Some(self.fetch), 0, 1)?)) } fn fetch(&self) -> Option { @@ -494,8 +536,8 @@ mod tests { use arrow::array::RecordBatchOptions; use arrow::datatypes::Schema; use datafusion_common::stats::Precision; - use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::expressions::col; #[tokio::test] async fn limit() -> Result<()> { @@ -721,9 +763,12 @@ mod tests { row_number_inexact_statistics_for_global_limit(5, Some(10)).await?; assert_eq!(row_count, Precision::Inexact(10)); + // Input was Inexact, so an `nr <= skip` outcome must remain Inexact: + // the inexact estimate could be wrong, so we cannot promote 0 to + // Exact. let row_count = row_number_inexact_statistics_for_global_limit(400, Some(10)).await?; - assert_eq!(row_count, Precision::Exact(0)); + assert_eq!(row_count, Precision::Inexact(0)); let row_count = row_number_inexact_statistics_for_global_limit(398, Some(10)).await?; diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 92e789ebc5965..ad54905f474aa 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -27,14 +27,14 @@ use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, + RecordBatchStream, SendableRecordBatchStream, }; use arrow::array::RecordBatch; use arrow::datatypes::SchemaRef; -use datafusion_common::{assert_eq_or_internal_err, assert_or_internal_err, Result}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_common::{Result, assert_eq_or_internal_err, assert_or_internal_err}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -144,6 +144,9 @@ pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display { /// Generate the next batch, return `None` when no more batches are available fn generate_next_batch(&mut self) -> Result>; + + /// Returns a new instance with the state reset. + fn reset_state(&self) -> Arc>; } /// Execution plan for lazy in-memory batches of data @@ -158,7 +161,7 @@ pub struct LazyMemoryExec { /// Functions to generate batches for each partition batch_generators: Vec>>, /// Plan properties cache storing equivalence properties, partitioning, and execution mode - cache: PlanProperties, + cache: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, } @@ -197,7 +200,8 @@ impl LazyMemoryExec { EmissionType::Incremental, boundedness, ) - .with_scheduling_type(SchedulingType::Cooperative); + .with_scheduling_type(SchedulingType::Cooperative) + .into(); Ok(Self { schema, @@ -212,9 +216,9 @@ impl LazyMemoryExec { match projection.as_ref() { Some(columns) => { let projected = Arc::new(self.schema.project(columns).unwrap()); - self.cache = self.cache.with_eq_properties(EquivalenceProperties::new( - Arc::clone(&projected), - )); + Arc::make_mut(&mut self.cache).set_eq_properties( + EquivalenceProperties::new(Arc::clone(&projected)), + ); self.schema = projected; self.projection = projection; self @@ -233,12 +237,12 @@ impl LazyMemoryExec { partition_count, generator_count ); - self.cache.partitioning = partitioning; + Arc::make_mut(&mut self.cache).partitioning = partitioning; Ok(()) } pub fn add_ordering(&mut self, ordering: impl IntoIterator) { - self.cache + Arc::make_mut(&mut self.cache) .eq_properties .add_orderings(std::iter::once(ordering)); } @@ -295,15 +299,11 @@ impl ExecutionPlan for LazyMemoryExec { "LazyMemoryExec" } - fn as_any(&self) -> &dyn Any { - self - } - fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -336,10 +336,14 @@ impl ExecutionPlan for LazyMemoryExec { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + // Create a fresh generator via reset_state() so that each execute() + // call produces an independent stream starting from the beginning. + let generator = self.batch_generators[partition].read().reset_state(); + let stream = LazyMemoryStream { schema: Arc::clone(&self.schema), projection: self.projection.clone(), - generator: Arc::clone(&self.batch_generators[partition]), + generator, baseline_metrics, }; Ok(Box::pin(cooperative(stream))) @@ -349,8 +353,19 @@ impl ExecutionPlan for LazyMemoryExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema)) + fn reset_state(self: Arc) -> Result> { + let generators = self + .generators() + .iter() + .map(|g| g.read().reset_state()) + .collect::>(); + Ok(Arc::new(LazyMemoryExec { + schema: Arc::clone(&self.schema), + batch_generators: generators, + cache: Arc::clone(&self.cache), + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + })) } } @@ -450,6 +465,15 @@ mod lazy_memory_tests { vec![Arc::new(array)], )?)) } + + fn reset_state(&self) -> Arc> { + Arc::new(RwLock::new(TestGenerator { + counter: 0, + max_batches: self.max_batches, + batch_size: self.batch_size, + schema: Arc::clone(&self.schema), + })) + } } #[tokio::test] @@ -503,6 +527,41 @@ mod lazy_memory_tests { Ok(()) } + /// Verify that calling execute(0) twice on the same LazyMemoryExec + /// produces independent streams with the same data. + #[tokio::test] + async fn test_lazy_memory_exec_multiple_executions_are_independent() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: 3, + batch_size: 2, + schema: Arc::clone(&schema), + }; + + let exec = + LazyMemoryExec::try_new(schema, vec![Arc::new(RwLock::new(generator))])?; + let task_ctx = Arc::new(TaskContext::default()); + + // First execution — consume all batches + let batches_1 = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; + let total_rows_1: usize = batches_1.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows_1, 6); + + // Second execution — should produce the same data, not continue + // from where the first execution left off + let batches_2 = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; + let total_rows_2: usize = batches_2.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows_2, 6); + + // Verify contents are identical + for (b1, b2) in batches_1.iter().zip(batches_2.iter()) { + assert_eq!(b1, b2); + } + + Ok(()) + } + #[tokio::test] async fn test_lazy_memory_exec_invalid_partition() -> Result<()> { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); @@ -568,4 +627,31 @@ mod lazy_memory_tests { Ok(()) } + + #[tokio::test] + async fn test_lazy_memory_exec_reset_state() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let generator = TestGenerator { + counter: 0, + max_batches: 3, + batch_size: 2, + schema: Arc::clone(&schema), + }; + + let exec = Arc::new(LazyMemoryExec::try_new( + schema, + vec![Arc::new(RwLock::new(generator))], + )?); + let stream = exec.execute(0, Arc::new(TaskContext::default()))?; + let batches = collect(stream).await?; + + let exec_reset = exec.reset_state()?; + let stream = exec_reset.execute(0, Arc::new(TaskContext::default()))?; + let batches_reset = collect(stream).await?; + + // if the reset_state is not correct, the batches_reset will be empty + assert_eq!(batches, batches_reset); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/metrics.rs b/datafusion/physical-plan/src/metrics.rs new file mode 100644 index 0000000000000..fe17cbdd4a2c2 --- /dev/null +++ b/datafusion/physical-plan/src/metrics.rs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Metrics live in `datafusion-physical-expr-common`; this module re-exports +//! them to keep the public APIs stable. + +pub use datafusion_physical_expr_common::metrics::*; diff --git a/datafusion/physical-plan/src/operator_statistics/mod.rs b/datafusion/physical-plan/src/operator_statistics/mod.rs new file mode 100644 index 0000000000000..041ef4666658d --- /dev/null +++ b/datafusion/physical-plan/src/operator_statistics/mod.rs @@ -0,0 +1,2297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Pluggable statistics propagation for physical plans. +//! +//! This module provides an extensible mechanism for computing statistics +//! on [`ExecutionPlan`] nodes, following the chain of responsibility pattern +//! similar to `RelationPlanner` for SQL parsing. +//! +//! # Overview +//! +//! The default implementation delegates to each operator's built-in +//! `partition_statistics`. Users can register custom [`StatisticsProvider`] +//! implementations to: +//! +//! 1. Provide statistics for custom [`ExecutionPlan`] implementations +//! 2. Override default estimation with advanced approaches (e.g., histograms) +//! 3. Plug in domain-specific knowledge for better cardinality estimation +//! +//! # Architecture +//! +//! - [`StatisticsProvider`]: Chain element that computes statistics for specific operators +//! - [`StatisticsRegistry`]: Chains providers, lives in SessionState +//! - [`ExtendedStatistics`]: Statistics with type-safe custom extensions +//! +//! # Built-in Providers +//! +//! The following providers are included and can be registered in this order: +//! +//! 1. [`FilterStatisticsProvider`] - selectivity-based filter estimation +//! 2. [`ProjectionStatisticsProvider`] - column mapping through projections +//! 3. [`PassthroughStatisticsProvider`] - passthrough for cardinality-preserving operators +//! 4. [`AggregateStatisticsProvider`] - NDV-based GROUP BY cardinality estimation +//! 5. [`JoinStatisticsProvider`] - NDV-based join output estimation (hash, sort-merge, cross) +//! 6. [`LimitStatisticsProvider`] - caps output at the fetch limit (local and global) +//! 7. [`UnionStatisticsProvider`] - sums input row counts +//! 8. [`DefaultStatisticsProvider`] - fallback to `partition_statistics(None)` +//! +//! # Relationship to [#20184](https://github.com/apache/datafusion/issues/20184) +//! +//! This module performs its own bottom-up tree walk in [`StatisticsRegistry::compute`], +//! separate from the walk optimizer rules do via `transform_up`. This means existing +//! rules that call `partition_statistics` directly bypass the registry. +//! +//! [#20184](https://github.com/apache/datafusion/issues/20184) adds a `child_stats` +//! parameter to `partition_statistics`. Once it lands, the registry can feed enriched +//! **base** [`Statistics`] into operators' built-in `partition_statistics` calls, +//! removing redundancy for the base-stats path (row counts, column stats). However, +//! the separate registry walk is still required for [`ExtendedStatistics`] extension +//! propagation: `partition_statistics` returns `Arc`, so extensions +//! (histograms, sketches, etc.) are stripped at that boundary and can only flow +//! through the registry walk. +//! +//! If [`Statistics`] itself were extended to carry a type-erased extension map +//! (similar to [`ExtendedStatistics`]), the registry walk could be dropped entirely: +//! extensions would flow naturally through `partition_statistics(child_stats)` and +//! the registry would become a pure chain-of-responsibility on top of the existing +//! traversal with no separate walk needed. +//! +//! # Example +//! +//! ```ignore +//! use datafusion_physical_plan::operator_statistics::*; +//! +//! // Create registry with default provider +//! let mut registry = StatisticsRegistry::new(); +//! +//! // Register custom provider (higher priority) +//! registry.register(Arc::new(MyHistogramProvider)); +//! +//! // Compute statistics through the chain +//! let stats = registry.compute(plan.as_ref())?; +//! ``` + +use std::fmt::{self, Debug}; +use std::sync::Arc; + +use datafusion_common::extensions::Extensions; +use datafusion_common::stats::Precision; +use datafusion_common::{Result, Statistics}; + +use crate::ExecutionPlan; + +// ============================================================================ +// ExtendedStatistics: Statistics with type-safe extensions +// ============================================================================ + +/// Statistics with support for custom extensions. +/// +/// Wraps the standard [`Statistics`] and adds a type-erased extension map +/// for custom statistics like histograms, sketches, or domain-specific metadata. +/// +/// # Example +/// +/// ```ignore +/// // Define a custom statistics extension +/// #[derive(Debug, Clone)] +/// struct HistogramStats { +/// buckets: Vec<(i64, i64, usize)>, // (min, max, count) +/// } +/// +/// // Set extension in a planner +/// let mut stats = ExtendedStatistics::from(base_stats); +/// stats.set_extension(HistogramStats { buckets: vec![] }); +/// +/// // Retrieve in a consumer +/// if let Some(hist) = stats.get_extension::() { +/// // Use histogram for better estimation +/// } +/// ``` +#[derive(Debug, Clone, Default)] +pub struct ExtendedStatistics { + /// Standard statistics (num_rows, byte_size, column stats) + base: Arc, + /// Type-erased extensions for custom statistics + extensions: Extensions, +} + +impl ExtendedStatistics { + /// Create new ExtendedStatistics wrapping owned statistics. + pub fn new(base: Statistics) -> Self { + Self { + base: Arc::new(base), + extensions: Extensions::new(), + } + } + + /// Create new ExtendedStatistics from an [`Arc`]. + pub fn new_arc(base: Arc) -> Self { + Self { + base, + extensions: Extensions::new(), + } + } + + /// Returns a reference to the base [`Statistics`]. + pub fn base(&self) -> &Statistics { + &self.base + } + + /// Returns a reference to the underlying [`Arc`]. + pub fn base_arc(&self) -> &Arc { + &self.base + } + + /// Get a reference to a custom statistics extension by type. + pub fn get_extension(&self) -> Option<&T> { + self.extensions.get::() + } + + /// Set a custom statistics extension. + pub fn set_extension(&mut self, value: T) { + self.extensions.insert(value); + } + + /// Check if an extension of the given type exists. + pub fn has_extension(&self) -> bool { + self.extensions.contains::() + } + + /// Merge extensions from another ExtendedStatistics (other's extensions take precedence). + pub fn merge_extensions(&mut self, other: &ExtendedStatistics) { + self.extensions.merge(&other.extensions); + } +} + +impl From for ExtendedStatistics { + fn from(base: Statistics) -> Self { + Self::new(base) + } +} + +impl From> for ExtendedStatistics { + fn from(base: Arc) -> Self { + Self::new_arc(base) + } +} + +impl From for Statistics { + fn from(extended: ExtendedStatistics) -> Self { + Arc::unwrap_or_clone(extended.base) + } +} + +// ============================================================================ +// StatisticsProvider trait and registry +// ============================================================================ + +/// Result of attempting to compute statistics with a [`StatisticsProvider`]. +#[derive(Debug)] +pub enum StatisticsResult { + /// Statistics were computed by this provider + Computed(ExtendedStatistics), + /// This provider doesn't handle this operator; delegate to next in chain + Delegate, +} + +/// Customize statistics computation for [`ExecutionPlan`] nodes. +/// +/// Implementations can handle specific operator types or override default +/// estimation logic. The chain of providers is traversed until one returns +/// [`StatisticsResult::Computed`]. +/// +/// # Implementing a Custom Provider +/// +/// ```ignore +/// #[derive(Debug)] +/// struct MyStatisticsProvider; +/// +/// impl StatisticsProvider for MyStatisticsProvider { +/// fn compute_statistics( +/// &self, +/// plan: &dyn ExecutionPlan, +/// child_stats: &[ExtendedStatistics], +/// ) -> Result { +/// if let Some(my_exec) = plan.downcast_ref::() { +/// // Custom logic for MyCustomExec +/// Ok(StatisticsResult::Computed(/* ... */)) +/// } else { +/// // Let next provider handle it +/// Ok(StatisticsResult::Delegate) +/// } +/// } +/// } +/// ``` +pub trait StatisticsProvider: Debug + Send + Sync { + /// Compute statistics for an [`ExecutionPlan`] node. + /// + /// # Arguments + /// * `plan` - The execution plan node to compute statistics for + /// * `child_stats` - Extended statistics already computed for child nodes, + /// in the same order as `plan.children()`. Empty for leaf nodes. + /// + /// # Returns + /// * `StatisticsResult::Computed(stats)` - Short-circuits the chain + /// * `StatisticsResult::Delegate` - Passes to next provider in chain + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result; +} + +/// Default statistics provider that delegates to each operator's built-in +/// `partition_statistics` implementation. +#[derive(Debug, Default)] +pub struct DefaultStatisticsProvider; + +impl StatisticsProvider for DefaultStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + _child_stats: &[ExtendedStatistics], + ) -> Result { + let base = plan.partition_statistics(None)?; + Ok(StatisticsResult::Computed(ExtendedStatistics::new_arc( + base, + ))) + } +} + +/// Registry that chains [`StatisticsProvider`] implementations. +/// +/// The registry is a stateless provider chain: it holds no mutable state +/// and is cheaply `Clone`able / `Send` / `Sync`. +#[derive(Clone)] +pub struct StatisticsRegistry { + providers: Vec>, +} + +impl Debug for StatisticsRegistry { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "StatisticsRegistry({} providers)", self.providers.len()) + } +} + +impl Default for StatisticsRegistry { + fn default() -> Self { + Self::new() + } +} + +impl StatisticsRegistry { + /// Create a new empty registry. + /// + /// With no providers, `compute()` falls back to each plan node's + /// built-in `partition_statistics()`. Register providers to enhance + /// statistics (e.g., inject NDV, use histograms). + pub fn new() -> Self { + Self { + providers: Vec::new(), + } + } + + /// Create a registry with the given provider chain. + pub fn with_providers(providers: Vec>) -> Self { + Self { providers } + } + + /// Create a registry pre-loaded with the standard built-in providers. + /// + /// Provider order (first match wins): + /// 1. [`FilterStatisticsProvider`] + /// 2. [`ProjectionStatisticsProvider`] + /// 3. [`PassthroughStatisticsProvider`] + /// 4. [`AggregateStatisticsProvider`] + /// 5. [`JoinStatisticsProvider`] + /// 6. [`LimitStatisticsProvider`] + /// 7. [`UnionStatisticsProvider`] + /// 8. [`DefaultStatisticsProvider`] + pub fn default_with_builtin_providers() -> Self { + Self::with_providers(vec![ + Arc::new(FilterStatisticsProvider), + Arc::new(ProjectionStatisticsProvider), + Arc::new(PassthroughStatisticsProvider), + Arc::new(AggregateStatisticsProvider), + Arc::new(JoinStatisticsProvider), + Arc::new(LimitStatisticsProvider), + Arc::new(UnionStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]) + } + + /// Register a provider at the front of the chain (higher priority). + pub fn register(&mut self, provider: Arc) { + self.providers.insert(0, provider); + } + + /// Returns the current provider chain. + pub fn providers(&self) -> &[Arc] { + &self.providers + } + + /// Compute extended statistics for a plan through the provider chain. + /// + /// Performs a bottom-up tree walk: child statistics are computed recursively + /// and passed to providers, mirroring how `partition_statistics` composes + /// operators. Once [#20184](https://github.com/apache/datafusion/issues/20184) + /// lands, the registry can feed enriched base stats directly into + /// `partition_statistics(child_stats)`, removing the need for a separate walk. + /// + /// If no providers are registered, falls back to the plan's built-in + /// `partition_statistics(None)` with no overhead. + pub fn compute(&self, plan: &dyn ExecutionPlan) -> Result { + // Fast path: no providers registered, skip the walk entirely + if self.providers.is_empty() { + let base = plan.partition_statistics(None)?; + return Ok(ExtendedStatistics::new_arc(base)); + } + + let children = plan.children(); + + // For leaf nodes, try providers with empty child stats. + // For non-leaf nodes, recursively compute enhanced child stats first. + let child_stats: Vec = if children.is_empty() { + Vec::new() + } else { + children + .iter() + .map(|child| self.compute(child.as_ref())) + .collect::>>()? + }; + + for provider in &self.providers { + match provider.compute_statistics(plan, &child_stats)? { + StatisticsResult::Computed(stats) => return Ok(stats), + StatisticsResult::Delegate => continue, + } + } + // Fallback: use plan's built-in stats + let base = plan.partition_statistics(None)?; + Ok(ExtendedStatistics::new_arc(base)) + } + + /// Compute statistics and return only the base Statistics (no extensions). + /// + /// Convenience method for callers that don't need extensions. + pub fn compute_base(&self, plan: &dyn ExecutionPlan) -> Result { + Ok(self.compute(plan)?.base().clone()) + } +} + +// ============================================================================ +// Statistics Utility Functions +// ============================================================================ + +/// Estimate the number of distinct values when sampling from a population. +/// +/// Given a domain with `domain_size` distinct values and `num_selected` rows +/// sampled/filtered from it, estimates how many distinct values will appear +/// in the sample. +/// +/// Uses the formula: `Expected distinct = N * [1 - (1 - 1/N)^n]` +/// +/// # References +/// +/// Based on Calcite's `RelMdUtil.numDistinctVals()`: +/// +pub fn num_distinct_vals(domain_size: usize, num_selected: usize) -> usize { + if domain_size == 0 || num_selected == 0 { + return 0; + } + + if num_selected >= domain_size { + return domain_size; + } + + let n = domain_size as f64; + let k = num_selected as f64; + + // For large n, (1-1/n).powf(k) loses precision because the base is near + // 1.0; use the equivalent exp(-k/n) form which is numerically stable. + // Threshold matches Calcite's RelMdUtil.numDistinctVals(). + let expected = if domain_size > 1000 { + n * (1.0 - (-k / n).exp()) + } else { + n * (1.0 - (1.0 - 1.0 / n).powf(k)) + }; + + let result = expected.round() as usize; + result.clamp(1, domain_size) +} + +/// Estimate NDV after applying a selectivity factor (filtering). +/// +/// When filtering rows, each distinct value has multiple rows. If a value +/// appears `k` times, the probability it survives the filter is `1 - (1-s)^k` +/// where `s` is the selectivity. +/// +/// Assuming uniform distribution (each value appears `rows/ndv` times): +/// ```text +/// NDV_after ~ NDV_before * [1 - (1 - selectivity)^(rows/NDV)] +/// ``` +pub fn ndv_after_selectivity( + original_ndv: usize, + original_rows: usize, + selectivity: f64, +) -> usize { + if selectivity <= 0.0 || original_ndv == 0 || original_rows == 0 { + return 0; + } + if selectivity >= 1.0 { + return original_ndv; + } + + let ndv = original_ndv as f64; + let rows = original_rows as f64; + + let rows_per_value = rows / ndv; + let survival_prob = 1.0 - (1.0 - selectivity).powf(rows_per_value); + let expected_ndv = ndv * survival_prob; + + (expected_ndv.round() as usize).clamp(1, original_ndv) +} + +/// Rescale `total_byte_size` proportionally after overriding `num_rows`. +/// +/// When a provider replaces `num_rows` but keeps the rest of the stats from +/// `partition_statistics`, the original `total_byte_size` becomes inconsistent. +/// This function adjusts it by the ratio `new_rows / old_rows`, preserving the +/// average bytes-per-row from the original estimate. +fn rescale_byte_size(stats: &mut Statistics, new_num_rows: Precision) { + let old_rows = stats.num_rows; + stats.num_rows = new_num_rows; + stats.total_byte_size = match (old_rows, new_num_rows, stats.total_byte_size) { + (Precision::Exact(old), Precision::Exact(new), Precision::Exact(bytes)) + if old > 0 => + { + Precision::Exact((bytes as f64 * new as f64 / old as f64).round() as usize) + } + _ => match ( + old_rows.get_value(), + new_num_rows.get_value(), + stats.total_byte_size.get_value(), + ) { + (Some(&old), Some(&new), Some(&bytes)) if old > 0 => Precision::Inexact( + (bytes as f64 * new as f64 / old as f64).round() as usize, + ), + _ => stats.total_byte_size, + }, + }; +} + +/// Fetches base statistics from the operator's built-in `partition_statistics`, +/// overrides `num_rows` with the registry-computed estimate, and rescales +/// `total_byte_size` proportionally. +/// +/// Used by providers that compute a better row count but cannot yet propagate +/// column-level stats (NDV, min/max) through the operator — pending #20184. +fn computed_with_row_count( + plan: &dyn ExecutionPlan, + num_rows: Precision, +) -> Result { + let mut base = Arc::unwrap_or_clone(plan.partition_statistics(None)?); + rescale_byte_size(&mut base, num_rows); + Ok(StatisticsResult::Computed(ExtendedStatistics::new(base))) +} + +/// Statistics provider for [`FilterExec`](crate::filter::FilterExec) that uses +/// pre-computed enhanced child statistics from the registry walk. +/// +/// Unlike the default provider (which calls `partition_statistics` and gets raw +/// child stats), this provider receives enhanced child stats that may include +/// NDV overrides injected at the scan level. It applies the same selectivity +/// estimation logic as `FilterExec::statistics_helper`, then additionally +/// adjusts each column's `distinct_count` using [`ndv_after_selectivity`] based +/// on the computed selectivity ratio. +#[derive(Debug, Default)] +pub struct FilterStatisticsProvider; + +impl StatisticsProvider for FilterStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::filter::FilterExec; + + let Some(filter) = plan.downcast_ref::() else { + return Ok(StatisticsResult::Delegate); + }; + if child_stats.is_empty() { + return Ok(StatisticsResult::Delegate); + } + + let input_stats = (*child_stats[0].base).clone(); + let input_rows = input_stats.num_rows; + let mut stats = FilterExec::statistics_helper( + &filter.input().schema(), + input_stats, + filter.predicate(), + filter.default_selectivity(), + // TODO: pass filter.expression_analyzer_registry() once #21122 lands + )?; + + // Adjust distinct_count for each column using the selectivity ratio + // via the probabilistic survival model from + // ndv_after_selectivity to account for rows removed by the filter. + if let (Some(&orig_rows), Some(&filtered_rows)) = + (input_rows.get_value(), stats.num_rows.get_value()) + && orig_rows > 0 + && filtered_rows < orig_rows + { + let selectivity = filtered_rows as f64 / orig_rows as f64; + for col_stat in &mut stats.column_statistics { + if let Some(&ndv) = col_stat.distinct_count.get_value() { + let adjusted = ndv_after_selectivity(ndv, orig_rows, selectivity); + col_stat.distinct_count = Precision::Inexact(adjusted); + } + } + } + + let stats = stats.project(filter.projection().as_ref()); + Ok(StatisticsResult::Computed(ExtendedStatistics::new(stats))) + } +} + +/// Statistics provider for [`ProjectionExec`](crate::projection::ProjectionExec) +/// that uses pre-computed enhanced child statistics from the registry walk. +/// +/// Maps enhanced child column statistics to output columns based on the +/// projection expressions, preserving NDV and other statistics through +/// column references. +#[derive(Debug, Default)] +pub struct ProjectionStatisticsProvider; + +impl StatisticsProvider for ProjectionStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::projection::ProjectionExec; + + let Some(proj) = plan.downcast_ref::() else { + return Ok(StatisticsResult::Delegate); + }; + if child_stats.is_empty() { + return Ok(StatisticsResult::Delegate); + } + + let input_stats = (*child_stats[0].base).clone(); + let output_schema = proj.schema(); + // TODO: pass proj.expression_analyzer_registry() once #21122 lands, + // so expression-level NDV/min/max feeds into projected column stats. + let stats = proj + .projection_expr() + .project_statistics(input_stats, &output_schema)?; + Ok(StatisticsResult::Computed(ExtendedStatistics::new(stats))) + } +} + +/// Statistics provider for single-input operators with +/// [`CardinalityEffect::Equal`](crate::execution_plan::CardinalityEffect::Equal). +/// +/// These operators (Sort, Repartition, CoalescePartitions, etc.) don't +/// transform statistics, so we pass through the enhanced child stats directly. +/// This avoids the fallback calling `partition_statistics(None)` which would +/// trigger a redundant internal recursion with raw (non-enhanced) stats. +#[derive(Debug, Default)] +pub struct PassthroughStatisticsProvider; + +impl StatisticsProvider for PassthroughStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::execution_plan::CardinalityEffect; + + if child_stats.len() != 1 + || !matches!(plan.cardinality_effect(), CardinalityEffect::Equal) + { + return Ok(StatisticsResult::Delegate); + } + + // Only pass through when the schema is unchanged (same column count). + // Operators like WindowAggExec preserve row count but add columns; + // passing through child stats would produce wrong column_statistics. + let input_cols = child_stats[0].base.column_statistics.len(); + let output_cols = plan.schema().fields().len(); + if input_cols != output_cols { + return Ok(StatisticsResult::Delegate); + } + + Ok(StatisticsResult::Computed(child_stats[0].clone())) + } +} + +/// Statistics provider for [`AggregateExec`](crate::aggregates::AggregateExec) +/// that estimates output cardinality from the NDV of GROUP BY columns. +/// +/// For each GROUP BY column, looks up `distinct_count` from the enhanced +/// child statistics. The estimated output rows is the product of all +/// column NDVs, capped at the input row count. This assumes independence +/// between columns, so correlated columns (e.g., `city` and `state`) will +/// produce overestimates. +/// +/// For GROUPING SETS / CUBE / ROLLUP, delegates to the built-in +/// `partition_statistics`, which handles per-set NDV estimation correctly. +/// +/// Delegates when: +/// - The plan is not an `AggregateExec` +/// - The aggregate is `Partial` (per-partition, not bounded by global NDV) +/// - GROUP BY is empty (scalar aggregate) +/// - Any GROUP BY expression is not a simple column reference +/// - Any GROUP BY column lacks NDV information +#[derive(Debug, Default)] +pub struct AggregateStatisticsProvider; + +impl StatisticsProvider for AggregateStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::aggregates::AggregateExec; + use datafusion_physical_expr::expressions::Column; + + use crate::aggregates::AggregateMode; + + let Some(agg) = plan.downcast_ref::() else { + return Ok(StatisticsResult::Delegate); + }; + + // Partial aggregates produce per-partition groups, not bounded by + // global NDV; delegate to the built-in estimate for those. + if matches!(agg.mode(), AggregateMode::Partial) { + return Ok(StatisticsResult::Delegate); + } + + if child_stats.is_empty() || agg.group_expr().expr().is_empty() { + return Ok(StatisticsResult::Delegate); + } + + let input_stats = &child_stats[0].base; + + // Compute NDV product of GROUP BY columns + let mut ndv_product: Option = None; + for (expr, _) in agg.group_expr().expr().iter() { + let Some(col) = expr.downcast_ref::() else { + return Ok(StatisticsResult::Delegate); + }; + let Some(&ndv) = input_stats + .column_statistics + .get(col.index()) + .and_then(|s| s.distinct_count.get_value()) + else { + return Ok(StatisticsResult::Delegate); + }; + if ndv == 0 { + return Ok(StatisticsResult::Delegate); + } + ndv_product = Some(match ndv_product { + Some(prev) => prev.saturating_mul(ndv), + None => ndv, + }); + } + + let Some(product) = ndv_product else { + return Ok(StatisticsResult::Delegate); + }; + + // For CUBE/ROLLUP/GROUPING SETS (multiple grouping sets), delegate to + // the built-in estimate, which handles per-set NDV estimation correctly. + if agg.group_expr().groups().len() > 1 { + return Ok(StatisticsResult::Delegate); + } + + // Cap at input rows + let estimate = match input_stats.num_rows.get_value() { + Some(&rows) => product.min(rows), + None => product, + }; + + let num_rows = Precision::Inexact(estimate); + + computed_with_row_count(plan, num_rows) + } +} + +/// Statistics provider for equi-joins (hash join, sort-merge join) and cross joins. +/// +/// For equi-joins, estimates output cardinality as +/// `left_rows * right_rows / product(max(left_ndv_i, right_ndv_i))` +/// across all join key columns (assuming independence between keys), +/// falling back to the Cartesian product when any key lacks NDV on both sides. +/// For cross joins, uses the exact Cartesian product. +/// +/// The base inner-join estimate is then adjusted for the join type: +/// - Semi joins: capped at the preserved-side row count +/// - Anti joins: preserved-side minus matched rows (clamped to 0) +/// - Left/Right outer: at least as many rows as the preserved side +/// - Full outer: at least `left + right - inner_estimate` +/// - Left mark: exactly `left_rows` (one output row per left row) +/// +/// Delegates when: +/// - The plan is not a supported join type +/// - Either input lacks row count information +#[derive(Debug, Default)] +pub struct JoinStatisticsProvider; + +impl StatisticsProvider for JoinStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::joins::{CrossJoinExec, HashJoinExec, SortMergeJoinExec}; + use datafusion_common::JoinType; + use datafusion_physical_expr::expressions::Column; + + if child_stats.len() < 2 { + return Ok(StatisticsResult::Delegate); + } + + let left = &child_stats[0].base; + let right = &child_stats[1].base; + + let (Some(&left_rows), Some(&right_rows)) = + (left.num_rows.get_value(), right.num_rows.get_value()) + else { + return Ok(StatisticsResult::Delegate); + }; + + use crate::joins::JoinOnRef; + + /// Estimate equi-join output using NDV of join key columns: + /// left_rows * right_rows / product(max(left_ndv_i, right_ndv_i)) + /// Falls back to Cartesian product if any key lacks NDV on both sides. + fn equi_join_estimate( + on: JoinOnRef, + left: &Statistics, + right: &Statistics, + left_rows: usize, + right_rows: usize, + ) -> usize { + if on.is_empty() { + return left_rows.saturating_mul(right_rows); + } + let mut ndv_divisor: usize = 1; + for (left_key, right_key) in on { + let left_ndv = left_key + .downcast_ref::() + .and_then(|c| left.column_statistics.get(c.index())) + .and_then(|s| s.distinct_count.get_value().copied()); + let right_ndv = right_key + .downcast_ref::() + .and_then(|c| right.column_statistics.get(c.index())) + .and_then(|s| s.distinct_count.get_value().copied()); + match (left_ndv, right_ndv) { + (Some(l), Some(r)) if l > 0 && r > 0 => { + ndv_divisor = ndv_divisor.saturating_mul(l.max(r)); + } + _ => return left_rows.saturating_mul(right_rows), + } + } + let max_rows = left_rows.saturating_mul(right_rows); + max_rows.checked_div(ndv_divisor).unwrap_or(max_rows) + } + + let (inner_estimate, is_exact_cartesian, join_type) = if let Some(hash_join) = + plan.downcast_ref::() + { + let est = + equi_join_estimate(hash_join.on(), left, right, left_rows, right_rows); + (est, false, *hash_join.join_type()) + } else if let Some(smj) = plan.downcast_ref::() { + let est = equi_join_estimate(smj.on(), left, right, left_rows, right_rows); + (est, false, smj.join_type()) + } else if plan.downcast_ref::().is_some() { + let both_exact = left.num_rows.is_exact().unwrap_or(false) + && right.num_rows.is_exact().unwrap_or(false); + ( + left_rows.saturating_mul(right_rows), + both_exact, + JoinType::Inner, + ) + } else { + return Ok(StatisticsResult::Delegate); + }; + + // Apply join-type-aware cardinality bounds + let estimated = match join_type { + JoinType::Inner => inner_estimate, + JoinType::Left => inner_estimate.max(left_rows), + JoinType::Right => inner_estimate.max(right_rows), + JoinType::Full => { + // At least left + right - matched, but never less than inner + let outer_bound = left_rows + .saturating_add(right_rows) + .saturating_sub(inner_estimate); + inner_estimate.max(outer_bound) + } + JoinType::LeftSemi => inner_estimate.min(left_rows), + JoinType::RightSemi => inner_estimate.min(right_rows), + JoinType::LeftAnti => left_rows.saturating_sub(inner_estimate.min(left_rows)), + JoinType::RightAnti => { + right_rows.saturating_sub(inner_estimate.min(right_rows)) + } + JoinType::LeftMark => left_rows, + JoinType::RightMark => right_rows, + }; + + // NL join inner with exact inputs is an exact Cartesian product; + // NDV-based estimates are inherently inexact. + let num_rows = if is_exact_cartesian && join_type == JoinType::Inner { + Precision::Exact(estimated) + } else { + Precision::Inexact(estimated) + }; + + computed_with_row_count(plan, num_rows) + } +} + +/// Statistics provider for [`LocalLimitExec`](crate::limit::LocalLimitExec) and +/// [`GlobalLimitExec`](crate::limit::GlobalLimitExec). +/// +/// Caps output row count at the limit value, accounting for any leading skip offset +/// in `GlobalLimitExec`. +#[derive(Debug, Default)] +pub struct LimitStatisticsProvider; + +impl StatisticsProvider for LimitStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::limit::{GlobalLimitExec, LocalLimitExec}; + + if child_stats.is_empty() { + return Ok(StatisticsResult::Delegate); + } + + let (skip, fetch) = if let Some(limit) = plan.downcast_ref::() { + (0usize, Some(limit.fetch())) + } else if let Some(limit) = plan.downcast_ref::() { + (limit.skip(), limit.fetch()) + } else { + return Ok(StatisticsResult::Delegate); + }; + + let num_rows = match child_stats[0].base.num_rows { + Precision::Exact(rows) => { + let available = rows.saturating_sub(skip); + Precision::Exact(fetch.map_or(available, |f| available.min(f))) + } + Precision::Inexact(rows) => { + let available = rows.saturating_sub(skip); + match fetch { + Some(f) => Precision::Inexact(available.min(f)), + None => Precision::Inexact(available), + } + } + Precision::Absent => match fetch { + Some(f) => Precision::Inexact(f), + None => Precision::Absent, + }, + }; + + computed_with_row_count(plan, num_rows) + } +} + +/// Statistics provider for [`UnionExec`](crate::union::UnionExec). +/// +/// Sums row counts across all inputs. +#[derive(Debug, Default)] +pub struct UnionStatisticsProvider; + +impl StatisticsProvider for UnionStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + use crate::union::UnionExec; + + if plan.downcast_ref::().is_none() { + return Ok(StatisticsResult::Delegate); + } + + let total = child_stats.iter().try_fold( + Precision::Exact(0usize), + |acc, s| -> Result> { + Ok(match (acc, s.base.num_rows) { + (Precision::Absent, _) | (_, Precision::Absent) => Precision::Absent, + (Precision::Exact(a), Precision::Exact(b)) => { + Precision::Exact(a.saturating_add(b)) + } + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => { + Precision::Inexact(a.saturating_add(b)) + } + }) + }, + )?; + + computed_with_row_count(plan, total) + } +} + +type ProviderFn = dyn Fn(&dyn ExecutionPlan, &[ExtendedStatistics]) -> Result + + Send + + Sync; + +/// A [`StatisticsProvider`] backed by a user-supplied closure. +/// +/// Useful for injecting custom statistics in tests or for cardinality feedback +/// pipelines where real runtime statistics need to override plan estimates. +/// The closure receives the current plan node and its children's enhanced +/// statistics, returning a [`StatisticsResult`]. +/// +/// To distinguish between multiple nodes of the same type (e.g., two +/// `FilterExec` nodes), match on structural properties like the input schema's +/// column names, number of columns, or child row counts. +/// +/// # Example +/// +/// ```rust,ignore (requires crate-internal imports) +/// let provider = ClosureStatisticsProvider::new(|plan, child_stats| { +/// if plan.downcast_ref::().is_some() { +/// Ok(StatisticsResult::Computed(ExtendedStatistics::from(Statistics { +/// num_rows: Precision::Inexact(42), +/// ..Statistics::new_unknown(plan.schema().as_ref()) +/// }))) +/// } else { +/// Ok(StatisticsResult::Delegate) +/// } +/// }); +/// ``` +pub struct ClosureStatisticsProvider { + f: Box, +} + +impl ClosureStatisticsProvider { + /// Create a new provider from a closure. + pub fn new( + f: impl Fn(&dyn ExecutionPlan, &[ExtendedStatistics]) -> Result + + Send + + Sync + + 'static, + ) -> Self { + Self { f: Box::new(f) } + } +} + +impl Debug for ClosureStatisticsProvider { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ClosureStatisticsProvider") + } +} + +impl StatisticsProvider for ClosureStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + (self.f)(plan, child_stats) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::filter::FilterExec; + use crate::projection::ProjectionExec; + use crate::{DisplayAs, DisplayFormatType, PlanProperties}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::stats::Precision; + use datafusion_common::{ColumnStatistics, ScalarValue}; + use datafusion_expr::Operator; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, col, lit}; + use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; + use std::fmt; + + use crate::execution_plan::{Boundedness, EmissionType}; + + fn make_schema() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])) + } + + #[derive(Debug)] + struct MockSourceExec { + schema: Arc, + stats: Statistics, + cache: Arc, + } + + impl MockSourceExec { + fn new(schema: Arc, num_rows: Precision) -> Self { + let num_cols = schema.fields().len(); + Self::with_column_stats( + schema, + num_rows, + vec![ColumnStatistics::new_unknown(); num_cols], + ) + } + + fn with_column_stats( + schema: Arc, + num_rows: Precision, + column_statistics: Vec, + ) -> Self { + let eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + let cache = Arc::new(PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Self { + schema, + stats: Statistics { + num_rows, + total_byte_size: Precision::Absent, + column_statistics, + }, + cache, + } + } + } + + impl DisplayAs for MockSourceExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MockSourceExec") + } + } + + impl ExecutionPlan for MockSourceExec { + fn name(&self) -> &str { + "MockSourceExec" + } + + fn schema(&self) -> Arc { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn partition_statistics( + &self, + _partition: Option, + ) -> Result> { + Ok(Arc::new(self.stats.clone())) + } + } + + fn make_source(num_rows: usize) -> Arc { + Arc::new(MockSourceExec::new( + make_schema(), + Precision::Exact(num_rows), + )) + } + + #[test] + fn test_default_provider() -> Result<()> { + let engine = StatisticsRegistry::new(); + let source = make_source(1000); + + let stats = engine.compute(source.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Exact(1000))); + Ok(()) + } + + #[test] + fn test_custom_chain_configuration() -> Result<()> { + let source = make_source(1000); + + // Test with_providers: fully custom chain (no default) + let custom_only = + StatisticsRegistry::with_providers(vec![Arc::new(CustomStatisticsProvider)]); + // CustomStatisticsProvider only handles CustomExec, delegates for others + // With no default provider, filter returns fallback statistics + let filter: Arc = + Arc::new(FilterExec::try_new(lit(true), Arc::clone(&source))?); + let stats = custom_only.compute(filter.as_ref())?; + // Falls back to plan.statistics() since no provider handles it + assert!(stats.base.num_rows.get_value().is_some()); + + // Test with_providers: custom provider + built-in fallback + let with_override = + StatisticsRegistry::with_providers(vec![Arc::new(OverrideFilterProvider { + fixed_selectivity: 0.25, + }) + as Arc]); + // OverrideFilterProvider handles filters, built-in fallback handles the rest + let stats = with_override.compute(filter.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Inexact(250))); + + // Verify chain inspection + assert_eq!(with_override.providers().len(), 1); + + Ok(()) + } + + #[derive(Debug)] + struct CustomExec { + input: Arc, + } + + impl DisplayAs for CustomExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "CustomExec") + } + } + + impl ExecutionPlan for CustomExec { + fn name(&self) -> &str { + "CustomExec" + } + + fn schema(&self) -> Arc { + self.input.schema() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(CustomExec { + input: Arc::clone(&children[0]), + })) + } + + fn properties(&self) -> &Arc { + self.input.properties() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + } + + #[derive(Debug)] + struct CustomStatisticsProvider; + + impl StatisticsProvider for CustomStatisticsProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + if plan.downcast_ref::().is_some() { + Ok(StatisticsResult::Computed(child_stats[0].clone())) + } else { + Ok(StatisticsResult::Delegate) + } + } + } + + #[test] + fn test_custom_provider_for_custom_exec() -> Result<()> { + let mut engine = StatisticsRegistry::new(); + engine.register(Arc::new(CustomStatisticsProvider)); + + let source = make_source(1000); + let custom: Arc = Arc::new(CustomExec { input: source }); + + let stats = engine.compute(custom.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Exact(1000))); + Ok(()) + } + + #[derive(Debug)] + struct OverrideFilterProvider { + fixed_selectivity: f64, + } + + impl StatisticsProvider for OverrideFilterProvider { + fn compute_statistics( + &self, + plan: &dyn ExecutionPlan, + child_stats: &[ExtendedStatistics], + ) -> Result { + if plan.downcast_ref::().is_some() { + if let Some(&input_rows) = child_stats[0].base.num_rows.get_value() { + let estimated = (input_rows as f64 * self.fixed_selectivity) as usize; + Ok(StatisticsResult::Computed(ExtendedStatistics::from( + Statistics { + num_rows: Precision::Inexact(estimated), + total_byte_size: Precision::Absent, + column_statistics: child_stats[0] + .base + .column_statistics + .clone(), + }, + ))) + } else { + Ok(StatisticsResult::Delegate) + } + } else { + Ok(StatisticsResult::Delegate) + } + } + } + + #[test] + fn test_override_builtin_operator() -> Result<()> { + let mut engine = StatisticsRegistry::new(); + engine.register(Arc::new(OverrideFilterProvider { + fixed_selectivity: 0.1, + })); + + let source = make_source(1000); + let filter: Arc = + Arc::new(FilterExec::try_new(lit(true), source)?); + + let stats = engine.compute(filter.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Inexact(100))); + Ok(()) + } + + #[test] + fn test_filter_statistics_propagation() -> Result<()> { + let engine = StatisticsRegistry::new(); + let source = make_source(1000); + let predicate = lit(true); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, source)?); + + let stats = engine.compute(filter.as_ref())?; + assert!(stats.base.num_rows.get_value().unwrap_or(&0) <= &1000); + Ok(()) + } + + #[test] + fn test_filter_adjusts_ndv_by_selectivity() -> Result<()> { + use datafusion_common::ScalarValue; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{ + BinaryExpr, Column as PhysColumn, Literal, + }; + + // Source: 1000 rows, NDV(a)=1000 (unique), NDV(b)=800 (near-unique) + // With NDV close to num_rows, each value has ~1.25 rows, so filtering + // visibly reduces the number of surviving distinct values. + let schema = make_schema(); // "a" Int32, "b" Int32 + let col_stats = vec![ + { + let mut cs = ColumnStatistics::new_unknown(); + cs.distinct_count = Precision::Exact(1000); + cs.min_value = Precision::Exact(ScalarValue::Int32(Some(1))); + cs.max_value = Precision::Exact(ScalarValue::Int32(Some(1000))); + cs + }, + { + let mut cs = ColumnStatistics::new_unknown(); + cs.distinct_count = Precision::Exact(800); + cs.min_value = Precision::Exact(ScalarValue::Int32(Some(1))); + cs.max_value = Precision::Exact(ScalarValue::Int32(Some(800))); + cs + }, + ]; + let source: Arc = Arc::new(MockSourceExec::with_column_stats( + schema, + Precision::Exact(1000), + col_stats, + )); + + // Filter: a > 900 (selectivity ~10%, keeps values 901-1000) + let predicate: Arc = Arc::new(BinaryExpr::new( + Arc::new(PhysColumn::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(900)))), + )); + let filter: Arc = + Arc::new(FilterExec::try_new(predicate, source)?); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(FilterStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(filter.as_ref())?; + + let output_ndv_a = stats.base.column_statistics[0] + .distinct_count + .get_value() + .copied() + .unwrap_or(0); + let output_ndv_b = stats.base.column_statistics[1] + .distinct_count + .get_value() + .copied() + .unwrap_or(0); + + // NDV(a): interval analysis narrows to [901,1000] -> ~100 distinct values + assert!( + output_ndv_a <= 100, + "Expected NDV(a) <= 100 after filter, got {output_ndv_a}" + ); + // NDV(b): not in predicate, but selectivity ~10% with 1.25 rows/value + // means many distinct values are lost. ndv_after_selectivity(800, 1000, 0.1) + // gives ~76. Significantly less than the original 800. + assert!( + output_ndv_b < 200, + "Expected NDV(b) < 200 after filter, got {output_ndv_b}" + ); + Ok(()) + } + + #[test] + fn test_projection_statistics_propagation() -> Result<()> { + let engine = StatisticsRegistry::new(); + let source = make_source(1000); + let schema = make_schema(); + let proj: Arc = Arc::new(ProjectionExec::try_new( + vec![(col("a", &schema)?, "a".to_string())], + source, + )?); + + let stats = engine.compute(proj.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Exact(1000))); + Ok(()) + } + + #[test] + fn test_passthrough_statistics_propagation() -> Result<()> { + use crate::coalesce_partitions::CoalescePartitionsExec; + + let engine = StatisticsRegistry::new(); + let source = make_source(1000); + let coalesce: Arc = + Arc::new(CoalescePartitionsExec::new(source)); + + let stats = engine.compute(coalesce.as_ref())?; + // PassthroughStatisticsProvider should propagate child row count unchanged + assert_eq!(stats.base.num_rows, Precision::Exact(1000)); + Ok(()) + } + + #[test] + fn test_chain_priority() -> Result<()> { + let mut engine = StatisticsRegistry::new(); + engine.register(Arc::new(OverrideFilterProvider { + fixed_selectivity: 0.5, + })); + engine.register(Arc::new(CustomStatisticsProvider)); + + let source = make_source(1000); + + // CustomExec handled by CustomStatisticsProvider + let custom: Arc = Arc::new(CustomExec { + input: Arc::clone(&source), + }); + let stats = engine.compute(custom.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Exact(1000))); + + // FilterExec: CustomStatisticsProvider delegates, OverrideFilterProvider handles + let filter: Arc = + Arc::new(FilterExec::try_new(lit(true), source)?); + let stats = engine.compute(filter.as_ref())?; + assert!(matches!(stats.base.num_rows, Precision::Inexact(500))); + + Ok(()) + } + + // ========================================================================= + // num_distinct_vals Utility Tests + // ========================================================================= + + #[test] + fn test_num_distinct_vals_basic() { + assert_eq!(num_distinct_vals(0, 100), 0); + assert_eq!(num_distinct_vals(100, 0), 0); + assert_eq!(num_distinct_vals(100, 100), 100); + assert_eq!(num_distinct_vals(100, 200), 100); + + let ndv = num_distinct_vals(1000, 100); + assert!((90..=100).contains(&ndv), "Expected ~95, got {ndv}"); + + let ndv = num_distinct_vals(1000, 500); + assert!((350..=450).contains(&ndv), "Expected ~393, got {ndv}"); + + let ndv = num_distinct_vals(1_000_000, 10_000); + assert!((9900..=10000).contains(&ndv), "Expected ~9950, got {ndv}"); + + let ndv = num_distinct_vals(1_000_000, 100); + assert!((99..=100).contains(&ndv), "Expected ~100, got {ndv}"); + } + + #[test] + fn test_num_distinct_vals_small_domain() { + let ndv = num_distinct_vals(10, 5); + assert!((3..=5).contains(&ndv), "Expected ~4, got {ndv}"); + + assert_eq!(num_distinct_vals(10, 20), 10); + assert_eq!(num_distinct_vals(10, 1), 1); + } + + #[test] + fn test_ndv_after_selectivity() { + let ndv = ndv_after_selectivity(1000, 10000, 0.1); + assert!((600..=700).contains(&ndv), "Expected ~632, got {ndv}"); + + let ndv = ndv_after_selectivity(1000, 10000, 0.01); + assert!((90..=100).contains(&ndv), "Expected ~95, got {ndv}"); + + assert_eq!(ndv_after_selectivity(1000, 10000, 0.0), 0); + assert_eq!(ndv_after_selectivity(1000, 10000, 1.0), 1000); + assert_eq!(ndv_after_selectivity(0, 10000, 0.5), 0); + } + + // ========================================================================= + // AggregateStatisticsProvider tests + // ========================================================================= + + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + + fn make_source_with_ndv( + num_rows: usize, + col_ndvs: Vec>, + ) -> Arc { + let fields: Vec = col_ndvs + .iter() + .enumerate() + .map(|(i, _)| Field::new(format!("c{i}"), DataType::Int32, false)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + let col_stats = col_ndvs + .into_iter() + .map(|ndv| { + let mut cs = ColumnStatistics::new_unknown(); + if let Some(n) = ndv { + cs.distinct_count = Precision::Exact(n); + } + cs + }) + .collect(); + Arc::new(MockSourceExec::with_column_stats( + schema, + Precision::Exact(num_rows), + col_stats, + )) + } + + fn make_aggregate( + input: Arc, + group_by: PhysicalGroupBy, + ) -> Result> { + Ok(Arc::new(AggregateExec::try_new( + AggregateMode::Single, + group_by, + vec![], + vec![], + Arc::clone(&input), + input.schema(), + )?)) + } + + #[test] + fn test_aggregate_provider_with_ndv() -> Result<()> { + let source = make_source_with_ndv(100, vec![Some(10)]); + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("c0", 0)), + "c0".to_string(), + )]); + let agg = make_aggregate(source, group_by)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(10)); + Ok(()) + } + + #[test] + fn test_aggregate_provider_multi_column() -> Result<()> { + let source = make_source_with_ndv(1000, vec![Some(10), Some(5)]); + let group_by = PhysicalGroupBy::new_single(vec![ + (Arc::new(Column::new("c0", 0)), "c0".to_string()), + (Arc::new(Column::new("c1", 1)), "c1".to_string()), + ]); + let agg = make_aggregate(source, group_by)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + // 10 * 5 = 50 + assert_eq!(stats.base.num_rows, Precision::Inexact(50)); + Ok(()) + } + + #[test] + fn test_aggregate_provider_caps_at_input_rows() -> Result<()> { + // NDV product (100 * 100 = 10_000) exceeds input rows (500) + let source = make_source_with_ndv(500, vec![Some(100), Some(100)]); + let group_by = PhysicalGroupBy::new_single(vec![ + (Arc::new(Column::new("c0", 0)), "c0".to_string()), + (Arc::new(Column::new("c1", 1)), "c1".to_string()), + ]); + let agg = make_aggregate(source, group_by)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(500)); + Ok(()) + } + + #[test] + fn test_aggregate_provider_no_ndv_delegates() -> Result<()> { + // No NDV on the GROUP BY column + let source = make_source_with_ndv(100, vec![None]); + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("c0", 0)), + "c0".to_string(), + )]); + let agg = make_aggregate(source, group_by)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + // Delegates to DefaultStatisticsProvider, which calls partition_statistics + assert!( + stats.base.num_rows.get_value().is_some() + || matches!(stats.base.num_rows, Precision::Absent) + ); + Ok(()) + } + + #[test] + fn test_aggregate_provider_non_column_expr_delegates() -> Result<()> { + let source = make_source_with_ndv(100, vec![Some(10), Some(5)]); + // GROUP BY an expression (c0 + c1), not a simple column ref + let expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("c0", 0)), + Operator::Plus, + Arc::new(Column::new("c1", 1)), + )); + let group_by = PhysicalGroupBy::new_single(vec![(expr, "sum".to_string())]); + let agg = make_aggregate(source, group_by)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + // Should delegate (expression is not a Column) + assert!( + stats.base.num_rows.get_value().is_some() + || matches!(stats.base.num_rows, Precision::Absent) + ); + Ok(()) + } + + #[test] + fn test_aggregate_provider_grouping_sets() -> Result<()> { + let source = make_source_with_ndv(1000, vec![Some(10), Some(5)]); + // GROUPING SETS: (c0, c1), (c0), (c1) -> 3 groups + let group_by = PhysicalGroupBy::new( + vec![ + (Arc::new(Column::new("c0", 0)), "c0".to_string()), + (Arc::new(Column::new("c1", 1)), "c1".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Int32(None))), + "c0".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Int32(None))), + "c1".to_string(), + ), + ], + vec![ + vec![false, true], // (c0, NULL) - group by c0 only + vec![true, false], // (NULL, c1) - group by c1 only + vec![false, false], // (c0, c1) - group by both + ], + true, + ); + let agg = make_aggregate(source, group_by)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + // Multiple grouping sets: provider delegates to DefaultStatisticsProvider, + // which calls the built-in partition_statistics for correct per-set + // NDV estimation. The exact value depends on the built-in implementation. + assert!( + stats.base.num_rows.get_value().is_some() + || matches!(stats.base.num_rows, Precision::Absent) + ); + Ok(()) + } + + #[test] + fn test_aggregate_provider_partial_delegates() -> Result<()> { + // Partial aggregates produce per-partition groups; the provider + // should delegate rather than applying global NDV bounds. + let source = make_source_with_ndv(100, vec![Some(10)]); + let group_by = PhysicalGroupBy::new_single(vec![( + Arc::new(Column::new("c0", 0)), + "c0".to_string(), + )]); + let agg: Arc = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by, + vec![], + vec![], + Arc::clone(&source), + source.schema(), + )?); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(AggregateStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(agg.as_ref())?; + // Should fall through to DefaultStatisticsProvider (partition_statistics). + // The exact value depends on the built-in implementation. + assert!( + stats.base.num_rows.get_value().is_some() + || matches!(stats.base.num_rows, Precision::Absent) + ); + Ok(()) + } + + // ========================================================================= + // JoinStatisticsProvider tests + // ========================================================================= + + use crate::joins::{HashJoinExec, PartitionMode}; + use datafusion_common::{JoinType, NullEquality}; + + fn make_source_with_ndv_2col( + num_rows: usize, + ndv_a: Option, + ) -> Arc { + let schema = make_schema(); // "a" Int32, "b" Int32 + let col_stats = vec![ + { + let mut cs = ColumnStatistics::new_unknown(); + if let Some(n) = ndv_a { + cs.distinct_count = Precision::Exact(n); + } + cs + }, + ColumnStatistics::new_unknown(), + ]; + Arc::new(MockSourceExec::with_column_stats( + schema, + Precision::Exact(num_rows), + col_stats, + )) + } + + fn make_hash_join( + left: Arc, + right: Arc, + ) -> Result> { + let _schema = make_schema(); + let on: crate::joins::JoinOn = vec![( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(Column::new("a", 0)) as Arc, + )]; + Ok(Arc::new(HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?)) + } + + #[test] + fn test_join_provider_with_ndv() -> Result<()> { + // left: 1000 rows, NDV(a)=100; right: 500 rows, NDV(a)=50 + // expected = 1000 * 500 / max(100, 50) = 5000 + let left = make_source_with_ndv_2col(1000, Some(100)); + let right = make_source_with_ndv_2col(500, Some(50)); + let join = make_hash_join(left, right)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(join.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(5000)); + Ok(()) + } + + #[test] + fn test_join_provider_uses_actual_key_column_ndv() -> Result<()> { + // Join on column "b" (index 1), NDV only set on "b", not "a". + // Old first()-based code would look up column 0 (a), find no NDV, + // and fall back to Cartesian product. The fix looks up column 1 (b). + // left: 1000 rows, NDV(b)=50; right: 500 rows, NDV(b)=25 + // expected = 1000 * 500 / max(50, 25) = 10000 + let schema = make_schema(); // "a" Int32, "b" Int32 + let make_source_ndv_b = + |num_rows: usize, ndv_b: usize| -> Arc { + let col_stats = vec![ + ColumnStatistics::new_unknown(), // "a": no NDV + { + let mut cs = ColumnStatistics::new_unknown(); + cs.distinct_count = Precision::Exact(ndv_b); + cs + }, + ]; + Arc::new(MockSourceExec::with_column_stats( + Arc::clone(&schema), + Precision::Exact(num_rows), + col_stats, + )) + }; + + let left = make_source_ndv_b(1000, 50); + let right = make_source_ndv_b(500, 25); + + // Join on column "b" (index 1) + let on: crate::joins::JoinOn = vec![( + Arc::new(Column::new("b", 1)) as Arc, + Arc::new(Column::new("b", 1)) as Arc, + )]; + let join: Arc = Arc::new(HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(join.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(10_000)); + Ok(()) + } + + #[test] + fn test_join_provider_multi_key_ndv() -> Result<()> { + // Multi-key join: ON a.a = b.a AND a.b = b.b + // left: 1000 rows, NDV(a)=100, NDV(b)=20 + // right: 500 rows, NDV(a)=50, NDV(b)=10 + // expected = 1000 * 500 / (max(100,50) * max(20,10)) = 500000 / 2000 = 250 + let schema = make_schema(); // "a" Int32, "b" Int32 + let make_source_2ndv = + |num_rows: usize, ndv_a: usize, ndv_b: usize| -> Arc { + let col_stats = vec![ + { + let mut cs = ColumnStatistics::new_unknown(); + cs.distinct_count = Precision::Exact(ndv_a); + cs + }, + { + let mut cs = ColumnStatistics::new_unknown(); + cs.distinct_count = Precision::Exact(ndv_b); + cs + }, + ]; + Arc::new(MockSourceExec::with_column_stats( + Arc::clone(&schema), + Precision::Exact(num_rows), + col_stats, + )) + }; + + let left = make_source_2ndv(1000, 100, 20); + let right = make_source_2ndv(500, 50, 10); + + let on: crate::joins::JoinOn = vec![ + ( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(Column::new("a", 0)) as Arc, + ), + ( + Arc::new(Column::new("b", 1)) as Arc, + Arc::new(Column::new("b", 1)) as Arc, + ), + ]; + let join: Arc = Arc::new(HashJoinExec::try_new( + left, + right, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(join.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(250)); + Ok(()) + } + + #[test] + fn test_join_provider_fallback_cartesian() -> Result<()> { + // No NDV available -> Cartesian product estimate + let left = make_source_with_ndv_2col(100, None); + let right = make_source_with_ndv_2col(200, None); + let join = make_hash_join(left, right)?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(join.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(20_000)); + Ok(()) + } + + #[test] + fn test_nl_join_delegates() -> Result<()> { + use crate::joins::NestedLoopJoinExec; + + // NL join delegates to the built-in (NestedLoopJoinExec may have an + // arbitrary JoinFilter, so the provider cannot safely assume Cartesian). + let left = make_source(100); + let right = make_source(200); + let join: Arc = Arc::new(NestedLoopJoinExec::try_new( + left, + right, + None, + &JoinType::Inner, + None, + )?); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(join.as_ref())?; + // Provider delegates; result comes from built-in partition_statistics. + assert!( + stats.base.num_rows.get_value().is_some() + || matches!(stats.base.num_rows, Precision::Absent) + ); + Ok(()) + } + + fn make_hash_join_typed( + left: Arc, + right: Arc, + join_type: JoinType, + ) -> Result> { + let on: crate::joins::JoinOn = vec![( + Arc::new(Column::new("a", 0)) as Arc, + Arc::new(Column::new("a", 0)) as Arc, + )]; + Ok(Arc::new(HashJoinExec::try_new( + left, + right, + on, + None, + &join_type, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNull, + false, + )?)) + } + + fn compute_join_rows( + left_rows: usize, + left_ndv: Option, + right_rows: usize, + right_ndv: Option, + join_type: JoinType, + ) -> Result> { + let left = make_source_with_ndv_2col(left_rows, left_ndv); + let right = make_source_with_ndv_2col(right_rows, right_ndv); + let join = make_hash_join_typed(left, right, join_type)?; + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + Ok(registry.compute(join.as_ref())?.base.num_rows) + } + + #[test] + fn test_join_provider_left_outer() -> Result<()> { + // left=1000, right=500, NDV(a)=100/50 + // inner estimate = 1000*500/100 = 5000, already >= left_rows + // Left outer: max(5000, 1000) = 5000 + assert_eq!( + compute_join_rows(1000, Some(100), 500, Some(50), JoinType::Left)?, + Precision::Inexact(5000) + ); + // Small inner estimate: left=1000, right=10, NDV=100/100 + // inner = 1000*10/100 = 100, left outer = max(100, 1000) = 1000 + assert_eq!( + compute_join_rows(1000, Some(100), 10, Some(100), JoinType::Left)?, + Precision::Inexact(1000) + ); + Ok(()) + } + + #[test] + fn test_join_provider_right_outer() -> Result<()> { + // inner = 1000*10/100 = 100, right outer = max(100, 10) = 100 + assert_eq!( + compute_join_rows(1000, Some(100), 10, Some(100), JoinType::Right)?, + Precision::Inexact(100) + ); + // inner = 10*1000/100 = 100, right outer = max(100, 1000) = 1000 + assert_eq!( + compute_join_rows(10, Some(100), 1000, Some(100), JoinType::Right)?, + Precision::Inexact(1000) + ); + Ok(()) + } + + #[test] + fn test_join_provider_semi_join() -> Result<()> { + // inner = 5000, left semi = min(5000, 1000) = 1000 + assert_eq!( + compute_join_rows(1000, Some(100), 500, Some(50), JoinType::LeftSemi)?, + Precision::Inexact(1000) + ); + // inner = 5000, right semi = min(5000, 500) = 500 + assert_eq!( + compute_join_rows(1000, Some(100), 500, Some(50), JoinType::RightSemi)?, + Precision::Inexact(500) + ); + // Cartesian fallback (no NDV): inner = 1000*500 = 500000, + // left semi = min(500000, 1000) = 1000 (selectivity = 1.0) + assert_eq!( + compute_join_rows(1000, None, 500, None, JoinType::LeftSemi)?, + Precision::Inexact(1000) + ); + Ok(()) + } + + #[test] + fn test_join_provider_anti_join() -> Result<()> { + // inner = 1000*10/100 = 100, left anti = 1000 - min(100, 1000) = 900 + assert_eq!( + compute_join_rows(1000, Some(100), 10, Some(100), JoinType::LeftAnti)?, + Precision::Inexact(900) + ); + // inner = 5000, right anti = 500 - min(5000, 500) = 0 + assert_eq!( + compute_join_rows(1000, Some(100), 500, Some(50), JoinType::RightAnti)?, + Precision::Inexact(0) + ); + Ok(()) + } + + // ========================================================================= + // CrossJoinExec tests (handled by JoinStatisticsProvider) + // ========================================================================= + + #[test] + fn test_cross_join_provider_exact() -> Result<()> { + use crate::joins::CrossJoinExec; + let left = make_source(100); + let right = make_source(200); + let join: Arc = Arc::new(CrossJoinExec::new(left, right)); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(JoinStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(join.as_ref())?; + // Both inputs have Exact row counts -> result is also Exact + assert_eq!(stats.base.num_rows, Precision::Exact(20_000)); + Ok(()) + } + + // ========================================================================= + // LimitStatisticsProvider tests + // ========================================================================= + + use crate::limit::{GlobalLimitExec, LocalLimitExec}; + + #[test] + fn test_limit_provider_caps_output() -> Result<()> { + // input > fetch -> capped at fetch + let source = make_source(1000); + let limit: Arc = Arc::new(LocalLimitExec::new(source, 100)); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(LimitStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(limit.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Exact(100)); + Ok(()) + } + + #[test] + fn test_limit_provider_input_smaller_than_fetch() -> Result<()> { + // input < fetch -> output = input + let source = make_source(50); + let limit: Arc = Arc::new(LocalLimitExec::new(source, 200)); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(LimitStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(limit.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Exact(50)); + Ok(()) + } + + #[test] + fn test_global_limit_provider_skip_and_fetch() -> Result<()> { + // 1000 rows, skip 200, fetch 100 -> exactly 100 + let source = make_source(1000); + let limit: Arc = + Arc::new(GlobalLimitExec::new(source, 200, Some(100))); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(LimitStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(limit.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Exact(100)); + Ok(()) + } + + #[test] + fn test_global_limit_provider_skip_exceeds_rows() -> Result<()> { + // 100 rows, skip 200 -> 0 rows (skip > available) + let source = make_source(100); + let limit: Arc = + Arc::new(GlobalLimitExec::new(source, 200, Some(50))); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(LimitStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(limit.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Exact(0)); + Ok(()) + } + + #[test] + fn test_limit_provider_inexact_input() -> Result<()> { + // Inexact(1000) with fetch=100: result must stay Inexact, not Exact, + // because the actual row count could be less than 100. + let source = make_source_with_precision(Precision::Inexact(1000)); + let limit: Arc = Arc::new(LocalLimitExec::new(source, 100)); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(LimitStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(limit.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(100)); + Ok(()) + } + + // ========================================================================= + // UnionStatisticsProvider tests + // ========================================================================= + + use crate::union::UnionExec; + + fn make_source_with_precision(num_rows: Precision) -> Arc { + Arc::new(MockSourceExec::new(make_schema(), num_rows)) + } + + #[test] + fn test_union_provider_sums_rows() -> Result<()> { + let union = UnionExec::try_new(vec![make_source(300), make_source(700)])?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(UnionStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(union.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Exact(1000)); + Ok(()) + } + + #[test] + fn test_union_provider_three_inputs() -> Result<()> { + let union = UnionExec::try_new(vec![ + make_source(100), + make_source(200), + make_source(300), + ])?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(UnionStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(union.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Exact(600)); + Ok(()) + } + + #[test] + fn test_union_provider_absent_propagates() -> Result<()> { + // One input with unknown row count -> result must be Absent, not Inexact(300) + let union = UnionExec::try_new(vec![ + make_source(300), + make_source_with_precision(Precision::Absent), + ])?; + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(UnionStatisticsProvider), + Arc::new(DefaultStatisticsProvider), + ]); + let stats = registry.compute(union.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Absent); + Ok(()) + } + + // ========================================================================= + // ClosureStatisticsProvider tests + // ========================================================================= + + #[test] + fn test_closure_provider_basic() -> Result<()> { + // Override all FilterExec stats with a fixed row count + let provider = ClosureStatisticsProvider::new(|plan, _child_stats| { + if plan.downcast_ref::().is_some() { + Ok(StatisticsResult::Computed(ExtendedStatistics::from( + Statistics { + num_rows: Precision::Inexact(42), + total_byte_size: Precision::Absent, + column_statistics: vec![], + }, + ))) + } else { + Ok(StatisticsResult::Delegate) + } + }); + + let registry = StatisticsRegistry::with_providers(vec![ + Arc::new(provider), + Arc::new(DefaultStatisticsProvider), + ]); + + let source = make_source(1000); + let filter: Arc = + Arc::new(FilterExec::try_new(lit(true), source)?); + let stats = registry.compute(filter.as_ref())?; + assert_eq!(stats.base.num_rows, Precision::Inexact(42)); + Ok(()) + } + + #[test] + fn test_closure_provider_distinguishes_nodes_by_child_stats() -> Result<()> { + // Two FilterExec nodes with different input sizes. + // The closure uses the child row count as a proxy to distinguish them, + // which mirrors the cardinality feedback use case where you match a + // runtime-observed count to the right node in the plan tree. + let provider = ClosureStatisticsProvider::new(|plan, child_stats| { + if plan.downcast_ref::().is_none() { + return Ok(StatisticsResult::Delegate); + } + match child_stats[0].base.num_rows.get_value().copied() { + Some(500) => Ok(StatisticsResult::Computed(ExtendedStatistics::from( + Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Absent, + column_statistics: vec![], + }, + ))), + Some(200) => Ok(StatisticsResult::Computed(ExtendedStatistics::from( + Statistics { + num_rows: Precision::Inexact(50), + total_byte_size: Precision::Absent, + column_statistics: vec![], + }, + ))), + _ => Ok(StatisticsResult::Delegate), + } + }); + + let registry = StatisticsRegistry::with_providers(vec![Arc::new(provider)]); + + let filter_a: Arc = + Arc::new(FilterExec::try_new(lit(true), make_source(500))?); + let filter_b: Arc = + Arc::new(FilterExec::try_new(lit(true), make_source(200))?); + + let stats_a = registry.compute(filter_a.as_ref())?; + let stats_b = registry.compute(filter_b.as_ref())?; + + assert_eq!(stats_a.base.num_rows, Precision::Inexact(100)); + assert_eq!(stats_b.base.num_rows, Precision::Inexact(50)); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index be4c3da509e88..b99f9a93045fb 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -17,20 +17,19 @@ //! EmptyRelation produce_one_row=true execution plan -use std::any::Any; use std::sync::Arc; use crate::coop::cooperative; use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::memory::MemoryStream; use crate::{ - common, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, - SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, common, }; use arrow::array::{ArrayRef, NullArray, RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use datafusion_common::{assert_or_internal_err, Result}; +use datafusion_common::{Result, assert_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -43,7 +42,7 @@ pub struct PlaceholderRowExec { schema: SchemaRef, /// Number of partitions partitions: usize, - cache: PlanProperties, + cache: Arc, } impl PlaceholderRowExec { @@ -54,7 +53,7 @@ impl PlaceholderRowExec { PlaceholderRowExec { schema, partitions, - cache, + cache: Arc::new(cache), } } @@ -63,7 +62,7 @@ impl PlaceholderRowExec { self.partitions = partitions; // Update output partitioning when updating partitions: let output_partitioning = Self::output_partitioning_helper(self.partitions); - self.cache = self.cache.with_partitioning(output_partitioning); + Arc::make_mut(&mut self.cache).partitioning = output_partitioning; self } @@ -128,11 +127,7 @@ impl ExecutionPlan for PlaceholderRowExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -152,7 +147,12 @@ impl ExecutionPlan for PlaceholderRowExec { partition: usize, context: Arc, ) -> Result { - trace!("Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start PlaceholderRowExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); assert_or_internal_err!( partition < self.partitions, @@ -164,11 +164,7 @@ impl ExecutionPlan for PlaceholderRowExec { Ok(Box::pin(cooperative(ms))) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { let batches = self .data() .expect("Create single row placeholder RecordBatch should not fail"); @@ -179,11 +175,11 @@ impl ExecutionPlan for PlaceholderRowExec { None => vec![batches; self.partitions], }; - Ok(common::compute_record_batch_statistics( + Ok(Arc::new(common::compute_record_batch_statistics( &batches, &self.schema, None, - )) + ))) } } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 0b8c4ee5fbec1..ade3a988c7b61 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -20,20 +20,20 @@ //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the //! projection expressions. `SELECT` without `FROM` will only evaluate expressions. -use super::expressions::{Column, Literal}; +use super::expressions::Column; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, SortOrderPushdownResult, Statistics, }; +use crate::column_rewriter::PhysicalColumnRewriter; use crate::execution_plan::CardinalityEffect; use crate::filter_pushdown::{ - ChildPushdownResult, FilterDescription, FilterPushdownPhase, - FilterPushdownPropagation, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, FilterRemapper, PushedDownPredicate, }; use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef}; -use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr}; -use std::any::Any; +use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr, check_if_same_properties}; use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -45,17 +45,19 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{internal_err, JoinSide, Result}; +use datafusion_common::{DataFusionError, JoinSide, Result, internal_err}; use datafusion_execution::TaskContext; +use datafusion_expr::ExpressionPlacement; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::projection::Projector; -use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, PhysicalSortExpr, +}; // Re-exported from datafusion-physical-expr for backwards compatibility // We recommend updating your imports to use datafusion-physical-expr directly pub use datafusion_physical_expr::projection::{ - update_expr, ProjectionExpr, ProjectionExprs, + ProjectionExpr, ProjectionExprs, update_expr, }; use futures::stream::{Stream, StreamExt}; @@ -75,7 +77,7 @@ pub struct ProjectionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl ProjectionExec { @@ -134,13 +136,19 @@ impl ProjectionExec { E: Into, { let input_schema = input.schema(); - // convert argument to Vec - let expr_vec = expr.into_iter().map(Into::into).collect::>(); - let projection = ProjectionExprs::new(expr_vec); + let expr_arc = expr.into_iter().map(Into::into).collect::>(); + let projection = ProjectionExprs::from_expressions(expr_arc); let projector = projection.make_projector(&input_schema)?; + Self::try_from_projector(projector, input) + } + fn try_from_projector( + projector: Projector, + input: Arc, + ) -> Result { // Construct a map from the input expressions to the output expression of the Projection - let projection_mapping = projection.projection_mapping(&input_schema)?; + let projection_mapping = + projector.projection().projection_mapping(&input.schema())?; let cache = Self::compute_properties( &input, &projection_mapping, @@ -150,7 +158,7 @@ impl ProjectionExec { projector, input, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -159,6 +167,11 @@ impl ProjectionExec { self.projector.projection().as_ref() } + /// The projection expressions as a [`ProjectionExprs`]. + pub fn projection_expr(&self) -> &ProjectionExprs { + self.projector.projection() + } + /// The input plan pub fn input(&self) -> &Arc { &self.input @@ -185,6 +198,40 @@ impl ProjectionExec { input.boundedness(), )) } + + /// Collect reverse alias mapping from projection expressions. + /// The result hash map is a map from aliased Column in parent to original expr. + fn collect_reverse_alias( + &self, + ) -> Result>> { + let mut alias_map = datafusion_common::HashMap::new(); + for projection in self.projection_expr().iter() { + let (aliased_index, _output_field) = self + .projector + .output_schema() + .column_with_name(&projection.alias) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Expr {} with alias {} not found in output schema", + projection.expr, projection.alias + )) + })?; + let aliased_col = Column::new(&projection.alias, aliased_index); + alias_map.insert(aliased_col, Arc::clone(&projection.expr)); + } + Ok(alias_map) + } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for ProjectionExec { @@ -234,11 +281,7 @@ impl ExecutionPlan for ProjectionExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -254,10 +297,13 @@ impl ExecutionPlan for ProjectionExec { .as_ref() .iter() .all(|proj_expr| { - proj_expr.expr.as_any().is::() - || proj_expr.expr.as_any().is::() + !matches!( + proj_expr.expr.placement(), + ExpressionPlacement::KeepInPlace + ) }); - // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename, + // If expressions are all either column_expr or Literal (or other cheap expressions), + // then all computations in this projection are reorder or rename, // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false. vec![!all_simple_exprs] } @@ -270,8 +316,9 @@ impl ExecutionPlan for ProjectionExec { self: Arc, mut children: Vec>, ) -> Result> { - ProjectionExec::try_new( - self.projector.projection().clone(), + check_if_same_properties!(self, children); + ProjectionExec::try_from_projector( + self.projector.clone(), children.swap_remove(0), ) .map(|p| Arc::new(p) as _) @@ -282,9 +329,16 @@ impl ExecutionPlan for ProjectionExec { partition: usize, context: Arc, ) -> Result { - trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); + + let projector = self.projector.with_metrics(&self.metrics, partition); Ok(Box::pin(ProjectionStream::new( - self.projector.clone(), + projector, self.input.execute(partition, context)?, BaselineMetrics::new(&self.metrics, partition), )?)) @@ -294,15 +348,15 @@ impl ExecutionPlan for ProjectionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - let input_stats = self.input.partition_statistics(partition)?; - self.projector - .projection() - .project_statistics(input_stats, &self.input.schema()) + fn partition_statistics(&self, partition: Option) -> Result> { + let input_stats = + Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + let output_schema = self.schema(); + Ok(Arc::new( + self.projector + .projection() + .project_statistics(input_stats, &output_schema)?, + )) } fn supports_limit_pushdown(&self) -> bool { @@ -317,12 +371,9 @@ impl ExecutionPlan for ProjectionExec { &self, projection: &ProjectionExec, ) -> Result>> { - let maybe_unified = try_unifying_projections(projection, self)?; - if let Some(new_plan) = maybe_unified { - // To unify 3 or more sequential projections: - remove_unnecessary_projections(new_plan).data().map(Some) - } else { - Ok(Some(Arc::new(projection.clone()))) + match try_collapse_projection_chain(projection)? { + Some(plan) => Ok(Some(plan)), + None => Ok(Some(Arc::new(projection.clone()))), } } @@ -332,10 +383,28 @@ impl ExecutionPlan for ProjectionExec { parent_filters: Vec>, _config: &ConfigOptions, ) -> Result { - // TODO: In future, we can try to handle inverting aliases here. - // For the time being, we pass through untransformed filters, so filters on aliases are not handled. - // https://github.com/apache/datafusion/issues/17246 - FilterDescription::from_children(parent_filters, &self.children()) + // expand alias column to original expr in parent filters + let invert_alias_map = self.collect_reverse_alias()?; + let output_schema = self.schema(); + let remapper = FilterRemapper::new(output_schema); + let mut child_parent_filters = Vec::with_capacity(parent_filters.len()); + + for filter in parent_filters { + // Check that column exists in child, then reassign column indices to match child schema + if let Some(reassigned) = remapper.try_remap(&filter)? { + // rewrite filter expression using invert alias map + let mut rewriter = PhysicalColumnRewriter::new(&invert_alias_map); + let rewritten = reassigned.rewrite(&mut rewriter)?.data; + child_parent_filters.push(PushedDownPredicate::supported(rewritten)); + } else { + child_parent_filters.push(PushedDownPredicate::unsupported(filter)); + } + } + + Ok(FilterDescription::new().with_child(ChildFilterDescription { + parent_filters: child_parent_filters, + self_filters: vec![], + })) } fn handle_child_pushdown_result( @@ -346,6 +415,83 @@ impl ExecutionPlan for ProjectionExec { ) -> Result>> { Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) } + + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + let child = self.input(); + let mut child_order = Vec::new(); + + // Check and transform sort expressions + for sort_expr in order { + // Recursively transform the expression + let mut can_pushdown = true; + let transformed = Arc::clone(&sort_expr.expr).transform(|expr| { + if let Some(col) = expr.downcast_ref::() { + // Check if column index is valid. + // This should always be true but fail gracefully if it's not. + if col.index() >= self.expr().len() { + can_pushdown = false; + return Ok(Transformed::no(expr)); + } + + let proj_expr = &self.expr()[col.index()]; + + // Check if projection expression is a simple column + // We cannot push down order by clauses that depend on + // projected computations as they would have nothing to reference. + if let Some(child_col) = proj_expr.expr.downcast_ref::() { + // Replace with the child column + Ok(Transformed::yes(Arc::new(child_col.clone()) as _)) + } else { + // Projection involves computation, cannot push down + can_pushdown = false; + Ok(Transformed::no(expr)) + } + } else { + Ok(Transformed::no(expr)) + } + })?; + + if !can_pushdown { + return Ok(SortOrderPushdownResult::Unsupported); + } + + child_order.push(PhysicalSortExpr { + expr: transformed.data, + options: sort_expr.options, + }); + } + + // Recursively push down to child node + match child.try_pushdown_sort(&child_order)? { + SortOrderPushdownResult::Exact { inner } => { + let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?; + Ok(SortOrderPushdownResult::Exact { inner: new_exec }) + } + SortOrderPushdownResult::Inexact { inner } => { + let new_exec = Arc::new(self.clone()).with_new_children(vec![inner])?; + Ok(SortOrderPushdownResult::Inexact { inner: new_exec }) + } + SortOrderPushdownResult::Unsupported => { + Ok(SortOrderPushdownResult::Unsupported) + } + } + } + + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } } impl ProjectionStream { @@ -404,6 +550,15 @@ impl RecordBatchStream for ProjectionStream { } } +/// Trait for execution plans that can embed a projection, avoiding a separate +/// [`ProjectionExec`] wrapper. +/// +/// # Empty projections +/// +/// `Some(vec![])` is a valid projection that produces zero output columns while +/// preserving the correct row count. Implementors must ensure that runtime batch +/// construction still returns batches with the right number of rows even when no +/// columns are selected (e.g. for `SELECT count(1) … JOIN …`). pub trait EmbeddedProjection: ExecutionPlan + Sized { fn with_projection(&self, projection: Option>) -> Result; } @@ -414,6 +569,15 @@ pub fn try_embed_projection( projection: &ProjectionExec, execution_plan: &Exec, ) -> Result>> { + // If the projection has no expressions at all (e.g., ProjectionExec: expr=[]), + // embed an empty projection into the execution plan so it outputs zero columns. + // This avoids allocating throwaway null arrays for build-side columns + // when no output columns are actually needed (e.g., count(1) over a right join). + if projection.expr().is_empty() { + let new_execution_plan = Arc::new(execution_plan.with_projection(Some(vec![]))?); + return Ok(Some(new_execution_plan)); + } + // Collect all column indices from the given projection expressions. let projection_index = collect_column_indices(projection.expr()); @@ -421,13 +585,7 @@ pub fn try_embed_projection( return Ok(None); }; - // If the projection indices is the same as the input columns, we don't need to embed the projection to hash join. - // Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields. - if projection_index.len() == projection_index.last().unwrap() + 1 - && projection_index.len() == execution_plan.schema().fields().len() - { - return Ok(None); - } + let columns_reduced = projection_index.len() < execution_plan.schema().fields().len(); let new_execution_plan = Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?); @@ -462,9 +620,16 @@ pub fn try_embed_projection( Arc::clone(&new_execution_plan) as _, )?); if is_projection_removable(&new_projection) { + // Residual is identity — embedding fully absorbed the projection. Ok(Some(new_execution_plan)) - } else { + } else if columns_reduced { + // Embedding reduced columns even though a residual is still needed + // for renames or expressions — worth keeping. Ok(Some(new_projection)) + } else { + // No columns eliminated and residual still needed — embedding just + // adds an unnecessary column reorder inside the operator. + Ok(None) } } @@ -546,20 +711,19 @@ pub fn try_pushdown_through_join( pub fn remove_unnecessary_projections( plan: Arc, ) -> Result>> { - let maybe_modified = - if let Some(projection) = plan.as_any().downcast_ref::() { - // If the projection does not cause any change on the input, we can - // safely remove it: - if is_projection_removable(projection) { - return Ok(Transformed::yes(Arc::clone(projection.input()))); - } - // If it does, check if we can push it under its child(ren): - projection - .input() - .try_swapping_with_projection(projection)? - } else { - return Ok(Transformed::no(plan)); - }; + let maybe_modified = if let Some(projection) = plan.downcast_ref::() { + // If the projection does not cause any change on the input, we can + // safely remove it: + if is_projection_removable(projection) { + return Ok(Transformed::yes(Arc::clone(projection.input()))); + } + // If it does, check if we can push it under its child(ren): + projection + .input() + .try_swapping_with_projection(projection)? + } else { + return Ok(Transformed::no(plan)); + }; Ok(maybe_modified.map_or_else(|| Transformed::no(plan), Transformed::yes)) } @@ -570,7 +734,7 @@ pub fn remove_unnecessary_projections( fn is_projection_removable(projection: &ProjectionExec) -> bool { let exprs = projection.expr(); exprs.iter().enumerate().all(|(idx, proj_expr)| { - let Some(col) = proj_expr.expr.as_any().downcast_ref::() else { + let Some(col) = proj_expr.expr.downcast_ref::() else { return false; }; col.name() == proj_expr.alias && col.index() == idx @@ -583,7 +747,6 @@ pub fn all_alias_free_columns(exprs: &[ProjectionExpr]) -> bool { exprs.iter().all(|proj_expr| { proj_expr .expr - .as_any() .downcast_ref::() .map(|column| column.name() == proj_expr.alias) .unwrap_or(false) @@ -602,7 +765,6 @@ pub fn new_projections_for_columns( .filter_map(|proj_expr| { proj_expr .expr - .as_any() .downcast_ref::() .map(|expr| source[expr.index()]) }) @@ -621,9 +783,7 @@ pub fn make_with_child( /// Returns `true` if all the expressions in the argument are `Column`s. pub fn all_columns(exprs: &[ProjectionExpr]) -> bool { - exprs - .iter() - .all(|proj_expr| proj_expr.expr.as_any().is::()) + exprs.iter().all(|proj_expr| proj_expr.expr.is::()) } /// Updates the given lexicographic ordering according to given projected @@ -672,7 +832,6 @@ pub fn physical_to_column_exprs( .map(|proj_expr| { proj_expr .expr - .as_any() .downcast_ref::() .map(|col| (col.clone(), proj_expr.alias.clone())) }) @@ -780,10 +939,6 @@ pub fn update_join_on( hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)], left_field_size: usize, ) -> Option> { - // TODO: Clippy wants the "map" call removed, but doing so generates - // a compilation error. Remove the clippy directive once this - // issue is fixed. - #[allow(clippy::map_identity)] let (left_idx, right_idx): (Vec<_>, Vec<_>) = hash_join_on .iter() .map(|(left, right)| (left, right)) @@ -845,64 +1000,104 @@ pub fn update_join_filter( }) } -/// Unifies `projection` with its input (which is also a [`ProjectionExec`]). -fn try_unifying_projections( - projection: &ProjectionExec, - child: &ProjectionExec, +/// Collapse a chain of consecutive [`ProjectionExec`]s into one. Returns +/// `None` if nothing could be merged. +fn try_collapse_projection_chain( + outer: &ProjectionExec, ) -> Result>> { - let mut projected_exprs = vec![]; + let mut current_exprs: Vec = outer.expr().to_vec(); + let mut current_input: Arc = Arc::clone(outer.input()); let mut column_ref_map: HashMap = HashMap::new(); + let mut collapsed_any = false; + + 'outer: while let Some(inner_proj) = current_input.downcast_ref::() { + // Collect the column references usage in the outer projection. + column_ref_map.clear(); + for proj_expr in ¤t_exprs { + proj_expr.expr.apply(|expr| { + if let Some(column) = expr.downcast_ref::() { + *column_ref_map.entry(column.clone()).or_default() += 1; + } + Ok(TreeNodeRecursion::Continue) + })?; + } + let inner_exprs = inner_proj.expr(); + // Merging these projections is not beneficial, e.g + // If an expression is not trivial (KeepInPlace) and it is referred more than 1, unifies projections will be + // beneficial as caching mechanism for non-trivial computations. + // See discussion in: https://github.com/apache/datafusion/issues/8296 + let blocked = column_ref_map.iter().any(|(column, count)| { + *count > 1 + && !inner_exprs[column.index()] + .expr + .placement() + .should_push_to_leaves() + }); + if blocked { + break; + } - // Collect the column references usage in the outer projection. - projection.expr().iter().for_each(|proj_expr| { - proj_expr - .expr - .apply(|expr| { - Ok({ - if let Some(column) = expr.as_any().downcast_ref::() { - *column_ref_map.entry(column.clone()).or_default() += 1; - } - TreeNodeRecursion::Continue - }) - }) - .unwrap(); - }); - // Merging these projections is not beneficial, e.g - // If an expression is not trivial and it is referred more than 1, unifies projections will be - // beneficial as caching mechanism for non-trivial computations. - // See discussion in: https://github.com/apache/datafusion/issues/8296 - if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr)) - }) { - return Ok(None); + let mut new_phys: Vec> = + Vec::with_capacity(current_exprs.len()); + for proj_expr in ¤t_exprs { + // If there is no match in the input projection, we cannot unify these + // projections. This case will arise if the projection expression contains + // a `PhysicalExpr` variant `update_expr` doesn't support. + let Some(expr) = update_expr(&proj_expr.expr, inner_exprs, true)? else { + break 'outer; + }; + new_phys.push(expr); + } + for (proj_expr, expr) in current_exprs.iter_mut().zip(new_phys) { + proj_expr.expr = expr; + } + current_input = Arc::clone(inner_proj.input()); + collapsed_any = true; } - for proj_expr in projection.expr() { - // If there is no match in the input projection, we cannot unify these - // projections. This case will arise if the projection expression contains - // a `PhysicalExpr` variant `update_expr` doesn't support. - let Some(expr) = update_expr(&proj_expr.expr, child.expr(), true)? else { - return Ok(None); - }; - projected_exprs.push(ProjectionExpr { - expr, - alias: proj_expr.alias.clone(), - }); + + if !collapsed_any { + return Ok(None); } - ProjectionExec::try_new(projected_exprs, Arc::clone(child.input())) - .map(|e| Some(Arc::new(e) as _)) + + // To unify 3 or more sequential projections: + let unified: Arc = + Arc::new(ProjectionExec::try_new(current_exprs, current_input)?); + remove_unnecessary_projections(unified).data().map(Some) } /// Collect all column indices from the given projection expressions. fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec { - // Collect indices and remove duplicates. - let mut indices = exprs - .iter() - .flat_map(|proj_expr| collect_columns(&proj_expr.expr)) - .map(|x| x.index()) - .collect::>() - .into_iter() - .collect::>(); - indices.sort(); + // Collect column indices in a deterministic order that preserves the + // projection's column ordering. For simple Column expressions, we use + // the column index directly. For complex expressions, we walk the + // expression tree to collect column references in traversal order. + // This allows the embedded projection to match the desired output + // column order, avoiding a residual ProjectionExec. + let mut seen = std::collections::HashSet::new(); + let mut indices = Vec::new(); + for proj_expr in exprs { + if let Some(col) = proj_expr.expr.downcast_ref::() { + // Simple column reference: preserve projection order. + if seen.insert(col.index()) { + indices.push(col.index()); + } + } else { + // Complex expression: collect all referenced columns in + // expression tree traversal order (deterministic) to preserve + // the natural ordering of column references. + proj_expr + .expr + .apply(|expr| { + if let Some(col) = expr.downcast_ref::() + && seen.insert(col.index()) + { + indices.push(col.index()); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("closure always returns OK"); + } + } indices } @@ -949,7 +1144,7 @@ fn new_columns_for_join_on( // Rewrite all columns in `on` Arc::clone(*on) .transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { + if let Some(column) = expr.downcast_ref::() { // Find the column in the projection expressions let new_column = projection_exprs .iter() @@ -982,28 +1177,24 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Checks if the given expression is trivial. -/// An expression is considered trivial if it is either a `Column` or a `Literal`. -fn is_expr_trivial(expr: &Arc) -> bool { - expr.as_any().downcast_ref::().is_some() - || expr.as_any().downcast_ref::().is_some() -} - #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; use crate::common::collect; + + use crate::filter_pushdown::PushedDown; use crate::test; use crate::test::exec::StatisticsExec; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::stats::{ColumnStatistics, Precision, Statistics}; use datafusion_common::ScalarValue; + use datafusion_common::stats::{ColumnStatistics, Precision, Statistics}; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, Literal}; + use datafusion_physical_expr::expressions::{ + BinaryExpr, Column, DynamicFilterPhysicalExpr, Literal, binary, col, lit, + }; #[test] fn test_collect_column_indices() -> Result<()> { @@ -1020,7 +1211,8 @@ mod tests { expr, alias: "b-(1+a)".to_string(), }]); - assert_eq!(column_indices, vec![1, 7]); + // Tree traversal order: b@7 is visited before a@1 + assert_eq!(column_indices, vec![7, 1]); Ok(()) } @@ -1192,4 +1384,431 @@ mod tests { ); assert!(stats.total_byte_size.is_exact().unwrap_or(false)); } + + #[test] + fn test_filter_pushdown_with_alias() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics::new_unknown(&input_schema), + input_schema.clone(), + )); + + // project "a" as "b" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "b".to_string(), + }], + input, + )?; + + // filter "b > 5" + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" + // "a" is index 0 in input + let expected_filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + assert_eq!(description.self_filters(), vec![vec![]]); + let pushed_filters = &description.parent_filters()[0]; + assert_eq!( + format!("{}", pushed_filters[0].predicate), + format!("{}", expected_filter) + ); + // Verify the predicate was actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_multiple_aliases() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "y" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "y < 10" + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" and "b < 10" + let expected_filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let expected_filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // Note: The order of filters is preserved + assert_eq!( + format!("{}", pushed_filters[0].predicate), + format!("{}", expected_filter1) + ); + assert_eq!( + format!("{}", pushed_filters[1].predicate), + format!("{}", expected_filter2) + ); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_swapped_aliases() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "b", "b" as "a" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "b".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "a".to_string(), + }, + ], + input, + )?; + + // filter "b > 5" (output column 0, which is "a" in input) + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "a < 10" (output column 1, which is "b" in input) + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + + // "b" (output index 0) -> "a" (input index 0) + let expected_filter1 = "a@0 > 5"; + // "a" (output index 1) -> "b" (input index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_mixed_columns() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "b" (pass through) + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "b < 10" (using output index 1 which corresponds to 'b') + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // "x" -> "a" (index 0) + let expected_filter1 = "a@0 > 5"; + // "b" -> "b" (index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + // Verify the predicates were actually pushed down + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert!(matches!(pushed_filters[1].discriminant, PushedDown::Yes)); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_complex_expression() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a + 1" as "z" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Plus, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )), + alias: "z".to_string(), + }], + input, + )?; + + // filter "z > 10" + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("z", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + // expand to `a + 1 > 10` + let pushed_filters = &description.parent_filters()[0]; + assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes)); + assert_eq!(format!("{}", pushed_filters[0].predicate), "a@0 + 1 > 10"); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_unknown_column() -> Result<()> { + let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "a" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a".to_string(), + }], + input, + )?; + + // filter "unknown_col > 5" - using a column name that doesn't exist in projection output + // Column constructor: name, index. Index 1 doesn't exist. + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("unknown_col", 1)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert!(matches!(pushed_filters[0].discriminant, PushedDown::No)); + // The column shouldn't be found in the alias map, so it remains unchanged with its index + assert_eq!( + format!("{}", pushed_filters[0].predicate), + "unknown_col@1 > 5" + ); + + Ok(()) + } + + /// Basic test for `DynamicFilterPhysicalExpr` can correctly update its child expression + /// i.e. starting with lit(true) and after update it becomes `a > 5` + /// with projection [b - 1 as a], the pushed down filter should be `b - 1 > 5` + #[test] + fn test_basic_dyn_filter_projection_pushdown_update_child() -> Result<()> { + let input_schema = + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)])); + + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.as_ref().clone(), + )); + + // project "b" - 1 as "a" + let projection = ProjectionExec::try_new( + vec![ProjectionExpr { + expr: binary( + Arc::new(Column::new("b", 0)), + Operator::Minus, + lit(1), + &input_schema, + ) + .unwrap(), + alias: "a".to_string(), + }], + input, + )?; + + // simulate projection's parent create a dynamic filter on "a" + let projected_schema = projection.schema(); + let col_a = col("a", &projected_schema)?; + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::clone(&col_a)], + lit(true), + )); + // Initial state should be lit(true) + let current = dynamic_filter.current()?; + assert_eq!(format!("{current}"), "true"); + + let dyn_phy_expr: Arc = Arc::clone(&dynamic_filter) as _; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![dyn_phy_expr], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0][0]; + + // Check currently pushed_filters is lit(true) + assert_eq!( + format!("{}", pushed_filters.predicate), + "DynamicFilter [ empty ]" + ); + + // Update to a > 5 (after projection, b is now called a) + let new_expr = + Arc::new(BinaryExpr::new(Arc::clone(&col_a), Operator::Gt, lit(5i32))); + dynamic_filter.update(new_expr)?; + + // Now it should be a > 5 + let current = dynamic_filter.current()?; + assert_eq!(format!("{current}"), "a@0 > 5"); + + // Check currently pushed_filters is b - 1 > 5 (because b - 1 is projected as a) + assert_eq!( + format!("{}", pushed_filters.predicate), + "DynamicFilter [ b@0 - 1 > 5 ]" + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 7b966ed23dbde..7289ac43e510c 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -22,22 +22,30 @@ use std::sync::Arc; use std::task::{Context, Poll}; use super::work_table::{ReservedBatches, WorkTable}; -use crate::execution_plan::{Boundedness, EmissionType}; +use crate::aggregates::group_values::{GroupValues, new_group_values}; +use crate::aggregates::order::GroupOrdering; +use crate::common::project_plan_to_schema; +use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; +use crate::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, +}; use crate::{ - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, }; -use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; - +use arrow::array::{BooleanArray, BooleanBuilder}; +use arrow::compute::filter_record_batch; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_datafusion_err, not_impl_err, Result}; -use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_common::{ + Result, exec_datafusion_err, internal_datafusion_err, not_impl_err, +}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; -use futures::{ready, Stream, StreamExt}; +use futures::{Stream, StreamExt, ready}; /// Recursive query execution plan. /// @@ -69,22 +77,25 @@ pub struct RecursiveQueryExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl RecursiveQueryExec { /// Create a new RecursiveQueryExec pub fn try_new( name: String, + output_schema: SchemaRef, static_term: Arc, recursive_term: Arc, is_distinct: bool, ) -> Result { // Each recursive query needs its own work table - let work_table = Arc::new(WorkTable::new()); + let work_table = Arc::new(WorkTable::new(name.clone())); // Use the same work table for both the WorkTableExec and the recursive term + let static_term = project_plan_to_schema(static_term, &output_schema)?; let recursive_term = assign_work_table(recursive_term, &work_table)?; - let cache = Self::compute_properties(static_term.schema()); + let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?; + let cache = Self::compute_properties(output_schema); Ok(RecursiveQueryExec { name, static_term, @@ -92,7 +103,7 @@ impl RecursiveQueryExec { is_distinct, work_table, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -134,11 +145,7 @@ impl ExecutionPlan for RecursiveQueryExec { "RecursiveQueryExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -169,6 +176,7 @@ impl ExecutionPlan for RecursiveQueryExec { ) -> Result> { RecursiveQueryExec::try_new( self.name.clone(), + self.schema(), Arc::clone(&children[0]), Arc::clone(&children[1]), self.is_distinct, @@ -195,17 +203,14 @@ impl ExecutionPlan for RecursiveQueryExec { Arc::clone(&self.work_table), Arc::clone(&self.recursive_term), static_stream, + self.is_distinct, baseline_metrics, - ))) + )?)) } fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } - - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } } impl DisplayAs for RecursiveQueryExec { @@ -267,8 +272,10 @@ struct RecursiveQueryStream { buffer: Vec, /// Tracks the memory used by the buffer reservation: MemoryReservation, - // /// Metrics. - _baseline_metrics: BaselineMetrics, + /// If the distinct flag is set, then we use this hash table to remove duplicates from result and work tables + distinct_deduplicator: Option, + /// Metrics. + baseline_metrics: BaselineMetrics, } impl RecursiveQueryStream { @@ -278,12 +285,16 @@ impl RecursiveQueryStream { work_table: Arc, recursive_term: Arc, static_stream: SendableRecordBatchStream, + is_distinct: bool, baseline_metrics: BaselineMetrics, - ) -> Self { + ) -> Result { let schema = static_stream.schema(); let reservation = MemoryConsumer::new("RecursiveQuery").register(task_context.memory_pool()); - Self { + let distinct_deduplicator = is_distinct + .then(|| DistinctDeduplicator::new(Arc::clone(&schema), &task_context)) + .transpose()?; + Ok(Self { task_context, work_table, recursive_term, @@ -292,21 +303,29 @@ impl RecursiveQueryStream { schema, buffer: vec![], reservation, - _baseline_metrics: baseline_metrics, - } + distinct_deduplicator, + baseline_metrics, + }) } /// Push a clone of the given batch to the in memory buffer, and then return /// a poll with it. fn push_batch( mut self: std::pin::Pin<&mut Self>, - batch: RecordBatch, + mut batch: RecordBatch, ) -> Poll>> { + let baseline_metrics = self.baseline_metrics.clone(); + + if let Some(deduplicator) = &mut self.distinct_deduplicator { + let _timer_guard = baseline_metrics.elapsed_compute().timer(); + batch = deduplicator.deduplicate(&batch)?; + } + if let Err(e) = self.reservation.try_grow(batch.get_array_memory_size()) { return Poll::Ready(Some(Err(e))); } - self.buffer.push(batch.clone()); + (&batch).record_output(&baseline_metrics); Poll::Ready(Some(Ok(batch))) } @@ -361,8 +380,6 @@ fn assign_work_table( work_table_refs += 1; Ok(Transformed::yes(new_plan)) } - } else if plan.as_any().is::() { - not_impl_err!("Recursive queries cannot be nested") } else { Ok(Transformed::no(plan)) } @@ -370,20 +387,6 @@ fn assign_work_table( .data() } -/// Some plans will change their internal states after execution, making them unable to be executed again. -/// This function uses [`ExecutionPlan::reset_state`] to reset any internal state within the plan. -/// -/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan. -/// However, if the data of the left table is derived from the work table, it will become outdated -/// as the work table changes. When the next iteration executes this plan again, we must clear the left table. -fn reset_plan_states(plan: Arc) -> Result> { - plan.transform_up(|plan| { - let new_plan = Arc::clone(&plan).reset_state()?; - Ok(Transformed::yes(new_plan)) - }) - .data() -} - impl Stream for RecursiveQueryStream { type Item = Result; @@ -391,7 +394,6 @@ impl Stream for RecursiveQueryStream { mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - // TODO: we should use this poll to record some metrics! if let Some(static_stream) = &mut self.static_stream { // While the static term's stream is available, we'll be forwarding the batches from it (also // saving them for the initial iteration of the recursive term). @@ -428,5 +430,128 @@ impl RecordBatchStream for RecursiveQueryStream { } } +/// Deduplicator based on a hash table. +struct DistinctDeduplicator { + /// Grouped rows used for distinct + group_values: Box, + reservation: MemoryReservation, + intern_output_buffer: Vec, +} + +impl DistinctDeduplicator { + fn new(schema: SchemaRef, task_context: &TaskContext) -> Result { + let group_values = new_group_values(schema, &GroupOrdering::None)?; + let reservation = MemoryConsumer::new("RecursiveQueryHashTable") + .register(task_context.memory_pool()); + Ok(Self { + group_values, + reservation, + intern_output_buffer: Vec::new(), + }) + } + + /// Remove duplicated rows from the given batch, keeping a state between batches. + /// + /// We use a hash table to allocate new group ids for the new rows. + /// [`GroupValues`] allocate increasing group ids. + /// Hence, if groups (i.e., rows) are new, then they have ids >= length before interning, we keep them. + /// We also detect duplicates by enforcing that group ids are increasing. + fn deduplicate(&mut self, batch: &RecordBatch) -> Result { + let size_before = self.group_values.len(); + let additional = batch.num_rows(); + self.intern_output_buffer + .try_reserve(additional) + .map_err(|e| { + exec_datafusion_err!( + "failed to reserve {additional} recursive query group ids: {e}" + ) + })?; + self.group_values + .intern(batch.columns(), &mut self.intern_output_buffer)?; + let mask = new_groups_mask(&self.intern_output_buffer, size_before); + self.intern_output_buffer.clear(); + // We update the reservation to reflect the new size of the hash table. + self.reservation.try_resize(self.group_values.size())?; + Ok(filter_record_batch(batch, &mask)?) + } +} + +/// Return a mask, each element being true if, and only if, the element is greater than all previous elements and greater or equal than the provided max_already_seen_group_id +fn new_groups_mask( + values: &[usize], + mut max_already_seen_group_id: usize, +) -> BooleanArray { + let mut output = BooleanBuilder::with_capacity(values.len()); + for value in values { + if *value >= max_already_seen_group_id { + output.append_value(true); + max_already_seen_group_id = *value + 1; // We want to be increasing + } else { + output.append_value(false); + } + } + output.finish() +} + #[cfg(test)] -mod tests {} +mod tests { + use super::*; + use crate::empty::EmptyExec; + use crate::projection::ProjectionExec; + + use arrow::datatypes::{DataType, Field, Schema}; + + fn empty_exec(fields: Vec) -> Arc { + Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) + } + + #[test] + fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> { + let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); + let recursive_term = + empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]); + + let exec = RecursiveQueryExec::try_new( + "numbers".to_string(), + static_term.schema(), + Arc::clone(&static_term), + Arc::clone(&recursive_term), + false, + )?; + + assert_eq!(exec.schema(), static_term.schema()); + let projection = exec + .recursive_term() + .downcast_ref::() + .expect("recursive term should be aligned with ProjectionExec"); + assert!(Arc::ptr_eq(projection.input(), &recursive_term)); + assert!(!projection.schema().field(0).is_nullable()); + assert_eq!(projection.expr()[0].alias, "value"); + Ok(()) + } + + #[test] + fn recursive_query_exec_reconciles_nullability() -> Result<()> { + let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); + let recursive_term = + empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); + let output_schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + true, + )])); + + let exec = RecursiveQueryExec::try_new( + "numbers".to_string(), + Arc::clone(&output_schema), + static_term, + recursive_term, + false, + )?; + + assert!(exec.schema().field(0).is_nullable()); + assert!(exec.static_term().schema().field(0).is_nullable()); + assert!(exec.recursive_term().schema().field(0).is_nullable()); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/render_tree.rs b/datafusion/physical-plan/src/render_tree.rs index f86e4c55e7b0e..40e2763698093 100644 --- a/datafusion/physical-plan/src/render_tree.rs +++ b/datafusion/physical-plan/src/render_tree.rs @@ -31,11 +31,12 @@ use crate::{DisplayFormatType, ExecutionPlan}; // TODO: It's never used. /// Represents a 2D coordinate in the rendered tree. /// Used to track positions of nodes and their connections. -#[allow(dead_code)] pub struct Coordinate { /// Horizontal position in the tree + #[expect(dead_code)] pub x: usize, /// Vertical position in the tree + #[expect(dead_code)] pub y: usize, } diff --git a/datafusion/physical-plan/src/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs index 34294d0f2326d..22872d1e32d49 100644 --- a/datafusion/physical-plan/src/repartition/distributor_channels.rs +++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs @@ -43,8 +43,8 @@ use std::{ ops::DerefMut, pin::Pin, sync::{ - atomic::{AtomicUsize, Ordering}, Arc, + atomic::{AtomicUsize, Ordering}, }, task::{Context, Poll, Waker}, }; @@ -476,7 +476,7 @@ type SharedGate = Arc; mod tests { use std::sync::atomic::AtomicBool; - use futures::{task::ArcWake, FutureExt}; + use futures::{FutureExt, task::ArcWake}; use super::*; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 843d975c7d769..3d30dd82762b1 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -22,23 +22,28 @@ use std::fmt::{Debug, Formatter}; use std::pin::Pin; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::{Context, Poll}; -use std::{any::Any, vec}; +use std::vec; use super::common::SharedMemoryReservation; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; +use crate::coalesce::LimitedBatchCoalescer; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::hash_utils::create_hashes; use crate::metrics::{BaselineMetrics, SpillMetrics}; -use crate::projection::{all_columns, make_with_child, update_expr, ProjectionExec}; +use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::spill_manager::SpillManager; use crate::spill::spill_pool::{self, SpillPoolWriter}; -use crate::stream::RecordBatchStreamAdapter; -use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; +use crate::stream::{EmptyRecordBatchStream, RecordBatchStreamAdapter}; +use crate::{ + DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics, + check_if_same_properties, +}; use arrow::array::{PrimitiveArray, RecordBatch, RecordBatchOptions}; use arrow::compute::take_arrays; @@ -47,12 +52,12 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::stats::Precision; use datafusion_common::utils::transpose; use datafusion_common::{ - assert_or_internal_err, internal_err, ColumnStatistics, DataFusionError, HashMap, + ColumnStatistics, DataFusionError, HashMap, assert_or_internal_err, internal_err, }; -use datafusion_common::{not_impl_err, Result}; +use datafusion_common::{Result, not_impl_err}; use datafusion_common_runtime::SpawnedTask; -use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -60,6 +65,9 @@ use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation, }; +use crate::joins::SeededRandomState; +use crate::sort_pushdown::SortOrderPushdownResult; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; @@ -67,8 +75,9 @@ use log::trace; use parking_lot::Mutex; mod distributor_channels; +use crate::repartition::distributor_channels::SendError; use distributor_channels::{ - channels, partition_aware_channels, DistributionReceiver, DistributionSender, + DistributionReceiver, DistributionSender, channels, partition_aware_channels, }; /// A batch in the repartition queue - either in memory or spilled to disk. @@ -80,41 +89,49 @@ use distributor_channels::{ /// # Batch Flow with Spilling /// /// ```text -/// Input Stream ──▶ Partition Logic ──▶ try_grow() -/// │ -/// ┌───────────────┴────────────────┐ -/// │ │ -/// ▼ ▼ -/// try_grow() succeeds try_grow() fails -/// (Memory Available) (Memory Pressure) -/// │ │ -/// ▼ ▼ -/// RepartitionBatch::Memory spill_writer.push_batch() -/// (batch held in memory) (batch written to disk) -/// │ │ -/// │ ▼ -/// │ RepartitionBatch::Spilled -/// │ (marker - no batch data) -/// │ │ -/// └────────┬───────────────────────┘ -/// │ -/// ▼ -/// Send to channel -/// │ -/// ▼ -/// Output Stream (poll) -/// │ -/// ┌──────────────┴─────────────┐ -/// │ │ -/// ▼ ▼ -/// RepartitionBatch::Memory RepartitionBatch::Spilled -/// Return batch immediately Poll spill_stream (blocks) -/// │ │ -/// └────────┬───────────────────┘ -/// │ -/// ▼ -/// Return batch -/// (FIFO order preserved) +/// Input Stream ◀──────┐ +/// │ │ +/// ▼ │ +/// Partition Logic │ +/// │ `batch_size` not +/// ▼ reached yet +/// Coalesce Batch │ +/// ┌───────────────┴────────────────┘ +/// ▼ +/// `batch_size` reached +/// │ +/// └───────────────┐ +/// ▼ +/// try_grow() +/// ┌───────────────┴────────────────┐ +/// ▼ ▼ +/// try_grow() succeeds try_grow() fails +/// (Memory Available) (Memory Pressure) +/// │ │ +/// ▼ ▼ +/// RepartitionBatch::Memory spill_writer.push_batch() +/// (batch held in memory) (batch written to disk) +/// │ │ +/// │ ▼ +/// │ RepartitionBatch::Spilled +/// │ (marker - no batch data) +/// └──────────────┬─────────────────┘ +/// │ +/// ▼ +/// Send to channel +/// │ +/// ▼ +/// Output Stream (poll) +/// │ +/// ┌──────────────┴────────────────┐ +/// ▼ ▼ +/// RepartitionBatch::Memory RepartitionBatch::Spilled +/// Return batch immediately Poll spill_stream (blocks) +/// └─────────────┬─────────────────┘ +/// │ +/// ▼ +/// Return batch +/// (FIFO order preserved) /// ``` /// /// See [`RepartitionExec`] for overall architecture and [`StreamState`] for @@ -134,11 +151,122 @@ type MaybeBatch = Option>; type InputPartitionsToCurrentPartitionSender = Vec>; type InputPartitionsToCurrentPartitionReceiver = Vec>; -/// Output channel with its associated memory reservation and spill writer +/// Output channel with its associated memory reservation and spill writer. +/// +/// `coalescer` is `None` for preserve-order mode, where downstream +/// [`StreamingMergeBuilder`] performs the batching; otherwise it's a +/// [`SharedCoalescer`] cloned from the per-partition one held by +/// [`PartitionChannels`]. struct OutputChannel { sender: DistributionSender, reservation: SharedMemoryReservation, spill_writer: SpillPoolWriter, + shared_coalescer: Option, +} + +impl OutputChannel { + fn coalesce(&mut self, batch: RecordBatch) -> Result> { + match &self.shared_coalescer { + Some(shared) => Ok(shared.push_and_drain(batch)?), + None => Ok(vec![batch]), + } + } + + /// Send a single batch through the channel for `partition`, applying + /// the memory reservation / spill-writer fallback. Removes the channel + /// from `self.inner` if the receiver has hung up. + /// + /// Used after [`OutputChannel::coalesce`] for performance purposes. + async fn send(&mut self, batch: RecordBatch) -> Result<(), SendError> { + let size = batch.get_array_memory_size(); + + // Decide the payload outside of any await: never hold a MutexGuard + // across an await point. + let (payload, is_memory_batch) = { + match self.reservation.try_grow(size) { + Ok(_) => (Ok(RepartitionBatch::Memory(batch)), true), + Err(_) => match self.spill_writer.push_batch(&batch) { + Ok(()) => (Ok(RepartitionBatch::Spilled), false), + Err(err) => (Err(err), false), + }, + } + }; + + let result = self.sender.send(Some(payload)).await; + if result.is_err() && is_memory_batch { + self.reservation.shrink(size); + } + result + } + + async fn finalize(mut self) -> Result<()> { + let Some(shared) = self.shared_coalescer.take() else { + return Ok(()); + }; + for batch in shared.finalize()? { + // If this errored, it means that nobody is listening on the other side, which is fine + // and can happen in certain cases, like when a LIMIT drops the stream that listens. + let _ = self.send(batch).await; + } + Ok(()) + } +} + +/// A producer-side coalescer shared across all input tasks targeting a +/// single output partition. +/// +/// Bundles the [`LimitedBatchCoalescer`] (behind a [`Mutex`]) with the +/// active-sender counter that tracks how many input tasks may still push +/// into it. The last task to call [`Self::finalize`] is the one that +/// finalizes the coalescer and ships the residual batch. +/// +/// Cheap to [`Clone`]: both fields are [`Arc`]s. +#[derive(Clone)] +struct SharedCoalescer { + inner: Arc>, + active_senders: Arc, +} + +impl SharedCoalescer { + fn new(schema: SchemaRef, target_batch_size: usize, num_senders: usize) -> Self { + Self { + inner: Arc::new(Mutex::new(LimitedBatchCoalescer::new( + schema, + target_batch_size, + None, + ))), + active_senders: Arc::new(AtomicUsize::new(num_senders)), + } + } + + /// Push `batch` into the coalescer and drain any newly completed + /// batches. The mutex is held only briefly. + fn push_and_drain(&self, batch: RecordBatch) -> Result> { + let mut acc = Vec::new(); + let mut c = self.inner.lock(); + c.push_batch(batch)?; + while let Some(b) = c.next_completed_batch() { + acc.push(b); + } + Ok(acc) + } + + /// Decrement the active-senders counter. If this caller was the last + /// sender, finalize the coalescer and return its residual batches; if + /// other senders are still active, return `Ok(None)`. + fn finalize(&self) -> Result> { + let was_last = self.active_senders.fetch_sub(1, Ordering::AcqRel) == 1; + if !was_last { + return Ok(vec![]); + } + let mut acc = Vec::new(); + let mut c = self.inner.lock(); + c.finish()?; + while let Some(b) = c.next_completed_batch() { + acc.push(b); + } + Ok(acc) + } } /// Channels and resources for a single output partition. @@ -169,6 +297,10 @@ struct PartitionChannels { rx: InputPartitionsToCurrentPartitionReceiver, /// Memory reservation for this output partition reservation: SharedMemoryReservation, + /// Shared coalescer used by all input tasks targeting this output + /// partition. `None` in preserve-order mode (downstream + /// `StreamingMergeBuilder` handles batching). + shared_coalescer: Option, /// Spill writers for writing spilled data. /// SpillPoolWriter is Clone, so multiple writers can share state in non-preserve-order mode. spill_writers: Vec, @@ -275,7 +407,9 @@ impl RepartitionExecState { let RepartitionExecState::InputStreamsInitialized(value) = self else { // This cannot happen, as ensure_input_streams_initialized() was just called, // but the compiler does not know. - return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized"); + return internal_err!( + "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized" + ); }; value } @@ -311,11 +445,11 @@ impl RepartitionExecState { let mut channels = HashMap::with_capacity(txs.len()); for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() { - let reservation = Arc::new(Mutex::new( + let reservation = Arc::new( MemoryConsumer::new(format!("{name}[{partition}]")) .with_can_spill(true) .register(context.memory_pool()), - )); + ); // Create spill channels based on mode: // - preserve_order: one spill channel per (input, output) pair for proper FIFO ordering @@ -336,6 +470,18 @@ impl RepartitionExecState { .map(|_| spill_pool::channel(max_file_size, Arc::clone(&spill_manager))) .unzip(); + // Coalesce on the producer side, before the channel's gate, so + // the consumer never sees the per-input-task small batches. + // Skip in preserve-order mode: each input has its own dedicated + // channel and `StreamingMergeBuilder` handles batching. + let shared_coalescer = (!preserve_order).then(|| { + SharedCoalescer::new( + input.schema(), + context.session_config().batch_size(), + num_input_partitions, + ) + }); + channels.insert( partition, PartitionChannels { @@ -344,6 +490,7 @@ impl RepartitionExecState { reservation, spill_readers, spill_writers, + shared_coalescer, }, ); } @@ -366,6 +513,7 @@ impl RepartitionExecState { reservation: Arc::clone(&channels.reservation), spill_writer: channels.spill_writers[spill_writer_idx] .clone(), + shared_coalescer: channels.shared_coalescer.clone(), }, ) }) @@ -413,8 +561,9 @@ pub struct BatchPartitioner { enum BatchPartitionerState { Hash { exprs: Vec>, - num_partitions: usize, + partition_reducer: StrengthReducedU64, hash_buffer: Vec, + indices: Vec>, }, RoundRobin { num_partitions: usize, @@ -424,37 +573,181 @@ enum BatchPartitionerState { /// Fixed RandomState used for hash repartitioning to ensure consistent behavior across /// executions and runs. -pub const REPARTITION_RANDOM_STATE: ahash::RandomState = - ahash::RandomState::with_seeds(0, 0, 0, 0); +pub const REPARTITION_RANDOM_STATE: SeededRandomState = SeededRandomState::with_seed(0); + +/// Computes `value % divisor` without division in the hot loop when `divisor` +/// is fixed for many values. +/// +/// Hash repartitioning computes a remainder for every row. Integer division is +/// relatively expensive, so this precomputes the strength-reduced form of the +/// divisor: powers of two use a bit mask, and other divisors use a reciprocal +/// multiply to recover the quotient and therefore the remainder. This is the +/// same invariant-divisor optimization compilers use for `%` by a constant. +#[derive(Debug, Clone, Copy)] +enum StrengthReducedU64 { + PowerOfTwo { mask: u64 }, + Reciprocal { divisor: u64, reciprocal: u128 }, +} + +impl StrengthReducedU64 { + fn new(divisor: u64) -> Self { + debug_assert!(divisor > 0); + + if divisor.is_power_of_two() { + Self::PowerOfTwo { mask: divisor - 1 } + } else { + Self::Reciprocal { + divisor, + // ceil(2^128 / divisor), computed without representing 2^128 + reciprocal: u128::MAX / u128::from(divisor) + 1, + } + } + } + + fn partition_indices(self, hash_buffer: &[u64], indices: &mut [Vec]) { + match self { + Self::PowerOfTwo { mask } => { + for (index, hash) in hash_buffer.iter().enumerate() { + indices[(*hash & mask) as usize].push(index as u32); + } + } + Self::Reciprocal { + divisor, + reciprocal, + } => { + for (index, hash) in hash_buffer.iter().enumerate() { + let quotient = Self::quotient(*hash, reciprocal); + let partition = *hash - quotient * divisor; + indices[partition as usize].push(index as u32); + } + } + } + } + + #[cfg(test)] + fn remainder(self, value: u64) -> u64 { + match self { + Self::PowerOfTwo { mask } => value & mask, + Self::Reciprocal { + divisor, + reciprocal, + } => value - Self::quotient(value, reciprocal) * divisor, + } + } + + #[inline] + fn quotient(value: u64, reciprocal: u128) -> u64 { + let reciprocal_low = reciprocal as u64; + let reciprocal_high = (reciprocal >> 64) as u64; + let low_product = u128::from(value) * u128::from(reciprocal_low); + let high_product = u128::from(value) * u128::from(reciprocal_high); + let carry = ((high_product & u128::from(u64::MAX)) + (low_product >> 64)) >> 64; + + ((high_product >> 64) + carry) as u64 + } +} impl BatchPartitioner { - /// Create a new [`BatchPartitioner`] with the provided [`Partitioning`] + /// Create a new [`BatchPartitioner`] for hash-based repartitioning. + /// + /// # Parameters + /// - `exprs`: Expressions used to compute the hash for each input row. + /// - `num_partitions`: Total number of output partitions. + /// - `timer`: Metric used to record time spent during repartitioning. + /// + /// The partition count is fixed for the lifetime of the partitioner, so this + /// precomputes a strength-reduced reducer for `hash % num_partitions`. + /// + /// # Errors + /// Returns an error if `num_partitions` is zero. + pub fn new_hash_partitioner( + exprs: Vec>, + num_partitions: usize, + timer: metrics::Time, + ) -> Result { + if num_partitions == 0 { + return internal_err!("Hash repartition requires at least one partition"); + } + + Ok(Self { + state: BatchPartitionerState::Hash { + exprs, + partition_reducer: StrengthReducedU64::new(num_partitions as u64), + hash_buffer: vec![], + indices: vec![vec![]; num_partitions], + }, + timer, + }) + } + + /// Create a new [`BatchPartitioner`] for round-robin repartitioning. + /// + /// # Parameters + /// - `num_partitions`: Total number of output partitions. + /// - `timer`: Metric used to record time spent during repartitioning. + /// - `input_partition`: Index of the current input partition. + /// - `num_input_partitions`: Total number of input partitions. /// - /// The time spent repartitioning will be recorded to `timer` + /// # Notes + /// The starting output partition is derived from the input partition + /// to avoid skew when multiple input partitions are used. + pub fn new_round_robin_partitioner( + num_partitions: usize, + timer: metrics::Time, + input_partition: usize, + num_input_partitions: usize, + ) -> Self { + Self { + state: BatchPartitionerState::RoundRobin { + num_partitions, + next_idx: (input_partition * num_partitions) / num_input_partitions, + }, + timer, + } + } + /// Create a new [`BatchPartitioner`] based on the provided [`Partitioning`] scheme. + /// + /// This is a convenience constructor that delegates to the specialized + /// hash or round-robin constructors depending on the partitioning variant. + /// + /// # Parameters + /// - `partitioning`: Partitioning scheme to apply (hash or round-robin). + /// - `timer`: Metric used to record time spent during repartitioning. + /// - `input_partition`: Index of the current input partition. + /// - `num_input_partitions`: Total number of input partitions. + /// + /// # Errors + /// Returns an error if the provided partitioning scheme is not supported, + /// or if hash partitioning is requested with zero output partitions. pub fn try_new( partitioning: Partitioning, timer: metrics::Time, input_partition: usize, num_input_partitions: usize, ) -> Result { - let state = match partitioning { + match partitioning { + Partitioning::Hash(exprs, num_partitions) => { + Self::new_hash_partitioner(exprs, num_partitions, timer) + } Partitioning::RoundRobinBatch(num_partitions) => { - BatchPartitionerState::RoundRobin { + Ok(Self::new_round_robin_partitioner( num_partitions, - // Distribute starting index evenly based on input partition, number of input partitions and number of partitions - // to avoid they all start at partition 0 and heavily skew on the lower partitions - next_idx: ((input_partition * num_partitions) / num_input_partitions), - } + timer, + input_partition, + num_input_partitions, + )) } - Partitioning::Hash(exprs, num_partitions) => BatchPartitionerState::Hash { - exprs, - num_partitions, - hash_buffer: vec![], - }, - other => return not_impl_err!("Unsupported repartitioning scheme {other:?}"), - }; - - Ok(Self { state, timer }) + Partitioning::Range(_) => { + // Range repartition execution is tracked in + // https://github.com/apache/datafusion/issues/22397 + not_impl_err!( + "Range partitioning execution is not implemented by RepartitionExec" + ) + } + other => { + not_impl_err!("Unsupported repartitioning scheme {other:?}") + } + } } /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`] @@ -476,12 +769,21 @@ impl BatchPartitioner { }) } - /// Actual implementation of [`partition`](Self::partition). + /// Returns an iterator of `(partition_index, RecordBatch)` pairs for the given batch. + /// + /// This is useful for async consumers that want to separate CPU-bound partitioning + /// from I/O. For example, you can iterate results on the async side and send them + /// through a channel, while performing file I/O on a blocking task: /// - /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions, - /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve - /// this (so we don't need to clone the entire implementation). - fn partition_iter( + /// ```ignore + /// for result in partitioner.partition_iter(batch)? { + /// let (partition, batch) = result?; + /// tx.send((partition, batch)).await?; + /// } + /// ``` + /// + /// The sync [`partition`](Self::partition) method is implemented on top of this. + pub fn partition_iter( &mut self, batch: RecordBatch, ) -> Result> + Send + '_> { @@ -497,8 +799,9 @@ impl BatchPartitioner { } BatchPartitionerState::Hash { exprs, - num_partitions: partitions, + partition_reducer, hash_buffer, + indices, } => { // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); @@ -509,48 +812,23 @@ impl BatchPartitioner { hash_buffer.clear(); hash_buffer.resize(batch.num_rows(), 0); - create_hashes(&arrays, &REPARTITION_RANDOM_STATE, hash_buffer)?; + create_hashes( + &arrays, + REPARTITION_RANDOM_STATE.random_state(), + hash_buffer, + )?; - let mut indices: Vec<_> = (0..*partitions) - .map(|_| Vec::with_capacity(batch.num_rows())) - .collect(); + indices.iter_mut().for_each(|v| v.clear()); - for (index, hash) in hash_buffer.iter().enumerate() { - indices[(*hash % *partitions as u64) as usize].push(index as u32); - } + partition_reducer.partition_indices(hash_buffer, indices); // Finished building index-arrays for output partitions timer.done(); - // Borrowing partitioner timer to prevent moving `self` to closure - let partitioner_timer = &self.timer; - let it = indices - .into_iter() - .enumerate() - .filter_map(|(partition, indices)| { - let indices: PrimitiveArray = indices.into(); - (!indices.is_empty()).then_some((partition, indices)) - }) - .map(move |(partition, indices)| { - // Tracking time required for repartitioned batches construction - let _timer = partitioner_timer.timer(); - - // Produce batches based on indices - let columns = take_arrays(batch.columns(), &indices, None)?; - - let mut options = RecordBatchOptions::new(); - options = options.with_row_count(Some(indices.len())); - let batch = RecordBatch::try_new_with_options( - batch.schema(), - columns, - &options, - ) - .unwrap(); - - Ok((partition, batch)) - }); - - Box::new(it) + let partitioned_batches = + Self::partition_grouped_take(&batch, indices, &self.timer)?; + + Box::new(partitioned_batches.into_iter()) } }; @@ -559,11 +837,71 @@ impl BatchPartitioner { // return the number of output partitions fn num_partitions(&self) -> usize { - match self.state { - BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions, - BatchPartitionerState::Hash { num_partitions, .. } => num_partitions, + match &self.state { + BatchPartitionerState::RoundRobin { num_partitions, .. } => *num_partitions, + BatchPartitionerState::Hash { indices, .. } => indices.len(), } } + + /// Build repartitioned hash output batches using one `take` per input batch. + /// + /// The hash router first fills one index vector per output partition. This method + /// concatenates those index vectors, performs one grouped `take_arrays`, and + /// then returns each output partition as a slice of the reordered batch. + /// + /// For example, given partition indices: + /// + /// ```text + /// partition 0: [2, 5] + /// partition 1: [] + /// partition 2: [0, 3, 4] + /// ``` + /// + /// this method takes rows in `[2, 5, 0, 3, 4]` order once, then returns + /// `partition 0 = slice(0, 2)` and `partition 2 = slice(2, 3)`. + fn partition_grouped_take( + batch: &RecordBatch, + indices: &mut [Vec], + timer: &metrics::Time, + ) -> Result>> { + let mut partition_ranges = Vec::with_capacity(indices.len()); + let mut reordered_indices = Vec::with_capacity(batch.num_rows()); + + for (partition, p_indices) in indices.iter_mut().enumerate() { + if p_indices.is_empty() { + continue; + } + + let start = reordered_indices.len(); + reordered_indices.extend_from_slice(p_indices); + partition_ranges.push((partition, start, p_indices.len())); + p_indices.clear(); + } + + if reordered_indices.is_empty() { + return Ok(vec![]); + } + + let batches = { + let _timer = timer.timer(); + let indices_array: PrimitiveArray = reordered_indices.into(); + let columns = take_arrays(batch.columns(), &indices_array, None)?; + + let mut options = RecordBatchOptions::new(); + options = options.with_row_count(Some(indices_array.len())); + let reordered_batch = + RecordBatch::try_new_with_options(batch.schema(), columns, &options)?; + + partition_ranges + .into_iter() + .map(|(partition, start, len)| { + Ok((partition, reordered_batch.slice(start, len))) + }) + .collect() + }; + + Ok(batches) + } } /// Maps `N` input partitions to `M` output partitions based on a @@ -585,7 +923,7 @@ impl BatchPartitioner { /// used to get 3 even streams of `RecordBatch`es /// /// -///```text +/// ```text /// ▲ ▲ ▲ /// │ │ │ /// │ │ │ @@ -625,6 +963,34 @@ impl BatchPartitioner { /// arbitrary interleaving (and thus unordered) unless /// [`Self::with_preserve_order`] specifies otherwise. /// +/// # Batch coalescing +/// +/// Repartitioning one [`RecordBatch`] implies creating multiple smaller batches, potentially +/// as many as the number of output partitions. [`RepartitionExec`] makes sure that the returned +/// batches adhere to the configured `datafusion.execution.batch_size` for efficient operations, +/// and for that, it will automatically coalesce batches right after repartitioning. +/// +/// For this, one shared [`LimitedBatchCoalescer`] per output partition is used: +/// +/// ```text +/// ┌───┐ ┌───┐ +/// ┌─▶│ │────────▶.───────────. │ │ ┌──────────────────┐ +/// │ └───┘ ┌───┐ ( Coalescer 0 )──▶ ├───┤ ───▶│ Output 0 │ +/// │┌──────▶│ │──▶`───────────' │ │ └──────────────────┘ +/// ││ └───┘ └───┘ +/// ┌──────────────────┐ ││ ┌──────────────────┐ +/// │BatchPartitioner 0│─┘│ │ Output 1 │ +/// └──────────────────┘ │ └──────────────────┘ +/// │ +/// ┌──────────────────┐ │ ... ┌──────────────────┐ +/// │BatchPartitioner 1│──┘ │ Output 2 │ +/// └──────────────────┘ └──────────────────┘ +/// +/// ┌──────────────────┐ +/// │ Output 3 │ +/// └──────────────────┘ +/// ``` +/// /// # Spilling Architecture /// /// RepartitionExec uses [`SpillPool`](crate::spill::spill_pool) channels to handle @@ -664,6 +1030,10 @@ impl BatchPartitioner { /// system Paper](https://dl.acm.org/doi/pdf/10.1145/93605.98720) /// which uses the term "Exchange" for the concept of repartitioning /// data across threads. +/// +/// For more background, please also see the [Optimizing Repartitions in DataFusion] blog. +/// +/// [Optimizing Repartitions in DataFusion]: https://datafusion.apache.org/blog/2025/12/15/avoid-consecutive-repartitions #[derive(Debug, Clone)] pub struct RepartitionExec { /// Input execution plan @@ -677,7 +1047,7 @@ pub struct RepartitionExec { /// `SortPreservingRepartitionExec`, false means `RepartitionExec`. preserve_order: bool, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } #[derive(Debug, Clone)] @@ -746,6 +1116,18 @@ impl RepartitionExec { pub fn name(&self) -> &str { "RepartitionExec" } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + state: Default::default(), + ..Self::clone(self) + } + } } impl DisplayAs for RepartitionExec { @@ -801,11 +1183,7 @@ impl ExecutionPlan for RepartitionExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -817,6 +1195,7 @@ impl ExecutionPlan for RepartitionExec { self: Arc, mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); let mut repartition = RepartitionExec::try_new( children.swap_remove(0), self.partitioning().clone(), @@ -982,15 +1361,11 @@ impl ExecutionPlan for RepartitionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if let Some(partition) = partition { let partition_count = self.partitioning().partition_count(); if partition_count == 0 { - return Ok(Statistics::new_unknown(&self.schema())); + return Ok(Arc::new(Statistics::new_unknown(&self.schema()))); } assert_or_internal_err!( @@ -1000,7 +1375,7 @@ impl ExecutionPlan for RepartitionExec { partition_count ); - let mut stats = self.input.partition_statistics(None)?; + let mut stats = Arc::unwrap_or_clone(self.input.partition_statistics(None)?); // Distribute statistics across partitions stats.num_rows = stats @@ -1021,7 +1396,7 @@ impl ExecutionPlan for RepartitionExec { .map(|_| ColumnStatistics::new_unknown()) .collect(); - Ok(stats) + Ok(Arc::new(stats)) } else { self.input.partition_statistics(None) } @@ -1062,6 +1437,13 @@ impl ExecutionPlan for RepartitionExec { } Partitioning::Hash(new_partitions, *size) } + Partitioning::Range(_) => { + // Range partitioning optimizer propagation is tracked in + // https://github.com/apache/datafusion/issues/22395 + return not_impl_err!( + "Projection pushdown through RepartitionExec with range partitioning is not implemented" + ); + } others => others.clone(), }; @@ -1089,24 +1471,64 @@ impl ExecutionPlan for RepartitionExec { Ok(FilterPushdownPropagation::if_all(child_pushdown_result)) } + fn try_pushdown_sort( + &self, + order: &[PhysicalSortExpr], + ) -> Result>> { + // RepartitionExec only maintains input order if preserve_order is set + // or if there's only one partition + if !self.maintains_input_order()[0] { + return Ok(SortOrderPushdownResult::Unsupported); + } + match self.partitioning() { + Partitioning::Range(_) => { + // Range partitioning optimizer propagation is tracked in + // https://github.com/apache/datafusion/issues/22395 + return not_impl_err!( + "Sort pushdown through RepartitionExec with range partitioning is not implemented" + ); + } + Partitioning::RoundRobinBatch(_) + | Partitioning::Hash(_, _) + | Partitioning::UnknownPartitioning(_) => {} + } + + // Delegate to the child and wrap with a new RepartitionExec + self.input.try_pushdown_sort(order)?.try_map(|new_input| { + let mut new_repartition = + RepartitionExec::try_new(new_input, self.partitioning().clone())?; + if self.preserve_order { + new_repartition = new_repartition.with_preserve_order(); + } + Ok(Arc::new(new_repartition) as Arc) + }) + } + fn repartitioned( &self, target_partitions: usize, _config: &ConfigOptions, ) -> Result>> { use Partitioning::*; - let mut new_properties = self.cache.clone(); + let mut new_properties = PlanProperties::clone(&self.cache); new_properties.partitioning = match new_properties.partitioning { RoundRobinBatch(_) => RoundRobinBatch(target_partitions), Hash(hash, _) => Hash(hash, target_partitions), UnknownPartitioning(_) => UnknownPartitioning(target_partitions), + Range(_) => { + // Range repartition execution is tracked in + // https://github.com/apache/datafusion/issues/22397 + return not_impl_err!( + "Changing RepartitionExec partition counts with range partitioning is not implemented" + ); + } }; Ok(Some(Arc::new(Self { input: Arc::clone(&self.input), state: Arc::clone(&self.state), metrics: self.metrics.clone(), preserve_order: self.preserve_order, - cache: new_properties, + cache: new_properties.into(), }))) } } @@ -1126,7 +1548,7 @@ impl RepartitionExec { state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, - cache, + cache: Arc::new(cache), }) } @@ -1187,7 +1609,7 @@ impl RepartitionExec { // to maintain order self.input.output_partitioning().partition_count() > 1; let eq_properties = Self::eq_properties_helper(&self.input, self.preserve_order); - self.cache = self.cache.with_eq_properties(eq_properties); + Arc::make_mut(&mut self.cache).set_eq_properties(eq_properties); self } @@ -1212,12 +1634,33 @@ impl RepartitionExec { input_partition: usize, num_input_partitions: usize, ) -> Result<()> { - let mut partitioner = BatchPartitioner::try_new( - partitioning, - metrics.repartition_time.clone(), - input_partition, - num_input_partitions, - )?; + let mut partitioner = match &partitioning { + Partitioning::Hash(exprs, num_partitions) => { + BatchPartitioner::new_hash_partitioner( + exprs.clone(), + *num_partitions, + metrics.repartition_time.clone(), + )? + } + Partitioning::RoundRobinBatch(num_partitions) => { + BatchPartitioner::new_round_robin_partitioner( + *num_partitions, + metrics.repartition_time.clone(), + input_partition, + num_input_partitions, + ) + } + Partitioning::Range(_) => { + // Range repartition execution is tracked in + // https://github.com/apache/datafusion/issues/22397 + return not_impl_err!( + "Range partitioning execution is not implemented by RepartitionExec" + ); + } + other => { + return not_impl_err!("Unsupported repartitioning scheme {other:?}"); + } + }; // While there are still outputs to send to, keep pulling inputs let mut batches_until_yield = partitioner.num_partitions(); @@ -1240,33 +1683,17 @@ impl RepartitionExec { for res in partitioner.partition_iter(batch)? { let (partition, batch) = res?; - let size = batch.get_array_memory_size(); let timer = metrics.send_time[partition].timer(); // if there is still a receiver, send to it - if let Some(channel) = output_channels.get_mut(&partition) { - let (batch_to_send, is_memory_batch) = - match channel.reservation.lock().try_grow(size) { - Ok(_) => { - // Memory available - send in-memory batch - (RepartitionBatch::Memory(batch), true) - } - Err(_) => { - // We're memory limited - spill to SpillPool - // SpillPool handles file handle reuse and rotation - channel.spill_writer.push_batch(&batch)?; - // Send marker indicating batch was spilled - (RepartitionBatch::Spilled, false) - } - }; - - if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() { - // If the other end has hung up, it was an early shutdown (e.g. LIMIT) - // Only shrink memory if it was a memory batch - if is_memory_batch { - channel.reservation.lock().shrink(size); + if let Some(output_channel) = output_channels.get_mut(&partition) { + for batch in output_channel.coalesce(batch)? { + if output_channel.send(batch).await.is_err() { + // If the other end has hung up, it was an early shutdown (e.g. LIMIT) + // so ignore this channel from now on. + output_channels.remove(&partition); + break; } - output_channels.remove(&partition); } } timer.done(); @@ -1296,6 +1723,14 @@ impl RepartitionExec { } } + // End of input for this task. For each output partition we still + // have a channel to, decrement the active-senders counter; whoever + // sees the count drop to zero is the last input task and must + // finalize the shared coalescer and ship its residual. + for (_, output_channel) in output_channels.drain() { + output_channel.finalize().await?; + } + // Spill writers will auto-finalize when dropped // No need for explicit flush Ok(()) @@ -1475,9 +1910,7 @@ impl PerPartitionStream { Some(Some(v)) => match v { Ok(RepartitionBatch::Memory(batch)) => { // Release memory and return batch - self.reservation - .lock() - .shrink(batch.get_array_memory_size()); + self.reservation.shrink(batch.get_array_memory_size()); return Poll::Ready(Some(Ok(batch))); } Ok(RepartitionBatch::Spilled) => { @@ -1518,7 +1951,11 @@ impl PerPartitionStream { return Poll::Ready(Some(Err(e))); } Poll::Ready(None) => { - // Spill stream ended, keep draining the memory channel + // Spill stream ended — release its resources before + // we go back to draining the memory channel. + let spill_schema = self.spill_stream.schema(); + self.spill_stream = + Box::pin(EmptyRecordBatchStream::new(spill_schema)); self.state = StreamState::ReadingMemory; } Poll::Pending => { @@ -1562,8 +1999,8 @@ mod tests { test::{ assert_is_pending, exec::{ - assert_strong_count_converges_to_zero, BarrierExec, BlockingExec, - ErrorExec, MockExec, + BarrierExec, BlockingExec, ErrorExec, MockExec, + assert_strong_count_converges_to_zero, }, }, {collect, expressions::col}, @@ -1571,13 +2008,107 @@ mod tests { use arrow::array::{ArrayRef, StringArray, UInt32Array}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; use datafusion_common::cast::as_string_array; use datafusion_common::exec_err; use datafusion_common::test_util::batches_to_sort_string; use datafusion_common_runtime::JoinSet; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr::{PhysicalSortExpr, RangePartitioning, SplitPoint}; use insta::assert_snapshot; - use itertools::Itertools; + + #[test] + fn strength_reduced_u64_remainder_matches_modulo() { + let divisors = [ + 1, + 2, + 3, + 4, + 5, + 7, + 8, + 10, + 16, + 31, + 32, + 63, + 64, + 65, + 97, + u64::from(u32::MAX), + u64::from(u32::MAX) + 1, + 1_u64 << 32, + (1_u64 << 63) - 1, + 1_u64 << 63, + u64::MAX - 1, + u64::MAX, + ]; + let values = [ + 0, + 1, + 2, + 3, + 4, + 5, + 31, + 32, + 33, + 63, + 64, + 65, + u64::from(u32::MAX) - 1, + u64::from(u32::MAX), + u64::from(u32::MAX) + 1, + (1_u64 << 32) - 1, + 1_u64 << 32, + (1_u64 << 32) + 1, + (1_u64 << 63) - 1, + 1_u64 << 63, + (1_u64 << 63) + 1, + u64::MAX - 1, + u64::MAX, + ]; + + for divisor in divisors { + let reducer = StrengthReducedU64::new(divisor); + for value in values { + assert_eq!( + reducer.remainder(value), + value % divisor, + "value={value} divisor={divisor}" + ); + } + + let mut value = 0x1234_5678_9abc_def0 ^ divisor; + for _ in 0..10_000 { + value = value + .wrapping_mul(6_364_136_223_846_793_005) + .wrapping_add(1_442_695_040_888_963_407); + assert_eq!( + reducer.remainder(value), + value % divisor, + "value={value} divisor={divisor}" + ); + } + } + } + + #[test] + fn hash_partitioner_requires_nonzero_partitions() { + let metrics = ExecutionPlanMetricsSet::new(); + let timer = MetricBuilder::new(&metrics).subset_time("test", 0); + + let err = BatchPartitioner::new_hash_partitioner(vec![], 0, timer) + .err() + .expect("zero hash partitions should fail") + .to_string(); + + assert!( + err.contains("Hash repartition requires at least one partition"), + "actual: {err}" + ); + } #[tokio::test] async fn one_to_many_round_robin() -> Result<()> { @@ -1591,10 +2122,13 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?; assert_eq!(4, output_partitions.len()); - assert_eq!(13, output_partitions[0].len()); - assert_eq!(13, output_partitions[1].len()); - assert_eq!(12, output_partitions[2].len()); - assert_eq!(12, output_partitions[3].len()); + for partition in &output_partitions { + assert_eq!(1, partition.len()); + } + assert_eq!(13 * 8, output_partitions[0][0].num_rows()); + assert_eq!(13 * 8, output_partitions[1][0].num_rows()); + assert_eq!(12 * 8, output_partitions[2][0].num_rows()); + assert_eq!(12 * 8, output_partitions[3][0].num_rows()); Ok(()) } @@ -1611,7 +2145,7 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?; assert_eq!(1, output_partitions.len()); - assert_eq!(150, output_partitions[0].len()); + assert_eq!(150 * 8, output_partitions[0][0].num_rows()); Ok(()) } @@ -1627,12 +2161,12 @@ mod tests { let output_partitions = repartition(&schema, partitions, Partitioning::RoundRobinBatch(5)).await?; + let total_rows_per_partition = 8 * 50 * 3 / 5; assert_eq!(5, output_partitions.len()); - assert_eq!(30, output_partitions[0].len()); - assert_eq!(30, output_partitions[1].len()); - assert_eq!(30, output_partitions[2].len()); - assert_eq!(30, output_partitions[3].len()); - assert_eq!(30, output_partitions[4].len()); + for partition in output_partitions { + assert_eq!(1, partition.len()); + assert_eq!(total_rows_per_partition, partition[0].num_rows()); + } Ok(()) } @@ -1662,6 +2196,32 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_repartition_with_coalescing() -> Result<()> { + let schema = test_schema(); + // create 50 batches, each having 8 rows + let partition = create_vec_batches(50); + let partitions = vec![partition.clone(), partition.clone()]; + let partitioning = Partitioning::RoundRobinBatch(1); + + let session_config = SessionConfig::new().with_batch_size(200); + let task_ctx = TaskContext::default().with_session_config(session_config); + let task_ctx = Arc::new(task_ctx); + + // create physical plan + let exec = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?; + let exec = RepartitionExec::try_new(exec, partitioning)?; + + for i in 0..exec.partitioning().partition_count() { + let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; + while let Some(result) = stream.next().await { + let batch = result?; + assert_eq!(200, batch.num_rows()); + } + } + Ok(()) + } + fn test_schema() -> Arc { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) } @@ -1707,12 +2267,12 @@ mod tests { let output_partitions = handle.join().await.unwrap().unwrap(); + let total_rows_per_partition = 8 * 50 * 3 / 5; assert_eq!(5, output_partitions.len()); - assert_eq!(30, output_partitions[0].len()); - assert_eq!(30, output_partitions[1].len()); - assert_eq!(30, output_partitions[2].len()); - assert_eq!(30, output_partitions[3].len()); - assert_eq!(30, output_partitions[4].len()); + for partition in output_partitions { + assert_eq!(1, partition.len()); + assert_eq!(total_rows_per_partition, partition[0].num_rows()); + } Ok(()) } @@ -1748,6 +2308,40 @@ mod tests { ); } + #[tokio::test] + async fn unsupported_range_partitioning() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let batch = RecordBatch::try_from_iter(vec![( + "my_awesome_field", + Arc::new(StringArray::from(vec!["foo", "bar"])) as ArrayRef, + )])?; + + let schema = batch.schema(); + let expr = col("my_awesome_field", &schema)?; + let input = MockExec::new(vec![Ok(batch)], Arc::clone(&schema)); + let partitioning = Partitioning::Range(RangePartitioning::new( + [PhysicalSortExpr::new_default(expr)].into(), + vec![SplitPoint::new(vec![ScalarValue::Utf8(Some( + "foo".to_string(), + ))])], + )); + let exec = RepartitionExec::try_new(Arc::new(input), partitioning)?; + let output_stream = exec.execute(0, task_ctx)?; + + let result_string = crate::common::collect(output_stream) + .await + .unwrap_err() + .to_string(); + assert!( + result_string.contains( + "Range partitioning execution is not implemented by RepartitionExec" + ), + "actual: {result_string}" + ); + + Ok(()) + } + #[tokio::test] async fn error_for_input_exec() { // This generates an error on a call to execute. The error @@ -1950,14 +2544,13 @@ mod tests { }); let batches_with_drop = crate::common::collect(output_stream1).await.unwrap(); - fn sort(batch: Vec) -> Vec { - batch - .into_iter() - .sorted_by_key(|b| format!("{b:?}")) - .collect() - } - - assert_eq!(sort(batches_without_drop), sort(batches_with_drop)); + let items_vec_with_drop = str_batches_to_vec(&batches_with_drop); + let items_set_with_drop: HashSet<&str> = + items_vec_with_drop.iter().copied().collect(); + assert_eq!( + items_set_with_drop.symmetric_difference(&items_set).count(), + 0 + ); } fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> { @@ -2124,13 +2717,17 @@ mod tests { let input_partitions = vec![partition]; let partitioning = Partitioning::RoundRobinBatch(4); - // Set up context with moderate memory limit to force partial spilling - // 2KB should allow some batches in memory but force others to spill + // With `batch_size = 1024` and a single UInt32 column, each + // coalesced residual is ~4 KiB. An 8 KiB pool fits one and forces + // the rest to spill. let runtime = RuntimeEnvBuilder::default() - .with_memory_limit(2 * 1024, 1.0) + .with_memory_limit(8 * 1024, 1.0) .build_arc()?; - let task_ctx = TaskContext::default().with_runtime(runtime); + let session_config = SessionConfig::new().with_batch_size(1024); + let task_ctx = TaskContext::default() + .with_runtime(runtime) + .with_session_config(session_config); let task_ctx = Arc::new(task_ctx); // create physical plan @@ -2285,7 +2882,7 @@ mod tests { /// Create vector batches fn create_vec_batches(n: usize) -> Vec { let batch = create_batch(); - (0..n).map(|_| batch.clone()).collect() + std::iter::repeat_n(batch, n).collect() } /// Create batch @@ -2402,7 +2999,6 @@ mod test { use crate::union::UnionExec; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; /// Asserts that the plan is as expected /// @@ -2481,7 +3077,6 @@ mod test { #[tokio::test] async fn test_preserve_order_with_spilling() -> Result<()> { use datafusion_execution::runtime_env::RuntimeEnvBuilder; - use datafusion_execution::TaskContext; // Create sorted input data across multiple partitions // Partition1: [1,3], [5,7], [9,11] @@ -2608,7 +3203,6 @@ mod test { #[tokio::test] async fn test_hash_partitioning_with_spilling() -> Result<()> { use datafusion_execution::runtime_env::RuntimeEnvBuilder; - use datafusion_execution::TaskContext; // Create input data similar to the round-robin test let batch1 = record_batch!(("c0", UInt32, [1, 3])).unwrap(); diff --git a/datafusion/physical-plan/src/scalar_subquery.rs b/datafusion/physical-plan/src/scalar_subquery.rs new file mode 100644 index 0000000000000..25f7332f95272 --- /dev/null +++ b/datafusion/physical-plan/src/scalar_subquery.rs @@ -0,0 +1,558 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Execution plan for uncorrelated scalar subqueries. +//! +//! [`ScalarSubqueryExec`] wraps a main input plan and a set of subquery plans. +//! At execution time, it runs each subquery exactly once, extracts the scalar +//! result, and populates a shared [`ScalarSubqueryResults`] container that +//! [`ScalarSubqueryExpr`] instances hold directly and read from by index. +//! +//! [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr + +use std::fmt; +use std::sync::Arc; + +use datafusion_common::{Result, ScalarValue, Statistics, exec_err, internal_err}; +use datafusion_execution::TaskContext; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; + +use crate::execution_plan::{CardinalityEffect, ExecutionPlan, PlanProperties}; +use crate::joins::utils::{OnceAsync, OnceFut}; +use crate::stream::RecordBatchStreamAdapter; +use crate::{DisplayAs, DisplayFormatType, SendableRecordBatchStream}; + +use futures::StreamExt; +use futures::TryStreamExt; + +/// Links a scalar subquery's execution plan to its index in the shared results +/// container. The [`ScalarSubqueryExec`] that owns these links populates +/// `results[index]` at execution time, and [`ScalarSubqueryExpr`] instances +/// with the same index read from it. +/// +/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr +#[derive(Debug, Clone)] +pub struct ScalarSubqueryLink { + /// The physical plan for the subquery. + pub plan: Arc, + /// Index into the shared results container. + pub index: SubqueryIndex, +} + +/// Manages execution of uncorrelated scalar subqueries for a single plan +/// level. +/// +/// From a query-results perspective, this node is a pass-through: it yields +/// the same batches as its main input and exists only to populate scalar +/// subquery results as a side effect before those batches are produced. +/// +/// The first child node is the **main input plan**, whose batches are passed +/// through unchanged. The remaining children are **subquery plans**, each of +/// which must produce exactly zero or one row. Before any batches from the main +/// input are yielded, all subquery plans are executed and their scalar results +/// are stored in a shared [`ScalarSubqueryResults`] container owned by this +/// node. [`ScalarSubqueryExpr`] nodes embedded in the main input's expressions +/// hold the same container and read from it by index. +/// +/// All subqueries are evaluated eagerly when the first output partition is +/// requested, before any rows from the main input are produced. +/// +/// TODO: Consider overlapping computation of the subqueries with evaluating the +/// main query. +/// +/// [`ScalarSubqueryExpr`]: datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr +#[derive(Debug)] +pub struct ScalarSubqueryExec { + /// The main input plan whose output is passed through. + input: Arc, + /// Subquery plans and their result indexes. + subqueries: Vec, + /// Shared one-time async computation of subquery results. + subquery_future: Arc>, + /// Shared results container; the corresponding `ScalarSubqueryExpr` + /// nodes in the input plan hold the same underlying container. + results: ScalarSubqueryResults, + /// Cached plan properties (copied from input). + cache: Arc, +} + +impl ScalarSubqueryExec { + pub fn new( + input: Arc, + subqueries: Vec, + results: ScalarSubqueryResults, + ) -> Self { + let cache = Arc::clone(input.properties()); + Self { + input, + subqueries, + subquery_future: Arc::default(), + results, + cache, + } + } + + pub fn input(&self) -> &Arc { + &self.input + } + + pub fn subqueries(&self) -> &[ScalarSubqueryLink] { + &self.subqueries + } + + pub fn results(&self) -> &ScalarSubqueryResults { + &self.results + } + + /// Returns a per-child bool vec that is `true` for the main input + /// (child 0) and `false` for every subquery child. + fn true_for_input_only(&self) -> Vec { + std::iter::once(true) + .chain(std::iter::repeat_n(false, self.subqueries.len())) + .collect() + } +} + +impl DisplayAs for ScalarSubqueryExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "ScalarSubqueryExec: subqueries={}", + self.subqueries.len() + ) + } + DisplayFormatType::TreeRender => { + write!(f, "") + } + } + } +} + +impl ExecutionPlan for ScalarSubqueryExec { + fn name(&self) -> &'static str { + "ScalarSubqueryExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + let mut children = vec![&self.input]; + for sq in &self.subqueries { + children.push(&sq.plan); + } + children + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + // First child is the main input, the rest are subquery plans. + let input = children.remove(0); + let subqueries = self + .subqueries + .iter() + .zip(children) + .map(|(sq, new_plan)| ScalarSubqueryLink { + plan: new_plan, + index: sq.index, + }) + .collect(); + Ok(Arc::new(ScalarSubqueryExec::new( + input, + subqueries, + self.results.clone(), + ))) + } + + fn reset_state(self: Arc) -> Result> { + self.results.clear(); + Ok(Arc::new(ScalarSubqueryExec { + input: Arc::clone(&self.input), + subqueries: self.subqueries.clone(), + subquery_future: Arc::default(), + results: self.results.clone(), + cache: Arc::clone(&self.cache), + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let subqueries = self.subqueries.clone(); + let results = self.results.clone(); + let subquery_ctx = Arc::clone(&context); + let mut subquery_future = self.subquery_future.try_once(move || { + Ok(async move { execute_subqueries(subqueries, results, subquery_ctx).await }) + })?; + let input = Arc::clone(&self.input); + let schema = self.schema(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::once(async move { + // Execute all subqueries exactly once, even when multiple + // partitions call execute() concurrently. + wait_for_subqueries(&mut subquery_future).await?; + + // Now that the subqueries have finished execution, we can + // safely execute the main input + input.execute(partition, context) + }) + .try_flatten(), + ))) + } + + fn maintains_input_order(&self) -> Vec { + // Only the main input (first child); subquery children don't contribute + // to ordering. + self.true_for_input_only() + } + + fn benefits_from_input_partitioning(&self) -> Vec { + // ScalarSubqueryExec is a pass-through coordinator: it does not + // benefit from repartitioning any child directly below it. + vec![false; self.subqueries.len() + 1] + } + + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } +} + +/// Wait for the subquery execution future to complete. +async fn wait_for_subqueries(fut: &mut OnceFut<()>) -> Result<()> { + std::future::poll_fn(|cx| fut.get_shared(cx)).await?; + Ok(()) +} + +async fn execute_subqueries( + subqueries: Vec, + results: ScalarSubqueryResults, + context: Arc, +) -> Result<()> { + // Evaluate subqueries in parallel; wait for them all to finish evaluation + // before returning. + let futures = subqueries.iter().map(|sq| { + let plan = Arc::clone(&sq.plan); + let ctx = Arc::clone(&context); + let results = results.clone(); + let index = sq.index; + async move { + let value = execute_scalar_subquery(plan, ctx).await?; + results.set(index, value)?; + Ok(()) as Result<()> + } + }); + futures::future::try_join_all(futures).await?; + Ok(()) +} + +/// Execute a single subquery plan and extract the scalar value. +/// Returns NULL for 0 rows, the scalar value for exactly 1 row, +/// or an error for >1 rows. +async fn execute_scalar_subquery( + plan: Arc, + context: Arc, +) -> Result { + let schema = plan.schema(); + if schema.fields().len() != 1 { + // Should be enforced by the physical planner. + return internal_err!( + "Scalar subquery must return exactly one column, got {}", + schema.fields().len() + ); + } + + let mut stream = crate::execute_stream(plan, context)?; + let mut result: Option = None; + + while let Some(batch) = stream.next().await.transpose()? { + if batch.num_rows() == 0 { + continue; + } + if result.is_some() || batch.num_rows() > 1 { + return exec_err!("Scalar subquery returned more than one row"); + } + result = Some(ScalarValue::try_from_array(batch.column(0), 0)?); + } + + // 0 rows → typed NULL per SQL semantics + match result { + Some(v) => Ok(v), + None => ScalarValue::try_from(schema.field(0).data_type()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::{self, TestMemoryExec}; + use crate::{ + execution_plan::reset_plan_states, + projection::{ProjectionExec, ProjectionExpr}, + }; + + use std::sync::atomic::{AtomicUsize, Ordering}; + + use crate::test::exec::ErrorExec; + use arrow::array::{Int32Array, Int64Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; + + enum ExpectedSubqueryResult { + Value(ScalarValue), + Error(&'static str), + } + + #[derive(Debug)] + struct CountingExec { + inner: Arc, + execute_calls: Arc, + } + + impl CountingExec { + fn new(inner: Arc, execute_calls: Arc) -> Self { + Self { + inner, + execute_calls, + } + } + } + + impl DisplayAs for CountingExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CountingExec") + } + DisplayFormatType::TreeRender => write!(f, ""), + } + } + } + + impl ExecutionPlan for CountingExec { + fn name(&self) -> &'static str { + "CountingExec" + } + + fn properties(&self) -> &Arc { + self.inner.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new( + children.remove(0), + Arc::clone(&self.execute_calls), + ))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + self.execute_calls.fetch_add(1, Ordering::SeqCst); + self.inner.execute(partition, context) + } + } + + fn make_subquery_plan(batches: Vec) -> Arc { + let schema = batches[0].schema(); + TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap() + } + + fn int32_batch(values: Vec) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap() + } + + fn empty_int64_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![] as Vec))]) + .unwrap() + } + + fn placeholder_input() -> Arc { + Arc::new(crate::placeholder_row::PlaceholderRowExec::new( + test::aggr_test_schema(), + )) + } + + fn single_subquery_exec( + input: Arc, + subquery_plan: Arc, + results: ScalarSubqueryResults, + ) -> ScalarSubqueryExec { + ScalarSubqueryExec::new( + input, + vec![ScalarSubqueryLink { + plan: subquery_plan, + index: SubqueryIndex::new(0), + }], + results, + ) + } + + fn scalar_subquery_projection_input( + results: ScalarSubqueryResults, + ) -> Result> { + Ok(Arc::new(ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(ScalarSubqueryExpr::new( + DataType::Int32, + false, + SubqueryIndex::new(0), + results, + )), + alias: "sq".to_string(), + }], + placeholder_input(), + )?)) + } + + fn extract_single_int32_value(batches: &[RecordBatch]) -> i32 { + assert_eq!(batches.len(), 1); + let values = batches[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.len(), 1); + values.value(0) + } + + #[tokio::test] + async fn test_execute_scalar_subquery_row_count_semantics() -> Result<()> { + for (name, plan, expected) in [ + ( + "single_row", + make_subquery_plan(vec![int32_batch(vec![42])]), + ExpectedSubqueryResult::Value(ScalarValue::Int32(Some(42))), + ), + ( + "zero_rows", + make_subquery_plan(vec![empty_int64_batch()]), + ExpectedSubqueryResult::Value(ScalarValue::Int64(None)), + ), + ( + "multiple_rows", + make_subquery_plan(vec![int32_batch(vec![1, 2, 3])]), + ExpectedSubqueryResult::Error("more than one row"), + ), + ] { + let actual = + execute_scalar_subquery(plan, Arc::new(TaskContext::default())).await; + match expected { + ExpectedSubqueryResult::Value(expected) => { + assert_eq!(actual?, expected, "{name}"); + } + ExpectedSubqueryResult::Error(expected) => { + let err = actual.expect_err(name); + assert!( + err.to_string().contains(expected), + "{name}: expected error containing '{expected}', got {err}" + ); + } + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_failed_subquery_is_not_retried() -> Result<()> { + let execute_calls = Arc::new(AtomicUsize::new(0)); + let subquery_plan = Arc::new(CountingExec::new( + Arc::new(ErrorExec::new()), + Arc::clone(&execute_calls), + )); + let exec = single_subquery_exec( + placeholder_input(), + subquery_plan, + ScalarSubqueryResults::new(1), + ); + + let ctx = Arc::new(TaskContext::default()); + let stream = exec.execute(0, Arc::clone(&ctx))?; + assert!(crate::common::collect(stream).await.is_err()); + + let stream = exec.execute(0, ctx)?; + assert!(crate::common::collect(stream).await.is_err()); + + assert_eq!(execute_calls.load(Ordering::SeqCst), 1); + Ok(()) + } + + #[tokio::test] + async fn test_reset_state_clears_results_and_reexecutes_subqueries() -> Result<()> { + let execute_calls = Arc::new(AtomicUsize::new(0)); + let results = ScalarSubqueryResults::new(1); + let subquery_plan = Arc::new(CountingExec::new( + make_subquery_plan(vec![int32_batch(vec![42])]), + Arc::clone(&execute_calls), + )); + let exec: Arc = Arc::new(single_subquery_exec( + scalar_subquery_projection_input(results.clone())?, + subquery_plan, + results.clone(), + )); + + let batches = + crate::common::collect(exec.execute(0, Arc::new(TaskContext::default()))?) + .await?; + assert_eq!(extract_single_int32_value(&batches), 42); + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(42))) + ); + + let reset_exec = reset_plan_states(Arc::clone(&exec))?; + assert_eq!(results.get(SubqueryIndex::new(0)), None); + + let reset_batches = crate::common::collect( + reset_exec.execute(0, Arc::new(TaskContext::default()))?, + ) + .await?; + assert_eq!(extract_single_int32_value(&reset_batches), 42); + assert_eq!( + results.get(SubqueryIndex::new(0)), + Some(ScalarValue::Int32(Some(42))) + ); + assert_eq!(execute_calls.load(Ordering::SeqCst), 2); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/sort_pushdown.rs b/datafusion/physical-plan/src/sort_pushdown.rs new file mode 100644 index 0000000000000..8432fd5dabee7 --- /dev/null +++ b/datafusion/physical-plan/src/sort_pushdown.rs @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sort pushdown types for physical execution plans. +//! +//! This module provides types used for pushing sort ordering requirements +//! down through the execution plan tree to data sources. + +/// Result of attempting to push down sort ordering to a node. +/// +/// Used by [`ExecutionPlan::try_pushdown_sort`] to communicate +/// whether and how sort ordering was successfully pushed down. +/// +/// [`ExecutionPlan::try_pushdown_sort`]: crate::ExecutionPlan::try_pushdown_sort +#[derive(Debug, Clone)] +pub enum SortOrderPushdownResult { + /// The source can guarantee exact ordering (data is perfectly sorted). + /// + /// When this is returned, the optimizer can safely remove the Sort operator + /// entirely since the data source guarantees the requested ordering. + Exact { + /// The optimized node that provides exact ordering + inner: T, + }, + /// The source has optimized for the ordering but cannot guarantee perfect sorting. + /// + /// This indicates the data source has been optimized (e.g., reordered files/row groups + /// based on statistics, enabled reverse scanning) but the data may not be perfectly + /// sorted. The optimizer should keep the Sort operator but benefits from the + /// optimization (e.g., faster TopK queries due to early termination). + Inexact { + /// The optimized node that provides approximate ordering + inner: T, + }, + /// The source cannot optimize for this ordering. + /// + /// The data source does not support the requested sort ordering and no + /// optimization was applied. + Unsupported, +} + +impl SortOrderPushdownResult { + /// Extract the inner value if present + pub fn into_inner(self) -> Option { + match self { + Self::Exact { inner } | Self::Inexact { inner } => Some(inner), + Self::Unsupported => None, + } + } + + /// Map the inner value to a different type while preserving the variant. + pub fn map U>(self, f: F) -> SortOrderPushdownResult { + match self { + Self::Exact { inner } => SortOrderPushdownResult::Exact { inner: f(inner) }, + Self::Inexact { inner } => { + SortOrderPushdownResult::Inexact { inner: f(inner) } + } + Self::Unsupported => SortOrderPushdownResult::Unsupported, + } + } + + /// Try to map the inner value, returning an error if the function fails. + pub fn try_map Result>( + self, + f: F, + ) -> Result, E> { + match self { + Self::Exact { inner } => { + Ok(SortOrderPushdownResult::Exact { inner: f(inner)? }) + } + Self::Inexact { inner } => { + Ok(SortOrderPushdownResult::Inexact { inner: f(inner)? }) + } + Self::Unsupported => Ok(SortOrderPushdownResult::Unsupported), + } + } + + /// Convert this result to `Inexact`, downgrading `Exact` if present. + /// + /// This is useful when an operation (like merging multiple partitions) + /// cannot guarantee exact ordering even if the input provides it. + /// + /// # Examples + /// + /// ``` + /// # use datafusion_physical_plan::SortOrderPushdownResult; + /// let exact = SortOrderPushdownResult::Exact { inner: 42 }; + /// let inexact = exact.into_inexact(); + /// assert!(matches!(inexact, SortOrderPushdownResult::Inexact { inner: 42 })); + /// + /// let already_inexact = SortOrderPushdownResult::Inexact { inner: 42 }; + /// let still_inexact = already_inexact.into_inexact(); + /// assert!(matches!(still_inexact, SortOrderPushdownResult::Inexact { inner: 42 })); + /// + /// let unsupported = SortOrderPushdownResult::::Unsupported; + /// let still_unsupported = unsupported.into_inexact(); + /// assert!(matches!(still_unsupported, SortOrderPushdownResult::Unsupported)); + /// ``` + pub fn into_inexact(self) -> Self { + match self { + Self::Exact { inner } => Self::Inexact { inner }, + Self::Inexact { inner } => Self::Inexact { inner }, + Self::Unsupported => Self::Unsupported, + } + } +} diff --git a/datafusion/physical-plan/src/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs index 9b2fa968222c4..75eb2ff980325 100644 --- a/datafusion/physical-plan/src/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -16,11 +16,14 @@ // under the License. use crate::spill::get_record_batch_memory_size; +use arrow::array::ArrayRef; use arrow::compute::interleave; use arrow::datatypes::SchemaRef; +use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; +use log::warn; use std::sync::Arc; #[derive(Debug, Copy, Clone, Default)] @@ -40,9 +43,24 @@ pub struct BatchBuilder { /// Maintain a list of [`RecordBatch`] and their corresponding stream batches: Vec<(usize, RecordBatch)>, - /// Accounts for memory used by buffered batches + /// Accounts for memory used by buffered batches. + /// + /// May include pre-reserved bytes (from `sort_spill_reservation_bytes`) + /// that were transferred via [`MemoryReservation::take()`] to prevent + /// starvation when concurrent sort partitions compete for pool memory. reservation: MemoryReservation, + /// Tracks the actual memory used by buffered batches (not including + /// pre-reserved bytes). This allows [`Self::push_batch`] to skip pool + /// allocation requests when the pre-reserved bytes cover the batch. + batches_mem_used: usize, + + /// The initial reservation size at construction time. When the reservation + /// is pre-loaded with `sort_spill_reservation_bytes` (via `take()`), this + /// records that amount so we never shrink below it, maintaining the + /// anti-starvation guarantee throughout the merge. + initial_reservation: usize, + /// The current [`BatchCursor`] for each stream cursors: Vec, @@ -59,19 +77,26 @@ impl BatchBuilder { batch_size: usize, reservation: MemoryReservation, ) -> Self { + let initial_reservation = reservation.size(); Self { schema, batches: Vec::with_capacity(stream_count * 2), cursors: vec![BatchCursor::default(); stream_count], indices: Vec::with_capacity(batch_size), reservation, + batches_mem_used: 0, + initial_reservation, } } /// Append a new batch in `stream_idx` pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) -> Result<()> { - self.reservation - .try_grow(get_record_batch_memory_size(&batch))?; + let size = get_record_batch_memory_size(&batch); + self.batches_mem_used += size; + // Only request additional memory from the pool when actual batch + // usage exceeds the current reservation (which may include + // pre-reserved bytes from sort_spill_reservation_bytes). + try_grow_reservation_to_at_least(&mut self.reservation, self.batches_mem_used)?; let batch_idx = self.batches.len(); self.batches.push((stream_idx, batch)); self.cursors[stream_idx] = BatchCursor { @@ -104,9 +129,79 @@ impl BatchBuilder { &self.schema } + /// Try to interleave all columns using the given index slice. + fn try_interleave_columns( + &self, + indices: &[(usize, usize)], + ) -> Result> { + (0..self.schema.fields.len()) + .map(|column_idx| { + let arrays: Vec<_> = self + .batches + .iter() + .map(|(_, batch)| batch.column(column_idx).as_ref()) + .collect(); + // Arrow 58.1.0+ returns OffsetOverflowError directly from + // interleave, allowing retry_interleave to shrink the batch. + interleave(&arrays, indices).map_err(Into::into) + }) + .collect::>>() + } + + /// Builds a record batch from the first `rows_to_emit` buffered rows. + fn finish_record_batch( + &mut self, + rows_to_emit: usize, + columns: Vec, + ) -> Result { + // Remove consumed indices, keeping any remaining for the next call. + self.indices.drain(..rows_to_emit); + + // Only clean up fully-consumed batches when all indices are drained, + // because remaining indices may still reference earlier batches. + // In the overflow/partial-emit case this may retain some extra memory + // across a few drain polls, but avoids costly index scanning on the + // hot path. The retention is bounded and short-lived since leftover + // rows are drained over subsequent polls. + if self.indices.is_empty() { + // New cursors are only created once the previous cursor for the stream + // is finished. This means all remaining rows from all but the last batch + // for each stream have been yielded to the newly created record batch + // + // We can therefore drop all but the last batch for each stream + let mut batch_idx = 0; + let mut retained = 0; + self.batches.retain(|(stream_idx, batch)| { + let stream_cursor = &mut self.cursors[*stream_idx]; + let retain = stream_cursor.batch_idx == batch_idx; + batch_idx += 1; + + if retain { + stream_cursor.batch_idx = retained; + retained += 1; + } else { + self.batches_mem_used -= get_record_batch_memory_size(batch); + } + retain + }); + } + + // Release excess memory back to the pool, but never shrink below + // initial_reservation to maintain the anti-starvation guarantee + // for the merge phase. + let target = self.batches_mem_used.max(self.initial_reservation); + if self.reservation.size() > target { + self.reservation.shrink(self.reservation.size() - target); + } + + RecordBatch::try_new(Arc::clone(&self.schema), columns).map_err(Into::into) + } + /// Drains the in_progress row indexes, and builds a new RecordBatch from them /// - /// Will then drop any batches for which all rows have been yielded to the output + /// Will then drop any batches for which all rows have been yielded to the output. + /// If an offset overflow occurs (e.g. string/list offsets exceed i32::MAX), + /// retries with progressively fewer rows until it succeeds. /// /// Returns `None` if no pending rows pub fn build_record_batch(&mut self) -> Result> { @@ -114,43 +209,151 @@ impl BatchBuilder { return Ok(None); } - let columns = (0..self.schema.fields.len()) - .map(|column_idx| { - let arrays: Vec<_> = self - .batches - .iter() - .map(|(_, batch)| batch.column(column_idx).as_ref()) - .collect(); - Ok(interleave(&arrays, &self.indices)?) - }) - .collect::>>()?; - - self.indices.clear(); - - // New cursors are only created once the previous cursor for the stream - // is finished. This means all remaining rows from all but the last batch - // for each stream have been yielded to the newly created record batch - // - // We can therefore drop all but the last batch for each stream - let mut batch_idx = 0; - let mut retained = 0; - self.batches.retain(|(stream_idx, batch)| { - let stream_cursor = &mut self.cursors[*stream_idx]; - let retain = stream_cursor.batch_idx == batch_idx; - batch_idx += 1; - - if retain { - stream_cursor.batch_idx = retained; - retained += 1; - } else { - self.reservation.shrink(get_record_batch_memory_size(batch)); + let (rows_to_emit, columns) = + retry_interleave(self.indices.len(), self.indices.len(), |rows_to_emit| { + self.try_interleave_columns(&self.indices[..rows_to_emit]) + })?; + + Ok(Some(self.finish_record_batch(rows_to_emit, columns)?)) + } +} + +/// Try to grow `reservation` so it covers at least `needed` bytes. +/// +/// When a reservation has been pre-loaded with bytes (e.g. via +/// [`MemoryReservation::take()`]), this avoids redundant pool +/// allocations: if the reservation already covers `needed`, this is +/// a no-op; otherwise only the deficit is requested from the pool. +pub(crate) fn try_grow_reservation_to_at_least( + reservation: &mut MemoryReservation, + needed: usize, +) -> Result<()> { + if needed > reservation.size() { + reservation.try_grow(needed - reservation.size())?; + } + Ok(()) +} + +/// Returns true if the error is an Arrow offset overflow. +fn is_offset_overflow(e: &DataFusionError) -> bool { + matches!( + e, + DataFusionError::ArrowError(boxed, _) + if matches!(boxed.as_ref(), ArrowError::OffsetOverflowError(_)) + ) +} + +#[cfg(test)] +fn offset_overflow_error() -> DataFusionError { + DataFusionError::ArrowError(Box::new(ArrowError::OffsetOverflowError(0)), None) +} + +fn retry_interleave( + mut rows_to_emit: usize, + total_rows: usize, + mut interleave: F, +) -> Result<(usize, T)> +where + F: FnMut(usize) -> Result, +{ + loop { + match interleave(rows_to_emit) { + Ok(value) => return Ok((rows_to_emit, value)), + // Only offset overflow is recoverable by emitting fewer rows. + Err(e) if is_offset_overflow(&e) => { + rows_to_emit /= 2; + if rows_to_emit == 0 { + return Err(e); + } + warn!( + "Interleave offset overflow with {total_rows} rows, retrying with {rows_to_emit}" + ); } - retain + Err(e) => return Err(e), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, ArrayDataBuilder, Int32Array, ListArray}; + use arrow::buffer::Buffer; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_execution::memory_pool::{ + MemoryConsumer, MemoryPool, UnboundedMemoryPool, + }; + + fn overflow_list_batch() -> RecordBatch { + let values_field = Arc::new(Field::new_list_field(DataType::Int32, true)); + // SAFETY: This intentionally constructs an invalid child length so + // Arrow's interleave hits offset overflow before touching child data. + let list = ListArray::from(unsafe { + ArrayDataBuilder::new(DataType::List(Arc::clone(&values_field))) + .len(1) + .add_buffer(Buffer::from_slice_ref([0_i32, i32::MAX])) + .add_child_data(Int32Array::from(Vec::::new()).to_data()) + .build_unchecked() }); + let schema = Arc::new(Schema::new(vec![Field::new( + "list_col", + DataType::List(values_field), + true, + )])); + RecordBatch::try_new(schema, vec![Arc::new(list)]).unwrap() + } + + #[test] + fn test_retry_interleave_halves_rows_until_success() { + let mut attempts = Vec::new(); + + let (rows_to_emit, result) = retry_interleave(4, 4, |rows_to_emit| { + attempts.push(rows_to_emit); + if rows_to_emit > 1 { + Err(offset_overflow_error()) + } else { + Ok("ok") + } + }) + .unwrap(); + + assert_eq!(rows_to_emit, 1); + assert_eq!(result, "ok"); + assert_eq!(attempts, vec![4, 2, 1]); + } + + #[test] + fn test_is_offset_overflow_matches_arrow_error() { + assert!(is_offset_overflow(&offset_overflow_error())); + } + + #[test] + fn test_retry_interleave_does_not_retry_non_offset_errors() { + let mut attempts = Vec::new(); + + let error = retry_interleave(4, 4, |rows_to_emit| { + attempts.push(rows_to_emit); + Err::<(), _>(DataFusionError::Execution("boom".into())) + }) + .unwrap_err(); + + assert_eq!(attempts, vec![4]); + assert!(matches!(error, DataFusionError::Execution(msg) if msg == "boom")); + } + + #[test] + fn test_try_interleave_columns_surfaces_arrow_offset_overflow() { + let batch = overflow_list_batch(); + let schema = batch.schema(); + let pool: Arc = Arc::new(UnboundedMemoryPool::default()); + let reservation = MemoryConsumer::new("test").register(&pool); + let mut builder = BatchBuilder::new(schema, 1, 2, reservation); + builder.push_batch(0, batch).unwrap(); + + let error = builder + .try_interleave_columns(&[(0, 0), (0, 0)]) + .unwrap_err(); - Ok(Some(RecordBatch::try_new( - Arc::clone(&self.schema), - columns, - )?)) + assert!(is_offset_overflow(&error)); } } diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index 54dc2414e4f08..288ec4cee1594 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -19,8 +19,8 @@ use std::cmp::Ordering; use std::sync::Arc; use arrow::array::{ - types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray, - GenericByteViewArray, OffsetSizeTrait, PrimitiveArray, StringViewArray, + Array, ArrowPrimitiveType, GenericByteArray, GenericByteViewArray, OffsetSizeTrait, + PrimitiveArray, StringViewArray, types::ByteArrayType, }; use arrow::buffer::{Buffer, OffsetBuffer, ScalarBuffer}; use arrow::compute::SortOptions; @@ -445,7 +445,6 @@ mod tests { use datafusion_execution::memory_pool::{ GreedyMemoryPool, MemoryConsumer, MemoryPool, }; - use std::sync::Arc; use super::*; diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 720a3e53e4597..c29933535adc5 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -20,13 +20,13 @@ use std::pin::Pin; use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll, ready}; +use crate::RecordBatchStream; use crate::metrics::BaselineMetrics; use crate::sorts::builder::BatchBuilder; use crate::sorts::cursor::{Cursor, CursorValues}; use crate::sorts::stream::PartitionedStream; -use crate::RecordBatchStream; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -53,6 +53,14 @@ pub(crate) struct SortPreservingMergeStream { /// `fetch` limit. done: bool, + /// Whether buffered rows should be drained after `done` is set. + /// + /// This is enabled when we stop because the `fetch` limit has been + /// reached, allowing partial batches left over after overflow handling to + /// be emitted on subsequent polls. It remains disabled for terminal + /// errors so the stream does not yield data after returning `Err`. + drain_in_progress_on_done: bool, + /// A loser tree that always produces the minimum cursor /// /// Node 0 stores the top winner, Nodes 1..num_streams store @@ -164,6 +172,7 @@ impl SortPreservingMergeStream { streams, metrics, done: false, + drain_in_progress_on_done: false, cursors: (0..stream_count).map(|_| None).collect(), prev_cursors: (0..stream_count).map(|_| None).collect(), round_robin_tie_breaker_mode: false, @@ -203,11 +212,28 @@ impl SortPreservingMergeStream { } } + fn emit_in_progress_batch(&mut self) -> Result> { + let rows_before = self.in_progress.len(); + let result = self.in_progress.build_record_batch(); + self.produced += rows_before - self.in_progress.len(); + result + } + fn poll_next_inner( &mut self, cx: &mut Context<'_>, ) -> Poll>> { if self.done { + // When `build_record_batch()` hits an i32 offset overflow (e.g. + // combined string offsets exceed 2 GB), it emits a partial batch + // and keeps the remaining rows in `self.in_progress.indices`. + // Drain those leftover rows before terminating the stream, + // otherwise they would be silently dropped. + // Repeated overflows are fine — each poll emits another partial + // batch until `in_progress` is fully drained. + if self.drain_in_progress_on_done && !self.in_progress.is_empty() { + return Poll::Ready(self.emit_in_progress_batch().transpose()); + } return Poll::Ready(None); } // Once all partitions have set their corresponding cursors for the loser tree, @@ -283,14 +309,13 @@ impl SortPreservingMergeStream { // stop sorting if fetch has been reached if self.fetch_reached() { self.done = true; + self.drain_in_progress_on_done = true; } else if self.in_progress.len() < self.batch_size { continue; } } - self.produced += self.in_progress.len(); - - return Poll::Ready(self.in_progress.build_record_batch().transpose()); + return Poll::Ready(self.emit_in_progress_batch().transpose()); } } @@ -542,3 +567,95 @@ impl RecordBatchStream for SortPreservingMergeStream Arc::clone(self.in_progress.schema()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::metrics::ExecutionPlanMetricsSet; + use crate::sorts::stream::PartitionedStream; + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_execution::memory_pool::{ + MemoryConsumer, MemoryPool, UnboundedMemoryPool, + }; + use futures::task::noop_waker_ref; + use std::cmp::Ordering; + + #[derive(Debug)] + struct EmptyPartitionedStream; + + impl PartitionedStream for EmptyPartitionedStream { + type Output = Result<(DummyValues, RecordBatch)>; + + fn partitions(&self) -> usize { + 1 + } + + fn poll_next( + &mut self, + _cx: &mut Context<'_>, + _stream_idx: usize, + ) -> Poll> { + Poll::Ready(None) + } + } + + #[derive(Debug)] + struct DummyValues; + + impl CursorValues for DummyValues { + fn len(&self) -> usize { + 0 + } + + fn eq(_l: &Self, _l_idx: usize, _r: &Self, _r_idx: usize) -> bool { + unreachable!("done-path test should not compare cursors") + } + + fn eq_to_previous(_cursor: &Self, _idx: usize) -> bool { + unreachable!("done-path test should not compare cursors") + } + + fn compare(_l: &Self, _l_idx: usize, _r: &Self, _r_idx: usize) -> Ordering { + unreachable!("done-path test should not compare cursors") + } + } + + #[test] + fn test_done_drains_buffered_rows() { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + let pool: Arc = Arc::new(UnboundedMemoryPool::default()); + let reservation = MemoryConsumer::new("test").register(&pool); + let metrics = ExecutionPlanMetricsSet::new(); + + let mut stream = SortPreservingMergeStream::::new( + Box::new(EmptyPartitionedStream), + Arc::clone(&schema), + BaselineMetrics::new(&metrics, 0), + 16, + Some(1), + reservation, + true, + ); + + let batch = + RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1]))]) + .unwrap(); + stream.in_progress.push_batch(0, batch).unwrap(); + stream.in_progress.push_row(0); + stream.done = true; + stream.drain_in_progress_on_done = true; + + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + + match stream.poll_next_inner(&mut cx) { + Poll::Ready(Some(Ok(batch))) => assert_eq!(batch.num_rows(), 1), + other => { + panic!("expected buffered rows to be drained after done, got {other:?}") + } + } + assert!(stream.in_progress.is_empty()); + assert!(matches!(stream.poll_next_inner(&mut cx), Poll::Ready(None))); + } +} diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index 9c72e34fe343e..ca8d4a4400c49 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -22,7 +22,10 @@ mod cursor; mod merge; mod multi_level_merge; pub mod partial_sort; +pub mod partitioned_topk; pub mod sort; pub mod sort_preserving_merge; mod stream; pub mod streaming_merge; + +pub(crate) use stream::IncrementalSortIterator; diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 6e7a5e7a72616..8985e1d8c70ee 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -30,7 +30,8 @@ use arrow::datatypes::SchemaRef; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; -use crate::sorts::sort::get_reserved_byte_for_record_batch_size; +use crate::sorts::builder::try_grow_reservation_to_at_least; +use crate::sorts::sort::get_reserved_bytes_for_record_batch_size; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::stream::RecordBatchStreamAdapter; use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; @@ -145,7 +146,7 @@ impl Debug for MultiLevelMergeBuilder { } impl MultiLevelMergeBuilder { - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub(crate) fn new( spill_manager: SpillManager, schema: SchemaRef, @@ -253,7 +254,12 @@ impl MultiLevelMergeBuilder { // Need to merge multiple streams (_, _) => { - let mut memory_reservation = self.reservation.new_empty(); + // Transfer any pre-reserved bytes (from sort_spill_reservation_bytes) + // to the merge memory reservation. This prevents starvation when + // concurrent sort partitions compete for pool memory: the pre-reserved + // bytes cover spill file buffer reservations without additional pool + // allocation. + let mut memory_reservation = self.reservation.take(); // Don't account for existing streams memory // as we are not holding the memory for them @@ -269,6 +275,15 @@ impl MultiLevelMergeBuilder { let is_only_merging_memory_streams = sorted_spill_files.is_empty(); + // If no spill files were selected (e.g. all too large for + // available memory but enough in-memory streams exist), + // return the pre-reserved bytes to self.reservation so + // create_new_merge_sort can transfer them to the merge + // stream's BatchBuilder. + if is_only_merging_memory_streams { + mem::swap(&mut self.reservation, &mut memory_reservation); + } + for spill in sorted_spill_files { let stream = self .spill_manager @@ -290,7 +305,11 @@ impl MultiLevelMergeBuilder { // If we're only merging memory streams, we don't need to attach the memory reservation // as it's empty if is_only_merging_memory_streams { - assert_eq!(memory_reservation.size(), 0, "when only merging memory streams, we should not have any memory reservation and let the merge sort handle the memory"); + assert_eq!( + memory_reservation.size(), + 0, + "when only merging memory streams, we should not have any memory reservation and let the merge sort handle the memory" + ); Ok(merge_sort_stream) } else { @@ -333,8 +352,10 @@ impl MultiLevelMergeBuilder { builder = builder.with_bypass_mempool(); } else { // If we are only merging in-memory streams, we need to use the memory reservation - // because we don't know the maximum size of the batches in the streams - builder = builder.with_reservation(self.reservation.new_empty()); + // because we don't know the maximum size of the batches in the streams. + // Use take() to transfer any pre-reserved bytes so the merge can use them + // as its initial budget without additional pool allocation. + builder = builder.with_reservation(self.reservation.take()); } builder.build() @@ -352,13 +373,24 @@ impl MultiLevelMergeBuilder { ) -> Result<(Vec, usize)> { assert_ne!(buffer_len, 0, "Buffer length must be greater than 0"); let mut number_of_spills_to_read_for_current_phase = 0; + // Track total memory needed for spill file buffers. When the + // reservation has pre-reserved bytes (from sort_spill_reservation_bytes), + // those bytes cover the first N spill files without additional pool + // allocation, preventing starvation under memory pressure. + let mut total_needed: usize = 0; for spill in &self.sorted_spill_files { - // For memory pools that are not shared this is good, for other this is not - // and there should be some upper limit to memory reservation so we won't starve the system - match reservation.try_grow(get_reserved_byte_for_record_batch_size( - spill.max_record_batch_memory * buffer_len, - )) { + let per_spill = get_reserved_bytes_for_record_batch_size( + spill.max_record_batch_memory, + // Size will be the same as the sliced size, bc it is a spilled batch. + spill.max_record_batch_memory, + ) * buffer_len; + total_needed += per_spill; + + // For memory pools that are not shared this is good, for other + // this is not and there should be some upper limit to memory + // reservation so we won't starve the system. + match try_grow_reservation_to_at_least(reservation, total_needed) { Ok(_) => { number_of_spills_to_read_for_current_phase += 1; } diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 7a623b0c30d32..3bf16af36c62b 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -51,7 +51,6 @@ //! The plan concats incoming data with such last rows of previous input //! and continues partial sorting of the segments. -use std::any::Any; use std::fmt::Debug; use std::pin::Pin; use std::sync::Arc; @@ -59,20 +58,22 @@ use std::task::{Context, Poll}; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::sorts::sort::sort_batch; +use crate::stream::EmptyRecordBatchStream; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + check_if_same_properties, }; use arrow::compute::concat_batches; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::Result; +use datafusion_common::utils::evaluate_partition_ranges; use datafusion_execution::{RecordBatchStream, TaskContext}; use datafusion_physical_expr::LexOrdering; -use futures::{ready, Stream, StreamExt}; +use futures::{Stream, StreamExt, ready}; use log::trace; /// Partial Sort execution plan. @@ -93,7 +94,7 @@ pub struct PartialSortExec { /// Fetch highest/lowest n results fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl PartialSortExec { @@ -114,7 +115,7 @@ impl PartialSortExec { metrics_set: ExecutionPlanMetricsSet::new(), preserve_partitioning, fetch: None, - cache, + cache: Arc::new(cache), } } @@ -132,12 +133,8 @@ impl PartialSortExec { /// input partitions producing a single, sorted partition. pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self { self.preserve_partitioning = preserve_partitioning; - self.cache = self - .cache - .with_partitioning(Self::output_partitioning_helper( - &self.input, - self.preserve_partitioning, - )); + Arc::make_mut(&mut self.cache).partitioning = + Self::output_partitioning_helper(&self.input, self.preserve_partitioning); self } @@ -207,6 +204,17 @@ impl PartialSortExec { input.boundedness(), )) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics_set: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for PartialSortExec { @@ -220,9 +228,17 @@ impl DisplayAs for PartialSortExec { let common_prefix_length = self.common_prefix_length; match self.fetch { Some(fetch) => { - write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr) + write!( + f, + "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", + self.expr + ) } - None => write!(f, "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr), + None => write!( + f, + "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", + self.expr + ), } } DisplayFormatType::TreeRender => match self.fetch { @@ -243,11 +259,7 @@ impl ExecutionPlan for PartialSortExec { "PartialSortExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -275,6 +287,7 @@ impl ExecutionPlan for PartialSortExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); let new_partial_sort = PartialSortExec::new( self.expr.clone(), Arc::clone(&children[0]), @@ -291,7 +304,12 @@ impl ExecutionPlan for PartialSortExec { partition: usize, context: Arc, ) -> Result { - trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); let input = self.input.execute(partition, Arc::clone(&context))?; @@ -316,11 +334,7 @@ impl ExecutionPlan for PartialSortExec { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { self.input.partition_statistics(partition) } } @@ -378,6 +392,9 @@ impl PartialSortStream { // Check if we've already reached the fetch limit if self.fetch == Some(0) { self.is_closed = true; + // Release the input pipeline's resources. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); return Poll::Ready(None); } @@ -411,6 +428,9 @@ impl PartialSortStream { Some(Err(e)) => return Poll::Ready(Some(Err(e))), None => { self.is_closed = true; + // Release the input pipeline's resources before sorting. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); // Once input is consumed, sort the rest of the inserted batches let remaining_batch = self.sort_in_mem_batch()?; return if remaining_batch.num_rows() > 0 { @@ -484,13 +504,13 @@ mod tests { use itertools::Itertools; use crate::collect; - use crate::expressions::col; use crate::expressions::PhysicalSortExpr; + use crate::expressions::col; use crate::sorts::sort::SortExec; use crate::test; - use crate::test::assert_is_pending; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::TestMemoryExec; + use crate::test::assert_is_pending; + use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; use super::*; @@ -536,18 +556,18 @@ mod tests { assert_eq!(2, result.len()); allow_duplicates! { - assert_snapshot!(batches_to_string(&result), @r#" - +---+---+---+ - | a | b | c | - +---+---+---+ - | 0 | 1 | 0 | - | 0 | 1 | 1 | - | 0 | 2 | 5 | - | 1 | 2 | 4 | - | 1 | 3 | 2 | - | 1 | 3 | 3 | - +---+---+---+ - "#); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 0 | 1 | 0 | + | 0 | 1 | 1 | + | 0 | 2 | 5 | + | 1 | 2 | 4 | + | 1 | 3 | 2 | + | 1 | 3 | 3 | + +---+---+---+ + "); } assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), @@ -604,16 +624,16 @@ mod tests { assert_eq!(2, result.len()); allow_duplicates! { - assert_snapshot!(batches_to_string(&result), @r#" - +---+---+---+ - | a | b | c | - +---+---+---+ - | 0 | 1 | 4 | - | 0 | 2 | 3 | - | 1 | 2 | 2 | - | 1 | 3 | 0 | - +---+---+---+ - "#); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 0 | 1 | 4 | + | 0 | 2 | 3 | + | 1 | 2 | 2 | + | 1 | 3 | 0 | + +---+---+---+ + "); } assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), @@ -680,20 +700,20 @@ mod tests { "The sort should have returned all memory used back to the memory manager" ); allow_duplicates! { - assert_snapshot!(batches_to_string(&result), @r#" - +---+---+---+ - | a | b | c | - +---+---+---+ - | 0 | 1 | 6 | - | 0 | 1 | 7 | - | 0 | 3 | 4 | - | 0 | 3 | 5 | - | 1 | 2 | 0 | - | 1 | 2 | 1 | - | 1 | 4 | 2 | - | 1 | 4 | 3 | - +---+---+---+ - "#); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 0 | 1 | 6 | + | 0 | 1 | 7 | + | 0 | 3 | 4 | + | 0 | 3 | 5 | + | 1 | 2 | 0 | + | 1 | 2 | 1 | + | 1 | 4 | 2 | + | 1 | 4 | 3 | + +---+---+---+ + "); } } Ok(()) @@ -1038,20 +1058,20 @@ mod tests { task_ctx, ) .await?; - assert_snapshot!(batches_to_string(&result), @r#" - +-----+------+-------+ - | a | b | c | - +-----+------+-------+ - | 1.0 | 20.0 | 20.0 | - | 1.0 | 20.0 | 10.0 | - | 1.0 | 40.0 | 10.0 | - | 2.0 | 40.0 | 100.0 | - | 2.0 | NaN | NaN | - | 3.0 | | | - | 3.0 | | 100.0 | - | 3.0 | NaN | NaN | - +-----+------+-------+ - "#); + assert_snapshot!(batches_to_string(&result), @r" + +-----+------+-------+ + | a | b | c | + +-----+------+-------+ + | 1.0 | 20.0 | 20.0 | + | 1.0 | 20.0 | 10.0 | + | 1.0 | 40.0 | 10.0 | + | 2.0 | 40.0 | 100.0 | + | 2.0 | NaN | NaN | + | 3.0 | | | + | 3.0 | | 100.0 | + | 3.0 | NaN | NaN | + +-----+------+-------+ + "); assert_eq!(result.len(), 2); let metrics = partial_sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); @@ -1164,21 +1184,21 @@ mod tests { assert_eq!(result.len(), 3,); allow_duplicates! { - assert_snapshot!(batches_to_string(&result), @r#" - +---+---+---+ - | a | b | c | - +---+---+---+ - | 1 | 1 | 1 | - | 1 | 1 | 2 | - | 1 | 1 | 3 | - | 2 | 2 | 4 | - | 2 | 2 | 4 | - | 2 | 2 | 6 | - | 3 | 3 | 7 | - | 3 | 3 | 8 | - | 3 | 3 | 9 | - +---+---+---+ - "#); + assert_snapshot!(batches_to_string(&result), @r" + +---+---+---+ + | a | b | c | + +---+---+---+ + | 1 | 1 | 1 | + | 1 | 1 | 2 | + | 1 | 1 | 3 | + | 2 | 2 | 4 | + | 2 | 2 | 4 | + | 2 | 2 | 6 | + | 3 | 3 | 7 | + | 3 | 3 | 8 | + | 3 | 3 | 9 | + +---+---+---+ + "); } assert_eq!(task_ctx.runtime_env().memory_pool.reserved(), 0,); diff --git a/datafusion/physical-plan/src/sorts/partitioned_topk.rs b/datafusion/physical-plan/src/sorts/partitioned_topk.rs new file mode 100644 index 0000000000000..fe876eeddf7f2 --- /dev/null +++ b/datafusion/physical-plan/src/sorts/partitioned_topk.rs @@ -0,0 +1,515 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`PartitionedTopKExec`]: Top-K per partition operator +//! +//! For queries like: +//! ```sql +//! SELECT *, ROW_NUMBER() OVER (PARTITION BY pk ORDER BY val) as rn +//! FROM t WHERE rn <= N +//! ``` +//! +//! Instead of sorting the entire dataset, this operator maintains a +//! [`TopK`] heap per partition (reusing the existing TopK implementation) +//! and emits only the top-K rows per partition in sorted order +//! `(partition_keys, order_keys)`. + +use std::fmt::{self, Formatter}; +use std::sync::Arc; + +use arrow::array::{RecordBatch, UInt32Array}; +use arrow::compute::{BatchCoalescer, take_record_batch}; +use arrow::datatypes::SchemaRef; +use arrow::row::{OwnedRow, RowConverter}; +use datafusion_common::{HashMap, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use futures::StreamExt; +use futures::TryStreamExt; +use parking_lot::RwLock; + +use crate::execution_plan::{Boundedness, EmissionType}; +use crate::metrics::ExecutionPlanMetricsSet; +use crate::topk::{TopK, TopKDynamicFilters, build_sort_fields}; +use crate::{ + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PlanProperties, SendableRecordBatchStream, stream::RecordBatchStreamAdapter, +}; + +/// Per-partition Top-K operator for window function queries. +/// +/// # Background +/// +/// "Top K per partition" is a common analytics pattern used for queries such as +/// "find the top 3 products by revenue for each store". The (simplified) SQL +/// for such a query might be: +/// +/// ```sql +/// SELECT * FROM ( +/// SELECT *, ROW_NUMBER() OVER (PARTITION BY store ORDER BY revenue DESC) as rn +/// FROM sales +/// ) WHERE rn <= 3; +/// ``` +/// +/// The unoptimized physical plan would be: +/// +/// ```text +/// FilterExec: rn <= 3 +/// BoundedWindowAggExec: ROW_NUMBER() PARTITION BY [store] ORDER BY [revenue DESC] +/// SortExec: expr=[store ASC, revenue DESC] +/// DataSourceExec +/// ``` +/// +/// This plan sorts the **entire** dataset (O(N log N)), computes `ROW_NUMBER` +/// for **all** rows, and then filters to keep only the top K per partition. +/// With 10M rows, 1K partitions, and K=3, it sorts all 10M rows but only +/// keeps 3K. +/// +/// # Optimization +/// +/// `PartitionedTopKExec` replaces the `SortExec` and the `FilterExec` is +/// removed. The optimized plan becomes: +/// +/// ```text +/// BoundedWindowAggExec: ROW_NUMBER() PARTITION BY [store] ORDER BY [revenue DESC] +/// PartitionedTopKExec: fetch=3, partition=[store], order=[revenue DESC] +/// DataSourceExec +/// ``` +/// +/// Instead of sorting the entire dataset, this operator reads unsorted input, +/// maintains a [`TopK`] heap per distinct partition key, and emits only the +/// top-K rows per partition in sorted order `(partition_keys, order_keys)`. +/// +/// Cost: O(N log K) time instead of O(N log N), and O(K × P × row_size) +/// memory where K = fetch, P = number of distinct partitions. +/// ## Why maintaining partition key order in output +/// Window functions do not require partition keys to be globally sorted, and +/// enforcing such ordering in the output can introduce unnecessary overhead. +/// However, the physical optimizer framework currently cannot express an +/// ordering that is only grouped by some keys while ordered by others. For +/// example: +/// +/// +/// # Example +/// +/// For the query above with `fetch=3` and input: +/// +/// ```text +/// store | revenue +/// ------|-------- +/// A | 100 +/// B | 50 +/// A | 200 +/// B | 150 +/// A | 300 +/// A | 400 +/// ``` +/// +/// The operator maintains two heaps: +/// - **store=A**: keeps top-3 by revenue DESC → {400, 300, 200}, evicts 100 +/// - **store=B**: keeps top-3 by revenue DESC → {150, 50} (only 2 rows) +/// +/// Output (sorted by store ASC, revenue DESC): +/// +/// ```text +/// store | revenue +/// ------|-------- +/// A | 400 +/// A | 300 +/// A | 200 +/// B | 150 +/// B | 50 +/// ``` +/// +/// This is then passed to `BoundedWindowAggExec` which assigns +/// `ROW_NUMBER` 1, 2, 3 to each partition — all of which satisfy `rn <= 3`. +/// +/// # Limitations +/// +/// - Only activated when the window function is `ROW_NUMBER` with a +/// `PARTITION BY` clause. Global top-K (no `PARTITION BY`) is already +/// handled efficiently by `SortExec` with `fetch`. +/// - For very high cardinality partition keys (millions of distinct values), +/// both memory usage and runtime overhead can become significant. In such +/// cases, the sort-based plan is more robust. Therefore, this optimization +/// is currently controlled by a configuration flag. +#[derive(Debug, Clone)] +pub struct PartitionedTopKExec { + /// Input execution plan (reads unsorted data) + input: Arc, + /// Full sort expressions: `[partition_keys..., order_keys...]`. + /// + /// For `PARTITION BY store ORDER BY revenue DESC` with sort + /// `[store ASC, revenue DESC]`, the first `partition_prefix_len` + /// expressions are the partition keys (`[store ASC]`) and the + /// remaining are the order-by keys (`[revenue DESC]`). + expr: LexOrdering, + /// Number of leading expressions in `expr` that define the partition + /// key. For example, `PARTITION BY a, b` → `partition_prefix_len = 2`. + partition_prefix_len: usize, + /// Maximum number of rows to keep per partition (the K in "top-K"). + /// Derived from the filter predicate: `rn <= 3` → `fetch = 3`, + /// `rn < 3` → `fetch = 2`. + fetch: usize, + /// Execution metrics + metrics_set: ExecutionPlanMetricsSet, + /// Cached plan properties (output ordering, partitioning, etc.) + cache: Arc, +} + +impl PartitionedTopKExec { + /// Create a new `PartitionedTopKExec`. + /// + /// # Arguments + /// + /// * `input` - The child execution plan providing unsorted input rows. + /// * `expr` - Full sort ordering `[partition_keys..., order_keys...]`. + /// For `PARTITION BY pk ORDER BY val ASC`, this would be `[pk ASC, val ASC]`. + /// * `partition_prefix_len` - Number of leading expressions in `expr` + /// that form the partition key. Must be >= 1. + /// * `fetch` - Maximum rows to retain per partition (the K in "top-K"). + /// + /// # Example + /// + /// ```text + /// // For: ROW_NUMBER() OVER (PARTITION BY store ORDER BY revenue DESC) ... WHERE rn <= 5 + /// PartitionedTopKExec::try_new( + /// data_source, + /// LexOrdering([store ASC, revenue DESC]), + /// 1, // partition_prefix_len: 1 partition column (store) + /// 5, // fetch: keep top 5 per partition + /// ) + /// ``` + pub fn try_new( + input: Arc, + expr: LexOrdering, + partition_prefix_len: usize, + fetch: usize, + ) -> Result { + let cache = Self::compute_properties(&input, expr.clone())?; + Ok(Self { + input, + expr, + partition_prefix_len, + fetch, + metrics_set: ExecutionPlanMetricsSet::new(), + cache: Arc::new(cache), + }) + } + + /// Returns the child execution plan. + pub fn input(&self) -> &Arc { + &self.input + } + + /// Returns the full sort ordering `[partition_keys..., order_keys...]`. + pub fn expr(&self) -> &LexOrdering { + &self.expr + } + + /// Returns the number of leading expressions in [`Self::expr`] that + /// define the partition key. + pub fn partition_prefix_len(&self) -> usize { + self.partition_prefix_len + } + + /// Returns the maximum number of rows retained per partition. + pub fn fetch(&self) -> usize { + self.fetch + } + + /// Compute [`PlanProperties`] for this operator. + /// + /// The output is sorted by `sort_exprs` (partition keys then order keys), + /// uses the same partitioning as the input, emits all output at once + /// (`EmissionType::Final`), and is bounded. + fn compute_properties( + input: &Arc, + sort_exprs: LexOrdering, + ) -> Result { + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.reorder(sort_exprs)?; + + Ok(PlanProperties::new( + eq_properties, + input.output_partitioning().clone(), + EmissionType::Final, + Boundedness::Bounded, + )) + } +} + +impl DisplayAs for PartitionedTopKExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let partition_exprs: Vec = self.expr[..self.partition_prefix_len] + .iter() + .map(|e| format!("{}", e.expr)) + .collect(); + let order_exprs: Vec = self.expr[self.partition_prefix_len..] + .iter() + .map(|e| format!("{e}")) + .collect(); + write!( + f, + "PartitionedTopKExec: fetch={}, partition=[{}], order=[{}]", + self.fetch, + partition_exprs.join(", "), + order_exprs.join(", "), + ) + } + DisplayFormatType::TreeRender => { + let partition_exprs: Vec = self.expr[..self.partition_prefix_len] + .iter() + .map(|e| format!("{}", e.expr)) + .collect(); + let order_exprs: Vec = self.expr[self.partition_prefix_len..] + .iter() + .map(|e| format!("{e}")) + .collect(); + writeln!(f, "fetch={}", self.fetch)?; + writeln!(f, "partition=[{}]", partition_exprs.join(", "))?; + writeln!(f, "order=[{}]", order_exprs.join(", ")) + } + } + } +} + +impl ExecutionPlan for PartitionedTopKExec { + fn name(&self) -> &'static str { + "PartitionedTopKExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn required_input_distribution(&self) -> Vec { + let partition_exprs: Vec> = self.expr + [..self.partition_prefix_len] + .iter() + .map(|e| Arc::clone(&e.expr)) + .collect(); + vec![Distribution::HashPartitioned(partition_exprs)] + } + + fn maintains_input_order(&self) -> Vec { + vec![false] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 1); + Ok(Arc::new(PartitionedTopKExec::try_new( + Arc::clone(&children[0]), + self.expr.clone(), + self.partition_prefix_len, + self.fetch, + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, Arc::clone(&context))?; + let schema = input.schema(); + + let partition_sort_fields = + build_sort_fields(&self.expr[..self.partition_prefix_len], &schema)?; + + let partition_converter = RowConverter::new(partition_sort_fields)?; + + let partition_exprs: Vec> = self.expr + [..self.partition_prefix_len] + .iter() + .map(|e| Arc::clone(&e.expr)) + .collect(); + let order_expr: LexOrdering = + LexOrdering::new(self.expr[self.partition_prefix_len..].iter().cloned()) + .expect("PartitionedTopKExec requires at least one order-by expression"); + let fetch = self.fetch; + let batch_size = context.session_config().batch_size(); + let runtime = Arc::clone(&context.runtime_env()); + let metrics_set = self.metrics_set.clone(); + + let stream = futures::stream::once(async move { + do_partitioned_topk( + input, + schema, + partition_converter, + partition_exprs, + order_expr, + fetch, + batch_size, + runtime, + metrics_set, + ) + .await + }) + .try_flatten(); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.input.schema(), + stream, + ))) + } +} + +/// Create a no-op [`TopKDynamicFilters`] for a per-partition [`TopK`]. +/// +/// In normal `SortExec` top-K mode, dynamic filters push predicates down to +/// the data source (e.g., telling Parquet to skip rows worse than the current +/// K-th best). For per-partition heaps the data is already in memory and split +/// by partition key, so there is no data source to push filters to. We pass +/// `lit(true)` (accept everything) so the filter never rejects any row. +fn create_noop_dynamic_filter() -> Arc> { + Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new( + DynamicFilterPhysicalExpr::new(vec![], lit(true)), + )))) +} + +/// Read all input, split batches by partition key, feed each sub-batch +/// to a per-partition [`TopK`], then emit results in partition-key order. +/// +/// # Phases +/// +/// 1. **Accumulation** — For each input batch: +/// - Evaluate partition key expressions to get partition column arrays +/// - Convert partition columns to binary [`arrow::row::Row`] format +/// - Group row indices by partition key +/// - Extract sub-batches via [`take_record_batch`] and insert into +/// the partition's [`TopK`] heap +/// +/// 2. **Emission** — After all input is consumed: +/// - Sort partition keys so output is ordered by partition key +/// - For each partition in sorted order, call [`TopK::emit`] to get +/// rows sorted by order-by key +/// - Return all batches as a single stream +/// +/// # Cost +/// +/// - Time: O(N log K) where N = total rows, K = fetch +/// - Memory: O(K × P × row_size) where P = number of distinct partitions +#[expect(clippy::too_many_arguments)] +async fn do_partitioned_topk( + mut input: SendableRecordBatchStream, + schema: SchemaRef, + partition_converter: RowConverter, + partition_exprs: Vec>, + order_expr: LexOrdering, + fetch: usize, + batch_size: usize, + runtime: Arc, + metrics_set: ExecutionPlanMetricsSet, +) -> Result { + let mut partitions: HashMap = HashMap::new(); + let mut partition_counter: usize = 0; + + // Macro-like helper: create a new TopK for a partition + macro_rules! new_topk { + () => {{ + let id = partition_counter; + partition_counter += 1; + TopK::try_new( + id, + Arc::clone(&schema), + vec![], + order_expr.clone(), + fetch, + batch_size, + Arc::clone(&runtime), + &metrics_set, + create_noop_dynamic_filter(), + ) + }}; + } + + // ---------- Accumulation phase ---------- + while let Some(batch) = input.next().await { + let batch = batch?; + let num_rows = batch.num_rows(); + if num_rows == 0 { + continue; + } + + // Evaluate partition key columns + let pk_arrays: Vec<_> = partition_exprs + .iter() + .map(|e| e.evaluate(&batch).and_then(|v| v.into_array(num_rows))) + .collect::>>()?; + + let pk_rows = partition_converter.convert_columns(&pk_arrays)?; + + // Group row indices by partition key + let mut groups: HashMap> = HashMap::new(); + for row_idx in 0..num_rows { + let pk = pk_rows.row(row_idx).owned(); + groups.entry(pk).or_default().push(row_idx as u32); + } + + // For each partition group, create a sub-batch and feed to TopK + for (pk, indices) in groups { + if !partitions.contains_key(&pk) { + partitions.insert(pk.clone(), new_topk!()?); + } + let topk = partitions.get_mut(&pk).unwrap(); + let indices_array = UInt32Array::from(indices); + let sub_batch = take_record_batch(&batch, &indices_array)?; + topk.insert_batch(sub_batch)?; + } + } + // Release the input pipeline now that accumulation is complete. + drop(input); + + // ---------- Emit phase ---------- + // Sort partition keys so output is ordered by (partition_keys, order_keys). + let mut sorted_pks: Vec = partitions.keys().cloned().collect(); + sorted_pks.sort(); + + let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), batch_size); + + for pk in sorted_pks { + if let Some(topk) = partitions.remove(&pk) { + // TopK::emit() returns a stream of sorted batches + let mut stream = topk.emit()?; + while let Some(batch) = stream.next().await { + coalescer.push_batch(batch?)?; + } + } + } + coalescer.finish_buffered_batch()?; + let mut output_batches: Vec = Vec::new(); + while let Some(batch) = coalescer.next_completed_batch() { + output_batches.push(batch); + } + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::iter(output_batches.into_iter().map(Ok)), + ))) +} diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index c4c871da749d9..929ff4f7dfc85 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -19,7 +19,6 @@ //! It will do in-memory sorting if it has enough memory budget //! but spills to disk if needed. -use std::any::Any; use std::fmt; use std::fmt::{Debug, Formatter}; use std::sync::Arc; @@ -27,23 +26,27 @@ use std::sync::Arc; use parking_lot::RwLock; use crate::common::spawn_buffered; -use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType}; +use crate::execution_plan::{ + Boundedness, CardinalityEffect, EmissionType, has_same_children_properties, +}; use crate::expressions::PhysicalSortExpr; +use crate::filter::FilterExec; use crate::filter_pushdown::{ - ChildFilterDescription, FilterDescription, FilterPushdownPhase, + ChildFilterDescription, ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, PushedDown, }; use crate::limit::LimitStream; use crate::metrics::{ - BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, SpillMetrics, - SplitMetrics, + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; -use crate::projection::{make_with_child, update_ordering, ProjectionExec}; +use crate::projection::{ProjectionExec, make_with_child, update_ordering}; +use crate::sorts::IncrementalSortIterator; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; use crate::spill::spill_manager::{GetSlicedSize, SpillManager}; -use crate::stream::BatchSplitStream; use crate::stream::RecordBatchStreamAdapter; +use crate::stream::ReservationStream; use crate::topk::TopK; use crate::topk::TopKDynamicFilters; use crate::{ @@ -52,20 +55,20 @@ use crate::{ Statistics, }; -use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray}; +use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays}; use arrow::datatypes::SchemaRef; use datafusion_common::config::SpillCompression; use datafusion_common::{ - assert_or_internal_err, internal_datafusion_err, unwrap_or_internal_err, - DataFusionError, Result, + DataFusionError, Result, assert_or_internal_err, internal_datafusion_err, + unwrap_or_internal_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::LexOrdering; use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -75,8 +78,6 @@ struct ExternalSorterMetrics { baseline: BaselineMetrics, spill_metrics: SpillMetrics, - - split_metrics: SplitMetrics, } impl ExternalSorterMetrics { @@ -84,7 +85,6 @@ impl ExternalSorterMetrics { Self { baseline: BaselineMetrics::new(metrics, partition), spill_metrics: SpillMetrics::new(metrics, partition), - split_metrics: SplitMetrics::new(metrics, partition), } } } @@ -266,7 +266,7 @@ struct ExternalSorter { impl ExternalSorter { // TODO: make a builder or some other nicer API to avoid the // clippy warning - #[allow(clippy::too_many_arguments)] + #[expect(clippy::too_many_arguments)] pub fn new( partition_id: usize, schema: SchemaRef, @@ -342,11 +342,6 @@ impl ExternalSorter { /// 2. A combined streaming merge incorporating both in-memory /// batches and data from spill files on disk. async fn sort(&mut self) -> Result { - // Release the memory reserved for merge back to the pool so - // there is some left when `in_mem_sort_stream` requests an - // allocation. - self.merge_reservation.free(); - if self.spilled_before() { // Sort `in_mem_batches` and spill it first. If there are many // `in_mem_batches` and the memory limit is almost reached, merging @@ -355,6 +350,13 @@ impl ExternalSorter { self.sort_and_spill_in_mem_batches().await?; } + // Transfer the pre-reserved merge memory to the streaming merge + // using `take()` instead of `new_empty()`. This ensures the merge + // stream starts with `sort_spill_reservation_bytes` already + // allocated, preventing starvation when concurrent sort partitions + // compete for pool memory. `take()` moves the bytes atomically + // without releasing them back to the pool, so other partitions + // cannot race to consume the freed memory. StreamingMergeBuilder::new() .with_sorted_spill_files(std::mem::take(&mut self.finished_spill_files)) .with_spill_manager(self.spill_manager.clone()) @@ -363,9 +365,14 @@ impl ExternalSorter { .with_metrics(self.metrics.baseline.clone()) .with_batch_size(self.batch_size) .with_fetch(None) - .with_reservation(self.merge_reservation.new_empty()) + .with_reservation(self.merge_reservation.take()) .build() } else { + // Release the memory reserved for merge back to the pool so + // there is some left when `in_mem_sort_stream` requests an + // allocation. Only needed for the non-spill path; the spill + // path transfers the reservation to the merge stream instead. + self.merge_reservation.free(); self.in_mem_sort_stream(self.metrics.baseline.clone()) } } @@ -375,6 +382,12 @@ impl ExternalSorter { self.reservation.size() } + /// How much memory is reserved for the merge phase? + #[cfg(test)] + fn merge_reservation_size(&self) -> usize { + self.merge_reservation.size() + } + /// How many bytes have been spilled to disk? fn spilled_bytes(&self) -> usize { self.metrics.spill_metrics.spilled_bytes.value() @@ -406,8 +419,6 @@ impl ExternalSorter { Some((self.spill_manager.create_in_progress_file("Sorting")?, 0)); } - Self::organize_stringview_arrays(globally_sorted_batches)?; - debug!("Spilling sort data of ExternalSorter to disk whilst inserting"); let batches_to_spill = std::mem::take(globally_sorted_batches); @@ -419,10 +430,9 @@ impl ExternalSorter { })?; for batch in batches_to_spill { - in_progress_file.append_batch(&batch)?; + let gc_sliced_size = in_progress_file.append_batch(&batch)?; - *max_record_batch_size = - (*max_record_batch_size).max(batch.get_sliced_size()?); + *max_record_batch_size = (*max_record_batch_size).max(gc_sliced_size); } assert_or_internal_err!( @@ -451,71 +461,6 @@ impl ExternalSorter { Ok(()) } - /// Reconstruct `globally_sorted_batches` to organize the payload buffers of each - /// `StringViewArray` in sequential order by calling `gc()` on them. - /// - /// Note this is a workaround until is - /// available - /// - /// # Rationale - /// After (merge-based) sorting, all batches will be sorted into a single run, - /// but physically this sorted run is chunked into many small batches. For - /// `StringViewArray`s inside each sorted run, their inner buffers are not - /// re-constructed by default, leading to non-sequential payload locations - /// (permutated by `interleave()` Arrow kernel). A single payload buffer might - /// be shared by multiple `RecordBatch`es. - /// When writing each batch to disk, the writer has to write all referenced buffers, - /// because they have to be read back one by one to reduce memory usage. This - /// causes extra disk reads and writes, and potentially execution failure. - /// - /// # Example - /// Before sorting: - /// batch1 -> buffer1 - /// batch2 -> buffer2 - /// - /// sorted_batch1 -> buffer1 - /// -> buffer2 - /// sorted_batch2 -> buffer1 - /// -> buffer2 - /// - /// Then when spilling each batch, the writer has to write all referenced buffers - /// repeatedly. - fn organize_stringview_arrays( - globally_sorted_batches: &mut Vec, - ) -> Result<()> { - let mut organized_batches = Vec::with_capacity(globally_sorted_batches.len()); - - for batch in globally_sorted_batches.drain(..) { - let mut new_columns: Vec> = - Vec::with_capacity(batch.num_columns()); - - let mut arr_mutated = false; - for array in batch.columns() { - if let Some(string_view_array) = - array.as_any().downcast_ref::() - { - let new_array = string_view_array.gc(); - new_columns.push(Arc::new(new_array)); - arr_mutated = true; - } else { - new_columns.push(Arc::clone(array)); - } - } - - let organized_batch = if arr_mutated { - RecordBatch::try_new(batch.schema(), new_columns)? - } else { - batch - }; - - organized_batches.push(organized_batch); - } - - *globally_sorted_batches = organized_batches; - - Ok(()) - } - /// Sorts the in-memory batches and merges them into a single sorted run, then writes /// the result to spill files. async fn sort_and_spill_in_mem_batches(&mut self) -> Result<()> { @@ -545,7 +490,7 @@ impl ExternalSorter { while let Some(batch) = sorted_stream.next().await { let batch = batch?; - let sorted_size = get_reserved_byte_for_record_batch(&batch); + let sorted_size = get_reserved_bytes_for_record_batch(&batch)?; if self.reservation.try_grow(sorted_size).is_err() { // Although the reservation is not enough, the batch is // already in memory, so it's okay to combine it with previously @@ -662,7 +607,7 @@ impl ExternalSorter { if self.in_mem_batches.len() == 1 { let batch = self.in_mem_batches.swap_remove(0); let reservation = self.reservation.take(); - return self.sort_batch_stream(batch, metrics, reservation, true); + return self.sort_batch_stream(batch, &metrics, reservation); } // If less than sort_in_place_threshold_bytes, concatenate and sort in place @@ -671,10 +616,10 @@ impl ExternalSorter { let batch = concat_batches(&self.schema, &self.in_mem_batches)?; self.in_mem_batches.clear(); self.reservation - .try_resize(get_reserved_byte_for_record_batch(&batch)) + .try_resize(get_reserved_bytes_for_record_batch(&batch)?) .map_err(Self::err_with_oom_context)?; let reservation = self.reservation.take(); - return self.sort_batch_stream(batch, metrics, reservation, true); + return self.sort_batch_stream(batch, &metrics, reservation); } let streams = std::mem::take(&mut self.in_mem_batches) @@ -683,15 +628,8 @@ impl ExternalSorter { let metrics = self.metrics.baseline.intermediate(); let reservation = self .reservation - .split(get_reserved_byte_for_record_batch(&batch)); - let input = self.sort_batch_stream( - batch, - metrics, - reservation, - // Passing false as `StreamingMergeBuilder` will split the - // stream into batches of `self.batch_size` rows. - false, - )?; + .split(get_reserved_bytes_for_record_batch(&batch)?); + let input = self.sort_batch_stream(batch, &metrics, reservation)?; Ok(spawn_buffered(input, 1)) }) .collect::>()?; @@ -709,52 +647,67 @@ impl ExternalSorter { /// Sorts a single `RecordBatch` into a single stream. /// - /// `reservation` accounts for the memory used by this batch and - /// is released when the sort is complete - /// - /// passing `split` true will return a [`BatchSplitStream`] where each batch maximum row count - /// will be `self.batch_size`. - /// If `split` is false, the stream will return a single batch + /// This may output multiple batches depending on the size of the + /// sorted data and the target batch size. + /// For single-batch output cases, `reservation` will be freed immediately after sorting, + /// as the batch will be output and is expected to be reserved by the consumer of the stream. + /// For multi-batch output cases, `reservation` will be grown to match the actual + /// size of sorted output, and as each batch is output, its memory will be freed from the reservation. + /// (This leads to the same behaviour, as futures are only evaluated when polled by the consumer.) fn sort_batch_stream( &self, batch: RecordBatch, - metrics: BaselineMetrics, + metrics: &BaselineMetrics, reservation: MemoryReservation, - mut split: bool, ) -> Result { assert_eq!( - get_reserved_byte_for_record_batch(&batch), + get_reserved_bytes_for_record_batch(&batch)?, reservation.size() ); - split = split && batch.num_rows() > self.batch_size; - let schema = batch.schema(); - let expressions = self.expr.clone(); - let stream = futures::stream::once(async move { - let _timer = metrics.elapsed_compute().timer(); + let batch_size = self.batch_size; + let output_row_metrics = metrics.output_rows().clone(); - let sorted = sort_batch(&batch, &expressions, None)?; + let stream = futures::stream::once(async move { + let schema = batch.schema(); - (&sorted).record_output(&metrics); - drop(batch); - drop(reservation); - Ok(sorted) - }); + // Sort the batch immediately and get all output batches + let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?; - let mut output: SendableRecordBatchStream = - Box::pin(RecordBatchStreamAdapter::new(schema, stream)); + // Resize the reservation to match the actual sorted output size. + // Using try_resize avoids a release-then-reacquire cycle, which + // matters for MemoryPool implementations where grow/shrink have + // non-trivial cost (e.g. JNI calls in Comet). + let total_sorted_size: usize = sorted_batches + .iter() + .map(get_record_batch_memory_size) + .sum(); + reservation + .try_resize(total_sorted_size) + .map_err(Self::err_with_oom_context)?; - if split { - output = Box::pin(BatchSplitStream::new( - output, - self.batch_size, - self.metrics.split_metrics.clone(), - )); - } + // Wrap in ReservationStream to hold the reservation + Result::<_, DataFusionError>::Ok(Box::pin(ReservationStream::new( + Arc::clone(&schema), + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter(sorted_batches.into_iter().map(Ok)), + )), + reservation, + )) as SendableRecordBatchStream) + }) + .try_flatten() + .map(move |batch| match batch { + Ok(batch) => { + output_row_metrics.add(batch.num_rows()); + Ok(batch) + } + Err(e) => Err(e), + }); - Ok(output) + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) } /// If this sort may spill, pre-allocates @@ -780,7 +733,7 @@ impl ExternalSorter { &mut self, input: &RecordBatch, ) -> Result<()> { - let size = get_reserved_byte_for_record_batch(input); + let size = get_reserved_bytes_for_record_batch(input)?; match self.reservation.try_grow(size) { Ok(_) => Ok(()), @@ -804,7 +757,8 @@ impl ExternalSorter { match e { DataFusionError::ResourcesExhausted(_) => e.context( "Not enough memory to continue external sort. \ - Consider increasing the memory limit, or decreasing sort_spill_reservation_bytes" + Consider increasing the memory limit config: 'datafusion.runtime.memory_limit', \ + or decreasing the config: 'datafusion.execution.sort_spill_reservation_bytes'." ), // This is not an OOM error, so just return it as is. _ => e, @@ -819,16 +773,29 @@ impl ExternalSorter { /// in sorting and merging. The sorted copies are in either row format or array format. /// Please refer to cursor.rs and stream.rs for more details. No matter what format the /// sorted copies are, they will use more memory than the original record batch. -pub(crate) fn get_reserved_byte_for_record_batch_size(record_batch_size: usize) -> usize { - // 2x may not be enough for some cases, but it's a good start. +/// +/// This can basically be calculated as the sum of the actual space it takes in +/// memory (which would be larger for a sliced batch), and the size of the actual data. +pub(crate) fn get_reserved_bytes_for_record_batch_size( + record_batch_size: usize, + sliced_size: usize, +) -> usize { + // Even 2x may not be enough for some cases, but it's a good enough estimation as a baseline. // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes` // to compensate for the extra memory needed. - record_batch_size * 2 + record_batch_size + sliced_size } /// Estimate how much memory is needed to sort a `RecordBatch`. -fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize { - get_reserved_byte_for_record_batch_size(get_record_batch_memory_size(batch)) +/// This will just call `get_reserved_bytes_for_record_batch_size` with the +/// memory size of the record batch and its sliced size. +pub(crate) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result { + batch.get_sliced_size().map(|sliced_size| { + get_reserved_bytes_for_record_batch_size( + get_record_batch_memory_size(batch), + sliced_size, + ) + }) } impl Debug for ExternalSorter { @@ -853,15 +820,7 @@ pub fn sort_batch( .collect::>>()?; let indices = lexsort_to_indices(&sort_columns, fetch)?; - let mut columns = take_arrays(batch.columns(), &indices, None)?; - - // The columns may be larger than the unsorted columns in `batch` especially for variable length - // data types due to exponential growth when building the sort columns. We shrink the columns - // to prevent memory reservation failures, as well as excessive memory allocation when running - // merges in `SortPreservingMergeStream`. - columns.iter_mut().for_each(|c| { - c.shrink_to_fit(); - }); + let columns = take_arrays(batch.columns(), &indices, None)?; let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); Ok(RecordBatch::try_new_with_options( @@ -871,6 +830,17 @@ pub fn sort_batch( )?) } +/// Sort a batch and return the result as multiple batches of size `batch_size`. +/// This is useful when you want to avoid creating one large sorted batch in memory, +/// and instead want to process the sorted data in smaller chunks. +pub fn sort_batch_chunked( + batch: &RecordBatch, + expressions: &LexOrdering, + batch_size: usize, +) -> Result> { + IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect() +} + /// Sort execution plan. /// /// Support sorting datasets that are larger than the memory allotted @@ -891,7 +861,7 @@ pub struct SortExec { /// Normalized common sort prefix between the input and the sort expressions (only used with fetch) common_sort_prefix: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Filter matching the state of the sort for dynamic filter pushdown. /// If `fetch` is `Some`, this will also be set and a TopK operator may be used. /// If `fetch` is `None`, this will be `None`. @@ -913,7 +883,7 @@ impl SortExec { preserve_partitioning, fetch: None, common_sort_prefix: sort_prefix, - cache, + cache: Arc::new(cache), filter: None, } } @@ -932,12 +902,8 @@ impl SortExec { /// input partitions producing a single, sorted partition. pub fn with_preserve_partitioning(mut self, preserve_partitioning: bool) -> Self { self.preserve_partitioning = preserve_partitioning; - self.cache = self - .cache - .with_partitioning(Self::output_partitioning_helper( - &self.input, - self.preserve_partitioning, - )); + Arc::make_mut(&mut self.cache).partitioning = + Self::output_partitioning_helper(&self.input, self.preserve_partitioning); self } @@ -961,7 +927,7 @@ impl SortExec { preserve_partitioning: self.preserve_partitioning, common_sort_prefix: self.common_sort_prefix.clone(), fetch: self.fetch, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), filter: self.filter.clone(), } } @@ -974,12 +940,12 @@ impl SortExec { /// operation since rows that are not going to be included /// can be dropped. pub fn with_fetch(&self, fetch: Option) -> Self { - let mut cache = self.cache.clone(); + let mut cache = PlanProperties::clone(&self.cache); // If the SortExec can emit incrementally (that means the sort requirements // and properties of the input match), the SortExec can generate its result // without scanning the entire input when a fetch value exists. let is_pipeline_friendly = matches!( - self.cache.emission_type, + cache.emission_type, EmissionType::Incremental | EmissionType::Both ); if fetch.is_some() && is_pipeline_friendly { @@ -991,7 +957,7 @@ impl SortExec { }); let mut new_sort = self.cloned(); new_sort.fetch = fetch; - new_sort.cache = cache; + new_sort.cache = cache.into(); new_sort.filter = filter; new_sort } @@ -1011,6 +977,30 @@ impl SortExec { self.fetch } + /// Returns the dynamic filter expression for this sort (TopK), if set. + pub fn dynamic_filter_expr(&self) -> Option> { + self.filter.as_ref().map(|f| f.read().expr()) + } + + /// Replace the dynamic filter expression for this sort. + /// + /// + /// Resets any internal state which may depend on the previous dynamic filter. + /// + /// Validates that the filter's children reference valid columns in + /// the sort's input schema. + pub fn with_dynamic_filter_expr( + mut self, + filter: Arc, + ) -> Result { + let input_schema = self.input.schema(); + for child in filter.children() { + child.data_type(&input_schema)?; + } + self.filter = Some(Arc::new(RwLock::new(TopKDynamicFilters::new(filter)))); + Ok(self) + } + fn output_partitioning_helper( input: &Arc, preserve_partitioning: bool, @@ -1087,13 +1077,16 @@ impl DisplayAs for SortExec { let preserve_partitioning = self.preserve_partitioning; match self.fetch { Some(fetch) => { - write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr)?; - if let Some(filter) = &self.filter { - if let Ok(current) = filter.read().expr().current() { - if !current.eq(&lit(true)) { - write!(f, ", filter=[{current}]")?; - } - } + write!( + f, + "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", + self.expr + )?; + if let Some(filter) = &self.filter + && let Ok(current) = filter.read().expr().current() + && !current.eq(&lit(true)) + { + write!(f, ", filter=[{current}]")?; } if !self.common_sort_prefix.is_empty() { write!(f, ", sort_prefix=[")?; @@ -1111,7 +1104,11 @@ impl DisplayAs for SortExec { Ok(()) } } - None => write!(f, "SortExec: expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr), + None => write!( + f, + "SortExec: expr=[{}], preserve_partitioning=[{preserve_partitioning}]", + self.expr + ), } } DisplayFormatType::TreeRender => match self.fetch { @@ -1135,11 +1132,7 @@ impl ExecutionPlan for SortExec { } } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1148,7 +1141,8 @@ impl ExecutionPlan for SortExec { vec![Distribution::UnspecifiedDistribution] } else { // global sort - // TODO support RangePartition and OrderedDistribution + // TODO support range partitioning and OrderedDistribution. + // See https://github.com/apache/datafusion/issues/22395 vec![Distribution::SinglePartition] } } @@ -1166,19 +1160,19 @@ impl ExecutionPlan for SortExec { children: Vec>, ) -> Result> { let mut new_sort = self.cloned(); - assert!( - children.len() == 1, - "SortExec should have exactly one child" - ); + assert_eq!(children.len(), 1, "SortExec should have exactly one child"); new_sort.input = Arc::clone(&children[0]); - // Recompute the properties based on the new input since they may have changed - let (cache, sort_prefix) = Self::compute_properties( - &new_sort.input, - new_sort.expr.clone(), - new_sort.preserve_partitioning, - )?; - new_sort.cache = cache; - new_sort.common_sort_prefix = sort_prefix; + + if !has_same_children_properties(self.as_ref(), &children)? { + // Recompute the properties based on the new input since they may have changed + let (cache, sort_prefix) = Self::compute_properties( + &new_sort.input, + new_sort.expr.clone(), + new_sort.preserve_partitioning, + )?; + new_sort.cache = Arc::new(cache); + new_sort.common_sort_prefix = sort_prefix; + } Ok(Arc::new(new_sort)) } @@ -1187,7 +1181,6 @@ impl ExecutionPlan for SortExec { let children = self.children().into_iter().cloned().collect(); let new_sort = self.with_new_children(children)?; let mut new_sort = new_sort - .as_any() .downcast_ref::() .expect("cloned 1 lines above this line, we know the type") .clone(); @@ -1203,7 +1196,12 @@ impl ExecutionPlan for SortExec { partition: usize, context: Arc, ) -> Result { - trace!("Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); let mut input = self.input.execute(partition, Arc::clone(&context))?; @@ -1247,6 +1245,7 @@ impl ExecutionPlan for SortExec { break; } } + drop(input); topk.emit() }) .try_flatten(), @@ -1271,6 +1270,7 @@ impl ExecutionPlan for SortExec { let batch = batch?; sorter.insert_batch(batch).await?; } + drop(input); sorter.sort().await }) .try_flatten(), @@ -1283,20 +1283,14 @@ impl ExecutionPlan for SortExec { Some(self.metrics_set.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - if !self.preserve_partitioning() { - return self - .input - .partition_statistics(None)? - .with_fetch(self.fetch, 0, 1); - } - self.input - .partition_statistics(partition)? - .with_fetch(self.fetch, 0, 1) + fn partition_statistics(&self, partition: Option) -> Result> { + let p = if !self.preserve_partitioning() { + None + } else { + partition + }; + let stats = Arc::unwrap_or_clone(self.input.partition_statistics(p)?); + Ok(Arc::new(stats.with_fetch(self.fetch, 0, 1)?)) } fn with_fetch(&self, limit: Option) -> Option> { @@ -1345,21 +1339,84 @@ impl ExecutionPlan for SortExec { parent_filters: Vec>, config: &datafusion_common::config::ConfigOptions, ) -> Result { - if !matches!(phase, FilterPushdownPhase::Post) { + if phase != FilterPushdownPhase::Post { + if self.fetch.is_some() { + return Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )); + } return FilterDescription::from_children(parent_filters, &self.children()); } - let mut child = - ChildFilterDescription::from_child(&parent_filters, self.input())?; + // In Post phase: block parent filters when fetch is set, + // but still push the TopK dynamic filter (self-filter). + let mut child = if self.fetch.is_some() { + ChildFilterDescription::all_unsupported(&parent_filters) + } else { + ChildFilterDescription::from_child(&parent_filters, self.input())? + }; - if let Some(filter) = &self.filter { - if config.optimizer.enable_topk_dynamic_filter_pushdown { - child = child.with_self_filter(filter.read().expr()); - } + if let Some(filter) = &self.filter + && config.optimizer.enable_topk_dynamic_filter_pushdown + { + child = child.with_self_filter(filter.read().expr()); } Ok(FilterDescription::new().with_child(child)) } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &datafusion_common::config::ConfigOptions, + ) -> Result>> { + // For a plain sort (no fetch) we intercept any unsupported filters + // by inserting a FilterExec below this Sort. Moving the filter below + // Sort is safe because Sort preserves all rows. + // + // Why not fetch (TopK)? + // A sort with fetch limits the number of output rows. Inserting a + // FilterExec *below* the TopK would change semantics. A filter *above* + // the TopK is supposed to post-filter its output (e.g. "take the top 10 + // rows, then keep only those with a > 5"). Pushing the filter below + // Sort changes the meaning to "filter first, then take top 10", which + // produces a different result. + if self.fetch.is_some() { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + // Collect parent filters that were NOT successfully pushed to our child. + let unsupported_filters: Vec> = child_pushdown_result + .parent_filters + .iter() + .filter(|&f| matches!(f.all(), PushedDown::No)) + .map(|f| Arc::clone(&f.filter)) + .collect(); + + if unsupported_filters.is_empty() { + // All filters were pushed — nothing extra to do. + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + // Build a single conjunctive predicate from the unsupported filters + // and insert a FilterExec between this SortExec and its child. + let predicate = datafusion_physical_expr::conjunction(unsupported_filters); + let new_child = + Arc::new(FilterExec::try_new(predicate, Arc::clone(self.input()))?) + as Arc; + let new_sort = Arc::new( + SortExec::new(self.expr.clone(), new_child) + .with_fetch(self.fetch()) + .with_preserve_partitioning(self.preserve_partitioning()), + ) as Arc; + + Ok(FilterPushdownPropagation { + filters: vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()], + updated_node: Some(new_sort), + }) + } } #[cfg(test)] @@ -1371,33 +1428,39 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::collect; + use crate::empty::EmptyExec; use crate::execution_plan::Boundedness; use crate::expressions::col; + use crate::filter_pushdown::{FilterPushdownPhase, PushedDown}; use crate::test; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::TestMemoryExec; + use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; use crate::test::{assert_is_pending, make_partition}; use arrow::array::*; use arrow::compute::SortOptions; use arrow::datatypes::*; + use datafusion_common::ScalarValue; use datafusion_common::cast::as_primitive_array; + use datafusion_common::config::ConfigOptions; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{DataFusionError, Result, ScalarValue}; + use datafusion_execution::RecordBatchStream; use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryConsumer, MemoryPool, + }; use datafusion_execution::runtime_env::RuntimeEnvBuilder; - use datafusion_execution::RecordBatchStream; - use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::expressions::{Column, Literal}; - use futures::{FutureExt, Stream}; + use futures::{FutureExt, Stream, TryStreamExt}; use insta::assert_snapshot; #[derive(Debug, Clone)] pub struct SortedUnboundedExec { schema: Schema, batch_size: u64, - cache: PlanProperties, + cache: Arc, } impl DisplayAs for SortedUnboundedExec { @@ -1433,11 +1496,7 @@ mod tests { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -1611,13 +1670,24 @@ mod tests { #[tokio::test] async fn test_batch_reservation_error() -> Result<()> { // Pick a memory limit and sort_spill_reservation that make the first batch reservation fail. - // These values assume that the ExternalSorter will reserve 800 bytes for the first batch. - let expected_batch_reservation = 800; let merge_reservation: usize = 0; // Set to 0 for simplicity - let memory_limit: usize = expected_batch_reservation + merge_reservation - 1; // Just short of what we need let session_config = SessionConfig::new().with_sort_spill_reservation_bytes(merge_reservation); + + let plan = test::scan_partitioned(1); + + // Read the first record batch to determine the actual memory requirement + let expected_batch_reservation = { + let temp_ctx = Arc::new(TaskContext::default()); + let mut stream = plan.execute(0, Arc::clone(&temp_ctx))?; + let first_batch = stream.next().await.unwrap()?; + get_reserved_bytes_for_record_batch(&first_batch)? + }; + + // Set memory limit just short of what we need + let memory_limit: usize = expected_batch_reservation + merge_reservation - 1; + let runtime = RuntimeEnvBuilder::new() .with_memory_limit(memory_limit, 1.0) .build_arc()?; @@ -1627,14 +1697,11 @@ mod tests { .with_runtime(runtime), ); - let plan = test::scan_partitioned(1); - - // Read the first record batch to assert that our memory limit and sort_spill_reservation - // settings trigger the test scenario. + // Verify that our memory limit is insufficient { let mut stream = plan.execute(0, Arc::clone(&task_ctx))?; let first_batch = stream.next().await.unwrap()?; - let batch_reservation = get_reserved_byte_for_record_batch(&first_batch); + let batch_reservation = get_reserved_bytes_for_record_batch(&first_batch)?; assert_eq!(batch_reservation, expected_batch_reservation); assert!(memory_limit < (merge_reservation + batch_reservation)); @@ -1659,6 +1726,21 @@ mod tests { "Assertion failed: expected a ResourcesExhausted error, but got: {err:?}" ); + // Verify external sorter error message when resource is exhausted + let config_vector = vec![ + "datafusion.runtime.memory_limit", + "datafusion.execution.sort_spill_reservation_bytes", + ]; + let error_message = err.message().to_string(); + for config in config_vector.into_iter() { + assert!( + error_message.as_str().contains(config), + "Config: '{}' should be contained in error message: {}.", + config, + error_message.as_str() + ); + } + Ok(()) } @@ -1679,7 +1761,7 @@ mod tests { // The input has 200 partitions, each partition has a batch containing 100 rows. // Each row has a single Utf8 column, the Utf8 string values are roughly 42 bytes. - // The total size of the input is roughly 8.4 KB. + // The total size of the input is roughly 820 KB. let input = test::scan_partitioned_utf8(200); let schema = input.schema(); @@ -1802,6 +1884,93 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_sort_memory_reduction_per_batch() -> Result<()> { + // This test verifies that memory reservation is reduced for every batch emitted + // during the sort process. This is important to ensure we don't hold onto + // memory longer than necessary. + + // Create a large enough batch that will be split into multiple output batches + let batch_size = 50; // Small batch size to force multiple output batches + let num_rows = 1000; // Create enough data for multiple batches + + let task_ctx = Arc::new( + TaskContext::default().with_session_config( + SessionConfig::new() + .with_batch_size(batch_size) + .with_sort_in_place_threshold_bytes(usize::MAX), // Ensure we don't concat batches + ), + ); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create unsorted data + let mut values: Vec = (0..num_rows).collect(); + values.reverse(); + + let input_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let batches = vec![input_batch]; + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }] + .into(), + TestMemoryExec::try_new_exec( + std::slice::from_ref(&batches), + Arc::clone(&schema), + None, + )?, + )); + + let mut stream = sort_exec.execute(0, Arc::clone(&task_ctx))?; + + let mut previous_reserved = task_ctx.runtime_env().memory_pool.reserved(); + let mut batch_count = 0; + + // Collect batches and verify memory is reduced with each batch + while let Some(result) = stream.next().await { + let batch = result?; + batch_count += 1; + + // Verify we got a non-empty batch + assert!(batch.num_rows() > 0, "Batch should not be empty"); + + let current_reserved = task_ctx.runtime_env().memory_pool.reserved(); + + // After the first batch, memory should be reducing or staying the same + // (it should not increase as we emit batches) + if batch_count > 1 { + assert!( + current_reserved <= previous_reserved, + "Memory reservation should decrease or stay same as batches are emitted. \ + Batch {batch_count}: previous={previous_reserved}, current={current_reserved}" + ); + } + + previous_reserved = current_reserved; + } + + assert!( + batch_count > 1, + "Expected multiple batches to be emitted, got {batch_count}" + ); + + // Verify all memory is returned at the end + assert_eq!( + task_ctx.runtime_env().memory_pool.reserved(), + 0, + "All memory should be returned after consuming all batches" + ); + + Ok(()) + } + #[tokio::test] async fn test_sort_metadata() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -2095,7 +2264,9 @@ mod tests { let source = SortedUnboundedExec { schema: schema.clone(), batch_size: 2, - cache: SortedUnboundedExec::compute_properties(Arc::new(schema.clone())), + cache: Arc::new(SortedUnboundedExec::compute_properties(Arc::new( + schema.clone(), + ))), }; let mut plan = SortExec::new( [PhysicalSortExpr::new_default(Arc::new(Column::new( @@ -2107,21 +2278,21 @@ mod tests { plan = plan.with_fetch(Some(9)); let batches = collect(Arc::new(plan), task_ctx).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+ - | c1 | - +----+ - | 0 | - | 1 | - | 2 | - | 3 | - | 4 | - | 5 | - | 6 | - | 7 | - | 8 | - +----+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+ + | c1 | + +----+ + | 0 | + | 1 | + | 2 | + | 3 | + | 4 | + | 5 | + | 6 | + | 7 | + | 8 | + +----+ + "); Ok(()) } @@ -2150,8 +2321,8 @@ mod tests { } #[tokio::test] - async fn should_return_stream_with_batches_in_the_requested_size_when_sorting_in_place( - ) -> Result<()> { + async fn should_return_stream_with_batches_in_the_requested_size_when_sorting_in_place() + -> Result<()> { let batch_size = 100; let create_task_ctx = |_: &[RecordBatch]| { @@ -2202,8 +2373,8 @@ mod tests { } #[tokio::test] - async fn should_return_stream_with_batches_in_the_requested_size_when_having_a_single_batch( - ) -> Result<()> { + async fn should_return_stream_with_batches_in_the_requested_size_when_having_a_single_batch() + -> Result<()> { let batch_size = 100; let create_task_ctx = |_: &[RecordBatch]| { @@ -2266,8 +2437,8 @@ mod tests { } #[tokio::test] - async fn should_return_stream_with_batches_in_the_requested_size_when_having_to_spill( - ) -> Result<()> { + async fn should_return_stream_with_batches_in_the_requested_size_when_having_to_spill() + -> Result<()> { let batch_size = 100; let create_task_ctx = |generated_batches: &[RecordBatch]| { @@ -2390,4 +2561,423 @@ mod tests { Ok((sorted_batches, metrics)) } + + #[tokio::test] + async fn test_sort_batch_chunked_basic() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with 1000 rows + let mut values: Vec = (0..1000).collect(); + // Shuffle to make it unsorted + values.reverse(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); + + // Sort with batch_size = 250 + let result_batches = sort_batch_chunked(&batch, &expressions, 250)?; + + // Verify 4 batches are returned + assert_eq!(result_batches.len(), 4); + + // Verify each batch has <= 250 rows + let mut total_rows = 0; + for (i, batch) in result_batches.iter().enumerate() { + assert!( + batch.num_rows() <= 250, + "Batch {} has {} rows, expected <= 250", + i, + batch.num_rows() + ); + total_rows += batch.num_rows(); + } + + // Verify total row count matches input + assert_eq!(total_rows, 1000); + + // Verify data is correctly sorted across all chunks + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!( + array.value(i) <= array.value(i + 1), + "Array not sorted at position {}: {} > {}", + i, + array.value(i), + array.value(i + 1) + ); + } + assert_eq!(array.value(0), 0); + assert_eq!(array.value(array.len() - 1), 999); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_smaller_than_batch_size() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with 50 rows + let values: Vec = (0..50).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); + + // Sort with batch_size = 100 + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + // Should return exactly 1 batch + assert_eq!(result_batches.len(), 1); + assert_eq!(result_batches[0].num_rows(), 50); + + // Verify it's correctly sorted + let array = as_primitive_array::(result_batches[0].column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + assert_eq!(array.value(0), 0); + assert_eq!(array.value(49), 49); + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_exact_multiple() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a batch with 1000 rows + let values: Vec = (0..1000).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); + + // Sort with batch_size = 100 + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + // Should return exactly 10 batches of 100 rows each + assert_eq!(result_batches.len(), 10); + for batch in &result_batches { + assert_eq!(batch.num_rows(), 100); + } + + // Verify sorted correctly across all batches + let concatenated = concat_batches(&schema, &result_batches)?; + let array = as_primitive_array::(concatenated.column(0))?; + for i in 0..array.len() - 1 { + assert!(array.value(i) <= array.value(i + 1)); + } + + Ok(()) + } + + #[tokio::test] + async fn test_sort_batch_chunked_empty_batch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let batch = RecordBatch::new_empty(Arc::clone(&schema)); + + let expressions: LexOrdering = + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(); + + let result_batches = sort_batch_chunked(&batch, &expressions, 100)?; + + // Empty input produces no output batches (0 chunks) + assert_eq!(result_batches.len(), 0); + + Ok(()) + } + + #[tokio::test] + async fn test_get_reserved_bytes_for_record_batch_with_sliced_batches() -> Result<()> + { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create a larger batch then slice it + let large_array = Int32Array::from((0..1000).collect::>()); + let sliced_array = large_array.slice(100, 50); // Take 50 elements starting at 100 + + let sliced_batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(sliced_array)])?; + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(large_array)])?; + + let sliced_reserved = get_reserved_bytes_for_record_batch(&sliced_batch)?; + let reserved = get_reserved_bytes_for_record_batch(&batch)?; + + // The reserved memory for the sliced batch should be less than that of the full batch + assert!(reserved > sliced_reserved); + + Ok(()) + } + + #[test] + fn test_with_dynamic_filter() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)); + + // SortExec with fetch creates a dynamic filter automatically. + let original_id = sort + .dynamic_filter_expr() + .expect("should have dynamic filter with fetch") + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"); + + // with_dynamic_filter replaces it with a new TopKDynamicFilters. + let new_df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as _], + lit(true), + )); + let new_id = new_df + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"); + let sort = sort.with_dynamic_filter_expr(Arc::clone(&new_df))?; + let restored_id = sort + .dynamic_filter_expr() + .expect("should still have dynamic filter") + .expression_id() + .expect("DynamicFilterPhysicalExpr always has an expression_id"); + assert_eq!(restored_id, new_id); + assert_ne!(restored_id, original_id); + Ok(()) + } + + #[test] + fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let sort = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::new(Column::new("a", 0)), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)); + + // Column index 99 is out of bounds for the input schema. + let df = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("bad", 99)) as _], + lit(true), + )); + assert!(sort.with_dynamic_filter_expr(df).is_err()); + Ok(()) + } + + /// Verifies that `ExternalSorter::sort()` transfers the pre-reserved + /// merge bytes to the merge stream via `take()`, rather than leaving + /// them in the sorter (via `new_empty()`). + /// + /// 1. Create a sorter with a tight memory pool and insert enough data + /// to force spilling + /// 2. Verify `merge_reservation` holds the pre-reserved bytes before sort + /// 3. Call `sort()` to get the merge stream + /// 4. Verify `merge_reservation` is now 0 (bytes transferred to merge stream) + /// 5. Simulate contention: a competing consumer grabs all available pool memory + /// 6. Verify the merge stream still works (it uses its pre-reserved bytes + /// as initial budget, not requesting from pool starting at 0) + /// + /// With `new_empty()` (before fix), step 4 fails: `merge_reservation` + /// still holds the bytes, the merge stream starts with 0 budget, and + /// those bytes become unaccounted-for reserved memory that nobody uses. + #[tokio::test] + async fn test_sort_merge_reservation_transferred_not_freed() -> Result<()> { + let sort_spill_reservation_bytes: usize = 10 * 1024; // 10 KB + + // Pool: merge reservation (10KB) + enough room for sort to work. + // The room must accommodate batch data accumulation before spilling. + let sort_working_memory: usize = 40 * 1024; // 40 KB for sort operations + let pool_size = sort_spill_reservation_bytes + sort_working_memory; + let pool: Arc = Arc::new(GreedyMemoryPool::new(pool_size)); + + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build_arc()?; + + let metrics_set = ExecutionPlanMetricsSet::new(); + let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)])); + + let mut sorter = ExternalSorter::new( + 0, + Arc::clone(&schema), + [PhysicalSortExpr::new_default(Arc::new(Column::new("x", 0)))].into(), + 128, // batch_size + sort_spill_reservation_bytes, + usize::MAX, // sort_in_place_threshold_bytes (high to avoid concat path) + SpillCompression::Uncompressed, + &metrics_set, + Arc::clone(&runtime), + )?; + + // Insert enough data to force spilling. + let num_batches = 200; + for i in 0..num_batches { + let values: Vec = ((i * 100)..((i + 1) * 100)).rev().collect(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(values))], + )?; + sorter.insert_batch(batch).await?; + } + + assert!( + sorter.spilled_before(), + "Test requires spilling to exercise the merge path" + ); + + // Before sort(), merge_reservation holds sort_spill_reservation_bytes. + assert!( + sorter.merge_reservation_size() >= sort_spill_reservation_bytes, + "merge_reservation should hold the pre-reserved bytes before sort()" + ); + + // Call sort() to get the merge stream. With the fix (take()), + // the pre-reserved merge bytes are transferred to the merge + // stream. Without the fix (free() + new_empty()), the bytes + // are released back to the pool and the merge stream starts + // with 0 bytes. + let merge_stream = sorter.sort().await?; + + // THE KEY ASSERTION: after sort(), merge_reservation must be 0. + // This proves take() transferred the bytes to the merge stream, + // rather than them being freed back to the pool where other + // partitions could steal them. + assert_eq!( + sorter.merge_reservation_size(), + 0, + "After sort(), merge_reservation should be 0 (bytes transferred \ + to merge stream via take()). If non-zero, the bytes are still \ + held by the sorter and will be freed on drop, allowing other \ + partitions to steal them." + ); + + // Drop the sorter to free its reservations back to the pool. + drop(sorter); + + // Simulate contention: another partition grabs ALL available + // pool memory. If the merge stream didn't receive the + // pre-reserved bytes via take(), it will fail when it tries + // to allocate memory for reading spill files. + let contender = MemoryConsumer::new("CompetingPartition").register(&pool); + let available = pool_size.saturating_sub(pool.reserved()); + if available > 0 { + contender.try_grow(available).unwrap(); + } + + // The merge stream must still produce correct results despite + // the pool being fully consumed by the contender. This only + // works if sort() transferred the pre-reserved bytes to the + // merge stream (via take()) rather than freeing them. + let batches: Vec = merge_stream.try_collect().await?; + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!( + total_rows, + (num_batches * 100) as usize, + "Merge stream should produce all rows even under memory contention" + ); + + // Verify data is sorted + let merged = concat_batches(&schema, &batches)?; + let col = merged.column(0).as_primitive::(); + for i in 1..col.len() { + assert!( + col.value(i - 1) <= col.value(i), + "Output should be sorted, but found {} > {} at index {}", + col.value(i - 1), + col.value(i), + i + ); + } + + drop(contender); + Ok(()) + } + + fn make_sort_exec_with_fetch(fetch: Option) -> SortExec { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let input = Arc::new(EmptyExec::new(schema)); + SortExec::new( + [PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(), + input, + ) + .with_fetch(fetch) + } + + #[test] + fn test_sort_with_fetch_blocks_filter_pushdown() -> Result<()> { + let sort = make_sort_exec_with_fetch(Some(10)); + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Pre, + vec![Arc::new(Column::new("a", 0))], + &ConfigOptions::new(), + )?; + // Sort with fetch (TopK) must not allow filters to be pushed below it. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::No + )); + Ok(()) + } + + #[test] + fn test_sort_without_fetch_allows_filter_pushdown() -> Result<()> { + let sort = make_sort_exec_with_fetch(None); + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Pre, + vec![Arc::new(Column::new("a", 0))], + &ConfigOptions::new(), + )?; + // Plain sort (no fetch) is filter-commutative. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::Yes + )); + Ok(()) + } + + #[test] + fn test_sort_with_fetch_allows_topk_self_filter_in_post_phase() -> Result<()> { + let sort = make_sort_exec_with_fetch(Some(10)); + assert!(sort.filter.is_some(), "TopK filter should be created"); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_topk_dynamic_filter_pushdown = true; + let desc = sort.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![Arc::new(Column::new("a", 0))], + &config, + )?; + // Parent filters are still blocked in the Post phase. + assert!(matches!( + desc.parent_filters()[0][0].discriminant, + PushedDown::No + )); + // But the TopK self-filter should be pushed down. + assert_eq!(desc.self_filters()[0].len(), 1); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 3361a7cdb7185..eb9b5f09aa3ed 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -17,22 +17,22 @@ //! [`SortPreservingMergeExec`] merges multiple sorted streams into one sorted stream. -use std::any::Any; use std::sync::Arc; use crate::common::spawn_buffered; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::{make_with_child, update_ordering, ProjectionExec}; +use crate::projection::{ProjectionExec, make_with_child, update_ordering}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + check_if_same_properties, }; -use datafusion_common::{assert_eq_or_internal_err, internal_err, Result}; -use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements}; use crate::execution_plan::{EvaluationType, SchedulingType}; @@ -93,7 +93,7 @@ pub struct SortPreservingMergeExec { /// Optional number of rows to fetch. Stops producing rows after this fetch fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// Use round-robin selection of tied winners of loser tree /// /// See [`Self::with_round_robin_repartition`] for more information. @@ -109,7 +109,7 @@ impl SortPreservingMergeExec { expr, metrics: ExecutionPlanMetricsSet::new(), fetch: None, - cache, + cache: Arc::new(cache), enable_round_robin_repartition: true, } } @@ -180,6 +180,17 @@ impl SortPreservingMergeExec { .with_evaluation_type(drive) .with_scheduling_type(scheduling) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for SortPreservingMergeExec { @@ -221,11 +232,7 @@ impl ExecutionPlan for SortPreservingMergeExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -240,11 +247,24 @@ impl ExecutionPlan for SortPreservingMergeExec { expr: self.expr.clone(), metrics: self.metrics.clone(), fetch: limit, - cache: self.cache.clone(), - enable_round_robin_repartition: true, + cache: Arc::clone(&self.cache), + enable_round_robin_repartition: self.enable_round_robin_repartition, })) } + fn with_preserve_order( + &self, + preserve_order: bool, + ) -> Option> { + self.input + .with_preserve_order(preserve_order) + .and_then(|new_input| { + Arc::new(self.clone()) + .with_new_children(vec![new_input]) + .ok() + }) + } + fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution] } @@ -267,10 +287,11 @@ impl ExecutionPlan for SortPreservingMergeExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new( - SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0])) + SortPreservingMergeExec::new(self.expr.clone(), children.swap_remove(0)) .with_fetch(self.fetch), )) } @@ -304,7 +325,9 @@ impl ExecutionPlan for SortPreservingMergeExec { 1 => match self.fetch { Some(fetch) => { let stream = self.input.execute(0, context)?; - debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"); + debug!( + "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}" + ); Ok(Box::pin(LimitStream::new( stream, 0, @@ -314,7 +337,9 @@ impl ExecutionPlan for SortPreservingMergeExec { } None => { let stream = self.input.execute(0, context); - debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"); + debug!( + "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch" + ); stream } }, @@ -327,7 +352,9 @@ impl ExecutionPlan for SortPreservingMergeExec { }) .collect::>()?; - debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute"); + debug!( + "Done setting up sender-receiver for SortPreservingMergeExec::execute" + ); let result = StreamingMergeBuilder::new() .with_streams(receivers) @@ -340,7 +367,9 @@ impl ExecutionPlan for SortPreservingMergeExec { .with_round_robin_tie_breaker(self.enable_round_robin_repartition) .build()?; - debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); + debug!( + "Got stream result from SortPreservingMergeStream::new_from_receivers" + ); Ok(result) } @@ -351,11 +380,7 @@ impl ExecutionPlan for SortPreservingMergeExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.input.partition_statistics(None) - } - - fn partition_statistics(&self, _partition: Option) -> Result { + fn partition_statistics(&self, _partition: Option) -> Result> { self.input.partition_statistics(None) } @@ -396,11 +421,10 @@ mod tests { use std::fmt::Formatter; use std::pin::Pin; use std::sync::Mutex; - use std::task::{ready, Context, Poll, Waker}; + use std::task::{Context, Poll, Waker, ready}; use std::time::Duration; use super::*; - use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::execution_plan::{Boundedness, EmissionType}; use crate::expressions::col; @@ -408,8 +432,8 @@ mod tests { use crate::repartition::RepartitionExec; use crate::sorts::sort::SortExec; use crate::stream::RecordBatchReceiverStream; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::TestMemoryExec; + use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; use crate::test::{self, assert_is_pending, make_partition}; use crate::{collect, common}; @@ -422,11 +446,11 @@ mod tests { use datafusion_common::test_util::batches_to_string; use datafusion_common::{assert_batches_eq, exec_err}; use datafusion_common_runtime::SpawnedTask; + use datafusion_execution::RecordBatchStream; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; - use datafusion_execution::RecordBatchStream; - use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -436,31 +460,33 @@ mod tests { // The number in the function is highly related to the memory limit we are testing // any change of the constant should be aware of - fn generate_task_ctx_for_round_robin_tie_breaker() -> Result> { + fn generate_task_ctx_for_round_robin_tie_breaker( + target_batch_size: usize, + ) -> Result> { let runtime = RuntimeEnvBuilder::new() .with_memory_limit(20_000_000, 1.0) .build_arc()?; - let config = SessionConfig::new(); + let mut config = SessionConfig::new(); + config.options_mut().execution.batch_size = target_batch_size; let task_ctx = TaskContext::default() .with_runtime(runtime) .with_session_config(config); Ok(Arc::new(task_ctx)) } + // The number in the function is highly related to the memory limit we are testing, // any change of the constant should be aware of fn generate_spm_for_round_robin_tie_breaker( enable_round_robin_repartition: bool, ) -> Result> { - let target_batch_size = 12500; let row_size = 12500; let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?; - - let rbs = (0..1024).map(|_| rb.clone()).collect::>(); - let schema = rb.schema(); + + let rbs = std::iter::repeat_n(rb, 1024).collect::>(); let sort = [ PhysicalSortExpr { expr: col("b", &schema)?, @@ -477,9 +503,7 @@ mod tests { TestMemoryExec::try_new_exec(&[rbs], schema, None)?, Partitioning::RoundRobinBatch(2), )?; - let coalesce_batches_exec = - CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size); - let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec)) + let spm = SortPreservingMergeExec::new(sort, Arc::new(repartition_exec)) .with_round_robin_repartition(enable_round_robin_repartition); Ok(Arc::new(spm)) } @@ -491,7 +515,8 @@ mod tests { /// based on whether the tie breaker is enabled or disabled. #[tokio::test(flavor = "multi_thread")] async fn test_round_robin_tie_breaker_success() -> Result<()> { - let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let target_batch_size = 12500; + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(target_batch_size)?; let spm = generate_spm_for_round_robin_tie_breaker(true)?; let _collected = collect(spm, task_ctx).await?; Ok(()) @@ -504,7 +529,7 @@ mod tests { /// based on whether the tie breaker is enabled or disabled. #[tokio::test(flavor = "multi_thread")] async fn test_round_robin_tie_breaker_fail() -> Result<()> { - let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(8192)?; let spm = generate_spm_for_round_robin_tie_breaker(false)?; let _err = collect(spm, task_ctx).await.unwrap_err(); Ok(()) @@ -975,22 +1000,22 @@ mod tests { let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); - assert_snapshot!(batches_to_string(collected.as_slice()), @r#" - +---+---+-------------------------------+ - | a | b | c | - +---+---+-------------------------------+ - | 1 | | 1970-01-01T00:00:00.000000008 | - | 1 | | 1970-01-01T00:00:00.000000008 | - | 2 | a | | - | 7 | b | 1970-01-01T00:00:00.000000006 | - | 2 | b | | - | 9 | d | | - | 3 | e | 1970-01-01T00:00:00.000000004 | - | 3 | g | 1970-01-01T00:00:00.000000005 | - | 4 | h | | - | 5 | i | 1970-01-01T00:00:00.000000004 | - +---+---+-------------------------------+ - "#); + assert_snapshot!(batches_to_string(collected.as_slice()), @r" + +---+---+-------------------------------+ + | a | b | c | + +---+---+-------------------------------+ + | 1 | | 1970-01-01T00:00:00.000000008 | + | 1 | | 1970-01-01T00:00:00.000000008 | + | 2 | a | | + | 7 | b | 1970-01-01T00:00:00.000000006 | + | 2 | b | | + | 9 | d | | + | 3 | e | 1970-01-01T00:00:00.000000004 | + | 3 | g | 1970-01-01T00:00:00.000000005 | + | 4 | h | | + | 5 | i | 1970-01-01T00:00:00.000000004 | + +---+---+-------------------------------+ + "); } #[tokio::test] @@ -1016,14 +1041,14 @@ mod tests { let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); - assert_snapshot!(batches_to_string(collected.as_slice()), @r#" - +---+---+ - | a | b | - +---+---+ - | 1 | a | - | 2 | b | - +---+---+ - "#); + assert_snapshot!(batches_to_string(collected.as_slice()), @r" + +---+---+ + | a | b | + +---+---+ + | 1 | a | + | 2 | b | + +---+---+ + "); } #[tokio::test] @@ -1048,17 +1073,17 @@ mod tests { let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); - assert_snapshot!(batches_to_string(collected.as_slice()), @r#" - +---+---+ - | a | b | - +---+---+ - | 1 | a | - | 2 | b | - | 7 | c | - | 9 | d | - | 3 | e | - +---+---+ - "#); + assert_snapshot!(batches_to_string(collected.as_slice()), @r" + +---+---+ + | a | b | + +---+---+ + | 1 | a | + | 2 | b | + | 7 | c | + | 9 | d | + | 3 | e | + +---+---+ + "); } #[tokio::test] @@ -1157,16 +1182,16 @@ mod tests { let collected = collect(Arc::clone(&merge) as Arc, task_ctx) .await .unwrap(); - assert_snapshot!(batches_to_string(collected.as_slice()), @r#" - +----+---+ - | a | b | - +----+---+ - | 1 | a | - | 10 | b | - | 2 | c | - | 20 | d | - +----+---+ - "#); + assert_snapshot!(batches_to_string(collected.as_slice()), @r" + +----+---+ + | a | b | + +----+---+ + | 1 | a | + | 10 | b | + | 2 | c | + | 20 | d | + +----+---+ + "); // Now, validate metrics let metrics = merge.metrics().unwrap(); @@ -1272,32 +1297,32 @@ mod tests { // Expect the data to be sorted first by "batch_number" (because // that was the order it was fed in, even though only "value" // is in the sort key) - assert_snapshot!(batches_to_string(collected.as_slice()), @r#" - +--------------+-------+ - | batch_number | value | - +--------------+-------+ - | 0 | A | - | 1 | A | - | 2 | A | - | 3 | A | - | 4 | A | - | 5 | A | - | 6 | A | - | 7 | A | - | 8 | A | - | 9 | A | - | 0 | B | - | 1 | B | - | 2 | B | - | 3 | B | - | 4 | B | - | 5 | B | - | 6 | B | - | 7 | B | - | 8 | B | - | 9 | B | - +--------------+-------+ - "#); + assert_snapshot!(batches_to_string(collected.as_slice()), @r" + +--------------+-------+ + | batch_number | value | + +--------------+-------+ + | 0 | A | + | 1 | A | + | 2 | A | + | 3 | A | + | 4 | A | + | 5 | A | + | 6 | A | + | 7 | A | + | 8 | A | + | 9 | A | + | 0 | B | + | 1 | B | + | 2 | B | + | 3 | B | + | 4 | B | + | 5 | B | + | 6 | B | + | 7 | B | + | 8 | B | + | 9 | B | + +--------------+-------+ + "); } #[derive(Debug)] @@ -1342,7 +1367,7 @@ mod tests { #[derive(Debug, Clone)] struct CongestedExec { schema: Schema, - cache: PlanProperties, + cache: Arc, congestion: Arc, } @@ -1375,10 +1400,7 @@ mod tests { fn name(&self) -> &'static str { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } fn children(&self) -> Vec<&Arc> { @@ -1464,14 +1486,10 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]); let properties = CongestedExec::compute_properties(Arc::new(schema.clone())); - let &partition_count = match properties.output_partitioning() { - Partitioning::RoundRobinBatch(partitions) => partitions, - Partitioning::Hash(_, partitions) => partitions, - Partitioning::UnknownPartitioning(partitions) => partitions, - }; + let partition_count = properties.output_partitioning().partition_count(); let source = CongestedExec { schema: schema.clone(), - cache: properties, + cache: Arc::new(properties), congestion: Arc::new(Congestion::new(partition_count)), }; let spm = SortPreservingMergeExec::new( @@ -1491,4 +1509,59 @@ mod tests { Err(_) => exec_err!("SortPreservingMerge caused a deadlock"), } } + + #[tokio::test] + async fn test_sort_merge_stops_after_error_with_buffered_rows() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)])); + let sort: LexOrdering = [PhysicalSortExpr::new_default(Arc::new(Column::new( + "i", 0, + )) + as Arc)] + .into(); + + let mut stream0 = RecordBatchReceiverStream::builder(Arc::clone(&schema), 2); + let tx0 = stream0.tx(); + let schema0 = Arc::clone(&schema); + stream0.spawn(async move { + let batch = + RecordBatch::try_new(schema0, vec![Arc::new(Int32Array::from(vec![1]))])?; + tx0.send(Ok(batch)).await.unwrap(); + tx0.send(exec_err!("stream failure")).await.unwrap(); + Ok(()) + }); + + let mut stream1 = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1); + let tx1 = stream1.tx(); + let schema1 = Arc::clone(&schema); + stream1.spawn(async move { + let batch = + RecordBatch::try_new(schema1, vec![Arc::new(Int32Array::from(vec![2]))])?; + tx1.send(Ok(batch)).await.unwrap(); + Ok(()) + }); + + let metrics = ExecutionPlanMetricsSet::new(); + let reservation = + MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool); + + let mut merge_stream = StreamingMergeBuilder::new() + .with_streams(vec![stream0.build(), stream1.build()]) + .with_schema(Arc::clone(&schema)) + .with_expressions(&sort) + .with_metrics(BaselineMetrics::new(&metrics, 0)) + .with_batch_size(task_ctx.session_config().batch_size()) + .with_fetch(None) + .with_reservation(reservation) + .build()?; + + let first = merge_stream.next().await.unwrap(); + assert!(first.is_err(), "expected merge stream to surface the error"); + assert!( + merge_stream.next().await.is_none(), + "merge stream yielded data after returning an error" + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index b0c631cf9135f..ff7f259dd1347 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -15,21 +15,25 @@ // specific language governing permissions and limitations // under the License. -use crate::sorts::cursor::{ArrayValues, CursorArray, RowValues}; use crate::SendableRecordBatchStream; +use crate::sorts::cursor::{ArrayValues, CursorArray, RowValues}; use crate::{PhysicalExpr, PhysicalSortExpr}; -use arrow::array::Array; +use arrow::array::{Array, UInt32Array}; +use arrow::compute::take_record_batch; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use datafusion_common::{internal_datafusion_err, Result}; +use arrow_ord::sort::lexsort_to_indices; +use datafusion_common::{Result, internal_datafusion_err}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr_common::sort_expr::LexOrdering; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::stream::{Fuse, StreamExt}; +use std::iter::FusedIterator; use std::marker::PhantomData; +use std::mem; use std::sync::Arc; -use std::task::{ready, Context, Poll}; +use std::task::{Context, Poll, ready}; /// A [`Stream`](futures::Stream) that has multiple partitions that can /// be polled separately but not concurrently @@ -103,7 +107,7 @@ impl ReusableRows { self.inner[stream_idx][1] = Some(Arc::clone(rows)); // swap the current with the previous one, so that the next poll can reuse the Rows from the previous poll let [a, b] = &mut self.inner[stream_idx]; - std::mem::swap(a, b); + mem::swap(a, b); } } @@ -180,7 +184,7 @@ impl RowCursorStream { self.rows.save(stream_idx, &rows); // track the memory in the newly created Rows. - let mut rows_reservation = self.reservation.new_empty(); + let rows_reservation = self.reservation.new_empty(); rows_reservation.try_grow(rows.size())?; Ok(RowValues::new(rows, rows_reservation)) } @@ -246,7 +250,7 @@ impl FieldCursorStream { let array = value.into_array(batch.num_rows())?; let size_in_mem = array.get_buffer_memory_size(); let array = array.as_any().downcast_ref::().expect("field values"); - let mut array_reservation = self.reservation.new_empty(); + let array_reservation = self.reservation.new_empty(); array_reservation.try_grow(size_in_mem)?; Ok(ArrayValues::new( self.sort.options, @@ -276,3 +280,159 @@ impl PartitionedStream for FieldCursorStream { })) } } + +/// A lazy, memory-efficient sort iterator used as a fallback during aggregate +/// spill when there is not enough memory for an eager sort (which requires ~2x +/// peak memory to hold both the unsorted and sorted copies simultaneously). +/// +/// On the first call to `next()`, a sorted index array (`UInt32Array`) is +/// computed via `lexsort_to_indices`. Subsequent calls yield chunks of +/// `batch_size` rows by `take`-ing from the original batch using slices of +/// this index array. Each `take` copies data for the chunk (not zero-copy), +/// but only one chunk is live at a time since the caller consumes it before +/// requesting the next. Once all rows have been yielded, the original batch +/// and index array are dropped to free memory. +/// +/// The caller must reserve `sizeof(batch) + sizeof(one chunk)` for this iterator, +/// and free the reservation once the iterator is depleted. +pub(crate) struct IncrementalSortIterator { + batch: RecordBatch, + expressions: LexOrdering, + batch_size: usize, + indices: Option, + cursor: usize, +} + +impl IncrementalSortIterator { + pub(crate) fn new( + batch: RecordBatch, + expressions: LexOrdering, + batch_size: usize, + ) -> Self { + Self { + batch, + expressions, + batch_size, + cursor: 0, + indices: None, + } + } +} + +impl Iterator for IncrementalSortIterator { + type Item = Result; + + fn next(&mut self) -> Option { + if self.cursor >= self.batch.num_rows() { + return None; + } + + match self.indices.as_ref() { + None => { + let sort_columns = match self + .expressions + .iter() + .map(|expr| expr.evaluate_to_sort_column(&self.batch)) + .collect::>>() + { + Ok(cols) => cols, + Err(e) => return Some(Err(e)), + }; + + let indices = match lexsort_to_indices(&sort_columns, None) { + Ok(indices) => indices, + Err(e) => return Some(Err(e.into())), + }; + self.indices = Some(indices); + + // Call again, this time it will hit the Some(indices) branch and return the first batch + self.next() + } + Some(indices) => { + let batch_size = self.batch_size.min(self.batch.num_rows() - self.cursor); + + // Perform the take to produce the next batch + let new_batch_indices = indices.slice(self.cursor, batch_size); + let new_batch = match take_record_batch(&self.batch, &new_batch_indices) { + Ok(batch) => batch, + Err(e) => return Some(Err(e.into())), + }; + + self.cursor += batch_size; + + // If this is the last batch, we can release the memory + if self.cursor >= self.batch.num_rows() { + let schema = self.batch.schema(); + let _ = mem::replace(&mut self.batch, RecordBatch::new_empty(schema)); + self.indices = None; + } + + // Return the new batch + Some(Ok(new_batch)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + let num_rows = self.batch.num_rows(); + let batch_size = self.batch_size; + let num_batches = num_rows.div_ceil(batch_size); + (num_batches, Some(num_batches)) + } +} + +impl FusedIterator for IncrementalSortIterator {} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{AsArray, Int32Array}; + use arrow::datatypes::{DataType, Field, Int32Type}; + use datafusion_common::DataFusionError; + use datafusion_physical_expr::expressions::col; + + /// Verifies that `take_record_batch` in `IncrementalSortIterator` actually + /// copies the data into a new allocation rather than returning a zero-copy + /// slice of the original batch. If the output arrays were slices, their + /// underlying buffer length would match the original array's length; a true + /// copy will have a buffer sized to fit only the chunk. + #[test] + fn incremental_sort_iterator_copies_data() -> Result<()> { + let original_len = 10; + let batch_size = 3; + + // Build a batch with a single Int32 column of descending values + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let col_a: Int32Array = Int32Array::from(vec![0; original_len]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(col_a)])?; + + // Sort ascending on column "a" + let expressions = LexOrdering::new(vec![PhysicalSortExpr::new_default(col( + "a", + &batch.schema(), + )?)]) + .unwrap(); + + let mut total_rows = 0; + IncrementalSortIterator::new(batch.clone(), expressions, batch_size).try_for_each( + |result| { + let chunk = result?; + total_rows += chunk.num_rows(); + + // Every output column must be a fresh allocation whose length + // equals the chunk size, NOT the original array length. + chunk.columns().iter().zip(batch.columns()).for_each(|(arr, original_arr)| { + let (_, scalar_buf, _) = arr.as_primitive::().clone().into_parts(); + let (_, original_scalar_buf, _) = original_arr.as_primitive::().clone().into_parts(); + + assert_ne!(scalar_buf.inner().data_ptr(), original_scalar_buf.inner().data_ptr(), "Expected a copy of the data for each chunk, but got a slice that shares the same buffer as the original array"); + }); + + Result::<_, DataFusionError>::Ok(()) + }, + )?; + + assert_eq!(total_rows, original_len); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 047fbd8cbd81d..8129c3d8f695d 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -27,11 +27,11 @@ use crate::sorts::{ use crate::{SendableRecordBatchStream, SpillManager}; use arrow::array::*; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::{assert_or_internal_err, internal_err, Result}; +use datafusion_common::human_readable_size; +use datafusion_common::{Result, assert_or_internal_err, internal_err}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{ - human_readable_size, MemoryConsumer, MemoryPool, MemoryReservation, - UnboundedMemoryPool, + MemoryConsumer, MemoryPool, MemoryReservation, UnboundedMemoryPool, }; use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index e7f354a73b4cd..e0548bd5bf860 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -24,7 +24,10 @@ use arrow::array::RecordBatch; use datafusion_common::exec_datafusion_err; use datafusion_execution::disk_manager::RefCountedTempFile; -use super::{spill_manager::SpillManager, IPCStreamWriter}; +use super::{ + IPCStreamWriter, gc_view_arrays, + spill_manager::{GetSlicedSize, SpillManager}, +}; /// Represents an in-progress spill file used for writing `RecordBatch`es to disk, created by `SpillManager`. /// Caller is able to use this struct to incrementally append in-memory batches to @@ -51,19 +54,32 @@ impl InProgressSpillFile { /// Appends a `RecordBatch` to the spill file, initializing the writer if necessary. /// + /// Before writing, performs GC on StringView/BinaryView arrays to compact backing + /// buffers. When a view array is sliced, it still references the original full buffers, + /// causing massive spill files without GC (see issue #19414: 820MB → 33MB after GC). + /// + /// Returns the post-GC sliced memory size of the batch for memory accounting. + /// /// # Errors /// - Returns an error if the file is not active (has been finalized) /// - Returns an error if appending would exceed the disk usage limit configured /// by `max_temp_directory_size` in `DiskManager` - pub fn append_batch(&mut self, batch: &RecordBatch) -> Result<()> { + pub fn append_batch(&mut self, batch: &RecordBatch) -> Result { if self.in_progress_file.is_none() { return Err(exec_datafusion_err!( "Append operation failed: No active in-progress file. The file may have already been finalized." )); } + + let gc_batch = gc_view_arrays(batch)?; + if self.writer.is_none() { - let schema = batch.schema(); - if let Some(ref in_progress_file) = self.in_progress_file { + // Use the SpillManager's declared schema rather than the batch's schema. + // Individual batches may have different schemas (e.g., different nullability) + // when they come from different branches of a UnionExec. The SpillManager's + // schema represents the canonical schema that all batches should conform to. + let schema = self.spill_writer.schema(); + if let Some(in_progress_file) = &mut self.in_progress_file { self.writer = Some(IPCStreamWriter::new( in_progress_file.path(), schema.as_ref(), @@ -72,18 +88,38 @@ impl InProgressSpillFile { // Update metrics self.spill_writer.metrics.spill_file_count.add(1); + + // Update initial size (schema/header) + in_progress_file.update_disk_usage()?; + let initial_size = in_progress_file.current_disk_usage(); + self.spill_writer + .metrics + .spilled_bytes + .add(initial_size as usize); } } if let Some(writer) = &mut self.writer { - let (spilled_rows, _) = writer.write(batch)?; + let (spilled_rows, _) = writer.write(&gc_batch)?; if let Some(in_progress_file) = &mut self.in_progress_file { + let pre_size = in_progress_file.current_disk_usage(); in_progress_file.update_disk_usage()?; + let post_size = in_progress_file.current_disk_usage(); + + self.spill_writer.metrics.spilled_rows.add(spilled_rows); + self.spill_writer + .metrics + .spilled_bytes + .add((post_size - pre_size) as usize); } else { unreachable!() // Already checked inside current function } + } + gc_batch.get_sliced_size() + } - // Update metrics - self.spill_writer.metrics.spilled_rows.add(spilled_rows); + pub fn flush(&mut self) -> Result<()> { + if let Some(writer) = &mut self.writer { + writer.flush()?; } Ok(()) } @@ -106,11 +142,89 @@ impl InProgressSpillFile { // Since spill files are append-only, add the file size to spilled_bytes if let Some(in_progress_file) = &mut self.in_progress_file { // Since writer.finish() writes continuation marker and message length at the end + let pre_size = in_progress_file.current_disk_usage(); in_progress_file.update_disk_usage()?; - let size = in_progress_file.current_disk_usage(); - self.spill_writer.metrics.spilled_bytes.add(size as usize); + let post_size = in_progress_file.current_disk_usage(); + self.spill_writer + .metrics + .spilled_bytes + .add((post_size - pre_size) as usize); } Ok(self.in_progress_file.take()) } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, SpillMetrics, + }; + use futures::TryStreamExt; + + #[tokio::test] + async fn test_spill_file_uses_spill_manager_schema() -> Result<()> { + let nullable_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, true), + ])); + let non_nullable_schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int64, false), + Field::new("val", DataType::Int64, false), + ])); + + let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); + let metrics_set = ExecutionPlanMetricsSet::new(); + let spill_metrics = SpillMetrics::new(&metrics_set, 0); + let spill_manager = Arc::new(SpillManager::new( + runtime, + spill_metrics, + Arc::clone(&nullable_schema), + )); + + let mut in_progress = spill_manager.create_in_progress_file("test")?; + + // First batch: non-nullable val (simulates literal-0 UNION branch) + let non_nullable_batch = RecordBatch::try_new( + Arc::clone(&non_nullable_schema), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(Int64Array::from(vec![0, 0, 0])), + ], + )?; + in_progress.append_batch(&non_nullable_batch)?; + + // Second batch: nullable val with NULLs (simulates table UNION branch) + let nullable_batch = RecordBatch::try_new( + Arc::clone(&nullable_schema), + vec![ + Arc::new(Int64Array::from(vec![4, 5, 6])), + Arc::new(Int64Array::from(vec![Some(10), None, Some(30)])), + ], + )?; + in_progress.append_batch(&nullable_batch)?; + + let spill_file = in_progress.finish()?.unwrap(); + + let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + + // Stream schema should be nullable + assert_eq!(stream.schema(), nullable_schema); + + let batches = stream.try_collect::>().await?; + assert_eq!(batches.len(), 2); + + // Both batches must have the SpillManager's nullable schema + assert_eq!( + batches[0], + non_nullable_batch.with_schema(Arc::clone(&nullable_schema))? + ); + assert_eq!(batches[1], nullable_batch); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 6be7edcf32918..3c95a1da5b33c 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -18,9 +18,12 @@ //! Defines the spilling functions pub(crate) mod in_progress_spill_file; +pub(crate) mod replayable_spill_input; pub(crate) mod spill_manager; pub mod spill_pool; +// Moved for refactor, re-export to keep the public API stable +pub use datafusion_common::utils::memory::get_record_batch_memory_size; // Re-export SpillManager for doctests only (hidden from public docs) #[doc(hidden)] pub use spill_manager::SpillManager; @@ -29,26 +32,31 @@ use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; use std::pin::Pin; -use std::ptr::NonNull; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow::array::{layout, ArrayData, BufferSpec}; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::array::{ + Array, ArrayRef, BinaryViewArray, BufferSpec, GenericByteViewArray, StringViewArray, + layout, make_array, +}; +use arrow::datatypes::DataType; +use arrow::datatypes::{ByteViewType, Schema, SchemaRef}; use arrow::ipc::{ + MetadataVersion, reader::StreamReader, writer::{IpcWriteOptions, StreamWriter}, - MetadataVersion, }; use arrow::record_batch::RecordBatch; +use arrow_data::ArrayDataBuilder; +use arrow_ipc::CompressionType; use datafusion_common::config::SpillCompression; -use datafusion_common::{exec_datafusion_err, DataFusionError, HashSet, Result}; +use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err}; use datafusion_common_runtime::SpawnedTask; -use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::RecordBatchStream; +use datafusion_execution::disk_manager::RefCountedTempFile; use futures::{FutureExt as _, Stream}; -use log::warn; +use log::debug; /// Stream that reads spill files from disk where each batch is read in a spawned blocking task /// It will read one batch at a time and will not do any buffering, to buffer data use [`crate::common::spawn_buffered`] @@ -115,6 +123,7 @@ impl SpillReaderStream { unreachable!() }; + let expected_schema = Arc::clone(&self.schema); let task = SpawnedTask::spawn_blocking(move || { let file = BufReader::new(File::open(spill_file.path())?); // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications @@ -124,6 +133,21 @@ impl SpillReaderStream { StreamReader::try_new(file, None)?.with_skip_validation(true) }; + // Validate the schema read from Arrow IPC file is the same as the + // schema of the current `SpillManager` + let actual_schema = reader.schema(); + + if actual_schema != expected_schema { + return exec_err!( + "Spill file schema mismatch: expected {}, got {}. \ + The caller must use the same SpillManager that created the spill file to read it.", + expected_schema, + actual_schema + ); + } + + // TODO: Same-schema reads from a different SpillManager still pass today. + // Add a SpillManager UID to IPC metadata and validate it here as well. let next_batch = reader.next().transpose()?; Ok((reader, next_batch)) @@ -153,12 +177,12 @@ impl SpillReaderStream { > max_record_batch_memory + SPILL_BATCH_MEMORY_MARGIN { - warn!( - "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\ + debug!( + "Record batch memory usage ({actual_size} bytes) exceeds the expected limit ({max_record_batch_memory} bytes) \n\ by more than the allowed tolerance ({SPILL_BATCH_MEMORY_MARGIN} bytes).\n\ This likely indicates a bug in memory accounting during spilling.\n\ Please report this issue in https://github.com/apache/datafusion/issues/17340." - ); + ); } } self.state = SpillReaderStreamState::Waiting(reader); @@ -250,74 +274,6 @@ pub fn spill_record_batch_by_size( Ok(()) } -/// Calculate total used memory of this batch. -/// -/// This function is used to estimate the physical memory usage of the `RecordBatch`. -/// It only counts the memory of large data `Buffer`s, and ignores metadata like -/// types and pointers. -/// The implementation will add up all unique `Buffer`'s memory -/// size, due to: -/// - The data pointer inside `Buffer` are memory regions returned by global memory -/// allocator, those regions can't have overlap. -/// - The actual used range of `ArrayRef`s inside `RecordBatch` can have overlap -/// or reuse the same `Buffer`. For example: taking a slice from `Array`. -/// -/// Example: -/// For a `RecordBatch` with two columns: `col1` and `col2`, two columns are pointing -/// to a sub-region of the same buffer. -/// -/// {xxxxxxxxxxxxxxxxxxx} <--- buffer -/// ^ ^ ^ ^ -/// | | | | -/// col1->{ } | | -/// col2--------->{ } -/// -/// In the above case, `get_record_batch_memory_size` will return the size of -/// the buffer, instead of the sum of `col1` and `col2`'s actual memory size. -/// -/// Note: Current `RecordBatch`.get_array_memory_size()` will double count the -/// buffer memory size if multiple arrays within the batch are sharing the same -/// `Buffer`. This method provides temporary fix until the issue is resolved: -/// -pub fn get_record_batch_memory_size(batch: &RecordBatch) -> usize { - // Store pointers to `Buffer`'s start memory address (instead of actual - // used data region's pointer represented by current `Array`) - let mut counted_buffers: HashSet> = HashSet::new(); - let mut total_size = 0; - - for array in batch.columns() { - let array_data = array.to_data(); - count_array_data_memory_size(&array_data, &mut counted_buffers, &mut total_size); - } - - total_size -} - -/// Count the memory usage of `array_data` and its children recursively. -fn count_array_data_memory_size( - array_data: &ArrayData, - counted_buffers: &mut HashSet>, - total_size: &mut usize, -) { - // Count memory usage for `array_data` - for buffer in array_data.buffers() { - if counted_buffers.insert(buffer.data_ptr()) { - *total_size += buffer.capacity(); - } // Otherwise the buffer's memory is already counted - } - - if let Some(null_buffer) = array_data.nulls() { - if counted_buffers.insert(null_buffer.inner().inner().data_ptr()) { - *total_size += null_buffer.inner().inner().capacity(); - } - } - - // Count all children `ArrayData` recursively - for child in array_data.child_data() { - count_array_data_memory_size(child, counted_buffers, total_size); - } -} - /// Write in Arrow IPC Stream format to a file. /// /// Stream format is used for spill because it supports dictionary replacement, and the random @@ -335,10 +291,21 @@ struct IPCStreamWriter { impl IPCStreamWriter { /// Create new writer + /// + /// # Codec contract + /// + /// `arrow-ipc` must be compiled with the `lz4` and `zstd` features + /// (declared explicitly in `datafusion-physical-plan/Cargo.toml`). If + /// those features are absent, `try_with_compression` will return an + /// error at runtime for [`SpillCompression::Lz4Frame`] and + /// [`SpillCompression::Zstd`] variants. The Cargo dependency keeps this + /// contract local and build-visible during Cargo feature resolution, + /// rather than relying solely on workspace-level feature unification; + /// see #21917. pub fn new( path: &Path, schema: &Schema, - compression_type: SpillCompression, + spill_compression: SpillCompression, ) -> Result { let file = File::create(path).map_err(|e| { exec_datafusion_err!("(Hint: you may increase the file descriptor limit with shell command 'ulimit -n 4096') Failed to create partition file at {path:?}: {e:?}") @@ -353,7 +320,8 @@ impl IPCStreamWriter { let alignment = get_max_alignment_for_schema(schema); let mut write_options = IpcWriteOptions::try_new(alignment, false, metadata_version)?; - write_options = write_options.try_with_compression(compression_type.into())?; + let compression_type = Option::::from(spill_compression); + write_options = write_options.try_with_compression(compression_type)?; let writer = StreamWriter::try_new_with_options(file, schema, write_options)?; Ok(Self { @@ -377,6 +345,11 @@ impl IPCStreamWriter { Ok((delta_num_rows, delta_num_bytes)) } + pub fn flush(&mut self) -> Result<()> { + self.writer.flush()?; + Ok(()) + } + /// Finish the writer pub fn finish(&mut self) -> Result<()> { self.writer.finish().map_err(Into::into) @@ -406,6 +379,174 @@ fn get_max_alignment_for_schema(schema: &Schema) -> usize { max_alignment } +/// Size of a single view structure in StringView/BinaryView arrays (in bytes). +/// Each view is 16 bytes: 4 bytes length + 4 bytes prefix + 8 bytes buffer ID/offset. +const VIEW_SIZE_BYTES: usize = 16; + +/// Performs garbage collection on StringView and BinaryView arrays before spilling to reduce memory usage. +/// +/// # Why GC is needed +/// +/// StringView and BinaryView arrays can accumulate significant memory waste when sliced. +/// When a large array is sliced (e.g., taking first 100 rows of 1000), the view array +/// still references the original data buffers containing all 1000 rows of data. +/// +/// For example, in the ClickBench benchmark (issue #19414), repeated slicing of StringView +/// arrays resulted in 820MB of spill files that could be reduced to just 33MB after GC - +/// a 96% reduction in size. +/// +/// # How it works +/// +/// The GC process: +/// 1. Identifies view arrays (StringView/BinaryView) in the batch +/// 2. Checks if their data buffers exceed a memory threshold +/// 3. If exceeded, calls the Arrow `gc()` method which creates new compact buffers +/// containing only the data referenced by the current views +/// 4. Returns a new batch with GC'd arrays (or original arrays if GC not needed) +/// +/// # When GC is triggered +/// +/// GC is only performed when data buffers exceed a threshold (currently 10KB). +/// This balances memory savings against the CPU overhead of garbage collection. +/// Small arrays are passed through unchanged since the GC overhead would exceed +/// any memory savings. +/// +/// # Performance considerations +/// +/// - If no view arrays need compaction, the original batch is cloned cheaply +/// - GC is skipped for small buffers to avoid unnecessary CPU overhead +/// - Nested container types are traversed recursively so view arrays inside +/// `List`, `Map`, `Union`, `Dictionary`, and other child-bearing arrays are compacted too +/// - The Arrow `gc()` method itself is optimized and only copies referenced data +pub(crate) fn gc_view_arrays(batch: &RecordBatch) -> Result { + let mut mutated = false; + let mut new_columns: Vec> = Vec::with_capacity(batch.num_columns()); + + for array in batch.columns() { + let (gc_array, array_mutated) = gc_array(array)?; + mutated |= array_mutated; + new_columns.push(gc_array); + } + + if mutated { + Ok(RecordBatch::try_new(batch.schema(), new_columns)?) + } else { + Ok(batch.clone()) + } +} + +fn gc_array(array: &ArrayRef) -> Result<(ArrayRef, bool)> { + match array.data_type() { + DataType::Utf8View => { + let string_view = array + .as_any() + .downcast_ref::() + .expect("Utf8View array should downcast to StringViewArray"); + if should_gc_view_array(string_view) { + Ok((Arc::new(string_view.gc()) as ArrayRef, true)) + } else { + Ok((Arc::clone(array), false)) + } + } + DataType::BinaryView => { + let binary_view = array + .as_any() + .downcast_ref::() + .expect("BinaryView array should downcast to BinaryViewArray"); + if should_gc_view_array(binary_view) { + Ok((Arc::new(binary_view.gc()) as ArrayRef, true)) + } else { + Ok((Arc::clone(array), false)) + } + } + _ => gc_array_children(array), + } +} + +fn gc_array_children(array: &ArrayRef) -> Result<(ArrayRef, bool)> { + let data = array.to_data(); + if data.child_data().is_empty() { + return Ok((Arc::clone(array), false)); + } + + let mut mutated = false; + let mut child_data = Vec::with_capacity(data.child_data().len()); + for child in data.child_data() { + let child_array = make_array(child.clone()); + let (gc_child, child_mutated) = gc_array(&child_array)?; + mutated |= child_mutated; + child_data.push(gc_child.to_data()); + } + + if !mutated { + return Ok((Arc::clone(array), false)); + } + + let rebuilt = ArrayDataBuilder::new(data.data_type().clone()) + .len(data.len()) + .offset(data.offset()) + .nulls(data.nulls().cloned()) + .buffers(data.buffers().to_vec()) + .child_data(child_data) + .build()?; + + Ok((make_array(rebuilt), true)) +} + +/// Determines whether a view array should be garbage collected before spilling. +/// +/// Arrow's `gc()` always allocates new compact buffers (it is never a no-op), so we +/// check here to skip the allocation cost when data buffers are small. We subtract +/// the views buffer (16 bytes × n_rows) from `get_buffer_memory_size()` so the +/// threshold tracks non-inline string data rather than row count. +fn should_gc_view_array(array: &GenericByteViewArray) -> bool { + const MIN_BUFFER_SIZE_FOR_GC: usize = 10 * 1024; // 10KB threshold + + if array.data_buffers().is_empty() { + return false; + } + + let data_buffer_size = array + .get_buffer_memory_size() + .saturating_sub(array.len() * VIEW_SIZE_BYTES); + data_buffer_size > MIN_BUFFER_SIZE_FOR_GC +} + +#[cfg(test)] +fn calculate_string_view_waste_ratio(array: &StringViewArray) -> f64 { + use arrow_data::MAX_INLINE_VIEW_LEN; + calculate_view_waste_ratio(array.len(), array.data_buffers(), |i| { + if !array.is_null(i) { + let value = array.value(i); + if value.len() > MAX_INLINE_VIEW_LEN as usize { + return value.len(); + } + } + 0 + }) +} + +#[cfg(test)] +fn calculate_view_waste_ratio( + len: usize, + data_buffers: &[arrow::buffer::Buffer], + get_value_size: F, +) -> f64 +where + F: Fn(usize) -> usize, +{ + let total_buffer_size: usize = data_buffers.iter().map(|b| b.capacity()).sum(); + if total_buffer_size == 0 { + return 0.0; + } + + let mut actual_used_size = (0..len).map(get_value_size).sum::(); + actual_used_size += len * VIEW_SIZE_BYTES; + + let waste = total_buffer_size.saturating_sub(actual_used_size); + waste as f64 / total_buffer_size as f64 +} + #[cfg(test)] mod tests { use super::in_progress_spill_file::InProgressSpillFile; @@ -415,16 +556,12 @@ mod tests { use crate::metrics::SpillMetrics; use crate::spill::spill_manager::SpillManager; use crate::test::build_table_i32; - use arrow::array::{ArrayRef, Float64Array, Int32Array, ListArray, StringArray}; + use arrow::array::{ArrayRef, Int32Array, StringArray}; use arrow::compute::cast; - use arrow::datatypes::{DataType, Field, Int32Type, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::Result; + use arrow::datatypes::{DataType, Field}; use datafusion_execution::runtime_env::RuntimeEnv; use futures::StreamExt as _; - use std::sync::Arc; - #[tokio::test] async fn test_batch_spill_and_read() -> Result<()> { let batch1 = build_table_i32( @@ -539,11 +676,12 @@ mod tests { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let row_batches: Vec = + (0..batch1.num_rows()).map(|i| batch1.slice(i, 1)).collect(); let (spill_file, max_batch_mem) = spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &batch1, + .spill_record_batch_iter_and_return_max_batch_memory( + row_batches.iter().map(Ok), "Test Spill", - 1, )? .unwrap(); assert!(spill_file.path().exists()); @@ -665,133 +803,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_record_batch_memory_size() { - // Create a simple record batch with two columns - let schema = Arc::new(Schema::new(vec![ - Field::new("ints", DataType::Int32, true), - Field::new("float64", DataType::Float64, false), - ])); - - let int_array = - Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); - let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); - - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(int_array), Arc::new(float64_array)], - ) - .unwrap(); - - let size = get_record_batch_memory_size(&batch); - assert_eq!(size, 60); - } - - #[test] - fn test_get_record_batch_memory_size_with_null() { - // Create a simple record batch with two columns - let schema = Arc::new(Schema::new(vec![ - Field::new("ints", DataType::Int32, true), - Field::new("float64", DataType::Float64, false), - ])); - - let int_array = Int32Array::from(vec![None, Some(2), Some(3)]); - let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0]); - - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(int_array), Arc::new(float64_array)], - ) - .unwrap(); - - let size = get_record_batch_memory_size(&batch); - assert_eq!(size, 100); - } - - #[test] - fn test_get_record_batch_memory_size_empty() { - // Test with empty record batch - let schema = Arc::new(Schema::new(vec![Field::new( - "ints", - DataType::Int32, - false, - )])); - - let int_array: Int32Array = Int32Array::from(vec![] as Vec); - let batch = RecordBatch::try_new(schema, vec![Arc::new(int_array)]).unwrap(); - - let size = get_record_batch_memory_size(&batch); - assert_eq!(size, 0, "Empty batch should have 0 memory size"); - } - - #[test] - fn test_get_record_batch_memory_size_shared_buffer() { - // Test with slices that share the same underlying buffer - let original = Int32Array::from(vec![1, 2, 3, 4, 5]); - let slice1 = original.slice(0, 3); - let slice2 = original.slice(2, 3); - - // `RecordBatch` with `original` array - // ---- - let schema_origin = Arc::new(Schema::new(vec![Field::new( - "origin_col", - DataType::Int32, - false, - )])); - let batch_origin = - RecordBatch::try_new(schema_origin, vec![Arc::new(original)]).unwrap(); - - // `RecordBatch` with all columns are reference to `original` array - // ---- - let schema = Arc::new(Schema::new(vec![ - Field::new("slice1", DataType::Int32, false), - Field::new("slice2", DataType::Int32, false), - ])); - - let batch_sliced = - RecordBatch::try_new(schema, vec![Arc::new(slice1), Arc::new(slice2)]) - .unwrap(); - - // Two sizes should all be only counting the buffer in `original` array - let size_origin = get_record_batch_memory_size(&batch_origin); - let size_sliced = get_record_batch_memory_size(&batch_sliced); - - assert_eq!(size_origin, size_sliced); - } - - #[test] - fn test_get_record_batch_memory_size_nested_array() { - let schema = Arc::new(Schema::new(vec![ - Field::new( - "nested_int", - DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), - false, - ), - Field::new( - "nested_int2", - DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))), - false, - ), - ])); - - let int_list_array = ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(2), Some(3)]), - ]); - - let int_list_array2 = ListArray::from_iter_primitive::(vec![ - Some(vec![Some(4), Some(5), Some(6)]), - ]); - - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(int_list_array), Arc::new(int_list_array2)], - ) - .unwrap(); - - let size = get_record_batch_memory_size(&batch); - assert_eq!(size, 8208); - } - // ==== Spill manager tests ==== #[test] @@ -879,13 +890,13 @@ mod tests { Arc::new(StringArray::from(vec!["d", "e", "f"])), ], )?; - // After appending each batch, spilled_rows should increase, while spill_file_count and - // spilled_bytes remain the same (spilled_bytes is updated only after finish() is called) + // After appending each batch, spilled_rows and spilled_bytes should increase incrementally, + // while spill_file_count remains 1 (since we're writing to the same file) in_progress_file.append_batch(&batch1)?; - verify_metrics(&in_progress_file, 1, 0, 3)?; + verify_metrics(&in_progress_file, 1, 440, 3)?; in_progress_file.append_batch(&batch2)?; - verify_metrics(&in_progress_file, 1, 0, 6)?; + verify_metrics(&in_progress_file, 1, 704, 6)?; let completed_file = in_progress_file.finish()?; assert!(completed_file.is_some()); @@ -920,7 +931,7 @@ mod tests { let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?; assert!(completed_file.is_none()); - // Test write empty batch with interface `spill_record_batch_by_size_and_return_max_batch_memory()` + // Test write empty batch with interface `spill_record_batch_iter_and_return_max_batch_memory()` let empty_batch = RecordBatch::try_new( Arc::clone(&schema), vec![ @@ -929,10 +940,9 @@ mod tests { ], )?; let completed_file = spill_manager - .spill_record_batch_by_size_and_return_max_batch_memory( - &empty_batch, + .spill_record_batch_iter_and_return_max_batch_memory( + std::iter::once(Ok(&empty_batch)), "Test", - 1, )?; assert!(completed_file.is_none()); @@ -993,4 +1003,540 @@ mod tests { assert_eq!(alignment, 8); Ok(()) } + #[tokio::test] + async fn test_real_time_spill_metrics() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ])); + + let spill_manager = Arc::new(SpillManager::new( + Arc::clone(&env), + metrics.clone(), + Arc::clone(&schema), + )); + let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + )?; + + // Before any batch, metrics should be 0 + assert_eq!(metrics.spilled_bytes.value(), 0); + assert_eq!(metrics.spill_file_count.value(), 0); + + // Append first batch + in_progress_file.append_batch(&batch1)?; + + // Metrics should be updated immediately (at least schema and first batch) + let bytes_after_batch1 = metrics.spilled_bytes.value(); + assert_eq!(bytes_after_batch1, 440); + assert_eq!(metrics.spill_file_count.value(), 1); + + // Check global progress + let progress = env.spilling_progress(); + assert_eq!(progress.current_bytes, bytes_after_batch1 as u64); + assert_eq!(progress.active_files_count, 1); + + // Append another batch + in_progress_file.append_batch(&batch1)?; + let bytes_after_batch2 = metrics.spilled_bytes.value(); + assert!(bytes_after_batch2 > bytes_after_batch1); + + // Check global progress again + let progress = env.spilling_progress(); + assert_eq!(progress.current_bytes, bytes_after_batch2 as u64); + + // Finish the file + let spilled_file = in_progress_file.finish()?; + let final_bytes = metrics.spilled_bytes.value(); + assert!(final_bytes > bytes_after_batch2); + + // Even after finish, file is still "active" until dropped + let progress = env.spilling_progress(); + assert!(progress.current_bytes > 0); + assert_eq!(progress.active_files_count, 1); + + drop(spilled_file); + assert_eq!(env.spilling_progress().active_files_count, 0); + assert_eq!(env.spilling_progress().current_bytes, 0); + + Ok(()) + } + + #[test] + fn test_gc_string_view_before_spill() -> Result<()> { + use arrow::array::StringViewArray; + + let strings: Vec = (0..200) + .map(|i| { + if i % 2 == 0 { + "short_string".to_string() + } else { + "this_is_a_much_longer_string_that_will_not_be_inlined".to_string() + } + }) + .collect(); + + let string_array = StringViewArray::from(strings); + let schema = Arc::new(Schema::new(vec![Field::new( + "strings", + DataType::Utf8View, + false, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(string_array) as ArrayRef], + )?; + let sliced_batch = batch.slice(0, 20); + let gc_batch = gc_view_arrays(&sliced_batch)?; + + assert_eq!(gc_batch.num_rows(), sliced_batch.num_rows()); + assert_eq!(gc_batch.num_columns(), sliced_batch.num_columns()); + + Ok(()) + } + + #[test] + fn test_gc_binary_view_before_spill() -> Result<()> { + use arrow::array::BinaryViewArray; + + let binaries: Vec> = (0..200) + .map(|i| { + if i % 2 == 0 { + vec![1, 2, 3, 4] + } else { + vec![1; 50] + } + }) + .collect(); + + let binary_array = + BinaryViewArray::from_iter(binaries.iter().map(|b| Some(b.as_slice()))); + let schema = Arc::new(Schema::new(vec![Field::new( + "binaries", + DataType::BinaryView, + false, + )])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(binary_array) as ArrayRef], + )?; + let sliced_batch = batch.slice(0, 20); + let gc_batch = gc_view_arrays(&sliced_batch)?; + + assert_eq!(gc_batch.num_rows(), sliced_batch.num_rows()); + assert_eq!(gc_batch.num_columns(), sliced_batch.num_columns()); + + Ok(()) + } + + #[test] + fn test_gc_skips_small_arrays() -> Result<()> { + use arrow::array::StringViewArray; + + let strings: Vec = (0..10).map(|i| format!("string_{i}")).collect(); + + let string_array = StringViewArray::from(strings); + let array_ref: ArrayRef = Arc::new(string_array); + + let schema = Arc::new(Schema::new(vec![Field::new( + "strings", + DataType::Utf8View, + false, + )])); + + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![array_ref])?; + + // GC should return the original batch for small arrays + let should_gc = should_gc_view_array( + batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(), + ); + let gc_batch = gc_view_arrays(&batch)?; + + assert!(!should_gc); + assert_eq!(gc_batch.num_rows(), batch.num_rows()); + assert!(Arc::ptr_eq(batch.column(0), gc_batch.column(0))); + + Ok(()) + } + + #[test] + fn test_gc_with_mixed_columns() -> Result<()> { + use arrow::array::{Int32Array, StringViewArray}; + + let strings: Vec = (0..200) + .map(|i| format!("long_string_for_gc_testing_{i}")) + .collect(); + + let string_array = StringViewArray::from(strings); + let int_array = Int32Array::from((0..200).collect::>()); + + let schema = Arc::new(Schema::new(vec![ + Field::new("strings", DataType::Utf8View, false), + Field::new("ints", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(string_array) as ArrayRef, + Arc::new(int_array) as ArrayRef, + ], + )?; + + let sliced_batch = batch.slice(0, 50); + let gc_batch = gc_view_arrays(&sliced_batch)?; + + assert_eq!(gc_batch.num_columns(), 2); + assert_eq!(gc_batch.num_rows(), 50); + + Ok(()) + } + + #[test] + fn test_verify_gc_triggers_for_sliced_arrays() -> Result<()> { + let strings: Vec = (0..200) + .map(|i| { + format!( + "http://example.com/very/long/path/that/exceeds/inline/threshold/{i}" + ) + }) + .collect(); + + let string_array = StringViewArray::from(strings); + let schema = Arc::new(Schema::new(vec![Field::new( + "url", + DataType::Utf8View, + false, + )])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(string_array.clone()) as ArrayRef], + )?; + + let sliced = batch.slice(0, 20); + + let sliced_array = sliced + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let should_gc = should_gc_view_array(sliced_array); + let waste_ratio = calculate_string_view_waste_ratio(sliced_array); + + assert!( + waste_ratio > 0.8, + "Waste ratio should be > 0.8 for sliced array" + ); + assert!( + should_gc, + "GC should trigger for sliced array with high waste" + ); + + Ok(()) + } + + #[test] + fn test_reproduce_issue_19414_string_view_spill_without_gc() -> Result<()> { + use arrow::array::StringViewArray; + use std::fs; + + let num_rows = 1000; + let mut strings = Vec::with_capacity(num_rows); + + for i in 0..num_rows { + let url = match i % 5 { + 0 => format!( + "http://irr.ru/index.php?showalbum/login-leniya7777294,938303130/{i}" + ), + 1 => format!("http://komme%2F27.0.1453.116/very/long/path/{i}"), + 2 => format!("https://produkty%2Fproduct/category/item/{i}"), + 3 => format!( + "http://irr.ru/index.php?showalbum/login-kapusta-advert2668/{i}" + ), + 4 => format!( + "http://irr.ru/index.php?showalbum/login-kapustic/product/{i}" + ), + _ => unreachable!(), + }; + strings.push(url); + } + + let string_array = StringViewArray::from(strings); + let schema = Arc::new(Schema::new(vec![Field::new( + "URL", + DataType::Utf8View, + false, + )])); + + let original_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(string_array.clone()) as ArrayRef], + )?; + + let total_buffer_size: usize = string_array + .data_buffers() + .iter() + .map(|buffer| buffer.capacity()) + .sum(); + + let mut sliced_batches = Vec::new(); + let slice_size = 100; + + for i in (0..num_rows).step_by(slice_size) { + let len = std::cmp::min(slice_size, num_rows - i); + let sliced = original_batch.slice(i, len); + sliced_batches.push(sliced); + } + + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = SpillManager::new(env, metrics, schema); + + let mut in_progress_file = spill_manager.create_in_progress_file("Test GC")?; + + for batch in &sliced_batches { + in_progress_file.append_batch(batch)?; + } + + let spill_file = in_progress_file.finish()?.unwrap(); + let file_size = fs::metadata(spill_file.path())?.len() as usize; + + let theoretical_without_gc = total_buffer_size * sliced_batches.len(); + let reduction_percent = ((theoretical_without_gc - file_size) as f64 + / theoretical_without_gc as f64) + * 100.0; + + assert!( + reduction_percent > 80.0, + "GC should reduce spill file size by >80%, got {reduction_percent:.1}%" + ); + + Ok(()) + } + + #[test] + fn test_spill_with_and_without_gc_comparison() -> Result<()> { + let num_rows = 400; + let strings: Vec = (0..num_rows) + .map(|i| { + format!( + "http://example.com/this/is/a/long/url/path/that/wont/be/inlined/{i}" + ) + }) + .collect(); + + let string_array = StringViewArray::from(strings); + let schema = Arc::new(Schema::new(vec![Field::new( + "url", + DataType::Utf8View, + false, + )])); + + let batch = + RecordBatch::try_new(schema, vec![Arc::new(string_array) as ArrayRef])?; + + let sliced_batch = batch.slice(0, 40); + + let array_without_gc = sliced_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let size_without_gc: usize = array_without_gc + .data_buffers() + .iter() + .map(|buffer| buffer.capacity()) + .sum(); + + let gc_batch = gc_view_arrays(&sliced_batch)?; + let array_with_gc = gc_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let size_with_gc: usize = array_with_gc + .data_buffers() + .iter() + .map(|buffer| buffer.capacity()) + .sum(); + + let reduction_percent = + ((size_without_gc - size_with_gc) as f64 / size_without_gc as f64) * 100.0; + + assert!( + reduction_percent > 85.0, + "Expected >85% reduction for 10% slice, got {reduction_percent:.1}%" + ); + + Ok(()) + } + + #[test] + fn test_gc_recurses_into_nested_view_arrays() -> Result<()> { + use arrow::array::{DictionaryArray, Int32Array}; + use arrow::buffer::Buffer; + + let strings: Vec = (0..200) + .map(|i| format!("http://example.com/nested/path/that/is/not/inlined/{i}")) + .collect(); + let string_values = Arc::new(StringViewArray::from(strings)) as ArrayRef; + + let list_data = ArrayDataBuilder::new(DataType::List(Arc::new( + Field::new_list_field(DataType::Utf8View, true), + ))) + .len(20) + .buffers(vec![Buffer::from_iter((0..=20).map(|i| i * 5_i32))]) + .child_data(vec![string_values.slice(0, 100).to_data()]) + .build()?; + let list_array = make_array(list_data); + + let keys = Int32Array::from_iter_values(0..20); + let dictionary = DictionaryArray::new(keys, string_values.slice(0, 20)); + let dictionary_array = Arc::new(dictionary) as ArrayRef; + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "list_strings", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8View, true))), + false, + ), + Field::new( + "dictionary_strings", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8View), + ), + false, + ), + ])); + let batch = RecordBatch::try_new(schema, vec![list_array, dictionary_array])?; + let gc_batch = gc_view_arrays(&batch)?; + + let gc_list_values = gc_batch.column(0).to_data().child_data()[0].clone(); + let gc_list_values = make_array(gc_list_values); + let gc_list_values = gc_list_values + .as_any() + .downcast_ref::() + .unwrap(); + assert!( + calculate_string_view_waste_ratio(gc_list_values) < 0.2, + "GC should compact nested List child views" + ); + + let gc_dictionary_values = gc_batch.column(1).to_data().child_data()[0].clone(); + let gc_dictionary_values = make_array(gc_dictionary_values); + let gc_dictionary_values = gc_dictionary_values + .as_any() + .downcast_ref::() + .unwrap(); + assert!( + calculate_string_view_waste_ratio(gc_dictionary_values) < 0.2, + "GC should compact nested Dictionary values" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_spill_file_size_gc_verification_string_view() -> Result<()> { + use arrow::array::StringViewArray; + use std::fs; + + // 1. Setup bloated data (large buffers) + let num_rows = 1000; + let string_array: StringViewArray = (0..num_rows) + .map(|i| Some(format!("this_is_a_long_string_to_ensure_it_is_not_inlined_and_causes_waste_{i}"))) + .collect(); + let schema = Arc::new(Schema::new(vec![Field::new( + "s", + DataType::Utf8View, + false, + )])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(string_array.clone()) as ArrayRef], + )?; + + // 2. Slice it heavily (1% of the data) + let sliced_batch = batch.slice(0, 10); + + // 3. Spill to disk using SpillManager + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = SpillManager::new(env, metrics, schema); + let spill_file = spill_manager + .spill_record_batch_and_finish(&[sliced_batch], "TestGC")? + .unwrap(); + + // 4. Check file size on disk + let file_size = fs::metadata(spill_file.path())?.len(); + + // The original buffer size is around 70KB. + // Without GC, the spill file would be > 70KB. + // With GC, it should be much smaller (only 10 rows of ~70 bytes each + metadata). + assert!( + file_size < 10 * 1024, + "Spill file is too large ({file_size} bytes)! GC might not be working." + ); + + Ok(()) + } + + #[tokio::test] + async fn test_spill_file_size_gc_verification_binary_view() -> Result<()> { + use arrow::array::BinaryViewArray; + use std::fs; + + // 1. Setup bloated data (large buffers) + let num_rows = 1000; + let binary_array: BinaryViewArray = + (0..num_rows).map(|i| Some(vec![i as u8; 100])).collect(); + let schema = Arc::new(Schema::new(vec![Field::new( + "b", + DataType::BinaryView, + false, + )])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(binary_array.clone()) as ArrayRef], + )?; + + // 2. Slice it heavily (1% of the data) + let sliced_batch = batch.slice(0, 10); + + // 3. Spill to disk using SpillManager + let env = Arc::new(RuntimeEnv::default()); + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = SpillManager::new(env, metrics, schema); + let spill_file = spill_manager + .spill_record_batch_and_finish(&[sliced_batch], "TestGCBinary")? + .unwrap(); + + // 4. Check file size on disk + let file_size = fs::metadata(spill_file.path())?.len(); + + // Original buffer is 100KB. + // With GC, it should be much smaller. + assert!( + file_size < 10 * 1024, + "Spill file is too large ({file_size} bytes)! GC might not be working." + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/replayable_spill_input.rs b/datafusion/physical-plan/src/spill/replayable_spill_input.rs new file mode 100644 index 0000000000000..fea998d268c59 --- /dev/null +++ b/datafusion/physical-plan/src/spill/replayable_spill_input.rs @@ -0,0 +1,448 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utility for replaying a one-shot input `RecordBatchStream` through spill. +//! +//! See comments in [`ReplayableStreamSource`] for details. + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{Result, internal_err}; +use datafusion_execution::RecordBatchStream; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_execution::disk_manager::RefCountedTempFile; +use futures::Stream; +use parking_lot::Mutex; + +use crate::EmptyRecordBatchStream; +use crate::spill::in_progress_spill_file::InProgressSpillFile; +use crate::spill::spill_manager::SpillManager; + +/// Spill-backed replayable stream source. +/// +/// [`ReplayableStreamSource`] is constructed from an input stream, usually produced +/// by executing an input `ExecutionPlan`. +/// +/// - On the first pass, it evaluates the input stream, produces `RecordBatch`es, +/// caches those batches to a local spill file, and also forwards them to the +/// output. +/// - On subsequent passes, it reads directly from the spill file. +/// +/// ```text +/// first pass: +/// +/// RecordBatch stream +/// | +/// v +/// [batch] -> output +/// | +/// +----> spill file +/// +/// +/// later passes: +/// +/// spill file +/// | +/// v +/// [batch] -> output +/// ``` +/// +/// This is useful when an input stream must be replayed and: +/// - Re-evaluation is expensive because the input stream may come from a long +/// and complex pipeline. +/// - The parent operator is under memory pressure and cannot cache the input in +/// memory for replay. +/// +/// # Concurrency assumption +/// Passes must be opened and consumed sequentially. +/// Opening another pass before exhausting the current one returns an error. +pub(crate) struct ReplayableStreamSource { + schema: SchemaRef, + input: Option, + spill_manager: SpillManager, + request_description: String, + /// Inner state is owned by either the source or one active stream to ensure + /// sequential access; see struct docs for the concurrency contract. + /// + /// Ownership model: + /// - No active stream: source owns the state (`source.state = Some(state)`). + /// - Active stream: the stream owns the state (`source.state = None`). + state: Arc>>, +} + +/// Inner state exclusively owned by either [`ReplayableStreamSource`] or one [`ReplayableSpillStream`] +enum StateInner { + Unopened, + Replayable(Option), + Poisoned, +} + +impl ReplayableStreamSource { + /// Creates a replayable stream producer over a one-shot input stream. + /// + /// It caches the input into a local spill file on the first pass, then + /// reads directly from that spill file on subsequent passes. + pub(crate) fn new( + input: SendableRecordBatchStream, + spill_manager: SpillManager, + request_description: impl Into, + ) -> Self { + let schema = input.schema(); + Self { + schema, + input: Some(input), + spill_manager, + request_description: request_description.into(), + state: Arc::new(Mutex::new(Some(StateInner::Unopened))), + } + } + + fn set_state(&self, state: StateInner) { + *self.state.lock() = Some(state); + } + + /// Opens the next pass over this input. + /// + /// The first call returns a stream that forwards upstream batches while + /// caching them to spill. Later calls return streams that read directly + /// from the completed spill file. + /// + /// # Note + /// Subsequent passes MUST be opened only after the previous pass is fully + /// consumed; otherwise, an error is returned. + pub(crate) fn open_pass(&mut self) -> Result { + let state = self.state.lock().take(); + let Some(state) = state else { + return internal_err!("ReplayableStreamSource pass is still active"); + }; + + match state { + StateInner::Unopened => { + let Some(input) = self.input.take() else { + self.set_state(StateInner::Poisoned); + return internal_err!( + "ReplayableStreamSource missing first-pass input" + ); + }; + let spill_file = match self + .spill_manager + .create_in_progress_file(&self.request_description) + { + Ok(spill_file) => spill_file, + Err(e) => { + self.input = Some(input); + self.set_state(StateInner::Unopened); + return Err(e); + } + }; + + Ok(Box::pin(ReplayableSpillStream::new_first( + Arc::clone(&self.schema), + input, + Arc::clone(&self.state), + spill_file, + ))) + } + StateInner::Poisoned => { + internal_err!( + "ReplayableStreamSource first pass did not complete successfully" + ) + } + StateInner::Replayable(spill_file) => { + let replay_state = spill_file.clone(); + match ReplayableSpillStream::new_replay( + Arc::clone(&self.schema), + &self.spill_manager, + Arc::clone(&self.state), + spill_file, + ) { + Ok(stream) => Ok(Box::pin(stream)), + Err(e) => { + self.set_state(StateInner::Replayable(replay_state)); + Err(e) + } + } + } + } + } +} + +/// Makes a one-shot stream replayable using spill caching, keeping replays fast +/// and memory efficient. +/// +/// On the first pass, it evaluates and forwards output from `inner` while +/// caching it to a spill file for future replays. +/// +/// On later passes, it replays directly from the cached spill file. +/// +/// See also [`ReplayableStreamSource`] for details. +struct ReplayableSpillStream { + schema: SchemaRef, + shared_state: Arc>>, + held_state: Option, + spill_file: Option, + inner: SendableRecordBatchStream, +} + +impl ReplayableSpillStream { + fn new_first( + schema: SchemaRef, + inner: SendableRecordBatchStream, + shared_state: Arc>>, + spill_file: InProgressSpillFile, + ) -> Self { + Self { + schema, + shared_state, + held_state: Some(StateInner::Unopened), + spill_file: Some(spill_file), + inner, + } + } + + fn new_replay( + schema: SchemaRef, + spill_manager: &SpillManager, + shared_state: Arc>>, + spill_file: Option, + ) -> Result { + let inner = if let Some(file) = spill_file.as_ref() { + spill_manager.read_spill_as_stream(file.clone(), None)? + } else { + Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) + }; + + Ok(Self { + schema, + shared_state, + held_state: Some(StateInner::Replayable(spill_file)), + spill_file: None, + inner, + }) + } + + fn restore_held_state(&mut self) { + if let Some(state) = self.held_state.take() { + *self.shared_state.lock() = Some(state); + } + } + + fn set_state(&mut self, state: StateInner) { + if self.held_state.take().is_some() { + *self.shared_state.lock() = Some(state); + } + } + + fn poison(&mut self) { + self.set_state(StateInner::Poisoned); + } +} + +impl Stream for ReplayableSpillStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(batch))) => { + if batch.num_rows() > 0 + && let Some(spill_file) = this.spill_file.as_mut() + && let Err(e) = spill_file.append_batch(&batch) + { + this.spill_file.take(); + this.poison(); + return Poll::Ready(Some(Err(e))); + } + + Poll::Ready(Some(Ok(batch))) + } + Poll::Ready(Some(Err(e))) => { + this.spill_file.take(); + this.poison(); + Poll::Ready(Some(Err(e))) + } + // The stream is exhausted, give the inner state ownership back to `ReplayableStreamSource` + Poll::Ready(None) => { + // Release the input pipeline's resources. + let inner_schema = this.inner.schema(); + this.inner = Box::pin(EmptyRecordBatchStream::new(inner_schema)); + if let Some(spill_file) = this.spill_file.as_mut() { + match spill_file.finish() { + Ok(file) => { + this.spill_file.take(); + this.set_state(StateInner::Replayable(file)); + Poll::Ready(None) + } + Err(e) => { + this.spill_file.take(); + this.poison(); + Poll::Ready(Some(Err(e))) + } + } + } else { + this.restore_held_state(); + Poll::Ready(None) + } + } + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for ReplayableSpillStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl Drop for ReplayableSpillStream { + /// If a stream is dropped before it finishes, poison the state so later + /// replay attempts fail. + /// + /// A partial first pass leaves the spill file incomplete, so replaying it + /// would be unsafe. + fn drop(&mut self) { + if self.held_state.is_some() { + self.poison(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Int64Array; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr_common::metrics::{ + ExecutionPlanMetricsSet, SpillMetrics, + }; + use futures::{StreamExt, TryStreamExt}; + + use crate::stream::RecordBatchStreamAdapter; + + fn build_spill_manager(schema: SchemaRef) -> Result { + let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); + let metrics_set = ExecutionPlanMetricsSet::new(); + let spill_metrics = SpillMetrics::new(&metrics_set, 0); + Ok(SpillManager::new(runtime, spill_metrics, schema)) + } + + fn build_batch(schema: SchemaRef, values: Vec) -> Result { + RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(values))]) + .map_err(Into::into) + } + + #[tokio::test] + async fn test_replayable_spill_input_replays_completed_first_pass() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let batch1 = build_batch(Arc::clone(&schema), vec![1, 2])?; + let batch2 = build_batch(Arc::clone(&schema), vec![3, 4])?; + + let input = Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter(vec![Ok(batch1.clone()), Ok(batch2.clone())]), + )); + let spill_manager = build_spill_manager(Arc::clone(&schema))?; + let mut replayable = + ReplayableStreamSource::new(input, spill_manager, "test replayable spill"); + + let pass1 = replayable.open_pass()?; + let pass1_batches = pass1.try_collect::>().await?; + assert_eq!(pass1_batches, vec![batch1.clone(), batch2.clone()]); + + let pass2 = replayable.open_pass()?; + let pass2_batches = pass2.try_collect::>().await?; + assert_eq!(pass2_batches, vec![batch1, batch2]); + + Ok(()) + } + + // Try to open a new pass, when the first pass has not finished. + // The spill file is only partially written, so an error will be returned. + #[tokio::test] + async fn test_replayable_spill_input_poisoned_when_first_pass_dropped() -> Result<()> + { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let batch1 = build_batch(Arc::clone(&schema), vec![1, 2])?; + let batch2 = build_batch(Arc::clone(&schema), vec![3, 4])?; + + let input = Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter(vec![Ok(batch1), Ok(batch2)]), + )); + let spill_manager = build_spill_manager(Arc::clone(&schema))?; + let mut replayable = + ReplayableStreamSource::new(input, spill_manager, "test replayable spill"); + + let mut pass1 = replayable.open_pass()?; + let first = pass1.next().await.transpose()?; + assert!(first.is_some()); + drop(pass1); + + let err = match replayable.open_pass() { + Ok(_) => panic!("expected first pass to poison replayable spill input"), + Err(err) => err.strip_backtrace(), + }; + assert!( + err.to_string().contains( + "ReplayableStreamSource first pass did not complete successfully" + ) + ); + + Ok(()) + } + + // Open a new pass, when the previous pass from spill is still in progress. + // An error is expected, since it requires sequential access. + #[tokio::test] + async fn test_replayable_spill_input_errors_when_replay_pass_in_progress() + -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let batch1 = build_batch(Arc::clone(&schema), vec![1, 2])?; + let batch2 = build_batch(Arc::clone(&schema), vec![3, 4])?; + + let input = Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + futures::stream::iter(vec![Ok(batch1.clone()), Ok(batch2.clone())]), + )); + let spill_manager = build_spill_manager(Arc::clone(&schema))?; + let mut replayable = + ReplayableStreamSource::new(input, spill_manager, "test replayable spill"); + + let pass1 = replayable.open_pass()?; + let _ = pass1.try_collect::>().await?; + + let pass2 = replayable.open_pass()?; + let err = match replayable.open_pass() { + Ok(_) => panic!("expected open_pass to fail while replay pass is active"), + Err(err) => err.strip_backtrace(), + }; + assert!( + err.to_string() + .contains("ReplayableStreamSource pass is still active") + ); + drop(pass2); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 6fd97a8e2e6a0..365a9f977eace 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,20 +17,19 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use arrow::array::StringViewArray; -use arrow::datatypes::SchemaRef; +use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; +use crate::coop::cooperative; +use crate::{common::spawn_buffered, metrics::SpillMetrics}; +use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; +use arrow::datatypes::{ByteViewType, SchemaRef}; use arrow::record_batch::RecordBatch; +use datafusion_common::{DataFusionError, Result, config::SpillCompression}; +use datafusion_execution::SendableRecordBatchStream; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::runtime_env::RuntimeEnv; +use std::borrow::Borrow; use std::sync::Arc; -use datafusion_common::{config::SpillCompression, Result}; -use datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::SendableRecordBatchStream; - -use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; -use crate::coop::cooperative; -use crate::{common::spawn_buffered, metrics::SpillMetrics}; - /// The `SpillManager` is responsible for the following tasks: /// - Reading and writing `RecordBatch`es to raw files based on the provided configurations. /// - Updating the associated metrics. @@ -110,39 +109,27 @@ impl SpillManager { in_progress_file.finish() } - /// Refer to the documentation for [`Self::spill_record_batch_and_finish`]. This method - /// additionally spills the `RecordBatch` into smaller batches, divided by `row_limit`. - /// - /// # Errors - /// - Returns an error if spilling would exceed the disk usage limit configured - /// by `max_temp_directory_size` in `DiskManager` - pub(crate) fn spill_record_batch_by_size_and_return_max_batch_memory( + /// Spill an iterator of `RecordBatch`es to disk and return the spill file and the size of the largest batch in memory + /// Note that this expects the caller to provide *non-sliced* batches, so the memory calculation of each batch is accurate. + pub(crate) fn spill_record_batch_iter_and_return_max_batch_memory( &self, - batch: &RecordBatch, + mut iter: impl Iterator>>, request_description: &str, - row_limit: usize, ) -> Result> { - let total_rows = batch.num_rows(); - let mut batches = Vec::new(); - let mut offset = 0; - - // It's ok to calculate all slices first, because slicing is zero-copy. - while offset < total_rows { - let length = std::cmp::min(total_rows - offset, row_limit); - let sliced_batch = batch.slice(offset, length); - batches.push(sliced_batch); - offset += length; - } - let mut in_progress_file = self.create_in_progress_file(request_description)?; let mut max_record_batch_size = 0; - for batch in batches { - in_progress_file.append_batch(&batch)?; - - max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); - } + iter.try_for_each(|batch| { + let batch = batch?; + let borrowed = batch.borrow(); + if borrowed.num_rows() == 0 { + return Ok(()); + } + let gc_sliced_size = in_progress_file.append_batch(borrowed)?; + max_record_batch_size = max_record_batch_size.max(gc_sliced_size); + Result::<_, DataFusionError>::Ok(()) + })?; let file = in_progress_file.finish()?; @@ -163,9 +150,9 @@ impl SpillManager { while let Some(batch) = stream.next().await { let batch = batch?; - in_progress_file.append_batch(&batch)?; + let gc_sliced_size = in_progress_file.append_batch(&batch)?; - max_record_batch_size = max_record_batch_size.max(batch.get_sliced_size()?); + max_record_batch_size = max_record_batch_size.max(gc_sliced_size); } let file = in_progress_file.finish()?; @@ -173,9 +160,22 @@ impl SpillManager { Ok(file.map(|f| (f, max_record_batch_size))) } - /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. - /// This method will generate output in FIFO order: the batch appended first - /// will be read first. + /// Reads a spill file as a stream. The file must be created by the current + /// `SpillManager`; otherwise an error will be returned. + /// + /// Output is produced in FIFO order: the batch appended first is read first. + /// + /// # Arg `max_record_batch_memory` + /// + /// Most callers should pass `None`. This is mainly useful for the + /// memory-limited sort-preserving merge path. + /// + /// When provided, this value is used only as a validation hint. If a + /// decoded batch exceeds this threshold, a debug-level log message is + /// emitted. + /// + /// That path uses the maximum spilled batch size to conservatively estimate + /// the merge degree when merging multiple sorted runs. pub fn read_spill_as_stream( &self, spill_file_path: RefCountedTempFile, @@ -189,12 +189,25 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + /// Same as `read_spill_as_stream`, but without buffering. + pub fn read_spill_as_stream_unbuffered( + &self, + spill_file_path: RefCountedTempFile, + max_record_batch_memory: Option, + ) -> Result { + Ok(Box::pin(cooperative(SpillReaderStream::new( + Arc::clone(&self.schema), + spill_file_path, + max_record_batch_memory, + )))) + } } pub(crate) trait GetSlicedSize { /// Returns the size of the `RecordBatch` when sliced. /// Note: if multiple arrays or even a single array share the same data buffers, we may double count each buffer. - /// Therefore, make sure we call gc() or organize_stringview_arrays() before using this method. + /// Therefore, make sure we call gc() or gc_view_arrays() before using this method. fn get_sliced_size(&self) -> Result; } @@ -214,26 +227,132 @@ impl GetSlicedSize for RecordBatch { // "bytes needed if we materialized exactly this slice into fresh buffers". // This is a workaround until https://github.com/apache/arrow-rs/issues/8230 if let Some(sv) = array.as_any().downcast_ref::() { - for buffer in sv.data_buffers() { - total += buffer.capacity(); - } + total += byte_view_data_buffer_size(sv); + } + if let Some(bv) = array.as_any().downcast_ref::() { + total += byte_view_data_buffer_size(bv); } } Ok(total) } } +fn byte_view_data_buffer_size(array: &GenericByteViewArray) -> usize { + array + .data_buffers() + .iter() + .map(|buffer| buffer.capacity()) + .sum() +} + #[cfg(test)] mod tests { + use super::SpillManager; + use crate::common::collect; + use crate::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; use crate::spill::{get_record_batch_memory_size, spill_manager::GetSlicedSize}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::{ - array::{ArrayRef, StringViewArray}, + array::{ArrayRef, Int32Array, StringArray, StringViewArray}, record_batch::RecordBatch, }; use datafusion_common::Result; + use datafusion_execution::runtime_env::RuntimeEnv; use std::sync::Arc; + fn build_test_spill_manager( + env: Arc, + schema: Arc, + ) -> SpillManager { + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + SpillManager::new(env, metrics, schema) + } + + fn build_writer_batch(schema: Arc) -> Result { + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .map_err(Into::into) + } + + #[tokio::test] + async fn test_read_spill_as_stream_from_another_spill_manager_same_schema() + -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let writer_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let reader_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + + let writer = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema)); + let reader = build_test_spill_manager(env, Arc::clone(&reader_schema)); + let written_batch = build_writer_batch(Arc::clone(&writer_schema))?; + + let spill_file = writer + .spill_record_batch_and_finish( + std::slice::from_ref(&written_batch), + "writer", + )? + .unwrap(); + + // Same-schema reads through a different SpillManager currently pass + // because only schema compatibility is validated. This is not a + // supported usage pattern. + let stream = reader.read_spill_as_stream(spill_file, None)?; + assert_eq!(stream.schema(), reader_schema); + + let batches = collect(stream).await?; + assert_eq!(batches, vec![written_batch]); + + Ok(()) + } + + #[tokio::test] + async fn test_read_spill_as_stream_from_another_spill_manager_different_schema() + -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let writer_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let reader_schema = Arc::new(Schema::new(vec![ + Field::new("other_id", DataType::Int32, true), + Field::new("other_value", DataType::Utf8, true), + ])); + + let writer = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&writer_schema)); + let reader = build_test_spill_manager(env, Arc::clone(&reader_schema)); + let written_batch = build_writer_batch(Arc::clone(&writer_schema))?; + + let spill_file = writer + .spill_record_batch_and_finish( + std::slice::from_ref(&written_batch), + "writer", + )? + .unwrap(); + + let stream = reader.read_spill_as_stream(spill_file, None)?; + let err = collect(stream) + .await + .expect_err("schema mismatch should fail fast"); + let err = err.to_string(); + assert!(err.contains("Spill file schema mismatch")); + assert!(err.contains("expected")); + assert!(err.contains("got")); + + Ok(()) + } + #[test] fn check_sliced_size_for_string_view_array() -> Result<()> { let array_length = 50; diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs index bbe54ca45caa3..2639188a2609d 100644 --- a/datafusion/physical-plan/src/spill/spill_pool.rs +++ b/datafusion/physical-plan/src/spill/spill_pool.rs @@ -61,6 +61,10 @@ struct SpillPoolShared { /// Writer's reference to the current file (shared by all cloned writers). /// Has its own lock to allow I/O without blocking queue access. current_write_file: Option>>, + /// Number of active writer clones. Only when this reaches zero should + /// `writer_dropped` be set to true. This prevents premature EOF signaling + /// when one writer clone is dropped while others are still active. + active_writer_count: usize, } impl SpillPoolShared { @@ -72,6 +76,7 @@ impl SpillPoolShared { waker: None, writer_dropped: false, current_write_file: None, + active_writer_count: 1, } } @@ -97,7 +102,6 @@ impl SpillPoolShared { /// The writer automatically manages file rotation based on the `max_file_size_bytes` /// configured in [`channel`]. When the last writer clone is dropped, it finalizes the /// current file so readers can access all written data. -#[derive(Clone)] pub struct SpillPoolWriter { /// Maximum size in bytes before rotating to a new file. /// Typically set from configuration `datafusion.execution.max_spill_file_size_bytes`. @@ -106,6 +110,18 @@ pub struct SpillPoolWriter { shared: Arc>, } +impl Clone for SpillPoolWriter { + fn clone(&self) -> Self { + // Increment the active writer count so that `writer_dropped` is only + // set to true when the *last* clone is dropped. + self.shared.lock().active_writer_count += 1; + Self { + max_file_size_bytes: self.max_file_size_bytes, + shared: Arc::clone(&self.shared), + } + } +} + impl SpillPoolWriter { /// Spills a batch to the pool, rotating files when necessary. /// @@ -194,6 +210,8 @@ impl SpillPoolWriter { // Append the batch if let Some(ref mut writer) = file_shared.writer { writer.append_batch(batch)?; + // make sure we flush the writer for readers + writer.flush()?; file_shared.batches_written += 1; file_shared.estimated_size += batch_size; } @@ -231,6 +249,15 @@ impl Drop for SpillPoolWriter { fn drop(&mut self) { let mut shared = self.shared.lock(); + shared.active_writer_count -= 1; + let is_last_writer = shared.active_writer_count == 0; + + if !is_last_writer { + // Other writer clones are still active; do not finalize or + // signal EOF to readers. + return; + } + // Finalize the current file when the last writer is dropped if let Some(current_file) = shared.current_write_file.take() { // Release shared lock before locking file @@ -384,28 +411,33 @@ impl Drop for SpillPoolWriter { /// // Create channel with 1MB file size limit /// let (writer, mut reader) = spill_pool::channel(1024 * 1024, spill_manager); /// -/// // Spawn writer task to produce batches -/// let write_handle = tokio::spawn(async move { +/// // Spawn writer and reader concurrently; writer wakes reader via wakers +/// let writer_task = tokio::spawn(async move { /// for i in 0..5 { /// let array: ArrayRef = Arc::new(Int32Array::from(vec![i; 100])); /// let batch = RecordBatch::try_new(schema.clone(), vec![array]).unwrap(); -/// writer.push_batch(&batch).unwrap(); +/// writer.push_batch(&batch)?; /// } -/// // Writer dropped here, finalizing current file +/// // Explicitly drop writer to finalize the spill file and wake the reader +/// drop(writer); +/// datafusion_common::Result::<()>::Ok(()) /// }); /// -/// // Reader consumes batches in FIFO order (can run concurrently with writer) -/// let mut batches_read = 0; -/// while let Some(result) = reader.next().await { -/// let batch = result?; -/// batches_read += 1; -/// // Process batch... -/// if batches_read == 5 { -/// break; // Got all expected batches +/// let reader_task = tokio::spawn(async move { +/// let mut batches_read = 0; +/// while let Some(result) = reader.next().await { +/// let _batch = result?; +/// batches_read += 1; /// } -/// } +/// datafusion_common::Result::::Ok(batches_read) +/// }); +/// +/// let (writer_res, reader_res) = tokio::join!(writer_task, reader_task); +/// writer_res +/// .map_err(|e| datafusion_common::DataFusionError::Execution(e.to_string()))??; +/// let batches_read = reader_res +/// .map_err(|e| datafusion_common::DataFusionError::Execution(e.to_string()))??; /// -/// write_handle.await.unwrap(); /// assert_eq!(batches_read, 5); /// # Ok(()) /// # } @@ -530,7 +562,11 @@ impl Stream for SpillFile { // Step 2: Lazy-create reader stream if needed if self.reader.is_none() && should_read { if let Some(file) = file { - match self.spill_manager.read_spill_as_stream(file, None) { + // we want this unbuffered because files are actively being written to + match self + .spill_manager + .read_spill_as_stream_unbuffered(file, None) + { Ok(stream) => { self.reader = Some(SpillFileReader { stream, @@ -712,7 +748,6 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::runtime_env::RuntimeEnv; - use futures::StreamExt; fn create_test_schema() -> SchemaRef { Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])) @@ -874,8 +909,8 @@ mod tests { ); assert_eq!( metrics.spilled_bytes.value(), - 0, - "Spilled bytes should be 0 before file finalization" + 320, + "Spilled bytes should reflect data written (header + 1 batch)" ); assert_eq!( metrics.spilled_rows.value(), @@ -1173,6 +1208,9 @@ mod tests { async fn test_reader_catches_up_to_writer() -> Result<()> { let (writer, mut reader) = create_spill_channel(1024 * 1024); + let (reader_waiting_tx, reader_waiting_rx) = tokio::sync::oneshot::channel(); + let (first_read_done_tx, first_read_done_rx) = tokio::sync::oneshot::channel(); + #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum ReadWriteEvent { ReadStart, @@ -1185,36 +1223,41 @@ mod tests { let reader_events = Arc::clone(&events); let reader_handle = SpawnedTask::spawn(async move { reader_events.lock().push(ReadWriteEvent::ReadStart); + reader_waiting_tx + .send(()) + .expect("reader_waiting channel closed unexpectedly"); let result = reader.next().await.unwrap().unwrap(); reader_events .lock() .push(ReadWriteEvent::Read(result.num_rows())); + first_read_done_tx + .send(()) + .expect("first_read_done channel closed unexpectedly"); let result = reader.next().await.unwrap().unwrap(); reader_events .lock() .push(ReadWriteEvent::Read(result.num_rows())); }); - // Give reader time to start pending - tokio::time::sleep(std::time::Duration::from_millis(5)).await; + // Wait until the reader is pending on the first batch + reader_waiting_rx + .await + .expect("reader should signal when waiting"); // Now write a batch (should wake the reader) let batch = create_test_batch(0, 5); events.lock().push(ReadWriteEvent::Write(batch.num_rows())); writer.push_batch(&batch)?; - // Wait for the reader to process - let processed = async { - loop { - if events.lock().len() >= 3 { - break; - } - tokio::time::sleep(std::time::Duration::from_micros(500)).await; - } - }; - tokio::time::timeout(std::time::Duration::from_secs(1), processed) + // Wait for the reader to finish the first read before allowing the + // second write. This ensures deterministic ordering of events: + // 1. The reader starts and pends on the first `next()` + // 2. The first write wakes the reader + // 3. The reader processes the first batch and signals completion + // 4. The second write is issued, ensuring consistent event ordering + first_read_done_rx .await - .unwrap(); + .expect("reader should signal when first read completes"); // Write another batch let batch = create_test_batch(5, 10); @@ -1287,11 +1330,11 @@ mod tests { writer.push_batch(&batch)?; } - // Check metrics before drop - spilled_bytes should be 0 since file isn't finalized yet + // Check metrics before drop - spilled_bytes already reflects written data let spilled_bytes_before = metrics.spilled_bytes.value(); assert_eq!( - spilled_bytes_before, 0, - "Spilled bytes should be 0 before writer is dropped" + spilled_bytes_before, 1088, + "Spilled bytes should reflect data written (header + 5 batches)" ); // Explicitly drop the writer - this should finalize the current file @@ -1324,6 +1367,81 @@ mod tests { Ok(()) } + /// Verifies that the reader stays alive as long as any writer clone exists. + /// + /// `SpillPoolWriter` is `Clone`, and in non-preserve-order repartitioning + /// mode multiple input partition tasks share clones of the same writer. + /// The reader must not see EOF until **all** clones have been dropped, + /// even if the queue is temporarily empty between writes from different + /// clones. + /// + /// The test sequence is: + /// + /// 1. writer1 writes a batch, then is dropped. + /// 2. The reader consumes that batch (queue is now empty). + /// 3. writer2 (still alive) writes a batch. + /// 4. The reader must see that batch. + /// 5. EOF is only signalled after writer2 is also dropped. + #[tokio::test] + async fn test_clone_drop_does_not_signal_eof_prematurely() -> Result<()> { + let (writer1, mut reader) = create_spill_channel(1024 * 1024); + let writer2 = writer1.clone(); + + // Synchronization: tell writer2 when it may proceed. + let (proceed_tx, proceed_rx) = tokio::sync::oneshot::channel::<()>(); + + // Spawn writer2 — it waits for the signal before writing. + let writer2_handle = SpawnedTask::spawn(async move { + proceed_rx.await.unwrap(); + writer2.push_batch(&create_test_batch(10, 10)).unwrap(); + // writer2 is dropped here (last clone → true EOF) + }); + + // Writer1 writes one batch, then drops. + writer1.push_batch(&create_test_batch(0, 10))?; + drop(writer1); + + // Read writer1's batch. + let batch1 = reader.next().await.unwrap()?; + assert_eq!(batch1.num_rows(), 10); + let col = batch1 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 0); + + // Signal writer2 to write its batch. It will execute when the + // current task yields (i.e. when reader.next() returns Pending). + proceed_tx.send(()).unwrap(); + + // The reader should wait (Pending) for writer2's data, not EOF. + let batch2 = + tokio::time::timeout(std::time::Duration::from_secs(5), reader.next()) + .await + .expect("Reader timed out — should not hang"); + + assert!( + batch2.is_some(), + "Reader must not return EOF while a writer clone is still alive" + ); + let batch2 = batch2.unwrap()?; + assert_eq!(batch2.num_rows(), 10); + let col = batch2 + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 10); + + writer2_handle.await.unwrap(); + + // All writers dropped — reader should see real EOF now. + assert!(reader.next().await.is_none()); + + Ok(()) + } + #[tokio::test] async fn test_disk_usage_decreases_as_files_consumed() -> Result<()> { use datafusion_execution::runtime_env::RuntimeEnvBuilder; diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 480b723d0b151..9d0b964886afd 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -27,11 +27,13 @@ use super::metrics::ExecutionPlanMetricsSet; use super::metrics::{BaselineMetrics, SplitMetrics}; use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; use crate::displayable; +use crate::spill::get_record_batch_memory_size; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_common_runtime::JoinSet; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryReservation; use futures::ready; use futures::stream::BoxStream; @@ -328,6 +330,11 @@ impl RecordBatchReceiverStreamBuilder { context: Arc, ) { let output = self.tx(); + let input_display = if log::log_enabled!(log::Level::Debug) { + displayable(input.as_ref()).one_line().to_string() + } else { + String::new() + }; self.inner.spawn(async move { let mut stream = match input.execute(partition, context) { @@ -336,14 +343,18 @@ impl RecordBatchReceiverStreamBuilder { // is no place to send the error and no reason to continue. output.send(Err(e)).await.ok(); debug!( - "Stopping execution: error executing input: {}", - displayable(input.as_ref()).one_line() + "Stopping execution: error executing input: {input_display}", ); return Ok(()); } Ok(stream) => stream, }; + // Drop the input early, as soon as we're done with it. + // Holding on to it can cause delays in cancelling the child plan when the query is + // cancelled. + drop(input); + // Transfer batches from inner stream to the output tx // immediately. while let Some(item) = stream.next().await { @@ -353,8 +364,7 @@ impl RecordBatchReceiverStreamBuilder { // place to send the error and no reason to continue. if output.send(item).await.is_err() { debug!( - "Stopping execution: output is gone, plan cancelling: {}", - displayable(input.as_ref()).one_line() + "Stopping execution: output is gone, plan cancelling: {input_display}", ); return Ok(()); } @@ -362,10 +372,7 @@ impl RecordBatchReceiverStreamBuilder { // Stop after the first error is encountered (Don't // drive all streams to completion) if is_err { - debug!( - "Stopping execution: plan returned error: {}", - displayable(input.as_ref()).one_line() - ); + debug!("Stopping execution: plan returned error: {input_display}"); return Ok(()); } } @@ -404,8 +411,11 @@ pin_project! { pub struct RecordBatchStreamAdapter { schema: SchemaRef, + // Wrapped in Option so we can drop the inner stream as soon as it + // returns `None`, releasing any upstream pipeline resources before the + // adapter itself is dropped. #[pin] - stream: S, + stream: Option, } } @@ -434,7 +444,10 @@ impl RecordBatchStreamAdapter { /// // ... /// ``` pub fn new(schema: SchemaRef, stream: S) -> Self { - Self { schema, stream } + Self { + schema, + stream: Some(stream), + } } } @@ -453,11 +466,29 @@ where type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().stream.poll_next(cx) + let mut this = self.project(); + let Some(inner) = this.stream.as_mut().as_pin_mut() else { + return Poll::Ready(None); + }; + let item = ready!(inner.poll_next(cx)); + if item.is_none() { + // Drop the inner stream in place to release its resources. + // SAFETY: the inner stream is dropped without moving it out of + // its pinned location; assigning `None` only runs the inner + // value's destructor in place, which is permitted for pinned + // values. + unsafe { + *this.stream.as_mut().get_unchecked_mut() = None; + } + } + Poll::Ready(item) } fn size_hint(&self) -> (usize, Option) { - self.stream.size_hint() + match self.stream.as_ref() { + Some(stream) => stream.size_hint(), + None => (0, Some(0)), + } } } @@ -531,6 +562,7 @@ impl ObservedStream { let Some(fetch) = self.fetch else { return poll }; if self.produced >= fetch { + self.release_inner(); return Poll::Ready(None); } @@ -538,12 +570,22 @@ impl ObservedStream { if self.produced + batch.num_rows() > fetch { let batch = batch.slice(0, fetch.saturating_sub(self.produced)); self.produced += batch.num_rows(); + if self.produced >= fetch { + self.release_inner(); + } return Poll::Ready(Some(Ok(batch))); }; self.produced += batch.num_rows() } poll } + + /// Replace the inner stream with an [`EmptyRecordBatchStream`], dropping + /// the original stream so its upstream pipeline can be torn down. + fn release_inner(&mut self) { + let schema = self.inner.schema(); + self.inner = Box::pin(EmptyRecordBatchStream::new(schema)); + } } impl RecordBatchStream for ObservedStream { @@ -671,7 +713,12 @@ impl BatchSplitStream { } } Some(Err(e)) => Poll::Ready(Some(Err(e))), - None => Poll::Ready(None), + None => { + // Release the input pipeline's resources. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + Poll::Ready(None) + } } } } @@ -699,11 +746,78 @@ impl RecordBatchStream for BatchSplitStream { } } +/// A stream that holds a memory reservation for its lifetime, +/// shrinking the reservation as batches are consumed. +/// The original reservation must have its batch sizes calculated using [`get_record_batch_memory_size`] +/// On error, the reservation is *NOT* freed, until the stream is dropped. +pub(crate) struct ReservationStream { + schema: SchemaRef, + inner: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl ReservationStream { + pub(crate) fn new( + schema: SchemaRef, + inner: SendableRecordBatchStream, + reservation: MemoryReservation, + ) -> Self { + Self { + schema, + inner, + reservation, + } + } +} + +impl Stream for ReservationStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = self.inner.poll_next_unpin(cx); + + match res { + Poll::Ready(res) => { + match res { + Some(Ok(batch)) => { + self.reservation + .shrink(get_record_batch_memory_size(&batch)); + Poll::Ready(Some(Ok(batch))) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => { + // Stream is done so free the reservation completely + self.reservation.free(); + // Release the input pipeline's resources. + let inner_schema = self.inner.schema(); + self.inner = Box::pin(EmptyRecordBatchStream::new(inner_schema)); + Poll::Ready(None) + } + } + } + Poll::Pending => Poll::Pending, + } + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl RecordBatchStream for ReservationStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + #[cfg(test)] mod test { use super::*; use crate::test::exec::{ - assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec, + BlockingExec, MockExec, PanicExec, assert_strong_count_converges_to_zero, }; use arrow::datatypes::{DataType, Field, Schema}; @@ -924,7 +1038,126 @@ mod test { assert_eq!( number_of_batches, 2, - "Should have received exactly one empty batch" + "Should have received exactly two empty batches" + ); + } + + #[tokio::test] + async fn test_reservation_stream_shrinks_on_poll() { + use arrow::array::Int32Array; + use datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc() + .unwrap(); + + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + // Create batches + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))], + ) + .unwrap(); + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![6, 7, 8, 9, 10]))], + ) + .unwrap(); + + let batch1_size = get_record_batch_memory_size(&batch1); + let batch2_size = get_record_batch_memory_size(&batch2); + + // Reserve memory upfront + reservation.try_grow(batch1_size + batch2_size).unwrap(); + let initial_reserved = runtime.memory_pool.reserved(); + assert_eq!(initial_reserved, batch1_size + batch2_size); + + // Create stream with batches + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]); + let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; + + let mut res_stream = + ReservationStream::new(Arc::clone(&schema), inner, reservation); + + // Poll first batch + let result1 = res_stream.next().await; + assert!(result1.is_some()); + + // Memory should be reduced by batch1_size + let after_first = runtime.memory_pool.reserved(); + assert_eq!(after_first, batch2_size); + + // Poll second batch + let result2 = res_stream.next().await; + assert!(result2.is_some()); + + // Memory should be reduced by batch2_size + let after_second = runtime.memory_pool.reserved(); + assert_eq!(after_second, 0); + + // Poll None (end of stream) + let result3 = res_stream.next().await; + assert!(result3.is_none()); + + // Memory should still be 0 + assert_eq!(runtime.memory_pool.reserved(), 0); + } + + #[tokio::test] + async fn test_reservation_stream_error_handling() { + use datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build_arc() + .unwrap(); + + let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + reservation.try_grow(1000).unwrap(); + let initial = runtime.memory_pool.reserved(); + assert_eq!(initial, 1000); + + // Create a stream that errors + let stream = futures::stream::iter(vec![exec_err!("Test error")]); + let inner = Box::pin(RecordBatchStreamAdapter::new(Arc::clone(&schema), stream)) + as SendableRecordBatchStream; + + let mut res_stream = + ReservationStream::new(Arc::clone(&schema), inner, reservation); + + // Get the error + let result = res_stream.next().await; + assert!(result.is_some()); + assert!(result.unwrap().is_err()); + + // Verify reservation is NOT automatically freed on error + // The reservation is only freed when poll_next returns Poll::Ready(None) + // After an error, the stream may continue to hold the reservation + // until it's explicitly dropped or polled to None + let after_error = runtime.memory_pool.reserved(); + assert_eq!( + after_error, 1000, + "Reservation should still be held after error" + ); + + // Drop the stream to free the reservation + drop(res_stream); + + // Now memory should be freed + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "Memory should be freed when stream is dropped" ); } } diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index f9a7feb9e726e..cdf4b08f718c6 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -17,24 +17,23 @@ //! Generic plans for deferred execution: [`StreamingTableExec`] and [`PartitionStream`] -use std::any::Any; use std::fmt::Debug; use std::sync::Arc; use super::{DisplayAs, DisplayFormatType, PlanProperties}; use crate::coop::make_cooperative; -use crate::display::{display_orderings, ProjectSchemaDisplay}; +use crate::display::{ProjectSchemaDisplay, display_orderings}; use crate::execution_plan::{Boundedness, EmissionType, SchedulingType}; use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ - all_alias_free_columns, new_projections_for_columns, update_ordering, ProjectionExec, + ProjectionExec, all_alias_free_columns, new_projections_for_columns, update_ordering, }; use crate::stream::RecordBatchStreamAdapter; use crate::{ExecutionPlan, Partitioning, SendableRecordBatchStream}; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; @@ -67,7 +66,7 @@ pub struct StreamingTableExec { projected_output_ordering: Vec, infinite: bool, limit: Option, - cache: PlanProperties, + cache: Arc, metrics: ExecutionPlanMetricsSet, } @@ -111,7 +110,7 @@ impl StreamingTableExec { projected_output_ordering, infinite, limit, - cache, + cache: Arc::new(cache), metrics: ExecutionPlanMetricsSet::new(), }) } @@ -232,11 +231,7 @@ impl ExecutionPlan for StreamingTableExec { "StreamingTableExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -335,7 +330,7 @@ impl ExecutionPlan for StreamingTableExec { projected_output_ordering: self.projected_output_ordering.clone(), infinite: self.infinite, limit, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), metrics: self.metrics.clone(), })) } @@ -346,7 +341,7 @@ mod test { use super::*; use crate::collect_partitioned; use crate::streaming::PartitionStream; - use crate::test::{make_partition, TestPartitionStream}; + use crate::test::{TestPartitionStream, make_partition}; use arrow::record_batch::RecordBatch; #[tokio::test] diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index e3b22611f4deb..a6e76cebcdee2 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -17,7 +17,6 @@ //! Utilities for testing datafusion-physical-plan -use std::any::Any; use std::collections::HashMap; use std::fmt; use std::fmt::{Debug, Formatter}; @@ -25,19 +24,19 @@ use std::pin::Pin; use std::sync::Arc; use std::task::Context; +use crate::ExecutionPlan; use crate::common; use crate::execution_plan::{Boundedness, EmissionType}; use crate::memory::MemoryStream; use crate::metrics::MetricsSet; use crate::stream::RecordBatchStreamAdapter; use crate::streaming::PartitionStream; -use crate::ExecutionPlan; use crate::{DisplayAs, DisplayFormatType, PlanProperties}; use arrow::array::{Array, ArrayRef, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{ - assert_or_internal_err, config::ConfigOptions, project_schema, Result, Statistics, + Result, Statistics, assert_or_internal_err, config::ConfigOptions, project_schema, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::equivalence::{ @@ -75,7 +74,7 @@ pub struct TestMemoryExec { /// The maximum number of records to read from this plan. If `None`, /// all records after filtering are returned. fetch: Option, - cache: PlanProperties, + cache: Arc, } impl DisplayAs for TestMemoryExec { @@ -105,10 +104,10 @@ impl DisplayAs for TestMemoryExec { .map_or(String::new(), |limit| format!(", fetch={limit}")); if self.show_sizes { write!( - f, - "partitions={}, partition_sizes={partition_sizes:?}{limit}{output_ordering}{constraints}", - partition_sizes.len(), - ) + f, + "partitions={}, partition_sizes={partition_sizes:?}{limit}{output_ordering}{constraints}", + partition_sizes.len(), + ) } else { write!( f, @@ -130,11 +129,7 @@ impl ExecutionPlan for TestMemoryExec { "DataSourceExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -146,7 +141,7 @@ impl ExecutionPlan for TestMemoryExec { self: Arc, _: Vec>, ) -> Result> { - unimplemented!() + Ok(self) } fn repartitioned( @@ -169,15 +164,11 @@ impl ExecutionPlan for TestMemoryExec { unimplemented!() } - fn statistics(&self) -> Result { - self.statistics_inner() - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if partition.is_some() { - Ok(Statistics::new_unknown(&self.schema)) + Ok(Arc::new(Statistics::new_unknown(&self.schema))) } else { - self.statistics_inner() + Ok(Arc::new(self.statistics_inner()?)) } } @@ -239,7 +230,7 @@ impl TestMemoryExec { Ok(Self { partitions: partitions.to_vec(), schema, - cache: PlanProperties::new( + cache: Arc::new(PlanProperties::new( EquivalenceProperties::new_with_orderings( Arc::clone(&projected_schema), Vec::::new(), @@ -247,7 +238,7 @@ impl TestMemoryExec { Partitioning::UnknownPartitioning(partitions.len()), EmissionType::Incremental, Boundedness::Bounded, - ), + )), projected_schema, projection, sort_information: vec![], @@ -265,7 +256,7 @@ impl TestMemoryExec { ) -> Result> { let mut source = Self::try_new(partitions, schema, projection)?; let cache = source.compute_properties(); - source.cache = cache; + source.cache = Arc::new(cache); Ok(Arc::new(source)) } @@ -273,7 +264,7 @@ impl TestMemoryExec { pub fn update_cache(source: &Arc) -> TestMemoryExec { let cache = source.compute_properties(); let mut source = (**source).clone(); - source.cache = cache; + source.cache = Arc::new(cache); source } @@ -342,6 +333,7 @@ impl TestMemoryExec { } self.sort_information = sort_information; + self.cache = Arc::new(self.compute_properties()); Ok(self) } diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index b720181b27fe0..e162571e32261 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -17,26 +17,25 @@ //! Simple iterator over batches for use in testing -use std::{ - any::Any, - pin::Pin, - sync::{Arc, Weak}, - task::{Context, Poll}, -}; - use crate::{ - common, execution_plan::Boundedness, DisplayAs, DisplayFormatType, ExecutionPlan, - Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, - Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, common, + execution_plan::Boundedness, }; use crate::{ execution_plan::EmissionType, stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}, }; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + pin::Pin, + sync::{Arc, Weak}, + task::{Context, Poll}, +}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; @@ -125,7 +124,7 @@ pub struct MockExec { /// if true (the default), sends data using a separate task to ensure the /// batches are not available without this stream yielding first use_task: bool, - cache: PlanProperties, + cache: Arc, } impl MockExec { @@ -142,7 +141,7 @@ impl MockExec { data, schema, use_task: true, - cache, + cache: Arc::new(cache), } } @@ -188,11 +187,7 @@ impl ExecutionPlan for MockExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -254,13 +249,9 @@ impl ExecutionPlan for MockExec { } // Panics if one of the batches is an error - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if partition.is_some() { - return Ok(Statistics::new_unknown(&self.schema)); + return Ok(Arc::new(Statistics::new_unknown(&self.schema))); } let data: Result> = self .data @@ -273,11 +264,11 @@ impl ExecutionPlan for MockExec { let data = data?; - Ok(common::compute_record_batch_statistics( + Ok(Arc::new(common::compute_record_batch_statistics( &[data], &self.schema, None, - )) + ))) } } @@ -298,29 +289,91 @@ pub struct BarrierExec { schema: SchemaRef, /// all streams wait on this barrier to produce - barrier: Arc, - cache: PlanProperties, + start_data_barrier: Option>, + + /// the stream wait for this to return Poll::Ready(None) + finish_barrier: Option>, + + cache: Arc, + + log: bool, } impl BarrierExec { /// Create a new exec with some number of partitions. pub fn new(data: Vec>, schema: SchemaRef) -> Self { // wait for all streams and the input - let barrier = Arc::new(Barrier::new(data.len() + 1)); + let barrier = Some(Arc::new(Barrier::new(data.len() + 1))); let cache = Self::compute_properties(Arc::clone(&schema), &data); Self { data, schema, - barrier, - cache, + start_data_barrier: barrier, + cache: Arc::new(cache), + finish_barrier: None, + log: true, } } + pub fn with_log(mut self, log: bool) -> Self { + self.log = log; + self + } + + pub fn without_start_barrier(mut self) -> Self { + self.start_data_barrier = None; + self + } + + pub fn with_finish_barrier(mut self) -> Self { + let barrier = Arc::new(( + // wait for all streams and the input + Barrier::new(self.data.len() + 1), + AtomicUsize::new(0), + )); + + self.finish_barrier = Some(barrier); + self + } + /// wait until all the input streams and this function is ready pub async fn wait(&self) { - println!("BarrierExec::wait waiting on barrier"); - self.barrier.wait().await; - println!("BarrierExec::wait done waiting"); + let barrier = &self + .start_data_barrier + .as_ref() + .expect("Must only be called when having a start barrier"); + if self.log { + println!("BarrierExec::wait waiting on barrier"); + } + barrier.wait().await; + if self.log { + println!("BarrierExec::wait done waiting"); + } + } + + pub async fn wait_finish(&self) { + let (barrier, _) = &self + .finish_barrier + .as_deref() + .expect("Must only be called when having a finish barrier"); + + if self.log { + println!("BarrierExec::wait_finish waiting on barrier"); + } + barrier.wait().await; + if self.log { + println!("BarrierExec::wait_finish done waiting"); + } + } + + /// Return true if the finish barrier has been reached in all partitions + pub fn is_finish_barrier_reached(&self) -> bool { + let (_, reached_finish) = self + .finish_barrier + .as_deref() + .expect("Must only be called when having finish barrier"); + + reached_finish.load(Ordering::Relaxed) == self.data.len() } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -360,11 +413,7 @@ impl ExecutionPlan for BarrierExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -391,17 +440,32 @@ impl ExecutionPlan for BarrierExec { // task simply sends data in order after barrier is reached let data = self.data[partition].clone(); - let b = Arc::clone(&self.barrier); + let start_barrier = self.start_data_barrier.as_ref().map(Arc::clone); + let finish_barrier = self.finish_barrier.as_ref().map(Arc::clone); + let log = self.log; let tx = builder.tx(); builder.spawn(async move { - println!("Partition {partition} waiting on barrier"); - b.wait().await; + if let Some(barrier) = start_barrier { + if log { + println!("Partition {partition} waiting on barrier"); + } + barrier.wait().await; + } for batch in data { - println!("Partition {partition} sending batch"); + if log { + println!("Partition {partition} sending batch"); + } if let Err(e) = tx.send(Ok(batch)).await { println!("ERROR batch via barrier stream stream: {e}"); } } + if let Some((barrier, reached_finish)) = finish_barrier.as_deref() { + if log { + println!("Partition {partition} waiting on finish barrier"); + } + reached_finish.fetch_add(1, Ordering::Relaxed); + barrier.wait().await; + } Ok(()) }); @@ -410,26 +474,22 @@ impl ExecutionPlan for BarrierExec { Ok(builder.build()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if partition.is_some() { - return Ok(Statistics::new_unknown(&self.schema)); + return Ok(Arc::new(Statistics::new_unknown(&self.schema))); } - Ok(common::compute_record_batch_statistics( + Ok(Arc::new(common::compute_record_batch_statistics( &self.data, &self.schema, None, - )) + ))) } } /// A mock execution plan that errors on a call to execute #[derive(Debug)] pub struct ErrorExec { - cache: PlanProperties, + cache: Arc, } impl Default for ErrorExec { @@ -446,7 +506,9 @@ impl ErrorExec { true, )])); let cache = Self::compute_properties(schema); - Self { cache } + Self { + cache: Arc::new(cache), + } } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -483,11 +545,7 @@ impl ExecutionPlan for ErrorExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -517,20 +575,20 @@ impl ExecutionPlan for ErrorExec { pub struct StatisticsExec { stats: Statistics, schema: Arc, - cache: PlanProperties, + cache: Arc, } impl StatisticsExec { pub fn new(stats: Statistics, schema: Schema) -> Self { assert_eq!( - stats - .column_statistics.len(), schema.fields().len(), + stats.column_statistics.len(), + schema.fields().len(), "if defined, the column statistics vector length should be the number of fields" ); let cache = Self::compute_properties(Arc::new(schema.clone())); Self { stats, schema: Arc::new(schema), - cache, + cache: Arc::new(cache), } } @@ -573,11 +631,7 @@ impl ExecutionPlan for StatisticsExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -600,16 +654,12 @@ impl ExecutionPlan for StatisticsExec { unimplemented!("This plan only serves for testing statistics") } - fn statistics(&self) -> Result { - Ok(self.stats.clone()) - } - - fn partition_statistics(&self, partition: Option) -> Result { - Ok(if partition.is_some() { + fn partition_statistics(&self, partition: Option) -> Result> { + Ok(Arc::new(if partition.is_some() { Statistics::new_unknown(&self.schema) } else { self.stats.clone() - }) + })) } } @@ -623,7 +673,7 @@ pub struct BlockingExec { /// Ref-counting helper to check if the plan and the produced stream are still in memory. refs: Arc<()>, - cache: PlanProperties, + cache: Arc, } impl BlockingExec { @@ -633,7 +683,7 @@ impl BlockingExec { Self { schema, refs: Default::default(), - cache, + cache: Arc::new(cache), } } @@ -680,11 +730,7 @@ impl ExecutionPlan for BlockingExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -766,7 +812,7 @@ pub struct PanicExec { /// Number of output partitions. Each partition will produce this /// many empty output record batches prior to panicking batches_until_panics: Vec, - cache: PlanProperties, + cache: Arc, } impl PanicExec { @@ -778,7 +824,7 @@ impl PanicExec { Self { schema, batches_until_panics, - cache, + cache: Arc::new(cache), } } @@ -826,11 +872,7 @@ impl ExecutionPlan for PanicExec { Self::static_name() } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 99af9b8f7ca12..9da606dc90db2 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -19,7 +19,7 @@ use arrow::{ array::{Array, AsArray}, - compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder}, + compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter}, row::{RowConverter, Rows, SortField}, }; use datafusion_expr::{ColumnarValue, Operator}; @@ -27,23 +27,24 @@ use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use super::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, RecordOutput, + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, + RecordOutput, }; use crate::spill::get_record_batch_memory_size; -use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; +use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion_common::{ - internal_datafusion_err, internal_err, HashMap, Result, ScalarValue, + HashMap, Result, ScalarValue, internal_datafusion_err, internal_err, }; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, runtime_env::RuntimeEnv, }; use datafusion_physical_expr::{ - expressions::{is_not_null, is_null, lit, BinaryExpr, DynamicFilterPhysicalExpr}, PhysicalExpr, + expressions::{BinaryExpr, DynamicFilterPhysicalExpr, is_not_null, is_null, lit}, }; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use parking_lot::RwLock; @@ -131,6 +132,9 @@ pub struct TopK { pub(crate) finished: bool, } +/// For more background, please also see the [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog] +/// +/// [Dynamic Filters: Passing Information Between Operators During Execution for 25x Faster Queries blog]: https://datafusion.apache.org/blog/2025/09/10/dynamic-filters #[derive(Debug, Clone)] pub struct TopKDynamicFilters { /// The current *global* threshold for the dynamic filter. @@ -159,7 +163,7 @@ impl TopKDynamicFilters { // Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter const ESTIMATED_BYTES_PER_ROW: usize = 20; -fn build_sort_fields( +pub(crate) fn build_sort_fields( ordering: &[PhysicalSortExpr], schema: &SchemaRef, ) -> Result> { @@ -217,7 +221,7 @@ impl TopK { expr, row_converter, scratch_rows, - heap: TopKHeap::new(k, batch_size), + heap: TopKHeap::new(k), common_sort_prefix_converter: prefix_row_converter, common_sort_prefix: Arc::from(common_sort_prefix), finished: false, @@ -250,13 +254,12 @@ impl TopK { let num_rows = batch.num_rows(); let array = filtered.into_array(num_rows)?; let mut filter = array.as_boolean().clone(); - let true_count = filter.true_count(); - if true_count == 0 { + if !filter.has_true() { // nothing to filter, so no need to update return Ok(()); } // only update the keys / rows if the filter does not match all rows - if true_count < num_rows { + if filter.null_count() > 0 || filter.has_false() { // Indices in `set_indices` should be correct if filter contains nulls // So we prepare the filter here. Note this is also done in the `FilterBuilder` // so there is no overhead to do this here. @@ -409,10 +412,10 @@ impl TopK { }; // Update the filter expression - if let Some(pred) = predicate { - if !pred.eq(&lit(true)) { - filter.expr.update(pred)?; - } + if let Some(pred) = predicate + && !pred.eq(&lit(true)) + { + filter.expr.update(pred)?; } Ok(()) @@ -644,6 +647,7 @@ impl TopKMetrics { Self { baseline: BaselineMetrics::new(metrics, partition), row_replacements: MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) .counter("row_replacements", partition), } } @@ -659,8 +663,6 @@ impl TopKMetrics { struct TopKHeap { /// The maximum number of elements to store in this heap. k: usize, - /// The target number of rows for output batches - batch_size: usize, /// Storage for up at most `k` items using a BinaryHeap. Reversed /// so that the smallest k so far is on the top inner: BinaryHeap, @@ -671,11 +673,10 @@ struct TopKHeap { } impl TopKHeap { - fn new(k: usize, batch_size: usize) -> Self { + fn new(k: usize) -> Self { assert!(k > 0); Self { k, - batch_size, inner: BinaryHeap::new(), store: RecordBatchStore::new(), owned_bytes: 0, @@ -721,8 +722,8 @@ impl TopKHeap { let row = row.as_ref(); // Reuse storage for evicted item if possible - let new_top_k = if self.inner.len() == self.k { - let prev_min = self.inner.pop().unwrap(); + if self.inner.len() == self.k { + let mut prev_min = self.inner.peek_mut().unwrap(); // Update batch use if prev_min.batch_id == batch_entry.id { @@ -733,15 +734,16 @@ impl TopKHeap { // update memory accounting self.owned_bytes -= prev_min.owned_size(); - prev_min.with_new_row(row, batch_id, index) - } else { - TopKRow::new(row, batch_id, index) - }; - self.owned_bytes += new_top_k.owned_size(); + prev_min.replace_with(row, batch_id, index); - // put the new row into the heap - self.inner.push(new_top_k) + self.owned_bytes += prev_min.owned_size(); + } else { + let new_row = TopKRow::new(row, batch_id, index); + self.owned_bytes += new_row.owned_size(); + // put the new row into the heap + self.inner.push(new_row); + }; } /// Returns the values stored in this heap, from values low to @@ -787,24 +789,26 @@ impl TopKHeap { /// Compact this heap, rewriting all stored batches into a single /// input batch pub fn maybe_compact(&mut self) -> Result<()> { - // we compact if the number of "unused" rows in the store is - // past some pre-defined threshold. Target holding up to - // around 20 batches, but handle cases of large k where some - // batches might be partially full - let max_unused_rows = (20 * self.batch_size) + self.k; - let unused_rows = self.store.unused_rows(); - - // don't compact if the store has one extra batch or - // unused rows is under the threshold - if self.store.len() <= 2 || unused_rows < max_unused_rows { + // Don't compact if there's only one batch (compacting into itself is pointless) + if self.store.len() <= 1 { + return Ok(()); + } + + let total_rows = self.store.total_rows; + let num_rows = self.inner.len(); + + // Compact when current store memory exceeds 2x what the compacted + // result would need. The multiplier avoids compacting when the + // savings would be marginal. + if total_rows <= num_rows * 2 { return Ok(()); } + // at first, compact the entire thing always into a new batch // (maybe we can get fancier in the future about ignoring // batches that have a high usage ratio already // Note: new batch is in the same order as inner - let num_rows = self.inner.len(); let (new_batch, mut topk_rows) = self.emit_with_state()?; let Some(new_batch) = new_batch else { return Ok(()); @@ -870,7 +874,7 @@ impl TopKHeap { ScalarValue::try_from_array(&array, 0)? } array => { - return internal_err!("Expected a scalar value, got {:?}", array) + return internal_err!("Expected a scalar value, got {:?}", array); } }; @@ -908,26 +912,13 @@ impl TopKRow { } } - /// Create a new TopKRow reusing the existing allocation - fn with_new_row( - self, - new_row: impl AsRef<[u8]>, - batch_id: u32, - index: usize, - ) -> Self { - let Self { - mut row, - batch_id: _, - index: _, - } = self; - row.clear(); - row.extend_from_slice(new_row.as_ref()); + // Replace the existing row capacity with new values + fn replace_with(&mut self, new_row: impl AsRef<[u8]>, batch_id: u32, index: usize) { + self.row.clear(); + self.row.extend_from_slice(new_row.as_ref()); - Self { - row, - batch_id, - index, - } + self.batch_id = batch_id; + self.index = index; } /// Returns the number of bytes owned by this row in the heap (not @@ -977,6 +968,8 @@ struct RecordBatchStore { batches: HashMap, /// total size of all record batches tracked by this store batches_size: usize, + /// row count of all the batches + total_rows: usize, } impl RecordBatchStore { @@ -985,6 +978,7 @@ impl RecordBatchStore { next_id: 0, batches: HashMap::new(), batches_size: 0, + total_rows: 0, } } @@ -1002,6 +996,7 @@ impl RecordBatchStore { // uses of 0 means that none of the rows in the batch were stored in the topk if entry.uses > 0 { self.batches_size += get_record_batch_memory_size(&entry.batch); + self.total_rows += entry.batch.num_rows(); self.batches.insert(entry.id, entry); } } @@ -1010,6 +1005,7 @@ impl RecordBatchStore { fn clear(&mut self) { self.batches.clear(); self.batches_size = 0; + self.total_rows = 0; } fn get(&self, id: u32) -> Option<&RecordBatchEntry> { @@ -1021,15 +1017,6 @@ impl RecordBatchStore { self.batches.len() } - /// Returns the total number of rows in batches minus the number - /// which are in use - fn unused_rows(&self) -> usize { - self.batches - .values() - .map(|batch_entry| batch_entry.batch.num_rows() - batch_entry.uses) - .sum() - } - /// returns true if the store has nothing stored fn is_empty(&self) -> bool { self.batches.is_empty() @@ -1053,6 +1040,11 @@ impl RecordBatchStore { .batches_size .checked_sub(get_record_batch_memory_size(&old_entry.batch)) .unwrap(); + + self.total_rows = self + .total_rows + .checked_sub(old_entry.batch.num_rows()) + .unwrap(); } } @@ -1068,7 +1060,7 @@ impl RecordBatchStore { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Float64Array, Int32Array, RecordBatch}; + use arrow::array::{BooleanArray, Float64Array, Int32Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SortOptions; use datafusion_common::assert_batches_eq; @@ -1251,4 +1243,184 @@ mod tests { Ok(()) } + + /// Tests that memory-based compaction triggers when a large batch + /// has very few rows referenced by the top-k heap. + #[tokio::test] + async fn test_topk_memory_compaction() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let sort_expr = PhysicalSortExpr { + expr: col("a", schema.as_ref())?, + options: SortOptions::default(), + }; + + let full_expr = LexOrdering::from([sort_expr.clone()]); + let prefix = vec![sort_expr]; + + let runtime = Arc::new(RuntimeEnv::default()); + let metrics = ExecutionPlanMetricsSet::new(); + + let k = 5; + let mut topk = TopK::try_new( + 0, + Arc::clone(&schema), + prefix, + full_expr, + k, + 8192, + runtime, + &metrics, + Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new( + DynamicFilterPhysicalExpr::new(vec![], lit(true)), + )))), + )?; + + // Insert a large batch (100,000 rows) with values 1..=100_000. + // Only the smallest 5 values (1..=5) will end up in the heap. + let large_values: Vec = (1..=100_000).collect(); + let array1: ArrayRef = Arc::new(Int32Array::from(large_values)); + let batch1 = RecordBatch::try_new(Arc::clone(&schema), vec![array1])?; + topk.insert_batch(batch1)?; + + // After the first batch, store has 1 batch — compaction should + // not trigger (guard: store.len() <= 1). + assert_eq!( + topk.heap.store.len(), + 1, + "should have 1 batch before second insert" + ); + + // Insert a second batch whose values displace entries in the heap. + // -1 and 0 are smaller than the current top-5 (1..=5), so they + // produce 2 replacements. With replacements > 0, `insert_batch` + // calls `insert_batch_entry` (briefly making store.len() == 2) + // and then `maybe_compact`, which should collapse it back to 1. + let array2: ArrayRef = Arc::new(Int32Array::from(vec![-1, 0])); + let batch2 = RecordBatch::try_new(Arc::clone(&schema), vec![array2])?; + let replacements_before = topk.metrics.row_replacements.value(); + topk.insert_batch(batch2)?; + + // Sanity check: batch2 was actually integrated. Without + // replacements, `maybe_compact` is never called and the + // store-length assertion below would pass vacuously. + assert!( + topk.metrics.row_replacements.value() > replacements_before, + "batch2 must produce replacements so compaction is exercised" + ); + + // The compacted-estimate guard is `total_rows <= num_rows * 2`, + // i.e. 100_002 <= 10, which is false, so compaction fires and + // collapses the two stored batches back into one. + assert_eq!( + topk.heap.store.len(), + 1, + "store should be compacted to 1 batch" + ); + + // Verify the emitted results are correct (top 5 ascending). + let results: Vec<_> = topk.emit()?.try_collect().await?; + assert_batches_eq!( + &[ + "+----+", "| a |", "+----+", "| -1 |", "| 0 |", "| 1 |", "| 2 |", + "| 3 |", "+----+", + ], + &results + ); + + Ok(()) + } + + /// Negative path: when stored rows are close to the heap size, + /// compaction must NOT fire even with multiple batches present, + /// because the savings would be marginal + /// (guard: `total_rows <= num_rows * 2`). + /// + /// Uses a bit-packed `BooleanArray` so that future changes to the + /// compaction heuristic that reintroduce a per-byte estimate + /// (where integer truncation could misbehave on sub-byte types) + /// are caught here. + #[tokio::test] + async fn test_topk_memory_compaction_skipped_when_marginal() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, false)])); + + let sort_expr = PhysicalSortExpr { + expr: col("a", schema.as_ref())?, + options: SortOptions::default(), + }; + let full_expr = LexOrdering::from([sort_expr.clone()]); + let prefix = vec![sort_expr]; + + let runtime = Arc::new(RuntimeEnv::default()); + let metrics = ExecutionPlanMetricsSet::new(); + + let k = 10; + let mut topk = TopK::try_new( + 0, + Arc::clone(&schema), + prefix, + full_expr, + k, + 8192, + runtime, + &metrics, + Arc::new(RwLock::new(TopKDynamicFilters::new(Arc::new( + DynamicFilterPhysicalExpr::new(vec![], lit(true)), + )))), + )?; + + // Two small batches; every row from both batches ends up referenced + // by the heap, so total_rows == num_rows == 10. + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(BooleanArray::from(vec![false, false, true, true, true])) + as ArrayRef, + ], + )?; + topk.insert_batch(batch1)?; + + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(BooleanArray::from(vec![false, false, false, true, true])) + as ArrayRef, + ], + )?; + topk.insert_batch(batch2)?; + + // Guard `total_rows <= num_rows * 2` should hold (10 <= 20), + // so compaction is skipped and BOTH batches remain in the store. + assert_eq!( + topk.heap.store.len(), + 2, + "store must keep 2 batches when savings would be marginal" + ); + assert_eq!(topk.heap.inner.len(), 10, "heap should hold all 10 rows"); + + // Output is still correct (5 falses then 5 trues ascending). + let results: Vec<_> = topk.emit()?.try_collect().await?; + assert_batches_eq!( + &[ + "+-------+", + "| a |", + "+-------+", + "| false |", + "| false |", + "| false |", + "| false |", + "| false |", + "| true |", + "| true |", + "| true |", + "| true |", + "| true |", + "+-------+", + ], + &results + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 85d7b33575ca2..aa4f144f91898 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -20,10 +20,10 @@ use std::fmt::{self, Display, Formatter}; use std::sync::Arc; -use crate::{displayable, with_new_children_if_necessary, ExecutionPlan}; +use crate::{ExecutionPlan, displayable, with_new_children_if_necessary}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; use datafusion_common::Result; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; impl DynTreeNode for dyn ExecutionPlan { fn arc_children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 06c28a8081ef6..3ea2eb5402fe5 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -23,33 +23,39 @@ use std::borrow::Borrow; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; -use std::{any::Any, sync::Arc}; use super::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, + PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, - ExecutionPlanProperties, Partitioning, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, }; +use crate::check_if_same_properties; use crate::execution_plan::{ - boundedness_from_children, check_default_invariants, emission_type_from_children, - InvariantLevel, + CardinalityEffect, InvariantLevel, boundedness_from_children, + check_default_invariants, emission_type_from_children, +}; +use crate::filter::FilterExec; +use crate::filter_pushdown::{ + ChildPushdownResult, FilterDescription, FilterPushdownPhase, + FilterPushdownPropagation, PushedDown, }; -use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; use crate::metrics::BaselineMetrics; -use crate::projection::{make_with_child, ProjectionExec}; +use crate::projection::{ProjectionExec, make_with_child}; use crate::stream::ObservedStream; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::stats::Precision; +use datafusion_common::stats::NdvFallback; use datafusion_common::{ - assert_or_internal_err, exec_err, internal_datafusion_err, Result, + Result, assert_or_internal_err, exec_err, internal_datafusion_err, }; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{calculate_union, EquivalenceProperties, PhysicalExpr}; +use datafusion_physical_expr::{ + EquivalenceProperties, PhysicalExpr, calculate_union, conjunction, +}; use futures::Stream; use itertools::Itertools; @@ -100,7 +106,7 @@ pub struct UnionExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl UnionExec { @@ -118,7 +124,7 @@ impl UnionExec { UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), } } @@ -147,7 +153,7 @@ impl UnionExec { Ok(Arc::new(UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), })) } } @@ -183,6 +189,17 @@ impl UnionExec { boundedness_from_children(inputs), )) } + + fn with_new_children_and_same_properties( + &self, + children: Vec>, + ) -> Self { + Self { + inputs: children, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for UnionExec { @@ -206,11 +223,7 @@ impl ExecutionPlan for UnionExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -259,6 +272,7 @@ impl ExecutionPlan for UnionExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); UnionExec::try_new(children) } @@ -267,7 +281,12 @@ impl ExecutionPlan for UnionExec { mut partition: usize, context: Arc, ) -> Result { - trace!("Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start UnionExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); // record the tiny amount of work done in this function so // elapsed_compute is reported as non zero @@ -299,11 +318,7 @@ impl ExecutionPlan for UnionExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { + fn partition_statistics(&self, partition: Option) -> Result> { if let Some(partition_idx) = partition { // For a specific partition, find which input it belongs to let mut remaining_idx = partition_idx; @@ -316,22 +331,23 @@ impl ExecutionPlan for UnionExec { remaining_idx -= input_partition_count; } // If we get here, the partition index is out of bounds - Ok(Statistics::new_unknown(&self.schema())) + Ok(Arc::new(Statistics::new_unknown(&self.schema()))) } else { - // Collect statistics from all inputs - let stats = self - .inputs - .iter() - .map(|input_exec| input_exec.partition_statistics(None)) - .collect::>>()?; - - Ok(stats - .into_iter() - .reduce(stats_union) - .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + let schema = self.schema(); + Ok(Arc::new(merge_input_statistics( + &self.inputs, + None, + schema.as_ref(), + )?)) } } + fn cardinality_effect(&self) -> CardinalityEffect { + // Union combines rows from multiple inputs, so output rows are not tied + // to any single input and can only be constrained as greater-or-equal. + CardinalityEffect::GreaterEqual + } + fn supports_limit_pushdown(&self) -> bool { true } @@ -365,6 +381,83 @@ impl ExecutionPlan for UnionExec { ) -> Result { FilterDescription::from_children(parent_filters, &self.children()) } + + fn handle_child_pushdown_result( + &self, + phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // Pre phase: handle heterogeneous pushdown by wrapping individual + // children with FilterExec and reporting all filters as handled. + // Post phase: use default behavior to let the filter creator decide how to handle + // filters that weren't fully pushed down. + if phase != FilterPushdownPhase::Pre { + return Ok(FilterPushdownPropagation::if_all(child_pushdown_result)); + } + + // UnionExec needs specialized filter pushdown handling when children have + // heterogeneous pushdown support. Without this, when some children support + // pushdown and others don't, the default behavior would leave FilterExec + // above UnionExec, re-applying filters to outputs of all children—including + // those that already applied the filters via pushdown. This specialized + // implementation adds FilterExec only to children that don't support + // pushdown, avoiding redundant filtering and improving performance. + // + // Example: Given Child1 (no pushdown support) and Child2 (has pushdown support) + // Default behavior: This implementation: + // FilterExec UnionExec + // UnionExec FilterExec + // Child1 Child1 + // Child2(filter) Child2(filter) + + // Collect unsupported filters for each child + let mut unsupported_filters_per_child = vec![Vec::new(); self.inputs.len()]; + for parent_filter_result in child_pushdown_result.parent_filters.iter() { + for (child_idx, &child_result) in + parent_filter_result.child_results.iter().enumerate() + { + if matches!(child_result, PushedDown::No) { + unsupported_filters_per_child[child_idx] + .push(Arc::clone(&parent_filter_result.filter)); + } + } + } + + // Wrap children that have unsupported filters with FilterExec + let mut new_children = self.inputs.clone(); + for (child_idx, unsupported_filters) in + unsupported_filters_per_child.iter().enumerate() + { + if !unsupported_filters.is_empty() { + let combined_filter = conjunction(unsupported_filters.clone()); + new_children[child_idx] = Arc::new(FilterExec::try_new( + combined_filter, + Arc::clone(&self.inputs[child_idx]), + )?); + } + } + + // Check if any children were modified + let children_modified = new_children + .iter() + .zip(self.inputs.iter()) + .any(|(new, old)| !Arc::ptr_eq(new, old)); + + let all_filters_pushed = + vec![PushedDown::Yes; child_pushdown_result.parent_filters.len()]; + let propagation = if children_modified { + let updated_node = UnionExec::try_new(new_children)?; + FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed) + .with_updated_node(updated_node) + } else { + FilterPushdownPropagation::with_parent_pushdown_result(all_filters_pushed) + }; + + // Report all parent filters as supported since we've ensured they're applied + // on all children (either pushed down or via FilterExec) + Ok(propagation) + } } /// Combines multiple input streams by interleaving them. @@ -406,7 +499,7 @@ pub struct InterleaveExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl InterleaveExec { @@ -420,7 +513,7 @@ impl InterleaveExec { Ok(InterleaveExec { inputs, metrics: ExecutionPlanMetricsSet::new(), - cache, + cache: Arc::new(cache), }) } @@ -442,6 +535,17 @@ impl InterleaveExec { boundedness_from_children(inputs), )) } + + fn with_new_children_and_same_properties( + &self, + children: Vec>, + ) -> Self { + Self { + inputs: children, + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for InterleaveExec { @@ -465,11 +569,7 @@ impl ExecutionPlan for InterleaveExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -490,6 +590,7 @@ impl ExecutionPlan for InterleaveExec { can_interleave(children.iter()), "Can not create InterleaveExec: new children can not be interleaved" ); + check_if_same_properties!(self, children); Ok(Arc::new(InterleaveExec::try_new(children)?)) } @@ -498,7 +599,12 @@ impl ExecutionPlan for InterleaveExec { partition: usize, context: Arc, ) -> Result { - trace!("Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); + trace!( + "Start InterleaveExec::execute for partition {} of context session_id {} and task_id {:?}", + partition, + context.session_id(), + context.task_id() + ); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); // record the tiny amount of work done in this function so // elapsed_compute is reported as non zero @@ -535,21 +641,13 @@ impl ExecutionPlan for InterleaveExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) - } - - fn partition_statistics(&self, partition: Option) -> Result { - let stats = self - .inputs - .iter() - .map(|stat| stat.partition_statistics(partition)) - .collect::>>()?; - - Ok(stats - .into_iter() - .reduce(stats_union) - .unwrap_or_else(|| Statistics::new_unknown(&self.schema()))) + fn partition_statistics(&self, partition: Option) -> Result> { + let schema = self.schema(); + Ok(Arc::new(merge_input_statistics( + &self.inputs, + partition, + schema.as_ref(), + )?)) } fn benefits_from_input_partitioning(&self) -> Vec { @@ -583,15 +681,28 @@ fn union_schema(inputs: &[Arc]) -> Result { } let first_schema = inputs[0].schema(); + let first_field_count = first_schema.fields().len(); + + // validate that all inputs have the same number of fields + for (idx, input) in inputs.iter().enumerate().skip(1) { + let field_count = input.schema().fields().len(); + if field_count != first_field_count { + return exec_err!( + "UnionExec/InterleaveExec requires all inputs to have the same number of fields. \ + Input 0 has {first_field_count} fields, but input {idx} has {field_count} fields" + ); + } + } - let fields = (0..first_schema.fields().len()) + let fields = (0..first_field_count) .map(|i| { // We take the name from the left side of the union to match how names are coerced during logical planning, // which also uses the left side names. let base_field = first_schema.field(i).clone(); // Coerce metadata and nullability across all inputs - let merged_field = inputs + + inputs .iter() .enumerate() .map(|(input_idx, input)| { @@ -613,9 +724,7 @@ fn union_schema(inputs: &[Arc]) -> Result { // We can unwrap this because if inputs was empty, this would've already panic'ed when we // indexed into inputs[0]. .unwrap() - .with_name(base_field.name()); - - merged_field + .with_name(base_field.name()) }) .collect::>(); @@ -697,46 +806,35 @@ impl Stream for CombinedRecordBatchStream { } } -fn col_stats_union( - mut left: ColumnStatistics, - right: &ColumnStatistics, -) -> ColumnStatistics { - left.distinct_count = Precision::Absent; - left.min_value = left.min_value.min(&right.min_value); - left.max_value = left.max_value.max(&right.max_value); - left.sum_value = left.sum_value.add(&right.sum_value); - left.null_count = left.null_count.add(&right.null_count); - - left -} +fn merge_input_statistics( + inputs: &[Arc], + partition: Option, + schema: &Schema, +) -> Result { + let stats = inputs + .iter() + .map(|input| { + input + .partition_statistics(partition) + .map(Arc::unwrap_or_clone) + }) + .collect::>>()?; -fn stats_union(mut left: Statistics, right: Statistics) -> Statistics { - let Statistics { - num_rows: right_num_rows, - total_byte_size: right_total_bytes, - column_statistics: right_column_statistics, - .. - } = right; - left.num_rows = left.num_rows.add(&right_num_rows); - left.total_byte_size = left.total_byte_size.add(&right_total_bytes); - left.column_statistics = left - .column_statistics - .into_iter() - .zip(right_column_statistics.iter()) - .map(|(a, b)| col_stats_union(a, b)) - .collect::>(); - left + Statistics::try_merge_iter_with_ndv_fallback(stats.iter(), schema, NdvFallback::Sum) } #[cfg(test)] mod tests { use super::*; use crate::collect; + use crate::repartition::RepartitionExec; + use crate::test::exec::StatisticsExec; use crate::test::{self, TestMemoryExec}; use arrow::compute::SortOptions; use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; + use datafusion_common::stats::Precision; + use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_physical_expr::equivalence::convert_to_orderings; use datafusion_physical_expr::expressions::col; @@ -754,6 +852,18 @@ mod tests { Ok(schema) } + fn create_test_schema2() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) + } + #[tokio::test] async fn test_union_partitions() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -779,94 +889,204 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_stats_union() { - let left = Statistics { - num_rows: Precision::Exact(5), - total_byte_size: Precision::Exact(23), - column_statistics: vec![ - ColumnStatistics { - distinct_count: Precision::Exact(5), - max_value: Precision::Exact(ScalarValue::Int64(Some(21))), - min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), - null_count: Precision::Exact(0), - }, - ColumnStatistics { - distinct_count: Precision::Exact(1), - max_value: Precision::Exact(ScalarValue::from("x")), - min_value: Precision::Exact(ScalarValue::from("a")), - sum_value: Precision::Absent, - null_count: Precision::Exact(3), - }, - ColumnStatistics { - distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Float32(Some(1.1))), - min_value: Precision::Exact(ScalarValue::Float32(Some(0.1))), - sum_value: Precision::Exact(ScalarValue::Float32(Some(42.0))), - null_count: Precision::Absent, - }, - ], - }; + fn stats_merge_inputs() -> (SchemaRef, Statistics, Statistics, Statistics) { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, true)])); + + let left = Statistics::default() + .with_num_rows(Precision::Exact(5)) + .with_total_byte_size(Precision::Exact(23)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Exact(5)) + .with_min_value(Precision::Exact(ScalarValue::UInt32(Some(1)))) + .with_max_value(Precision::Exact(ScalarValue::UInt32(Some(21)))) + .with_sum_value(Precision::Exact(ScalarValue::UInt32(Some(42)))) + .with_null_count(Precision::Exact(0)) + .with_byte_size(Precision::Exact(40)), + ); - let right = Statistics { - num_rows: Precision::Exact(7), - total_byte_size: Precision::Exact(29), - column_statistics: vec![ - ColumnStatistics { - distinct_count: Precision::Exact(3), - max_value: Precision::Exact(ScalarValue::Int64(Some(34))), - min_value: Precision::Exact(ScalarValue::Int64(Some(1))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(42))), - null_count: Precision::Exact(1), - }, - ColumnStatistics { - distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::from("c")), - min_value: Precision::Exact(ScalarValue::from("b")), - sum_value: Precision::Absent, - null_count: Precision::Absent, - }, - ColumnStatistics { - distinct_count: Precision::Absent, - max_value: Precision::Absent, - min_value: Precision::Absent, - sum_value: Precision::Absent, - null_count: Precision::Absent, - }, - ], - }; + let right = Statistics::default() + .with_num_rows(Precision::Exact(7)) + .with_total_byte_size(Precision::Exact(29)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Exact(3)) + .with_min_value(Precision::Exact(ScalarValue::UInt32(Some(22)))) + .with_max_value(Precision::Exact(ScalarValue::UInt32(Some(34)))) + .with_sum_value(Precision::Exact(ScalarValue::UInt32(Some(8)))) + .with_null_count(Precision::Exact(1)) + .with_byte_size(Precision::Exact(60)), + ); - let result = stats_union(left, right); - let expected = Statistics { - num_rows: Precision::Exact(12), - total_byte_size: Precision::Exact(52), - column_statistics: vec![ - ColumnStatistics { - distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::Int64(Some(34))), - min_value: Precision::Exact(ScalarValue::Int64(Some(-4))), - sum_value: Precision::Exact(ScalarValue::Int64(Some(84))), - null_count: Precision::Exact(1), - }, - ColumnStatistics { - distinct_count: Precision::Absent, - max_value: Precision::Exact(ScalarValue::from("x")), - min_value: Precision::Exact(ScalarValue::from("a")), - sum_value: Precision::Absent, - null_count: Precision::Absent, - }, - ColumnStatistics { - distinct_count: Precision::Absent, - max_value: Precision::Absent, - min_value: Precision::Absent, - sum_value: Precision::Absent, - null_count: Precision::Absent, - }, - ], - }; + let expected = Statistics::default() + .with_num_rows(Precision::Exact(12)) + .with_total_byte_size(Precision::Exact(52)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Inexact(8)) + .with_min_value(Precision::Exact(ScalarValue::UInt32(Some(1)))) + .with_max_value(Precision::Exact(ScalarValue::UInt32(Some(34)))) + .with_sum_value(Precision::Exact(ScalarValue::UInt64(Some(50)))) + .with_null_count(Precision::Exact(1)) + .with_byte_size(Precision::Exact(100)), + ); + + (schema, left, right, expected) + } + + fn stats_merge_multicolumn_inputs() -> (SchemaRef, Statistics, Statistics, Statistics) + { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float32, true), + ])); + + let left = Statistics::default() + .with_num_rows(Precision::Exact(5)) + .with_total_byte_size(Precision::Exact(23)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Exact(5)) + .with_min_value(Precision::Exact(ScalarValue::Int64(Some(-4)))) + .with_max_value(Precision::Exact(ScalarValue::Int64(Some(21)))) + .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(42)))) + .with_null_count(Precision::Exact(0)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Exact(2)) + .with_min_value(Precision::Exact(ScalarValue::from("a"))) + .with_max_value(Precision::Exact(ScalarValue::from("x"))) + .with_null_count(Precision::Exact(3)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_max_value(Precision::Exact(ScalarValue::Float32(Some(1.1)))) + .with_min_value(Precision::Exact(ScalarValue::Float32(Some(0.1)))) + .with_sum_value(Precision::Exact(ScalarValue::Float32(Some(42.0)))), + ); + + let right = Statistics::default() + .with_num_rows(Precision::Exact(7)) + .with_total_byte_size(Precision::Exact(29)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Exact(3)) + .with_min_value(Precision::Exact(ScalarValue::Int64(Some(1)))) + .with_max_value(Precision::Exact(ScalarValue::Int64(Some(34)))) + .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(42)))) + .with_null_count(Precision::Exact(1)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Exact(3)) + .with_min_value(Precision::Exact(ScalarValue::from("b"))) + .with_max_value(Precision::Exact(ScalarValue::from("z"))), + ) + .add_column_statistics(ColumnStatistics::new_unknown()); + + let expected = Statistics::default() + .with_num_rows(Precision::Exact(12)) + .with_total_byte_size(Precision::Exact(52)) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Inexact(6)) + .with_min_value(Precision::Exact(ScalarValue::Int64(Some(-4)))) + .with_max_value(Precision::Exact(ScalarValue::Int64(Some(34)))) + .with_sum_value(Precision::Exact(ScalarValue::Int64(Some(84)))) + .with_null_count(Precision::Exact(1)), + ) + .add_column_statistics( + ColumnStatistics::new_unknown() + .with_distinct_count(Precision::Inexact(5)) + .with_min_value(Precision::Exact(ScalarValue::from("a"))) + .with_max_value(Precision::Exact(ScalarValue::from("z"))), + ) + .add_column_statistics(ColumnStatistics::new_unknown()); + + (schema, left, right, expected) + } + + #[test] + fn test_union_partition_statistics_uses_shared_statistics_merge() -> Result<()> { + let (schema, left, right, expected) = stats_merge_inputs(); + + let left: Arc = + Arc::new(StatisticsExec::new(left, schema.as_ref().clone())); + let right: Arc = + Arc::new(StatisticsExec::new(right, schema.as_ref().clone())); + + let union = UnionExec::try_new(vec![left, right])?; + let stats = union.partition_statistics(None)?; + + assert_eq!(stats.as_ref(), &expected); + Ok(()) + } + + #[test] + fn test_union_partition_statistics_uses_shared_statistics_merge_multicolumn() + -> Result<()> { + let (schema, left, right, expected) = stats_merge_multicolumn_inputs(); + + let left: Arc = + Arc::new(StatisticsExec::new(left, schema.as_ref().clone())); + let right: Arc = + Arc::new(StatisticsExec::new(right, schema.as_ref().clone())); + + let union = UnionExec::try_new(vec![left, right])?; + let stats = union.partition_statistics(None)?; + + assert_eq!(stats.as_ref(), &expected); + Ok(()) + } - assert_eq!(result, expected); + #[test] + fn test_interleave_partition_statistics_uses_shared_statistics_merge() -> Result<()> { + let (schema, left, right, expected) = stats_merge_inputs(); + let hash_expr = vec![col("a", schema.as_ref())?]; + + let left: Arc = Arc::new(RepartitionExec::try_new( + Arc::new(StatisticsExec::new(left, schema.as_ref().clone())), + Partitioning::Hash(hash_expr.clone(), 2), + )?); + let right: Arc = Arc::new(RepartitionExec::try_new( + Arc::new(StatisticsExec::new(right, schema.as_ref().clone())), + Partitioning::Hash(hash_expr, 2), + )?); + + let interleave = InterleaveExec::try_new(vec![left, right])?; + let stats = interleave.partition_statistics(None)?; + + assert_eq!(stats.as_ref(), &expected); + Ok(()) + } + + #[test] + fn test_interleave_partition_statistics_for_partition_uses_shared_statistics_merge() + -> Result<()> { + let (schema, left, right, _) = stats_merge_inputs(); + let hash_expr = vec![col("a", schema.as_ref())?]; + + let left: Arc = Arc::new(RepartitionExec::try_new( + Arc::new(StatisticsExec::new(left, schema.as_ref().clone())), + Partitioning::Hash(hash_expr.clone(), 2), + )?); + let right: Arc = Arc::new(RepartitionExec::try_new( + Arc::new(StatisticsExec::new(right, schema.as_ref().clone())), + Partitioning::Hash(hash_expr, 2), + )?); + + let interleave = InterleaveExec::try_new(vec![left, right])?; + let stats = interleave.partition_statistics(Some(0))?; + + let expected = Statistics::default() + .with_num_rows(Precision::Inexact(5)) + .with_total_byte_size(Precision::Inexact(25)) + .add_column_statistics(ColumnStatistics::new_unknown()); + + assert_eq!(stats.as_ref(), &expected); + Ok(()) } #[tokio::test] @@ -973,20 +1193,24 @@ mod tests { fn test_union_empty_inputs() { // Test that UnionExec::try_new fails with empty inputs let result = UnionExec::try_new(vec![]); - assert!(result - .unwrap_err() - .to_string() - .contains("UnionExec requires at least one input")); + assert!( + result + .unwrap_err() + .to_string() + .contains("UnionExec requires at least one input") + ); } #[test] fn test_union_schema_empty_inputs() { // Test that union_schema fails with empty inputs let result = union_schema(&[]); - assert!(result - .unwrap_err() - .to_string() - .contains("Cannot create union schema from empty inputs")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Cannot create union schema from empty inputs") + ); } #[test] @@ -1019,7 +1243,6 @@ mod tests { // Downcast to verify it's a UnionExec let union = union_plan - .as_any() .downcast_ref::() .expect("Expected UnionExec"); @@ -1030,4 +1253,43 @@ mod tests { Ok(()) } + + #[test] + fn test_union_schema_mismatch() { + // Test that UnionExec properly rejects inputs with different field counts + let schema = create_test_schema().unwrap(); + let schema2 = create_test_schema2().unwrap(); + let memory_exec1 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None).unwrap()); + let memory_exec2 = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema2), None).unwrap()); + + let result = UnionExec::try_new(vec![memory_exec1, memory_exec2]); + assert!(result.is_err()); + assert!( + result.unwrap_err().to_string().contains( + "UnionExec/InterleaveExec requires all inputs to have the same number of fields" + ) + ); + } + + #[test] + fn test_union_cardinality_effect() -> Result<()> { + let schema = create_test_schema()?; + let input1: Arc = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + let input2: Arc = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + + let union = UnionExec::try_new(vec![input1, input2])?; + let union = union + .downcast_ref::() + .expect("expected UnionExec for multiple inputs"); + + assert!(matches!( + union.cardinality_effect(), + CardinalityEffect::GreaterEqual + )); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 3c999b1a40c1a..c31d0dd23fa68 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -18,22 +18,24 @@ //! Define a plan for unnesting values in columns that contain a list type. use std::cmp::{self, Ordering}; -use std::task::{ready, Poll}; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; +use std::task::{Poll, ready}; use super::metrics::{ - self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, - RecordOutput, + self, BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricCategory, + MetricsSet, RecordOutput, }; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; +use crate::stream::EmptyRecordBatchStream; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, RecordBatchStream, - SendableRecordBatchStream, + SendableRecordBatchStream, check_if_same_properties, }; use arrow::array::{ - new_null_array, Array, ArrayRef, AsArray, BooleanBufferBuilder, FixedSizeListArray, - Int64Array, LargeListArray, ListArray, PrimitiveArray, Scalar, StructArray, + Array, ArrayRef, AsArray, BooleanBufferBuilder, FixedSizeListArray, Int64Array, + LargeListArray, LargeListViewArray, ListArray, ListViewArray, PrimitiveArray, Scalar, + StructArray, new_null_array, }; use arrow::compute::kernels::length::length; use arrow::compute::kernels::zip::zip; @@ -43,13 +45,13 @@ use arrow::record_batch::RecordBatch; use arrow_ord::cmp::lt; use async_trait::async_trait; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, Constraints, HashMap, HashSet, Result, - UnnestOptions, + Constraints, HashMap, HashSet, Result, UnnestOptions, exec_datafusion_err, exec_err, + internal_err, }; use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalExpr; use futures::{Stream, StreamExt}; use log::trace; @@ -74,7 +76,7 @@ pub struct UnnestExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl UnnestExec { @@ -100,7 +102,7 @@ impl UnnestExec { struct_column_indices, options, metrics: Default::default(), - cache, + cache: Arc::new(cache), }) } @@ -193,6 +195,17 @@ impl UnnestExec { pub fn options(&self) -> &UnnestOptions { &self.options } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for UnnestExec { @@ -217,11 +230,7 @@ impl ExecutionPlan for UnnestExec { "UnnestExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -231,10 +240,11 @@ impl ExecutionPlan for UnnestExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(UnnestExec::new( - Arc::clone(&children[0]), + children.swap_remove(0), self.list_column_indices.clone(), self.struct_column_indices.clone(), Arc::clone(&self.schema), @@ -281,10 +291,13 @@ struct UnnestMetrics { impl UnnestMetrics { fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { - let input_batches = - MetricBuilder::new(metrics).counter("input_batches", partition); + let input_batches = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("input_batches", partition); - let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let input_rows = MetricBuilder::new(metrics) + .with_category(MetricCategory::Rows) + .counter("input_rows", partition); Self { baseline_metrics: BaselineMetrics::new(metrics, partition), @@ -362,6 +375,7 @@ impl UnnestStream { debug_assert!(result_batch.num_rows() > 0); Some(Ok(result_batch)) } + // If the stream is depleted or returned an error, log the finish message: other => { trace!( "Processed {} probe-side input batches containing {} rows and \ @@ -372,6 +386,14 @@ impl UnnestStream { self.metrics.baseline_metrics.output_rows(), self.metrics.baseline_metrics.elapsed_compute(), ); + + // In the non-error case, i.e., input is simply depleted: + if other.is_none() { + // Release the input pipeline's resources. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); + } + other } }); @@ -406,9 +428,7 @@ fn flatten_struct_cols( Ok(struct_arr.columns().to_vec()) } data_type => internal_err!( - "expecting column {} from input plan to be a struct, got {:?}", - idx, - data_type + "expecting column {idx} from input plan to be a struct, got {data_type}" ), }, None => Ok(vec![Arc::clone(column_data)]), @@ -827,6 +847,30 @@ impl ListArrayType for FixedSizeListArray { } } +impl ListArrayType for ListViewArray { + fn values(&self) -> &ArrayRef { + self.values() + } + + fn value_offsets(&self, row: usize) -> (i64, i64) { + let offset = self.value_offsets()[row] as i64; + let size = self.value_sizes()[row] as i64; + (offset, offset + size) + } +} + +impl ListArrayType for LargeListViewArray { + fn values(&self) -> &ArrayRef { + self.values() + } + + fn value_offsets(&self, row: usize) -> (i64, i64) { + let offset = self.value_offsets()[row]; + let size = self.value_sizes()[row]; + (offset, offset + size) + } +} + /// Unnest multiple list arrays according to the length array. fn unnest_list_arrays( list_arrays: &[ArrayRef], @@ -843,6 +887,12 @@ fn unnest_list_arrays( DataType::FixedSizeList(_, _) => { Ok(list_array.as_fixed_size_list() as &dyn ListArrayType) } + DataType::ListView(_) => { + Ok(list_array.as_list_view::() as &dyn ListArrayType) + } + DataType::LargeListView(_) => { + Ok(list_array.as_list_view::() as &dyn ListArrayType) + } other => exec_err!("Invalid unnest datatype {other }"), }) .collect::>>()?; @@ -1199,32 +1249,32 @@ mod tests { .unwrap(); assert_snapshot!(batches_to_string(&[ret]), - @r###" -+---------------------------------+---------------------------------+---------------------------------+ -| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 | -+---------------------------------+---------------------------------+---------------------------------+ -| [1, 2, 3] | 1 | a | -| | 2 | b | -| [4, 5] | 3 | | -| [1, 2, 3] | | a | -| | | b | -| [4, 5] | | | -| [1, 2, 3] | 4 | a | -| | 5 | b | -| [4, 5] | | | -| [7, 8, 9, 10] | 7 | c | -| | 8 | d | -| [11, 12, 13] | 9 | | -| | 10 | | -| [7, 8, 9, 10] | | c | -| | | d | -| [11, 12, 13] | | | -| [7, 8, 9, 10] | 11 | c | -| | 12 | d | -| [11, 12, 13] | 13 | | -| | | e | -+---------------------------------+---------------------------------+---------------------------------+ - "###); + @r" + +---------------------------------+---------------------------------+---------------------------------+ + | col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 | + +---------------------------------+---------------------------------+---------------------------------+ + | [1, 2, 3] | 1 | a | + | | 2 | b | + | [4, 5] | 3 | | + | [1, 2, 3] | | a | + | | | b | + | [4, 5] | | | + | [1, 2, 3] | 4 | a | + | | 5 | b | + | [4, 5] | | | + | [7, 8, 9, 10] | 7 | c | + | | 8 | d | + | [11, 12, 13] | 9 | | + | | 10 | | + | [7, 8, 9, 10] | | c | + | | | d | + | [11, 12, 13] | | | + | [7, 8, 9, 10] | 11 | c | + | | 12 | d | + | [11, 12, 13] | 13 | | + | | | e | + +---------------------------------+---------------------------------+---------------------------------+ + "); Ok(()) } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index a76316369ec77..6c6b26c9cf49f 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -20,8 +20,7 @@ //! the input data seen so far), which makes it appropriate when processing //! infinite inputs. -use std::any::Any; -use std::cmp::{min, Ordering}; +use std::cmp::{Ordering, min}; use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; @@ -29,6 +28,7 @@ use std::task::{Context, Poll}; use super::utils::create_schema; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::stream::EmptyRecordBatchStream; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, window_equivalence_properties, @@ -36,7 +36,7 @@ use crate::windows::{ use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, WindowExpr, + SendableRecordBatchStream, Statistics, WindowExpr, check_if_same_properties, }; use arrow::compute::take_record_batch; @@ -52,11 +52,11 @@ use datafusion_common::utils::{ evaluate_partition_ranges, get_at_indices, get_row_at_idx, }; use datafusion_common::{ - arrow_datafusion_err, exec_datafusion_err, exec_err, DataFusionError, HashMap, Result, + HashMap, Result, arrow_datafusion_err, exec_datafusion_err, exec_err, }; use datafusion_execution::TaskContext; -use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_expr::ColumnarValue; +use datafusion_expr::window_state::{PartitionBatchState, WindowAggState}; use datafusion_physical_expr::window::{ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; @@ -65,9 +65,10 @@ use datafusion_physical_expr_common::sort_expr::{ OrderingRequirements, PhysicalSortExpr, }; -use ahash::RandomState; +use crate::execution_plan::CardinalityEffect; +use datafusion_common::hash_utils::RandomState; use futures::stream::Stream; -use futures::{ready, StreamExt}; +use futures::{StreamExt, ready}; use hashbrown::hash_table::HashTable; use indexmap::IndexMap; use log::debug; @@ -93,7 +94,7 @@ pub struct BoundedWindowAggExec { // See `get_ordered_partition_by_indices` for more details. ordered_partition_by_indices: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// If `can_rerepartition` is false, partition_keys is always empty. can_repartition: bool, } @@ -134,7 +135,7 @@ impl BoundedWindowAggExec { metrics: ExecutionPlanMetricsSet::new(), input_order_mode, ordered_partition_by_indices, - cache, + cache: Arc::new(cache), can_repartition, }) } @@ -175,7 +176,9 @@ impl BoundedWindowAggExec { if self.window_expr()[0].partition_by().len() != ordered_partition_by_indices.len() { - return exec_err!("All partition by columns should have an ordering in Sorted mode."); + return exec_err!( + "All partition by columns should have an ordering in Sorted mode." + ); } Box::new(SortedSearch { partition_by_sort_keys, @@ -246,6 +249,17 @@ impl BoundedWindowAggExec { total_byte_size: Precision::Absent, }) } + + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) + } + } } impl DisplayAs for BoundedWindowAggExec { @@ -298,11 +312,7 @@ impl ExecutionPlan for BoundedWindowAggExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -337,6 +347,7 @@ impl ExecutionPlan for BoundedWindowAggExec { self: Arc, children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(BoundedWindowAggExec::try_new( self.window_expr.clone(), Arc::clone(&children[0]), @@ -366,13 +377,14 @@ impl ExecutionPlan for BoundedWindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.partition_statistics(None) + fn partition_statistics(&self, partition: Option) -> Result> { + let input_stat = + Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + Ok(Arc::new(self.statistics_helper(input_stat)?)) } - fn partition_statistics(&self, partition: Option) -> Result { - let input_stat = self.input.partition_statistics(partition)?; - self.statistics_helper(input_stat) + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal } } @@ -627,23 +639,23 @@ impl PartitionSearcher for LinearSearch { fn mark_partition_end(&self, partition_buffers: &mut PartitionBatches) { // We should be in the `PartiallySorted` case, otherwise we can not // tell when we are at the end of a given partition. - if !self.ordered_partition_by_indices.is_empty() { - if let Some((last_row, _)) = partition_buffers.last() { - let last_sorted_cols = self + if !self.ordered_partition_by_indices.is_empty() + && let Some((last_row, _)) = partition_buffers.last() + { + let last_sorted_cols = self + .ordered_partition_by_indices + .iter() + .map(|idx| last_row[*idx].clone()) + .collect::>(); + for (row, partition_batch_state) in partition_buffers.iter_mut() { + let sorted_cols = self .ordered_partition_by_indices .iter() - .map(|idx| last_row[*idx].clone()) - .collect::>(); - for (row, partition_batch_state) in partition_buffers.iter_mut() { - let sorted_cols = self - .ordered_partition_by_indices - .iter() - .map(|idx| &row[*idx]); - // All the partitions other than `last_sorted_cols` are done. - // We are sure that we will no longer receive values for these - // partitions (arrival of a new value would violate ordering). - partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols); - } + .map(|idx| &row[*idx]); + // All the partitions other than `last_sorted_cols` are done. + // We are sure that we will no longer receive values for these + // partitions (arrival of a new value would violate ordering). + partition_batch_state.is_end = !sorted_cols.eq(&last_sorted_cols); } } } @@ -1058,6 +1070,10 @@ impl BoundedWindowAggStream { let _timer = elapsed_compute.timer(); self.finished = true; + // Release the input pipeline's resources before computing the + // final aggregates. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); for (_, partition_batch_state) in self.partition_buffers.iter_mut() { partition_batch_state.is_end = true; } @@ -1242,23 +1258,24 @@ mod tests { use std::time::Duration; use crate::common::collect; + use crate::execution_plan::CardinalityEffect; use crate::expressions::PhysicalSortExpr; use crate::projection::{ProjectionExec, ProjectionExpr}; use crate::streaming::{PartitionStream, StreamingTableExec}; use crate::test::TestMemoryExec; use crate::windows::{ - create_udwf_window_expr, create_window_expr, BoundedWindowAggExec, InputOrderMode, + BoundedWindowAggExec, InputOrderMode, create_udwf_window_expr, create_window_expr, }; - use crate::{displayable, execute_stream, ExecutionPlan}; + use crate::{ExecutionPlan, displayable, execute_stream}; use arrow::array::{ - builder::{Int64Builder, UInt64Builder}, RecordBatch, + builder::{Int64Builder, UInt64Builder}, }; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::test_util::batches_to_string; - use datafusion_common::{exec_datafusion_err, Result, ScalarValue}; + use datafusion_common::{Result, ScalarValue, exec_datafusion_err}; use datafusion_execution::config::SessionConfig; use datafusion_execution::{ RecordBatchStream, SendableRecordBatchStream, TaskContext, @@ -1269,12 +1286,12 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_window::nth_value::last_value_udwf; use datafusion_functions_window::nth_value::nth_value_udwf; - use datafusion_physical_expr::expressions::{col, Column, Literal}; + use datafusion_physical_expr::expressions::{Column, Literal, col}; use datafusion_physical_expr::window::StandardWindowExpr; use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; use futures::future::Shared; - use futures::{pin_mut, ready, FutureExt, Stream, StreamExt}; + use futures::{FutureExt, Stream, StreamExt, pin_mut, ready}; use insta::assert_snapshot; use itertools::Itertools; use tokio::time::timeout; @@ -1474,20 +1491,6 @@ mod tests { Ok(results) } - /// Execute the [ExecutionPlan] and collect the results in memory - #[allow(dead_code)] - pub async fn collect_bonafide( - plan: Arc, - context: Arc, - ) -> Result> { - let stream = execute_stream(plan, context)?; - let mut results = vec![]; - - collect_stream(stream, &mut results).await?; - - Ok(results) - } - fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ Field::new("sn", DataType::UInt64, true), @@ -1496,14 +1499,16 @@ mod tests { } fn schema_orders(schema: &SchemaRef) -> Result> { - let orderings = vec![[PhysicalSortExpr { - expr: col("sn", schema)?, - options: SortOptions { - descending: false, - nulls_first: false, - }, - }] - .into()]; + let orderings = vec![ + [PhysicalSortExpr { + expr: col("sn", schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }] + .into(), + ]; Ok(orderings) } @@ -1700,21 +1705,21 @@ mod tests { DataSourceExec: partitions=1, partition_sizes=[3] "#); - assert_snapshot!(batches_to_string(&batches), @r#" - +---+------+---------------+---------------+ - | a | last | nth_value(-1) | nth_value(-2) | - +---+------+---------------+---------------+ - | 1 | 1 | 1 | | - | 2 | 2 | 2 | 1 | - | 3 | 3 | 3 | 2 | - | 1 | 1 | 1 | 3 | - | 2 | 2 | 2 | 1 | - | 3 | 3 | 3 | 2 | - | 1 | 1 | 1 | 3 | - | 2 | 2 | 2 | 1 | - | 3 | 3 | 3 | 2 | - +---+------+---------------+---------------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +---+------+---------------+---------------+ + | a | last | nth_value(-1) | nth_value(-2) | + +---+------+---------------+---------------+ + | 1 | 1 | 1 | | + | 2 | 2 | 2 | 1 | + | 3 | 3 | 3 | 2 | + | 1 | 1 | 1 | 3 | + | 2 | 2 | 2 | 1 | + | 3 | 3 | 3 | 2 | + | 1 | 1 | 1 | 3 | + | 2 | 2 | 2 | 1 | + | 3 | 3 | 3 | 2 | + +---+------+---------------+---------------+ + "); Ok(()) } @@ -1821,21 +1826,38 @@ mod tests { let task_ctx = task_context(); let batches = collect_with_timeout(plan, task_ctx, timeout_duration).await?; - assert_snapshot!(batches_to_string(&batches), @r#" - +----+------+-------+ - | sn | hash | col_2 | - +----+------+-------+ - | 0 | 2 | 2 | - | 1 | 2 | 2 | - | 2 | 2 | 2 | - | 3 | 2 | 1 | - | 4 | 1 | 2 | - | 5 | 1 | 2 | - | 6 | 1 | 2 | - | 7 | 1 | 1 | - +----+------+-------+ - "#); + assert_snapshot!(batches_to_string(&batches), @r" + +----+------+-------+ + | sn | hash | col_2 | + +----+------+-------+ + | 0 | 2 | 2 | + | 1 | 2 | 2 | + | 2 | 2 | 2 | + | 3 | 2 | 1 | + | 4 | 1 | 2 | + | 5 | 1 | 2 | + | 6 | 1 | 2 | + | 7 | 1 | 1 | + +----+------+-------+ + "); Ok(()) } + + #[test] + fn test_bounded_window_agg_cardinality_effect() -> Result<()> { + let schema = test_schema(); + let input: Arc = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + let plan = bounded_window_exec_pb_latent_range(input, 1, "hash", "sn")?; + let plan = plan + .downcast_ref::() + .expect("expected BoundedWindowAggExec"); + + assert!(matches!( + plan.cardinality_effect(), + CardinalityEffect::Equal + )); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index cd35325eb3d7a..b72a65cf996be 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -25,13 +25,13 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - expressions::PhysicalSortExpr, ExecutionPlan, ExecutionPlanProperties, - InputOrderMode, PhysicalExpr, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, + expressions::PhysicalSortExpr, }; use arrow::datatypes::{Schema, SchemaRef}; use arrow_schema::{FieldRef, SortOptions}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ LimitEffect, PartitionEvaluator, ReversedUDWF, SetMonotonicity, WindowFrame, WindowFunctionDefinition, WindowUDF, @@ -88,7 +88,7 @@ pub fn schema_add_window_field( } /// Create a physical expression for window function -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, @@ -226,6 +226,18 @@ impl WindowUDFExpr { pub fn fun(&self) -> &Arc { &self.fun } + + /// Returns all arguments passed to this window function. + /// + /// Unlike [`StandardWindowFunctionExpr::expressions`], which returns + /// only the expressions that need batch evaluation (and may filter out + /// literal offset/default args like those for `lead`/`lag`), this + /// method returns the complete, unfiltered argument list. This is + /// needed for serialization so that all arguments survive a + /// protobuf round-trip. + pub fn args(&self) -> &[Arc] { + &self.args + } } impl StandardWindowFunctionExpr for WindowUDFExpr { @@ -389,11 +401,11 @@ pub(crate) fn window_equivalence_properties( let mut found = false; for sort_expr in sort_options.into_iter() { candidate_ordering.push(sort_expr); - if let Some(lex) = LexOrdering::new(candidate_ordering.clone()) { - if window_eq_properties.ordering_satisfy(lex)? { - found = true; - break; - } + if let Some(lex) = LexOrdering::new(candidate_ordering.clone()) + && window_eq_properties.ordering_satisfy(lex)? + { + found = true; + break; } // This option didn't work, remove it and try the next one candidate_ordering.pop(); @@ -407,10 +419,10 @@ pub(crate) fn window_equivalence_properties( // If we successfully built an ordering for all columns, use it // When there are no partition expressions, candidate_ordering will be empty and won't be added - if candidate_ordering.len() == partitioning_exprs.len() { - if let Some(lex) = LexOrdering::new(candidate_ordering) { - all_satisfied_lexs.push(lex); - } + if candidate_ordering.len() == partitioning_exprs.len() + && let Some(lex) = LexOrdering::new(candidate_ordering) + { + all_satisfied_lexs.push(lex); } // If there is a partitioning, and no possible ordering cannot satisfy // the input plan's orderings, then we cannot further introduce any @@ -512,21 +524,21 @@ pub(crate) fn window_equivalence_properties( let is_asc = !sort_expr.options.descending; candidate_order.push(sort_expr); - if let Some(lex) = LexOrdering::new(candidate_order.clone()) { - if window_eq_properties.ordering_satisfy(lex)? { - if idx == 0 { - // The first column's ordering direction determines the overall - // monotonicity behavior of the window result. - // - If the aggregate has increasing set monotonicity (e.g., MAX, COUNT) - // and the first arg is ascending, the window result is increasing - // - If the aggregate has decreasing set monotonicity (e.g., MIN) - // and the first arg is ascending, the window result is also increasing - // This flag is used to determine the final window column ordering. - asc = is_asc; - } - found = true; - break; + if let Some(lex) = LexOrdering::new(candidate_order.clone()) + && window_eq_properties.ordering_satisfy(lex)? + { + if idx == 0 { + // The first column's ordering direction determines the overall + // monotonicity behavior of the window result. + // - If the aggregate has increasing set monotonicity (e.g., MAX, COUNT) + // and the first arg is ascending, the window result is increasing + // - If the aggregate has decreasing set monotonicity (e.g., MIN) + // and the first arg is ascending, the window result is also increasing + // This flag is used to determine the final window column ordering. + asc = is_asc; } + found = true; + break; } // This option didn't work, remove it and try the next one candidate_order.pop(); @@ -740,13 +752,13 @@ mod tests { use crate::expressions::col; use crate::streaming::StreamingTableExec; use crate::test::assert_is_pending; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; + use InputOrderMode::{Linear, PartiallySorted, Sorted}; use arrow::compute::SortOptions; use arrow_schema::{DataType, Field}; use datafusion_execution::TaskContext; use datafusion_functions_aggregate::count::count_udaf; - use InputOrderMode::{Linear, PartiallySorted, Sorted}; use futures::FutureExt; diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index b588608397f40..ee3b071fc9167 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -17,14 +17,14 @@ //! Stream and channel implementations for window function expressions. -use std::any::Any; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use super::utils::create_schema; -use crate::execution_plan::EmissionType; +use crate::execution_plan::{CardinalityEffect, EmissionType}; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; +use crate::stream::EmptyRecordBatchStream; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, window_equivalence_properties, @@ -32,7 +32,7 @@ use crate::windows::{ use crate::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, - SendableRecordBatchStream, Statistics, WindowExpr, + SendableRecordBatchStream, Statistics, WindowExpr, check_if_same_properties, }; use arrow::array::ArrayRef; @@ -42,13 +42,13 @@ use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; use datafusion_common::utils::{evaluate_partition_ranges, transpose}; -use datafusion_common::{assert_eq_or_internal_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr_common::sort_expr::{ OrderingRequirements, PhysicalSortExpr, }; -use futures::{ready, Stream, StreamExt}; +use futures::{Stream, StreamExt, ready}; /// Window execution plan #[derive(Debug, Clone)] @@ -65,7 +65,7 @@ pub struct WindowAggExec { // see `get_ordered_partition_by_indices` for more details. ordered_partition_by_indices: Vec, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, /// If `can_partition` is false, partition_keys is always empty. can_repartition: bool, } @@ -89,7 +89,7 @@ impl WindowAggExec { schema, metrics: ExecutionPlanMetricsSet::new(), ordered_partition_by_indices, - cache, + cache: Arc::new(cache), can_repartition, }) } @@ -159,22 +159,15 @@ impl WindowAggExec { } } - fn statistics_inner(&self) -> Result { - let input_stat = self.input.partition_statistics(None)?; - let win_cols = self.window_expr.len(); - let input_cols = self.input.schema().fields().len(); - // TODO stats: some windowing function will maintain invariants such as min, max... - let mut column_statistics = Vec::with_capacity(win_cols + input_cols); - // copy stats of the input to the beginning of the schema. - column_statistics.extend(input_stat.column_statistics); - for _ in 0..win_cols { - column_statistics.push(ColumnStatistics::new_unknown()) + fn with_new_children_and_same_properties( + &self, + mut children: Vec>, + ) -> Self { + Self { + input: children.swap_remove(0), + metrics: ExecutionPlanMetricsSet::new(), + ..Self::clone(self) } - Ok(Statistics { - num_rows: input_stat.num_rows, - column_statistics, - total_byte_size: Precision::Absent, - }) } } @@ -220,11 +213,7 @@ impl ExecutionPlan for WindowAggExec { } /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -260,11 +249,12 @@ impl ExecutionPlan for WindowAggExec { fn with_new_children( self: Arc, - children: Vec>, + mut children: Vec>, ) -> Result> { + check_if_same_properties!(self, children); Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), - Arc::clone(&children[0]), + children.swap_remove(0), true, )?)) } @@ -290,16 +280,27 @@ impl ExecutionPlan for WindowAggExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - self.statistics_inner() + fn partition_statistics(&self, partition: Option) -> Result> { + let input_stat = + Arc::unwrap_or_clone(self.input.partition_statistics(partition)?); + let win_cols = self.window_expr.len(); + let input_cols = self.input.schema().fields().len(); + // TODO stats: some windowing function will maintain invariants such as min, max... + let mut column_statistics = Vec::with_capacity(win_cols + input_cols); + // copy stats of the input to the beginning of the schema. + column_statistics.extend(input_stat.column_statistics); + for _ in 0..win_cols { + column_statistics.push(ColumnStatistics::new_unknown()) + } + Ok(Arc::new(Statistics { + num_rows: input_stat.num_rows, + column_statistics, + total_byte_size: Precision::Absent, + })) } - fn partition_statistics(&self, partition: Option) -> Result { - if partition.is_none() { - self.statistics_inner() - } else { - Ok(Statistics::new_unknown(&self.schema())) - } + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal } } @@ -430,6 +431,10 @@ impl WindowAggStream { } Some(Err(e)) => Err(e), None => { + // Release the input pipeline's resources before computing + // the final aggregates. + let input_schema = self.input.schema(); + self.input = Box::pin(EmptyRecordBatchStream::new(input_schema)); let Some(result) = self.compute_aggregates()? else { return Poll::Ready(None); }; @@ -450,3 +455,47 @@ impl RecordBatchStream for WindowAggStream { Arc::clone(&self.schema) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::TestMemoryExec; + use crate::windows::create_window_expr; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + }; + use datafusion_functions_aggregate::count::count_udaf; + + #[test] + fn test_window_agg_cardinality_effect() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)])); + let input: Arc = + Arc::new(TestMemoryExec::try_new(&[], Arc::clone(&schema), None)?); + let args = vec![crate::expressions::col("a", &schema)?]; + let window_expr = create_window_expr( + &WindowFunctionDefinition::AggregateUDF(count_udaf()), + "count(a)".to_string(), + &args, + &[], + &[], + Arc::new(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::CurrentRow, + )), + Arc::clone(&schema), + false, + false, + None, + )?; + + let window = WindowAggExec::try_new(vec![window_expr], input, true)?; + assert!(matches!( + window.cardinality_effect(), + CardinalityEffect::Equal + )); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index e2c6efd508ba9..28b9c8ddc704c 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -31,16 +31,15 @@ use crate::{ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::{assert_eq_or_internal_err, internal_datafusion_err, Result}; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err}; use datafusion_execution::TaskContext; +use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; /// A vector of record batches with a memory reservation. #[derive(Debug)] pub(super) struct ReservedBatches { batches: Vec, - #[allow(dead_code)] reservation: MemoryReservation, } @@ -59,13 +58,15 @@ impl ReservedBatches { #[derive(Debug)] pub struct WorkTable { batches: Mutex>, + name: String, } impl WorkTable { /// Create a new work table. - pub(super) fn new() -> Self { + pub(super) fn new(name: String) -> Self { Self { batches: Mutex::new(None), + name, } } @@ -101,25 +102,35 @@ pub struct WorkTableExec { name: String, /// The schema of the stream schema: SchemaRef, + /// Projection to apply to build the output stream from the recursion state + projection: Option>, /// The work table work_table: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, + cache: Arc, } impl WorkTableExec { /// Create a new execution plan for a worktable exec. - pub fn new(name: String, schema: SchemaRef) -> Self { + pub fn new( + name: String, + mut schema: SchemaRef, + projection: Option>, + ) -> Result { + if let Some(projection) = &projection { + schema = Arc::new(schema.project(projection)?); + } let cache = Self::compute_properties(Arc::clone(&schema)); - Self { - name, + Ok(Self { + name: name.clone(), schema, + projection, + work_table: Arc::new(WorkTable::new(name)), metrics: ExecutionPlanMetricsSet::new(), - work_table: Arc::new(WorkTable::new()), - cache, - } + cache: Arc::new(cache), + }) } /// Ref to name @@ -166,11 +177,7 @@ impl ExecutionPlan for WorkTableExec { "WorkTableExec" } - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { + fn properties(&self) -> &Arc { &self.cache } @@ -197,11 +204,22 @@ impl ExecutionPlan for WorkTableExec { 0, "WorkTableExec got an invalid partition {partition} (expected 0)" ); - let batch = self.work_table.take()?; + let ReservedBatches { + mut batches, + reservation, + } = self.work_table.take()?; + if let Some(projection) = &self.projection { + // We apply the projection + // TODO: it would be better to apply it as soon as possible and not only here + // TODO: an aggressive projection makes the memory reservation smaller, even if we do not edit it + batches = batches + .into_iter() + .map(|b| b.project(projection)) + .collect::, _>>()?; + } - let stream = - MemoryStream::try_new(batch.batches, Arc::clone(&self.schema), None)? - .with_reservation(batch.reservation); + let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)? + .with_reservation(reservation); Ok(Box::pin(cooperative(stream))) } @@ -209,12 +227,8 @@ impl ExecutionPlan for WorkTableExec { Some(self.metrics.clone_inner()) } - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } - - fn partition_statistics(&self, _partition: Option) -> Result { - Ok(Statistics::new_unknown(&self.schema())) + fn partition_statistics(&self, _partition: Option) -> Result> { + Ok(Arc::new(Statistics::new_unknown(&self.schema()))) } /// Injects run-time state into this `WorkTableExec`. @@ -231,12 +245,17 @@ impl ExecutionPlan for WorkTableExec { // Down-cast to the expected state type; propagate `None` on failure let work_table = state.downcast::().ok()?; + if work_table.name != self.name { + return None; // Different table + } + Some(Arc::new(Self { name: self.name.clone(), schema: Arc::clone(&self.schema), + projection: self.projection.clone(), metrics: ExecutionPlanMetricsSet::new(), work_table, - cache: self.cache.clone(), + cache: Arc::clone(&self.cache), })) } } @@ -244,17 +263,19 @@ impl ExecutionPlan for WorkTableExec { #[cfg(test)] mod tests { use super::*; - use arrow::array::{ArrayRef, Int32Array}; + use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array}; + use arrow_schema::{DataType, Field, Schema}; use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool}; + use futures::StreamExt; #[test] fn test_work_table() { - let work_table = WorkTable::new(); + let work_table = WorkTable::new("test".into()); // Can't take from empty work_table assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; - let mut reservation = MemoryConsumer::new("test_work_table").register(&pool); + let reservation = MemoryConsumer::new("test_work_table").register(&pool); // Update batch to work_table let array: ArrayRef = Arc::new((0..5).collect::()); @@ -278,4 +299,53 @@ mod tests { drop(memory_stream); assert_eq!(pool.reserved(), 0); } + + #[tokio::test] + async fn test_work_table_exec() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int16, false), + ])); + let work_table_exec = + WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1])) + .unwrap(); + + // We inject the work table + let work_table = Arc::new(WorkTable::new("wt".into())); + let work_table_exec = work_table_exec + .with_new_state(Arc::clone(&work_table) as _) + .unwrap(); + + // We update the work table + let pool = Arc::new(UnboundedMemoryPool::default()) as _; + let reservation = MemoryConsumer::new("test_work_table").register(&pool); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])), + ], + ) + .unwrap(); + work_table.update(ReservedBatches::new(vec![batch], reservation)); + + // We get back the batch from the work table + let returned_batch = work_table_exec + .execute(0, Arc::new(TaskContext::default())) + .unwrap() + .next() + .await + .unwrap() + .unwrap(); + assert_eq!( + returned_batch, + RecordBatch::try_from_iter(vec![ + ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as _), + ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _), + ]) + .unwrap() + ); + } } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 920e277b8ccc0..cfff8a949418a 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -28,9 +28,6 @@ license = { workspace = true } authors = { workspace = true } rust-version = { workspace = true } -# Exclude proto files so crates.io consumers don't need protoc -exclude = ["*.proto"] - [package.metadata.docs.rs] all-features = true @@ -39,9 +36,13 @@ name = "datafusion_proto" [features] default = ["parquet"] -json = ["pbjson", "serde", "serde_json", "datafusion-proto-common/json"] +json = [ + "serde_json", + "datafusion-proto-common/json", + "datafusion-proto-models/json", +] parquet = ["datafusion-datasource-parquet", "datafusion-common/parquet", "datafusion/parquet"] -avro = ["datafusion-datasource-avro", "datafusion-common/avro"] +avro = ["datafusion-datasource-avro"] # Note to developers: do *not* add `datafusion` as a dependency in # this crate. See https://github.com/apache/datafusion/issues/17713 @@ -62,17 +63,17 @@ datafusion-datasource-parquet = { workspace = true, optional = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-table = { workspace = true } -datafusion-physical-expr = { workspace = true } -datafusion-physical-expr-common = { workspace = true } -datafusion-physical-plan = { workspace = true } +datafusion-physical-expr = { workspace = true, features = ["proto"] } +datafusion-physical-expr-common = { workspace = true, features = ["proto"] } +datafusion-physical-plan = { workspace = true, features = ["proto"] } datafusion-proto-common = { workspace = true } +datafusion-proto-models = { workspace = true } object_store = { workspace = true } -pbjson = { workspace = true, optional = true } prost = { workspace = true } -serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +async-trait = { workspace = true } datafusion = { workspace = true, default-features = false, features = [ "sql", "datetime_expressions", diff --git a/datafusion/proto/regen.sh b/datafusion/proto/regen.sh index 02970a90add47..c4bcea9ff5408 100755 --- a/datafusion/proto/regen.sh +++ b/datafusion/proto/regen.sh @@ -17,5 +17,7 @@ # specific language governing permissions and limitations # under the License. +# The proto schema and code generation now live in `datafusion-proto-models`. +# This script is kept as a convenience wrapper. repo_root=$(git rev-parse --show-toplevel) -cd "$repo_root" && cargo run --manifest-path datafusion/proto/gen/Cargo.toml +exec "$repo_root/datafusion/proto-models/regen.sh" diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 6eab2239015a7..2b7d7ed8e849b 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -21,28 +21,21 @@ use crate::logical_plan::{ self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; use crate::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + DefaultPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalPlanDecodeContext, PhysicalProtoConverterExtension, }; use crate::protobuf; -use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_common::{Result, plan_datafusion_err}; use datafusion_execution::TaskContext; -use datafusion_expr::{ - create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility, - WindowUDF, -}; +use datafusion_expr::{Expr, LogicalPlan}; use prost::{ - bytes::{Bytes, BytesMut}, Message, + bytes::{Bytes, BytesMut}, }; use std::sync::Arc; -// Reexport Bytes which appears in the API -use datafusion_execution::registry::FunctionRegistry; -use datafusion_expr::planner::ExprPlanner; use datafusion_physical_plan::ExecutionPlan; -mod registry; - /// Encodes something (such as [`Expr`]) to/from a stream of /// bytes. /// @@ -64,26 +57,21 @@ pub trait Serializeable: Sized { /// Convert `self` to an opaque byte stream fn to_bytes(&self) -> Result; - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object. This will error if the serialized bytes contain any - /// user defined functions, in which case use - /// [`from_bytes_with_registry`] + /// Convert `bytes` (the output of [`to_bytes`]) back into an object. This + /// will error if the serialized bytes contain any user defined functions, + /// in which case use [`from_bytes_with_ctx`] /// /// [`to_bytes`]: Self::to_bytes - /// [`from_bytes_with_registry`]: Self::from_bytes_with_registry + /// [`from_bytes_with_ctx`]: Self::from_bytes_with_ctx fn from_bytes(bytes: &[u8]) -> Result { - Self::from_bytes_with_registry(bytes, ®istry::NoRegistry {}) + Self::from_bytes_with_ctx(bytes, &TaskContext::default()) } - /// Convert `bytes` (the output of [`to_bytes`]) back into an - /// object resolving user defined functions with the specified - /// `registry` + /// Convert `bytes` (the output of [`to_bytes`]) back into an object, + /// resolving user defined functions with the specified `ctx` /// /// [`to_bytes`]: Self::to_bytes - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result; + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result; } impl Serializeable for Expr { @@ -99,100 +87,22 @@ impl Serializeable for Expr { let bytes: Bytes = buffer.into(); - // the produced byte stream may lead to "recursion limit" errors, see + // The produced byte stream may lead to "recursion limit" errors, see // https://github.com/apache/datafusion/issues/3968 - // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) is fixed, we try to - // deserialize the data here and check for errors. - // - // Need to provide some placeholder registry because the stream may contain UDFs - struct PlaceHolderRegistry; - - impl FunctionRegistry for PlaceHolderRegistry { - fn udfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udf( - name, - vec![], - arrow::datatypes::DataType::Null, - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - ))) - } - - fn udaf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udaf( - name, - vec![arrow::datatypes::DataType::Null], - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|_| unimplemented!()), - Arc::new(vec![]), - ))) - } - - fn udwf(&self, name: &str) -> Result> { - Ok(Arc::new(create_udwf( - name, - arrow::datatypes::DataType::Null, - Arc::new(arrow::datatypes::DataType::Null), - Volatility::Immutable, - Arc::new(|| unimplemented!()), - ))) - } - fn register_udaf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udaf called in Placeholder Registry!" - ) - } - fn register_udf( - &mut self, - _udf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udf called in Placeholder Registry!" - ) - } - fn register_udwf( - &mut self, - _udaf: Arc, - ) -> Result>> { - datafusion_common::internal_err!( - "register_udwf called in Placeholder Registry!" - ) - } - - fn expr_planners(&self) -> Vec> { - vec![] - } - - fn udafs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - - fn udwfs(&self) -> std::collections::HashSet { - std::collections::HashSet::default() - } - } - Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; + // Until the underlying prost issue ( https://github.com/tokio-rs/prost/issues/736 ) + // is fixed, verify the bytes can be decoded without hitting that limit. + protobuf::LogicalExprNode::decode(bytes.as_ref()) + .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; Ok(bytes) } - fn from_bytes_with_registry( - bytes: &[u8], - registry: &dyn FunctionRegistry, - ) -> Result { + fn from_bytes_with_ctx(bytes: &[u8], ctx: &TaskContext) -> Result { let protobuf = protobuf::LogicalExprNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; let extension_codec = DefaultLogicalExtensionCodec {}; - logical_plan::from_proto::parse_expr(&protobuf, registry, &extension_codec) + logical_plan::from_proto::parse_expr(&protobuf, ctx, &extension_codec) .map_err(|e| plan_datafusion_err!("Error parsing protobuf into Expr: {e}")) } } @@ -276,16 +186,18 @@ pub fn logical_plan_from_json_with_extension_codec( /// Serialize a PhysicalPlan as bytes pub fn physical_plan_to_bytes(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_to_bytes_with_extension_codec(plan, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, &extension_codec, &proto_converter) } /// Serialize a PhysicalPlan as JSON #[cfg(feature = "json")] pub fn physical_plan_to_json(plan: Arc) -> Result { let extension_codec = DefaultPhysicalExtensionCodec {}; - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &extension_codec) - .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; + let proto_converter = DefaultPhysicalProtoConverter {}; + let protobuf = proto_converter + .execution_plan_to_proto(&plan, &extension_codec) + .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; serde_json::to_string(&protobuf) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}")) } @@ -295,8 +207,18 @@ pub fn physical_plan_to_bytes_with_extension_codec( plan: Arc, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result { - let protobuf = - protobuf::PhysicalPlanNode::try_from_physical_plan(plan, extension_codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_to_bytes_with_proto_converter(plan, extension_codec, &proto_converter) +} + +/// Serialize a PhysicalPlan as bytes, using the provided extension codec +/// and protobuf converter. +pub fn physical_plan_to_bytes_with_proto_converter( + plan: Arc, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + let protobuf = proto_converter.execution_plan_to_proto(&plan, extension_codec)?; let mut buffer = BytesMut::new(); protobuf .encode(&mut buffer) @@ -313,7 +235,9 @@ pub fn physical_plan_from_json( let back: protobuf::PhysicalPlanNode = serde_json::from_str(json) .map_err(|e| plan_datafusion_err!("Error serializing plan: {e}"))?; let extension_codec = DefaultPhysicalExtensionCodec {}; - back.try_into_physical_plan(ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &extension_codec); + proto_converter.proto_to_execution_plan(&back, &decode_ctx) } /// Deserialize a PhysicalPlan from bytes @@ -322,7 +246,13 @@ pub fn physical_plan_from_bytes( ctx: &TaskContext, ) -> Result> { let extension_codec = DefaultPhysicalExtensionCodec {}; - physical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec) + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + &extension_codec, + &proto_converter, + ) } /// Deserialize a PhysicalPlan from bytes @@ -330,8 +260,25 @@ pub fn physical_plan_from_bytes_with_extension_codec( bytes: &[u8], ctx: &TaskContext, extension_codec: &dyn PhysicalExtensionCodec, +) -> Result> { + let proto_converter = DefaultPhysicalProtoConverter {}; + physical_plan_from_bytes_with_proto_converter( + bytes, + ctx, + extension_codec, + &proto_converter, + ) +} + +/// Deserialize a PhysicalPlan from bytes +pub fn physical_plan_from_bytes_with_proto_converter( + bytes: &[u8], + ctx: &TaskContext, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let protobuf = protobuf::PhysicalPlanNode::decode(bytes) .map_err(|e| plan_datafusion_err!("Error decoding expr as protobuf: {e}"))?; - protobuf.try_into_physical_plan(ctx, extension_codec) + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, extension_codec); + proto_converter.proto_to_execution_plan(&protobuf, &decode_ctx) } diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index 508e9af419c58..bff017edbc998 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{assert_eq_or_internal_err, internal_datafusion_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err, internal_datafusion_err}; pub(crate) fn str_to_byte(s: &String, description: &str) -> Result { assert_eq_or_internal_err!( @@ -47,6 +47,24 @@ macro_rules! convert_required { }}; } +/// Like [`convert_required`] but for types whose proto conversion goes through +/// the [`TryFromProto`](crate::convert::TryFromProto) trait instead of +/// [`TryFrom`]. Required because some prost-generated types now live in a +/// separate crate, so `TryFrom`/`From` cannot be implemented on foreign-foreign +/// pairs from `datafusion-proto` directly. +#[macro_export] +macro_rules! convert_required_proto { + ($T:ty, $PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok::<$T, _>(<$T as $crate::convert::TryFromProto<_>>::try_from_proto( + field, + )?) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + #[macro_export] macro_rules! into_required { ($PB:expr) => {{ diff --git a/datafusion/proto/src/convert.rs b/datafusion/proto/src/convert.rs new file mode 100644 index 0000000000000..cb5c5bd7f8c12 --- /dev/null +++ b/datafusion/proto/src/convert.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Conversion traits between proto-generated types and DataFusion types. +//! +//! The `prost`-generated structs now live in `datafusion-proto-models`, while +//! their counterparts (`StringifiedPlan`, `JoinType`, `WindowFrame`, ...) live +//! in `datafusion-common` / `datafusion-expr` / `datafusion-datasource` etc. +//! Both sides are foreign to `datafusion-proto`, which means the orphan rule +//! forbids a direct `impl From<&protobuf::X> for Y` written here. +//! +//! To keep the conversion logic colocated with serialization while satisfying +//! the orphan rule, we route those conversions through the `FromProto` / +//! `TryFromProto` traits defined in this module. Their signatures mirror the +//! standard library's `From` / `TryFrom`, so callers spell the conversion +//! `Y::from_proto(&p)` / `Y::try_from_proto(&p)?` instead of +//! `(&p).into()` / `(&p).try_into()?`. + +/// Infallible conversion from a proto value into a DataFusion value (or vice +/// versa). Mirrors [`From`]. +pub trait FromProto: Sized { + fn from_proto(value: T) -> Self; +} + +/// Fallible conversion from a proto value into a DataFusion value (or vice +/// versa). Mirrors [`TryFrom`]. +pub trait TryFromProto: Sized { + type Error; + fn try_from_proto(value: T) -> std::result::Result; +} diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 130538d5af9fa..0e63bcf5f5acb 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -23,8 +23,6 @@ // Make sure fast / cheap clones on Arc are explicit: // https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Serialize / Deserialize DataFusion Plans to bytes @@ -125,12 +123,13 @@ //! ``` pub mod bytes; pub mod common; -pub mod generated; +pub mod convert; pub mod logical_plan; pub mod physical_plan; +pub use convert::{FromProto, TryFromProto}; + pub mod protobuf { - pub use crate::generated::datafusion::*; pub use datafusion_proto_common::common::proto_error; pub use datafusion_proto_common::protobuf_common::{ ArrowFormat, ArrowOptions, ArrowType, AvroFormat, AvroOptions, CsvFormat, @@ -138,6 +137,40 @@ pub mod protobuf { ScalarValue, Schema, }; pub use datafusion_proto_common::{FromProtoError, ToProtoError}; + // Re-export every type from `datafusion-proto-models`'s generated module + // so the existing `datafusion_proto::protobuf::Foo` paths keep resolving. + // Going through the deeper `generated::datafusion` path (rather than + // `datafusion_proto_models::protobuf`, which is itself a `pub use ::*`) + // avoids a double wildcard re-export that some tools (cargo-semver-checks) + // don't follow. + pub use datafusion_proto_models::generated::datafusion::*; +} + +/// Backwards-compatible re-export of the moved generated types. +/// +/// The prost-generated structs now live in `datafusion-proto-models`; +/// this module preserves the legacy `datafusion_proto::generated::*` paths +/// for downstream callers. Prefer the [`protobuf`] module (or +/// [`datafusion_proto_models`] directly) in new code. +#[deprecated( + since = "53.1.0", + note = "use `datafusion_proto::protobuf` (or `datafusion_proto_models::protobuf`) instead" +)] +pub mod generated { + /// Re-export of the prost-generated types defined in `datafusion.proto`. + #[deprecated( + since = "53.1.0", + note = "use `datafusion_proto::protobuf` (or `datafusion_proto_models::protobuf`) instead" + )] + pub use datafusion_proto_models::generated::datafusion; + + /// Re-export of the prost-generated common types defined in + /// `datafusion_common.proto`. + #[deprecated( + since = "53.1.0", + note = "use `datafusion_proto_common::protobuf_common` instead" + )] + pub use datafusion_proto_models::generated::datafusion_common; } #[cfg(doctest)] diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index d32bfb22ffddd..8e71cc926856c 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -17,11 +17,16 @@ use std::sync::Arc; -use crate::protobuf::{CsvOptions as CsvOptionsProto, JsonOptions as JsonOptionsProto}; +use super::LogicalExtensionCodec; +use crate::convert::{FromProto, TryFromProto}; +use crate::protobuf::{ + CsvOptions as CsvOptionsProto, CsvQuoteStyle as CsvQuoteStyleProto, + JsonOptions as JsonOptionsProto, +}; use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ - exec_datafusion_err, exec_err, not_impl_err, parsers::CompressionTypeVariant, - TableReference, + TableReference, exec_datafusion_err, exec_err, not_impl_err, + parsers::{CompressionTypeVariant, CsvQuoteStyle}, }; use datafusion_datasource::file_format::FileFormatFactory; use datafusion_datasource_arrow::file_format::ArrowFormatFactory; @@ -30,13 +35,11 @@ use datafusion_datasource_json::file_format::JsonFormatFactory; use datafusion_execution::TaskContext; use prost::Message; -use super::LogicalExtensionCodec; - #[derive(Debug)] pub struct CsvLogicalExtensionCodec; -impl CsvOptionsProto { - fn from_factory(factory: &CsvFormatFactory) -> Self { +impl FromProto<&CsvFormatFactory> for CsvOptionsProto { + fn from_proto(factory: &CsvFormatFactory) -> Self { if let Some(options) = &factory.options { CsvOptionsProto { has_header: options.has_header.map_or(vec![], |v| vec![v as u8]), @@ -62,6 +65,14 @@ impl CsvOptionsProto { .newlines_in_values .map_or(vec![], |v| vec![v as u8]), truncated_rows: options.truncated_rows.map_or(vec![], |v| vec![v as u8]), + compression_level: options.compression_level, + quote_style: options.quote_style as i32, + ignore_leading_whitespace: options + .ignore_leading_whitespace + .map_or(vec![], |v| vec![v as u8]), + ignore_trailing_whitespace: options + .ignore_trailing_whitespace + .map_or(vec![], |v| vec![v as u8]), } } else { CsvOptionsProto::default() @@ -69,8 +80,8 @@ impl CsvOptionsProto { } } -impl From<&CsvOptionsProto> for CsvOptions { - fn from(proto: &CsvOptionsProto) -> Self { +impl FromProto<&CsvOptionsProto> for CsvOptions { + fn from_proto(proto: &CsvOptionsProto) -> Self { CsvOptions { has_header: if !proto.has_header.is_empty() { Some(proto.has_header[0] != 0) @@ -152,6 +163,24 @@ impl From<&CsvOptionsProto> for CsvOptions { } else { Some(proto.truncated_rows[0] != 0) }, + compression_level: proto.compression_level, + quote_style: match CsvQuoteStyleProto::try_from(proto.quote_style) { + Ok(CsvQuoteStyleProto::Always) => CsvQuoteStyle::Always, + Ok(CsvQuoteStyleProto::NonNumeric) => CsvQuoteStyle::NonNumeric, + Ok(CsvQuoteStyleProto::Never) => CsvQuoteStyle::Never, + Ok(CsvQuoteStyleProto::Necessary) => CsvQuoteStyle::Necessary, + _ => CsvQuoteStyle::Necessary, + }, + ignore_leading_whitespace: if proto.ignore_leading_whitespace.is_empty() { + None + } else { + Some(proto.ignore_leading_whitespace[0] != 0) + }, + ignore_trailing_whitespace: if proto.ignore_trailing_whitespace.is_empty() { + None + } else { + Some(proto.ignore_trailing_whitespace[0] != 0) + }, } } } @@ -202,7 +231,7 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { let proto = CsvOptionsProto::decode(buf).map_err(|e| { exec_datafusion_err!("Failed to decode CsvOptionsProto: {e:?}") })?; - let options: CsvOptions = (&proto).into(); + let options = CsvOptions::from_proto(&proto); Ok(Arc::new(CsvFormatFactory { options: Some(options), })) @@ -213,14 +242,13 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { buf: &mut Vec, node: Arc, ) -> datafusion_common::Result<()> { - let options = - if let Some(csv_factory) = node.as_any().downcast_ref::() { - csv_factory.options.clone().unwrap_or_default() - } else { - return exec_err!("{}", "Unsupported FileFormatFactory type".to_string()); - }; + let options = if let Some(csv_factory) = node.downcast_ref::() { + csv_factory.options.clone().unwrap_or_default() + } else { + return exec_err!("{}", "Unsupported FileFormatFactory type".to_string()); + }; - let proto = CsvOptionsProto::from_factory(&CsvFormatFactory { + let proto = CsvOptionsProto::from_proto(&CsvFormatFactory { options: Some(options), }); @@ -232,12 +260,14 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { } } -impl JsonOptionsProto { - fn from_factory(factory: &JsonFormatFactory) -> Self { +impl FromProto<&JsonFormatFactory> for JsonOptionsProto { + fn from_proto(factory: &JsonFormatFactory) -> Self { if let Some(options) = &factory.options { JsonOptionsProto { compression: options.compression as i32, schema_infer_max_rec: options.schema_infer_max_rec.map(|v| v as u64), + compression_level: options.compression_level, + newline_delimited: Some(options.newline_delimited), } } else { JsonOptionsProto::default() @@ -245,8 +275,8 @@ impl JsonOptionsProto { } } -impl From<&JsonOptionsProto> for JsonOptions { - fn from(proto: &JsonOptionsProto) -> Self { +impl FromProto<&JsonOptionsProto> for JsonOptions { + fn from_proto(proto: &JsonOptionsProto) -> Self { JsonOptions { compression: match proto.compression { 0 => CompressionTypeVariant::GZIP, @@ -256,6 +286,8 @@ impl From<&JsonOptionsProto> for JsonOptions { _ => CompressionTypeVariant::UNCOMPRESSED, }, schema_infer_max_rec: proto.schema_infer_max_rec.map(|v| v as usize), + compression_level: proto.compression_level, + newline_delimited: proto.newline_delimited.unwrap_or(true), } } } @@ -309,7 +341,7 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { let proto = JsonOptionsProto::decode(buf).map_err(|e| { exec_datafusion_err!("Failed to decode JsonOptionsProto: {e:?}") })?; - let options: JsonOptions = (&proto).into(); + let options = JsonOptions::from_proto(&proto); Ok(Arc::new(JsonFormatFactory { options: Some(options), })) @@ -320,15 +352,14 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { buf: &mut Vec, node: Arc, ) -> datafusion_common::Result<()> { - let options = if let Some(json_factory) = - node.as_any().downcast_ref::() + let options = if let Some(json_factory) = node.downcast_ref::() { json_factory.options.clone().unwrap_or_default() } else { return exec_err!("Unsupported FileFormatFactory type"); }; - let proto = JsonOptionsProto::from_factory(&JsonFormatFactory { + let proto = JsonOptionsProto::from_proto(&JsonFormatFactory { options: Some(options), }); @@ -345,18 +376,20 @@ mod parquet { use super::*; use crate::protobuf::{ - parquet_column_options, parquet_options, + ParquetCdcOptions as ParquetCdcOptionsProto, ParquetColumnOptions as ParquetColumnOptionsProto, ParquetColumnSpecificOptions, ParquetOptions as ParquetOptionsProto, - TableParquetOptions as TableParquetOptionsProto, + TableParquetOptions as TableParquetOptionsProto, parquet_column_options, + parquet_options, }; use datafusion_common::config::{ - ParquetColumnOptions, ParquetOptions, TableParquetOptions, + MaxRowGroupBytes, ParquetCdcOptions, ParquetColumnOptions, ParquetOptions, + TableParquetOptions, }; use datafusion_datasource_parquet::file_format::ParquetFormatFactory; - impl TableParquetOptionsProto { - fn from_factory(factory: &ParquetFormatFactory) -> Self { + impl FromProto<&ParquetFormatFactory> for TableParquetOptionsProto { + fn from_proto(factory: &ParquetFormatFactory) -> Self { let global_options = if let Some(ref options) = factory.options { options.clone() } else { @@ -364,8 +397,7 @@ mod parquet { }; let column_specific_options = global_options.column_specific_options; - #[allow(deprecated)] // max_statistics_size - TableParquetOptionsProto { + TableParquetOptionsProto { global: Some(ParquetOptionsProto { enable_page_index: global_options.global.enable_page_index, pruning: global_options.global.pruning, @@ -375,9 +407,10 @@ mod parquet { }), pushdown_filters: global_options.global.pushdown_filters, reorder_filters: global_options.global.reorder_filters, + force_filter_selections: global_options.global.force_filter_selections, data_pagesize_limit: global_options.global.data_pagesize_limit as u64, write_batch_size: global_options.global.write_batch_size as u64, - writer_version: global_options.global.writer_version.clone(), + writer_version: global_options.global.writer_version.to_string(), compression_opt: global_options.global.compression.map(|compression| { parquet_options::CompressionOpt::Compression(compression) }), @@ -417,9 +450,21 @@ mod parquet { coerce_int96_opt: global_options.global.coerce_int96.map(|compression| { parquet_options::CoerceInt96Opt::CoerceInt96(compression) }), + coerce_int96_tz_opt: global_options.global.coerce_int96_tz.map(|tz| { + parquet_options::CoerceInt96TzOpt::CoerceInt96Tz(tz) + }), max_predicate_cache_size_opt: global_options.global.max_predicate_cache_size.map(|size| { parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(size as u64) }), + max_row_group_bytes_opt: global_options.global.max_row_group_bytes.map(|size| { + parquet_options::MaxRowGroupBytesOpt::MaxRowGroupBytes(size.get() as u64) + }), + content_defined_chunking: Some(ParquetCdcOptionsProto { + enabled: global_options.global.content_defined_chunking.enabled, + min_chunk_size: global_options.global.content_defined_chunking.min_chunk_size as u64, + max_chunk_size: global_options.global.content_defined_chunking.max_chunk_size as u64, + norm_level: global_options.global.content_defined_chunking.norm_level, + }), }), column_specific_options: column_specific_options.into_iter().map(|(column_name, options)| { ParquetColumnSpecificOptions { @@ -459,71 +504,153 @@ mod parquet { } } - impl From<&ParquetOptionsProto> for ParquetOptions { - fn from(proto: &ParquetOptionsProto) -> Self { - #[allow(deprecated)] // max_statistics_size - ParquetOptions { - enable_page_index: proto.enable_page_index, - pruning: proto.pruning, - skip_metadata: proto.skip_metadata, - metadata_size_hint: proto.metadata_size_hint_opt.as_ref().map(|opt| match opt { - parquet_options::MetadataSizeHintOpt::MetadataSizeHint(size) => *size as usize, - }), - pushdown_filters: proto.pushdown_filters, - reorder_filters: proto.reorder_filters, - data_pagesize_limit: proto.data_pagesize_limit as usize, - write_batch_size: proto.write_batch_size as usize, - writer_version: proto.writer_version.clone(), - compression: proto.compression_opt.as_ref().map(|opt| match opt { - parquet_options::CompressionOpt::Compression(compression) => compression.clone(), - }), - dictionary_enabled: proto.dictionary_enabled_opt.as_ref().map(|opt| match opt { - parquet_options::DictionaryEnabledOpt::DictionaryEnabled(enabled) => *enabled, - }), - dictionary_page_size_limit: proto.dictionary_page_size_limit as usize, - statistics_enabled: proto.statistics_enabled_opt.as_ref().map(|opt| match opt { - parquet_options::StatisticsEnabledOpt::StatisticsEnabled(statistics) => statistics.clone(), - }), - max_row_group_size: proto.max_row_group_size as usize, - created_by: proto.created_by.clone(), - column_index_truncate_length: proto.column_index_truncate_length_opt.as_ref().map(|opt| match opt { - parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(length) => *length as usize, - }), - statistics_truncate_length: proto.statistics_truncate_length_opt.as_ref().map(|opt| match opt { - parquet_options::StatisticsTruncateLengthOpt::StatisticsTruncateLength(length) => *length as usize, - }), - data_page_row_count_limit: proto.data_page_row_count_limit as usize, - encoding: proto.encoding_opt.as_ref().map(|opt| match opt { - parquet_options::EncodingOpt::Encoding(encoding) => encoding.clone(), - }), - bloom_filter_on_read: proto.bloom_filter_on_read, - bloom_filter_on_write: proto.bloom_filter_on_write, - bloom_filter_fpp: proto.bloom_filter_fpp_opt.as_ref().map(|opt| match opt { - parquet_options::BloomFilterFppOpt::BloomFilterFpp(fpp) => *fpp, - }), - bloom_filter_ndv: proto.bloom_filter_ndv_opt.as_ref().map(|opt| match opt { - parquet_options::BloomFilterNdvOpt::BloomFilterNdv(ndv) => *ndv, - }), - allow_single_file_parallelism: proto.allow_single_file_parallelism, - maximum_parallel_row_group_writers: proto.maximum_parallel_row_group_writers as usize, - maximum_buffered_record_batches_per_stream: proto.maximum_buffered_record_batches_per_stream as usize, - schema_force_view_types: proto.schema_force_view_types, - binary_as_string: proto.binary_as_string, - skip_arrow_metadata: proto.skip_arrow_metadata, - coerce_int96: proto.coerce_int96_opt.as_ref().map(|opt| match opt { - parquet_options::CoerceInt96Opt::CoerceInt96(coerce_int96) => coerce_int96.clone(), - }), - max_predicate_cache_size: proto.max_predicate_cache_size_opt.as_ref().map(|opt| match opt { - parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize(size) => *size as usize, - }), + impl FromProto for ParquetCdcOptions { + fn from_proto(value: ParquetCdcOptionsProto) -> Self { + ParquetCdcOptions { + enabled: value.enabled, + min_chunk_size: value.min_chunk_size as usize, + max_chunk_size: value.max_chunk_size as usize, + norm_level: value.norm_level, + } } + } + + impl TryFromProto<&ParquetOptionsProto> for ParquetOptions { + type Error = datafusion_common::DataFusionError; + + fn try_from_proto( + proto: &ParquetOptionsProto, + ) -> datafusion_common::Result { + let writer_version = match proto.writer_version.as_str() { + // Proto3 decodes an omitted string field as the empty string. The + // schema documents writer_version's logical default as "1.0", so + // preserve that default when the field is absent on the wire. + "" => ParquetOptions::default().writer_version, + version => version.parse()?, + }; + + Ok(ParquetOptions { + enable_page_index: proto.enable_page_index, + pruning: proto.pruning, + skip_metadata: proto.skip_metadata, + metadata_size_hint: proto + .metadata_size_hint_opt + .as_ref() + .map(|opt| match opt { + parquet_options::MetadataSizeHintOpt::MetadataSizeHint(size) => { + *size as usize + } + }), + pushdown_filters: proto.pushdown_filters, + reorder_filters: proto.reorder_filters, + force_filter_selections: proto.force_filter_selections, + data_pagesize_limit: proto.data_pagesize_limit as usize, + write_batch_size: proto.write_batch_size as usize, + writer_version, + compression: proto.compression_opt.as_ref().map(|opt| match opt { + parquet_options::CompressionOpt::Compression(compression) => { + compression.clone() + } + }), + dictionary_enabled: proto.dictionary_enabled_opt.as_ref().map(|opt| { + match opt { + parquet_options::DictionaryEnabledOpt::DictionaryEnabled( + enabled, + ) => *enabled, + } + }), + dictionary_page_size_limit: proto.dictionary_page_size_limit as usize, + statistics_enabled: proto.statistics_enabled_opt.as_ref().map( + |opt| match opt { + parquet_options::StatisticsEnabledOpt::StatisticsEnabled( + statistics, + ) => statistics.clone(), + }, + ), + max_row_group_size: proto.max_row_group_size as usize, + created_by: proto.created_by.clone(), + column_index_truncate_length: proto + .column_index_truncate_length_opt + .as_ref() + .map(|opt| match opt { + parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(length) => *length as usize, + }), + statistics_truncate_length: proto + .statistics_truncate_length_opt + .as_ref() + .map(|opt| match opt { + parquet_options::StatisticsTruncateLengthOpt::StatisticsTruncateLength(length) => *length as usize, + }), + data_page_row_count_limit: proto.data_page_row_count_limit as usize, + encoding: proto.encoding_opt.as_ref().map(|opt| match opt { + parquet_options::EncodingOpt::Encoding(encoding) => { + encoding.clone() + } + }), + bloom_filter_on_read: proto.bloom_filter_on_read, + bloom_filter_on_write: proto.bloom_filter_on_write, + bloom_filter_fpp: proto + .bloom_filter_fpp_opt + .as_ref() + .map(|opt| match opt { + parquet_options::BloomFilterFppOpt::BloomFilterFpp(fpp) => *fpp, + }), + bloom_filter_ndv: proto + .bloom_filter_ndv_opt + .as_ref() + .map(|opt| match opt { + parquet_options::BloomFilterNdvOpt::BloomFilterNdv(ndv) => *ndv, + }), + allow_single_file_parallelism: proto.allow_single_file_parallelism, + maximum_parallel_row_group_writers: proto + .maximum_parallel_row_group_writers + as usize, + maximum_buffered_record_batches_per_stream: proto + .maximum_buffered_record_batches_per_stream + as usize, + schema_force_view_types: proto.schema_force_view_types, + binary_as_string: proto.binary_as_string, + skip_arrow_metadata: proto.skip_arrow_metadata, + coerce_int96: proto.coerce_int96_opt.as_ref().map(|opt| match opt { + parquet_options::CoerceInt96Opt::CoerceInt96(coerce_int96) => { + coerce_int96.clone() + } + }), + coerce_int96_tz: proto + .coerce_int96_tz_opt + .as_ref() + .map(|opt| match opt { + parquet_options::CoerceInt96TzOpt::CoerceInt96Tz(tz) => { + tz.clone() + } + }), + max_predicate_cache_size: proto + .max_predicate_cache_size_opt + .as_ref() + .map(|opt| match opt { + parquet_options::MaxPredicateCacheSizeOpt::MaxPredicateCacheSize( + size, + ) => *size as usize, + }), + max_row_group_bytes: proto + .max_row_group_bytes_opt + .as_ref() + .and_then(|opt| match opt { + parquet_options::MaxRowGroupBytesOpt::MaxRowGroupBytes(size) => { + MaxRowGroupBytes::try_new(*size as usize).ok() + } + }), + content_defined_chunking: proto + .content_defined_chunking + .map(ParquetCdcOptions::from_proto) + .unwrap_or_default(), + }) } } - impl From for ParquetColumnOptions { - fn from(proto: ParquetColumnOptionsProto) -> Self { - #[allow(deprecated)] // max_statistics_size - ParquetColumnOptions { + impl FromProto for ParquetColumnOptions { + fn from_proto(proto: ParquetColumnOptionsProto) -> Self { + ParquetColumnOptions { bloom_filter_enabled: proto.bloom_filter_enabled_opt.map( |parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v)| v, ), @@ -549,13 +676,18 @@ mod parquet { } } - impl From<&TableParquetOptionsProto> for TableParquetOptions { - fn from(proto: &TableParquetOptionsProto) -> Self { - TableParquetOptions { + impl TryFromProto<&TableParquetOptionsProto> for TableParquetOptions { + type Error = datafusion_common::DataFusionError; + + fn try_from_proto( + proto: &TableParquetOptionsProto, + ) -> datafusion_common::Result { + Ok(TableParquetOptions { global: proto .global .as_ref() - .map(ParquetOptions::from) + .map(ParquetOptions::try_from_proto) + .transpose()? .unwrap_or_default(), column_specific_options: proto .column_specific_options @@ -563,7 +695,7 @@ mod parquet { .map(|parquet_column_options| { ( parquet_column_options.column_name.clone(), - ParquetColumnOptions::from( + ParquetColumnOptions::from_proto( parquet_column_options .options .clone() @@ -577,8 +709,8 @@ mod parquet { .iter() .map(|(k, v)| (k.clone(), Some(v.clone()))) .collect(), - crypto: Default::default(), - } + ..Default::default() + }) } } @@ -632,7 +764,7 @@ mod parquet { let proto = TableParquetOptionsProto::decode(buf).map_err(|e| { exec_datafusion_err!("Failed to decode TableParquetOptionsProto: {e:?}") })?; - let options: TableParquetOptions = (&proto).into(); + let options = TableParquetOptions::try_from_proto(&proto)?; Ok(Arc::new( datafusion_datasource_parquet::file_format::ParquetFormatFactory { options: Some(options), @@ -648,14 +780,14 @@ mod parquet { use datafusion_datasource_parquet::file_format::ParquetFormatFactory; let options = if let Some(parquet_factory) = - node.as_any().downcast_ref::() + node.downcast_ref::() { parquet_factory.options.clone().unwrap_or_default() } else { return exec_err!("Unsupported FileFormatFactory type"); }; - let proto = TableParquetOptionsProto::from_factory(&ParquetFormatFactory { + let proto = TableParquetOptionsProto::from_proto(&ParquetFormatFactory { options: Some(options), }); @@ -666,6 +798,64 @@ mod parquet { Ok(()) } } + + #[cfg(test)] + mod tests { + use super::*; + + fn encode_table_options(proto: TableParquetOptionsProto) -> Vec { + let mut buf = Vec::new(); + proto.encode(&mut buf).expect("encode parquet options"); + buf + } + + #[test] + fn try_decode_file_format_errors_on_invalid_writer_version() { + let proto = TableParquetOptionsProto { + global: Some(ParquetOptionsProto { + writer_version: "3.0".to_string(), + ..Default::default() + }), + ..Default::default() + }; + + let result = ParquetLogicalExtensionCodec.try_decode_file_format( + &encode_table_options(proto), + &TaskContext::default(), + ); + + let err = result.expect_err("invalid writer version should error"); + assert!( + err.to_string() + .contains("Invalid parquet writer version: 3.0"), + "{err}" + ); + } + + #[test] + fn try_decode_file_format_defaults_empty_writer_version() { + let proto = TableParquetOptionsProto { + global: Some(ParquetOptionsProto::default()), + ..Default::default() + }; + + let factory = ParquetLogicalExtensionCodec + .try_decode_file_format( + &encode_table_options(proto), + &TaskContext::default(), + ) + .expect("decode parquet options"); + let parquet_factory = factory + .downcast_ref::() + .expect("parquet format factory"); + let options = parquet_factory.options.as_ref().expect("parquet options"); + + assert_eq!( + options.global.writer_version, + ParquetOptions::default().writer_version + ); + } + } } #[cfg(feature = "parquet")] pub use parquet::ParquetLogicalExtensionCodec; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 598a77f5420e2..c68b83964f4cf 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,45 +17,49 @@ use std::sync::Arc; -use arrow::datatypes::Field; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::datatype::DataTypeExt; use datafusion_common::{ - exec_datafusion_err, internal_err, plan_datafusion_err, NullEquality, - RecursionUnnestOption, Result, ScalarValue, TableReference, UnnestOptions, + NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference, + UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err, }; +use datafusion_execution::TaskContext; use datafusion_execution::registry::FunctionRegistry; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{Alias, NullTreatment, Placeholder, Sort}; use datafusion_expr::expr::{Unnest, WildcardOptions}; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ - expr::{self, InList, WindowFunction}, - logical_plan::{PlanType, StringifiedPlan}, Between, BinaryExpr, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, + expr::{self, InList, WindowFunction}, + logical_plan::{PlanType, StringifiedPlan}, }; use datafusion_expr::{ExprFunctionExt, WriteOp}; -use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; +use datafusion_proto_common::{FromProtoError as Error, from_proto::FromOptionalField}; use crate::protobuf::plan_type::PlanTypeEnum::{ FinalPhysicalPlanWithSchema, InitialPhysicalPlanWithSchema, }; use crate::protobuf::{ - self, + self, AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, + OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, PhysicalPlanError, }, - AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, - OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, }; -use super::LogicalExtensionCodec; +use crate::convert::{FromProto, TryFromProto}; -impl From<&protobuf::UnnestOptions> for UnnestOptions { - fn from(opts: &protobuf::UnnestOptions) -> Self { +use super::{AsLogicalPlan, LogicalExtensionCodec}; + +impl FromProto<&protobuf::UnnestOptions> for UnnestOptions { + fn from_proto(opts: &protobuf::UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, recursions: opts @@ -71,8 +75,8 @@ impl From<&protobuf::UnnestOptions> for UnnestOptions { } } -impl From for WindowFrameUnits { - fn from(units: protobuf::WindowFrameUnits) -> Self { +impl FromProto for WindowFrameUnits { + fn from_proto(units: protobuf::WindowFrameUnits) -> Self { match units { protobuf::WindowFrameUnits::Rows => Self::Rows, protobuf::WindowFrameUnits::Range => Self::Range, @@ -81,10 +85,10 @@ impl From for WindowFrameUnits { } } -impl TryFrom for TableReference { +impl TryFromProto for TableReference { type Error = Error; - fn try_from(value: protobuf::TableReference) -> Result { + fn try_from_proto(value: protobuf::TableReference) -> Result { use protobuf::table_reference::TableReferenceEnum; let table_reference_enum = value .table_reference_enum @@ -107,8 +111,8 @@ impl TryFrom for TableReference { } } -impl From<&protobuf::StringifiedPlan> for StringifiedPlan { - fn from(stringified_plan: &protobuf::StringifiedPlan) -> Self { +impl FromProto<&protobuf::StringifiedPlan> for StringifiedPlan { + fn from_proto(stringified_plan: &protobuf::StringifiedPlan) -> Self { Self { plan_type: match stringified_plan .plan_type @@ -150,19 +154,25 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } } -impl TryFrom for WindowFrame { +impl TryFromProto for WindowFrame { type Error = Error; - fn try_from(window: protobuf::WindowFrame) -> Result { - let units = protobuf::WindowFrameUnits::try_from(window.window_frame_units) - .map_err(|_| Error::unknown("WindowFrameUnits", window.window_frame_units))? - .into(); - let start_bound = window.start_bound.required("start_bound")?; + fn try_from_proto(window: protobuf::WindowFrame) -> Result { + let units = WindowFrameUnits::from_proto( + protobuf::WindowFrameUnits::try_from(window.window_frame_units).map_err( + |_| Error::unknown("WindowFrameUnits", window.window_frame_units), + )?, + ); + let start_bound = WindowFrameBound::try_from_proto( + window + .start_bound + .ok_or_else(|| Error::required("start_bound"))?, + )?; let end_bound = window .end_bound .map(|end_bound| match end_bound { protobuf::window_frame::EndBound::Bound(end_bound) => { - end_bound.try_into() + WindowFrameBound::try_from_proto(end_bound) } }) .transpose()? @@ -171,10 +181,10 @@ impl TryFrom for WindowFrame { } } -impl TryFrom for WindowFrameBound { +impl TryFromProto for WindowFrameBound { type Error = Error; - fn try_from(bound: protobuf::WindowFrameBound) -> Result { + fn try_from_proto(bound: protobuf::WindowFrameBound) -> Result { let bound_type = protobuf::WindowFrameBoundType::try_from(bound.window_frame_bound_type) .map_err(|_| { @@ -194,8 +204,8 @@ impl TryFrom for WindowFrameBound { } } -impl From for JoinType { - fn from(t: protobuf::JoinType) -> Self { +impl FromProto for JoinType { + fn from_proto(t: protobuf::JoinType) -> Self { match t { protobuf::JoinType::Inner => JoinType::Inner, protobuf::JoinType::Left => JoinType::Left, @@ -211,8 +221,8 @@ impl From for JoinType { } } -impl From for JoinConstraint { - fn from(t: protobuf::JoinConstraint) -> Self { +impl FromProto for JoinConstraint { + fn from_proto(t: protobuf::JoinConstraint) -> Self { match t { protobuf::JoinConstraint::On => JoinConstraint::On, protobuf::JoinConstraint::Using => JoinConstraint::Using, @@ -220,8 +230,8 @@ impl From for JoinConstraint { } } -impl From for NullEquality { - fn from(t: protobuf::NullEquality) -> Self { +impl FromProto for NullEquality { + fn from_proto(t: protobuf::NullEquality) -> Self { match t { protobuf::NullEquality::NullEqualsNothing => NullEquality::NullEqualsNothing, protobuf::NullEquality::NullEqualsNull => NullEquality::NullEqualsNull, @@ -229,8 +239,8 @@ impl From for NullEquality { } } -impl From for WriteOp { - fn from(t: protobuf::dml_node::Type) -> Self { +impl FromProto for WriteOp { + fn from_proto(t: protobuf::dml_node::Type) -> Self { match t { protobuf::dml_node::Type::Update => WriteOp::Update, protobuf::dml_node::Type::Delete => WriteOp::Delete, @@ -240,12 +250,13 @@ impl From for WriteOp { } protobuf::dml_node::Type::InsertReplace => WriteOp::Insert(InsertOp::Replace), protobuf::dml_node::Type::Ctas => WriteOp::Ctas, + protobuf::dml_node::Type::Truncate => WriteOp::Truncate, } } } -impl From for NullTreatment { - fn from(t: protobuf::NullTreatment) -> Self { +impl FromProto for NullTreatment { + fn from_proto(t: protobuf::NullTreatment) -> Self { match t { protobuf::NullTreatment::RespectNulls => NullTreatment::RespectNulls, protobuf::NullTreatment::IgnoreNulls => NullTreatment::IgnoreNulls, @@ -255,7 +266,7 @@ impl From for NullTreatment { pub fn parse_expr( proto: &protobuf::LogicalExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { use protobuf::{logical_expr_node::ExprType, window_expr_node}; @@ -268,7 +279,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = parse_exprs(&binary_expr.operands, registry, codec)?; + let operands = parse_exprs(&binary_expr.operands, ctx, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -295,13 +306,13 @@ pub fn parse_expr( .window_function .as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; + let partition_by = parse_exprs(&expr.partition_by, ctx, codec)?; + let mut order_by = parse_sorts(&expr.order_by, ctx, codec)?; let window_frame = expr .window_frame .as_ref() .map::, _>(|window_frame| { - let window_frame: WindowFrame = window_frame.clone().try_into()?; + let window_frame = WindowFrame::try_from_proto(window_frame.clone())?; window_frame .regularize_order_bys(&mut order_by) .map(|_| window_frame) @@ -319,7 +330,7 @@ pub fn parse_expr( "Received a WindowExprNode message with unknown NullTreatment {null_treatment}", )) })?; - Some(NullTreatment::from(null_treatment)) + Some(NullTreatment::from_proto(null_treatment)) } None => None, }; @@ -328,7 +339,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => registry + None => ctx .udaf(udaf_name) .or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, }; @@ -337,7 +348,7 @@ pub fn parse_expr( window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => registry + None => ctx .udwf(udwf_name) .or_else(|_| codec.try_decode_udwf(udwf_name, &[]))?, }; @@ -345,7 +356,7 @@ pub fn parse_expr( } }; - let args = parse_exprs(&expr.exprs, registry, codec)?; + let args = parse_exprs(&expr.exprs, ctx, codec)?; let mut builder = Expr::from(WindowFunction::new(agg_fn, args)) .partition_by(partition_by) .order_by(order_by) @@ -356,8 +367,7 @@ pub fn parse_expr( builder = builder.distinct(); }; - if let Some(filter) = - parse_optional_expr(expr.filter.as_deref(), registry, codec)? + if let Some(filter) = parse_optional_expr(expr.filter.as_deref(), ctx, codec)? { builder = builder.filter(filter); } @@ -365,79 +375,79 @@ pub fn parse_expr( builder.build().map_err(Error::DataFusionError) } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( - parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(alias.expr.as_deref(), ctx, "expr", codec)?, alias .relation .first() - .map(|r| TableReference::try_from(r.clone())) + .map(|r| TableReference::try_from_proto(r.clone())) .transpose()?, alias.alias.clone(), ))), ExprType::IsNullExpr(is_null) => Ok(Expr::IsNull(Box::new(parse_required_expr( is_null.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotNullExpr(is_not_null) => Ok(Expr::IsNotNull(Box::new( - parse_required_expr(is_not_null.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(is_not_null.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::NotExpr(not) => Ok(Expr::Not(Box::new(parse_required_expr( not.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsTrue(msg) => Ok(Expr::IsTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsFalse(msg) => Ok(Expr::IsFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsUnknown(msg) => Ok(Expr::IsUnknown(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotTrue(msg) => Ok(Expr::IsNotTrue(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotFalse(msg) => Ok(Expr::IsNotFalse(Box::new(parse_required_expr( msg.expr.as_deref(), - registry, + ctx, "expr", codec, )?))), ExprType::IsNotUnknown(msg) => Ok(Expr::IsNotUnknown(Box::new( - parse_required_expr(msg.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(msg.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Between(between) => Ok(Expr::Between(Between::new( Box::new(parse_required_expr( between.expr.as_deref(), - registry, + ctx, "expr", codec, )?), between.negated, Box::new(parse_required_expr( between.low.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( between.high.as_deref(), - registry, + ctx, "expr", codec, )?), @@ -446,13 +456,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -463,13 +473,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -480,13 +490,13 @@ pub fn parse_expr( like.negated, Box::new(parse_required_expr( like.expr.as_deref(), - registry, + ctx, "expr", codec, )?), Box::new(parse_required_expr( like.pattern.as_deref(), - registry, + ctx, "pattern", codec, )?), @@ -500,13 +510,13 @@ pub fn parse_expr( .map(|e| { let when_expr = parse_required_expr( e.when_expr.as_ref(), - registry, + ctx, "when_expr", codec, )?; let then_expr = parse_required_expr( e.then_expr.as_ref(), - registry, + ctx, "then_expr", codec, )?; @@ -514,37 +524,45 @@ pub fn parse_expr( }) .collect::, Box)>, Error>>()?; Ok(Expr::Case(Case::new( - parse_optional_expr(case.expr.as_deref(), registry, codec)?.map(Box::new), + parse_optional_expr(case.expr.as_deref(), ctx, codec)?.map(Box::new), when_then_expr, - parse_optional_expr(case.else_expr.as_deref(), registry, codec)? - .map(Box::new), + parse_optional_expr(case.else_expr.as_deref(), ctx, codec)?.map(Box::new), ))) } ExprType::Cast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + let data_type: DataType = cast.arrow_type.as_ref().required("arrow_type")?; + let field = data_type + .into_nullable_field() + .with_nullable(cast.nullable.unwrap_or(true)); + Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field)))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( cast.expr.as_deref(), - registry, + ctx, "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + let data_type: DataType = cast.arrow_type.as_ref().required("arrow_type")?; + let field = data_type + .into_nullable_field() + .with_nullable(cast.nullable.unwrap_or(true)); + Ok(Expr::TryCast(TryCast::new_from_field( + expr, + Arc::new(field), + ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( - parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, + parse_required_expr(negative.expr.as_deref(), ctx, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let mut exprs = parse_exprs(&unnest.exprs, registry, codec)?; + let mut exprs = parse_exprs(&unnest.exprs, ctx, codec)?; if exprs.len() != 1 { return Err(proto_error("Unnest must have exactly one expression")); } @@ -553,15 +571,18 @@ pub fn parse_expr( ExprType::InList(in_list) => Ok(Expr::InList(InList::new( Box::new(parse_required_expr( in_list.expr.as_deref(), - registry, + ctx, "expr", codec, )?), - parse_exprs(&in_list.list, registry, codec)?, + parse_exprs(&in_list.list, ctx, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { - let qualifier = qualifier.to_owned().map(|x| x.try_into()).transpose()?; + let qualifier = qualifier + .to_owned() + .map(TableReference::try_from_proto) + .transpose()?; #[expect(deprecated)] Ok(Expr::Wildcard { qualifier, @@ -575,19 +596,19 @@ pub fn parse_expr( }) => { let scalar_fn = match fun_definition { Some(buf) => codec.try_decode_udf(fun_name, buf)?, - None => registry + None => ctx .udf(fun_name.as_str()) .or_else(|_| codec.try_decode_udf(fun_name, &[]))?, }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - parse_exprs(args, registry, codec)?, + parse_exprs(args, ctx, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { let agg_fn = match &pb.fun_definition { Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, - None => registry + None => ctx .udaf(&pb.fun_name) .or_else(|_| codec.try_decode_udaf(&pb.fun_name, &[]))?, }; @@ -599,17 +620,17 @@ pub fn parse_expr( "Received an AggregateUdfExprNode message with unknown NullTreatment {null_treatment}", )) })?; - Some(NullTreatment::from(null_treatment)) + Some(NullTreatment::from_proto(null_treatment)) } None => None, }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - parse_exprs(&pb.args, registry, codec)?, + parse_exprs(&pb.args, ctx, codec)?, pb.distinct, - parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - parse_sorts(&pb.order_by, registry, codec)?, + parse_optional_expr(pb.filter.as_deref(), ctx, codec)?.map(Box::new), + parse_sorts(&pb.order_by, ctx, codec)?, null_treatment, ))) } @@ -617,15 +638,15 @@ pub fn parse_expr( ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) + .map(|expr_list| parse_exprs(&expr_list.expr, ctx, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - parse_exprs(expr, registry, codec)?, + parse_exprs(expr, ctx, codec)?, ))), ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( - GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + GroupingSet::Rollup(parse_exprs(expr, ctx, codec)?), )), ExprType::Placeholder(PlaceholderNode { id, @@ -647,13 +668,41 @@ pub fn parse_expr( ))) } }, + ExprType::ScalarSubqueryExpr(sq) => { + let subquery = parse_subquery( + sq.subquery + .as_deref() + .ok_or_else(|| Error::required("ScalarSubqueryExprNode.subquery"))?, + ctx, + codec, + )?; + Ok(Expr::ScalarSubquery(subquery)) + } } } +fn parse_subquery( + proto: &protobuf::SubqueryNode, + ctx: &TaskContext, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan_node = proto + .subquery + .as_ref() + .ok_or_else(|| Error::required("SubqueryNode.subquery"))?; + let plan = plan_node.try_into_logical_plan(ctx, codec)?; + let outer_ref_columns = parse_exprs(&proto.outer_ref_columns, ctx, codec)?; + Ok(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Default::default(), + }) +} + /// Parse a vector of `protobuf::LogicalExprNode`s. pub fn parse_exprs<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -662,7 +711,7 @@ where let res = protos .into_iter() .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + parse_expr(elem, ctx, codec).map_err(|e| plan_datafusion_err!("{}", e)) }) .collect::>>()?; Ok(res) @@ -670,7 +719,7 @@ where pub fn parse_sorts<'a, I>( protos: I, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> where @@ -678,17 +727,17 @@ where { protos .into_iter() - .map(|sort| parse_sort(sort, registry, codec)) + .map(|sort| parse_sort(sort, ctx, codec)) .collect::, Error>>() } pub fn parse_sort( sort: &protobuf::SortExprNode, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result { Ok(Sort::new( - parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, + parse_required_expr(sort.expr.as_ref(), ctx, "expr", codec)?, sort.asc, sort.nulls_first, )) @@ -704,59 +753,32 @@ fn parse_escape_char(s: &str) -> Result> { } pub fn from_proto_binary_op(op: &str) -> Result { - match op { - "And" => Ok(Operator::And), - "Or" => Ok(Operator::Or), - "Eq" => Ok(Operator::Eq), - "NotEq" => Ok(Operator::NotEq), - "LtEq" => Ok(Operator::LtEq), - "Lt" => Ok(Operator::Lt), - "Gt" => Ok(Operator::Gt), - "GtEq" => Ok(Operator::GtEq), - "Plus" => Ok(Operator::Plus), - "Minus" => Ok(Operator::Minus), - "Multiply" => Ok(Operator::Multiply), - "Divide" => Ok(Operator::Divide), - "Modulo" => Ok(Operator::Modulo), - "IsDistinctFrom" => Ok(Operator::IsDistinctFrom), - "IsNotDistinctFrom" => Ok(Operator::IsNotDistinctFrom), - "BitwiseAnd" => Ok(Operator::BitwiseAnd), - "BitwiseOr" => Ok(Operator::BitwiseOr), - "BitwiseXor" => Ok(Operator::BitwiseXor), - "BitwiseShiftLeft" => Ok(Operator::BitwiseShiftLeft), - "BitwiseShiftRight" => Ok(Operator::BitwiseShiftRight), - "RegexIMatch" => Ok(Operator::RegexIMatch), - "RegexMatch" => Ok(Operator::RegexMatch), - "RegexNotIMatch" => Ok(Operator::RegexNotIMatch), - "RegexNotMatch" => Ok(Operator::RegexNotMatch), - "StringConcat" => Ok(Operator::StringConcat), - "AtArrow" => Ok(Operator::AtArrow), - "ArrowAt" => Ok(Operator::ArrowAt), - other => Err(proto_error(format!( - "Unsupported binary operator '{other:?}'" - ))), - } + // The proto-string <-> `Operator` mapping is canonically owned by + // `datafusion-expr-common` so `datafusion-proto` (logical plans) and + // `PhysicalExpr` decoders (e.g. `BinaryExpr`) share one source of truth. + Operator::from_proto_name(op) + .ok_or_else(|| proto_error(format!("Unsupported binary operator '{op:?}'"))) } fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, codec: &dyn LogicalExtensionCodec, ) -> Result, Error> { match p { - Some(expr) => parse_expr(expr, registry, codec).map(Some), + Some(expr) => parse_expr(expr, ctx, codec).map(Some), None => Ok(None), } } fn parse_required_expr( p: Option<&protobuf::LogicalExprNode>, - registry: &dyn FunctionRegistry, + ctx: &TaskContext, field: impl Into, codec: &dyn LogicalExtensionCodec, ) -> Result { match p { - Some(expr) => parse_expr(expr, registry, codec), + Some(expr) => parse_expr(expr, ctx, codec), None => Err(Error::required(field)), } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 7a8cbafc22bf8..35c2e76d880b9 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -19,63 +19,71 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; +use crate::convert::{FromProto, TryFromProto}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - dml_node, ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode, - CustomTableScanNode, DmlNode, SortExprNodeCollection, + ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode, + CustomTableScanNode, DmlNode, SortExprNodeCollection, dml_node, }; use crate::{ - convert_required, into_required, + convert_required, protobuf::{ - self, listing_table_scan_node::FileFormatType, - logical_plan_node::LogicalPlanType, LogicalExtensionNode, LogicalPlanNode, + self, LogicalExtensionNode, LogicalPlanNode, + listing_table_scan_node::FileFormatType, logical_plan_node::LogicalPlanType, }, }; -use crate::protobuf::{proto_error, ToProtoError}; +use crate::protobuf::{ToProtoError, proto_error}; use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder, SchemaRef}; use datafusion_catalog::cte_worktable::CteWorkTable; +use datafusion_catalog::empty::EmptyTable; use datafusion_common::file_options::file_type::FileType; +use datafusion_common::format::{ + ExplainAnalyzeCategories, ExplainFormat, MetricCategory, MetricType, +}; use datafusion_common::{ - assert_or_internal_err, context, internal_datafusion_err, internal_err, not_impl_err, - plan_err, Result, TableReference, ToDFSchema, + NullEquality, Result, TableReference, assert_or_internal_err, context, + internal_datafusion_err, internal_err, not_impl_err, plan_err, }; use datafusion_datasource::file_format::FileFormat; use datafusion_datasource::file_format::{ - file_type_to_format, format_as_file_type, FileFormatFactory, + FileFormatFactory, file_type_to_format, format_as_file_type, }; -use datafusion_datasource_arrow::file_format::ArrowFormat; +use datafusion_datasource_arrow::file_format::{ArrowFormat, ArrowFormatFactory}; #[cfg(feature = "avro")] use datafusion_datasource_avro::file_format::AvroFormat; -use datafusion_datasource_csv::file_format::CsvFormat; -use datafusion_datasource_json::file_format::JsonFormat as OtherNdJsonFormat; +use datafusion_datasource_csv::file_format::{CsvFormat, CsvFormatFactory}; +use datafusion_datasource_json::file_format::{ + JsonFormat as OtherNdJsonFormat, JsonFormatFactory, +}; #[cfg(feature = "parquet")] -use datafusion_datasource_parquet::file_format::ParquetFormat; +use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::{ - dml, - logical_plan::{ - builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, - CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, - Extension, Join, JoinConstraint, Prepare, Projection, Repartition, Sort, - SubqueryAlias, TableScan, Values, Window, - }, - DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, - Statement, WindowUDF, + AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RecursiveQuery, SkipType, + TableSource, Unnest, WriteOp, }; use datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, RecursiveQuery, SkipType, TableSource, Unnest, + DistinctOn, DropView, Expr, JoinConstraint, LogicalPlan, LogicalPlanBuilder, + ScalarUDF, SortExpr, Statement, WindowUDF, dml, + logical_plan::{ + Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, + DdlStatement, Distinct, EmptyRelation, Extension, Join, Prepare, Projection, + Repartition, Sort, SubqueryAlias, TableScan, TableScanBuilder, Values, Window, + builder::project, + }, }; +use datafusion_proto_common::protobuf_common; use self::to_proto::{serialize_expr, serialize_exprs}; use crate::logical_plan::to_proto::serialize_sorts; +use datafusion_catalog::TableProvider; use datafusion_catalog::default_table_source::{provider_as_source, source_as_provider}; use datafusion_catalog::view::ViewTable; -use datafusion_catalog::TableProvider; use datafusion_catalog_listing::{ListingOptions, ListingTable, ListingTableConfig}; use datafusion_datasource::ListingTableUrl; use datafusion_execution::TaskContext; -use prost::bytes::BufMut; use prost::Message; +use prost::bytes::BufMut; pub mod file_formats; pub mod from_proto; @@ -105,7 +113,7 @@ pub trait AsLogicalPlan: Debug + Send + Sync + Clone { Self: Sized; } -pub trait LogicalExtensionCodec: Debug + Send + Sync { +pub trait LogicalExtensionCodec: Debug + Send + Sync + std::any::Any { fn try_decode( &self, buf: &[u8], @@ -154,6 +162,24 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { Ok(()) } + fn try_decode_higher_order_function( + &self, + name: &str, + _buf: &[u8], + ) -> Result> { + not_impl_err!( + "LogicalExtensionCodec is not provided for higher order function {name}" + ) + } + + fn try_encode_higher_order_function( + &self, + _node: &HigherOrderUDF, + _buf: &mut Vec, + ) -> Result<()> { + Ok(()) + } + fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!( "LogicalExtensionCodec is not provided for aggregate function {name}" @@ -208,6 +234,95 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { ) -> Result<()> { not_impl_err!("LogicalExtensionCodec is not provided") } + + fn try_decode_file_format( + &self, + buf: &[u8], + ctx: &TaskContext, + ) -> Result> { + let proto = protobuf::FileFormatProto::decode(buf).map_err(|e| { + internal_datafusion_err!("Failed to decode FileFormatProto: {e}") + })?; + + let kind = protobuf::FileFormatKind::try_from(proto.kind).map_err(|_| { + internal_datafusion_err!("Unknown FileFormatKind: {}", proto.kind) + })?; + + match kind { + protobuf::FileFormatKind::Csv => file_formats::CsvLogicalExtensionCodec + .try_decode_file_format(&proto.encoded_file_format, ctx), + protobuf::FileFormatKind::Json => file_formats::JsonLogicalExtensionCodec + .try_decode_file_format(&proto.encoded_file_format, ctx), + #[cfg(feature = "parquet")] + protobuf::FileFormatKind::Parquet => { + file_formats::ParquetLogicalExtensionCodec + .try_decode_file_format(&proto.encoded_file_format, ctx) + } + protobuf::FileFormatKind::Arrow => file_formats::ArrowLogicalExtensionCodec + .try_decode_file_format(&proto.encoded_file_format, ctx), + protobuf::FileFormatKind::Avro => file_formats::AvroLogicalExtensionCodec + .try_decode_file_format(&proto.encoded_file_format, ctx), + #[cfg(not(feature = "parquet"))] + protobuf::FileFormatKind::Parquet => { + not_impl_err!("Parquet support requires the 'parquet' feature") + } + protobuf::FileFormatKind::Unspecified => { + not_impl_err!("Unspecified file format kind") + } + } + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> Result<()> { + let mut encoded_file_format = Vec::new(); + + let kind = if node.downcast_ref::().is_some() { + file_formats::CsvLogicalExtensionCodec + .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; + protobuf::FileFormatKind::Csv + } else if node.downcast_ref::().is_some() { + file_formats::JsonLogicalExtensionCodec + .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; + protobuf::FileFormatKind::Json + } else if node.downcast_ref::().is_some() { + file_formats::ArrowLogicalExtensionCodec + .try_encode_file_format(&mut encoded_file_format, Arc::clone(&node))?; + protobuf::FileFormatKind::Arrow + } else { + #[cfg(feature = "parquet")] + { + if node.downcast_ref::().is_some() { + file_formats::ParquetLogicalExtensionCodec.try_encode_file_format( + &mut encoded_file_format, + Arc::clone(&node), + )?; + protobuf::FileFormatKind::Parquet + } else { + return not_impl_err!( + "Unsupported FileFormatFactory type for DefaultLogicalExtensionCodec" + ); + } + } + #[cfg(not(feature = "parquet"))] + { + return not_impl_err!( + "Unsupported FileFormatFactory type for DefaultLogicalExtensionCodec" + ); + } + }; + + let proto = protobuf::FileFormatProto { + kind: kind as i32, + encoded_file_format, + }; + proto.encode(buf).map_err(|e| { + internal_datafusion_err!("Failed to encode FileFormatProto: {e}") + })?; + Ok(()) + } } #[macro_export] @@ -231,7 +346,7 @@ fn from_table_reference( ) })?; - Ok(table_ref.clone().try_into()?) + Ok(TableReference::try_from_proto(table_ref.clone())?) } /// Converts [LogicalPlan::TableScan] to [TableSource] @@ -260,19 +375,85 @@ fn from_table_source( target: Arc, extension_codec: &dyn LogicalExtensionCodec, ) -> Result { - let projected_schema = target.schema().to_dfschema_ref()?; - let r = LogicalPlan::TableScan(TableScan { - table_name, - source: target, - projection: None, - projected_schema, - filters: vec![], - fetch: None, - }); + let r = LogicalPlan::TableScan(TableScanBuilder::new(table_name, target).build()?); LogicalPlanNode::try_from_logical_plan(&r, extension_codec) } +fn metric_type_from_proto(value: i32) -> Result { + let pb = protobuf_common::MetricType::try_from(value) + .map_err(|_| proto_error(format!("Unknown MetricType discriminant: {value}")))?; + Ok(match pb { + protobuf_common::MetricType::Summary => MetricType::Summary, + protobuf_common::MetricType::Dev => MetricType::Dev, + }) +} + +fn metric_type_to_proto(value: MetricType) -> protobuf_common::MetricType { + match value { + MetricType::Summary => protobuf_common::MetricType::Summary, + MetricType::Dev => protobuf_common::MetricType::Dev, + } +} + +fn metric_category_from_proto(value: i32) -> Result { + let pb = protobuf_common::MetricCategory::try_from(value).map_err(|_| { + proto_error(format!("Unknown MetricCategory discriminant: {value}")) + })?; + Ok(match pb { + protobuf_common::MetricCategory::Rows => MetricCategory::Rows, + protobuf_common::MetricCategory::Bytes => MetricCategory::Bytes, + protobuf_common::MetricCategory::Timing => MetricCategory::Timing, + protobuf_common::MetricCategory::Uncategorized => MetricCategory::Uncategorized, + }) +} + +fn metric_category_to_proto(value: MetricCategory) -> protobuf_common::MetricCategory { + match value { + MetricCategory::Rows => protobuf_common::MetricCategory::Rows, + MetricCategory::Bytes => protobuf_common::MetricCategory::Bytes, + MetricCategory::Timing => protobuf_common::MetricCategory::Timing, + MetricCategory::Uncategorized => protobuf_common::MetricCategory::Uncategorized, + } +} + +fn explain_analyze_categories_from_proto( + node: &protobuf_common::ExplainAnalyzeCategoriesNode, +) -> Result { + if node.all { + Ok(ExplainAnalyzeCategories::All) + } else { + let cats = node + .only + .iter() + .copied() + .map(metric_category_from_proto) + .collect::>>()?; + Ok(ExplainAnalyzeCategories::Only(cats)) + } +} + +fn explain_analyze_categories_to_proto( + value: &ExplainAnalyzeCategories, +) -> protobuf_common::ExplainAnalyzeCategoriesNode { + match value { + ExplainAnalyzeCategories::All => protobuf_common::ExplainAnalyzeCategoriesNode { + all: true, + only: vec![], + }, + ExplainAnalyzeCategories::Only(cats) => { + protobuf_common::ExplainAnalyzeCategoriesNode { + all: false, + only: cats + .iter() + .copied() + .map(|c| metric_category_to_proto(c) as i32) + .collect(), + } + } + } +} + impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -423,14 +604,17 @@ impl AsLogicalPlan for LogicalPlanNode { } Arc::new(json) } - #[cfg_attr(not(feature = "avro"), allow(unused_variables))] FileFormatType::Avro(..) => { #[cfg(feature = "avro")] { Arc::new(AvroFormat) } #[cfg(not(feature = "avro"))] - panic!("Unable to process avro file since `avro` feature is not enabled"); + { + panic!( + "Unable to process avro file since `avro` feature is not enabled" + ); + } } FileFormatType::Arrow(..) => { Arc::new(ArrowFormat) @@ -606,27 +790,30 @@ impl AsLogicalPlan for LogicalPlanNode { } Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { - schema: pb_schema.try_into()?, - name: from_table_reference( - create_extern_table.name.as_ref(), - "CreateExternalTable", - )?, - location: create_extern_table.location.clone(), - file_type: create_extern_table.file_type.clone(), - table_partition_cols: create_extern_table - .table_partition_cols - .clone(), - order_exprs, - if_not_exists: create_extern_table.if_not_exists, - or_replace: create_extern_table.or_replace, - temporary: create_extern_table.temporary, - definition, - unbounded: create_extern_table.unbounded, - options: create_extern_table.options.clone(), - constraints: constraints.into(), - column_defaults, - }, + Box::new( + CreateExternalTable::builder( + from_table_reference( + create_extern_table.name.as_ref(), + "CreateExternalTable", + )?, + create_extern_table.location.clone(), + create_extern_table.file_type.clone(), + pb_schema.try_into()?, + ) + .with_partition_cols( + create_extern_table.table_partition_cols.clone(), + ) + .with_order_exprs(order_exprs) + .with_if_not_exists(create_extern_table.if_not_exists) + .with_or_replace(create_extern_table.or_replace) + .with_temporary(create_extern_table.temporary) + .with_definition(definition) + .with_unbounded(create_extern_table.unbounded) + .with_options(create_extern_table.options.clone()) + .with_constraints(constraints.into()) + .with_column_defaults(column_defaults) + .build(), + ), ))) } LogicalPlanType::CreateView(create_view) => { @@ -682,15 +869,62 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Analyze(analyze) => { let input: LogicalPlan = into_logical_plan!(analyze.input, ctx, extension_codec)?; + let analyze_level = analyze + .analyze_level + .map(metric_type_from_proto) + .transpose()?; + let analyze_categories = analyze + .analyze_categories + .as_ref() + .map(explain_analyze_categories_from_proto) + .transpose()?; + let pb_format = protobuf::ExplainFormat::try_from(analyze.format) + .map_err(|_| { + proto_error(format!( + "Received an AnalyzeNode message with unknown ExplainFormat {}", + analyze.format + )) + })?; + let analyze_format = match pb_format { + protobuf::ExplainFormat::Indent => ExplainFormat::Indent, + protobuf::ExplainFormat::Tree => ExplainFormat::Tree, + protobuf::ExplainFormat::Pgjson => ExplainFormat::PostgresJSON, + protobuf::ExplainFormat::Graphviz => ExplainFormat::Graphviz, + }; + let explain_option = + datafusion_expr::logical_plan::ExplainOption::default() + .with_verbose(analyze.verbose) + .with_analyze(true) + .with_analyze_level(analyze_level) + .with_analyze_categories(analyze_categories) + .with_format(analyze_format); LogicalPlanBuilder::from(input) - .explain(analyze.verbose, true)? + .explain_option_format(explain_option)? .build() } LogicalPlanType::Explain(explain) => { let input: LogicalPlan = into_logical_plan!(explain.input, ctx, extension_codec)?; + let pb_format = protobuf::ExplainFormat::try_from(explain.format) + .map_err(|_| { + proto_error(format!( + "Received an ExplainNode message with unknown ExplainFormat {}", + explain.format + )) + })?; + let explain_format = match pb_format { + protobuf::ExplainFormat::Indent => ExplainFormat::Indent, + protobuf::ExplainFormat::Tree => ExplainFormat::Tree, + protobuf::ExplainFormat::Pgjson => ExplainFormat::PostgresJSON, + protobuf::ExplainFormat::Graphviz => ExplainFormat::Graphviz, + }; + let explain_option = + datafusion_expr::logical_plan::ExplainOption::default() + .with_verbose(explain.verbose) + .with_format(explain_format) + .with_show_statistics(explain.show_statistics); LogicalPlanBuilder::from(input) - .explain(explain.verbose, false)? + .explain_option_format(explain_option)? .build() } LogicalPlanType::SubqueryAlias(aliased_relation) => { @@ -720,6 +954,13 @@ impl AsLogicalPlan for LogicalPlanNode { from_proto::parse_exprs(&join.left_join_key, ctx, extension_codec)?; let right_keys: Vec = from_proto::parse_exprs(&join.right_join_key, ctx, extension_codec)?; + if left_keys.len() != right_keys.len() { + return Err(proto_error(format!( + "Received a JoinNode message with left_join_key and right_join_key of different lengths: {} and {}", + left_keys.len(), + right_keys.len() + ))); + } let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( @@ -736,44 +977,39 @@ impl AsLogicalPlan for LogicalPlanNode { join.join_constraint )) })?; + let null_equality = protobuf::NullEquality::try_from(join.null_equality) + .map_err(|_| { + proto_error(format!( + "Received a JoinNode message with unknown NullEquality {}", + join.null_equality + )) + })?; let filter: Option = join .filter .as_ref() .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; - - let builder = LogicalPlanBuilder::from(into_logical_plan!( - join.left, - ctx, - extension_codec - )?); - let builder = match join_constraint.into() { - JoinConstraint::On => builder.join_with_expr_keys( - into_logical_plan!(join.right, ctx, extension_codec)?, - join_type.into(), - (left_keys, right_keys), - filter, - )?, - JoinConstraint::Using => { - // The equijoin keys in using-join must be column. - let using_keys = left_keys - .into_iter() - .map(|key| { - key.try_as_col().cloned() - .ok_or_else(|| internal_datafusion_err!( - "Using join keys must be column references, got: {key:?}" - )) - }) - .collect::, _>>()?; - builder.join_using( - into_logical_plan!(join.right, ctx, extension_codec)?, - join_type.into(), - using_keys, - )? - } - }; - - builder.build() + let left = into_logical_plan!(join.left, ctx, extension_codec)?; + let right = into_logical_plan!(join.right, ctx, extension_codec)?; + let on: Vec<(Expr, Expr)> = + left_keys.into_iter().zip(right_keys).collect(); + + // Construct the Join directly instead of going through + // LogicalPlanBuilder. The builder methods hardcode + // `null_equality` and `null_aware`, so a round trip through + // them silently loses both fields. Both sides of the round + // trip should already have validated keys, so we don't need + // the builder's normalization / equijoin-pair checks. + Ok(LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + on, + filter, + datafusion_expr::JoinType::from_proto(join_type), + JoinConstraint::from_proto(join_constraint), + NullEquality::from_proto(null_equality), + join.null_aware, + )?)) } LogicalPlanType::Union(union) => { assert_or_internal_err!( @@ -934,7 +1170,13 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input) .unnest_columns_with_options( unnest.exec_columns.iter().map(|c| c.into()).collect(), - into_required!(unnest.options)?, + unnest + .options + .as_ref() + .map(datafusion_common::UnnestOptions::from_proto) + .ok_or_else(|| { + proto_error("Missing required field in protobuf") + })?, )? .build() } @@ -955,12 +1197,15 @@ impl AsLogicalPlan for LogicalPlanNode { ))? .try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: recursive_query_node.name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: recursive_query_node.is_distinct, - })) + // The output schema is derived state, so decoding goes through + // the constructor after restoring the child terms. + RecursiveQuery::try_new( + recursive_query_node.name.clone(), + Arc::new(static_term), + Arc::new(recursive_term), + recursive_query_node.is_distinct, + ) + .map(LogicalPlan::RecursiveQuery) } LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; @@ -973,11 +1218,40 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } + LogicalPlanType::EmptyTableScan(scan) => { + let schema: Schema = convert_required!(scan.schema)?; + let schema = Arc::new(schema); + let mut projection = None; + if let Some(columns) = &scan.projection { + let column_indices = columns + .columns + .iter() + .map(|name| schema.index_of(name)) + .collect::, _>>()?; + projection = Some(column_indices); + } + + let filters = + from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?; + + let table_name = + from_table_reference(scan.table_name.as_ref(), "EmptyTableScan")?; + + let provider = Arc::new(EmptyTable::new(Arc::clone(&schema))); + + LogicalPlanBuilder::scan_with_filters( + table_name, + provider_as_source(provider), + projection, + filters, + )? + .build() + } LogicalPlanType::Dml(dml_node) => { Ok(LogicalPlan::Dml(datafusion_expr::DmlStatement::new( from_table_reference(dml_node.table_name.as_ref(), "DML ")?, to_table_source(&dml_node.target, ctx, extension_codec)?, - dml_node.dml_type().into(), + WriteOp::from_proto(dml_node.dml_type()), Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ))) } @@ -1018,7 +1292,6 @@ impl AsLogicalPlan for LogicalPlanNode { }) => { let provider = source_as_provider(source)?; let schema = provider.schema(); - let source = provider.as_any(); let projection = match projection { None => None, @@ -1036,13 +1309,13 @@ impl AsLogicalPlan for LogicalPlanNode { let filters: Vec = serialize_exprs(filters, extension_codec)?; - if let Some(listing_table) = source.downcast_ref::() { - let any = listing_table.options().format.as_any(); + if let Some(listing_table) = provider.downcast_ref::() { + let format = listing_table.options().format.as_ref(); let file_format_type = { let mut maybe_some_type = None; #[cfg(feature = "parquet")] - if let Some(parquet) = any.downcast_ref::() { + if let Some(parquet) = format.downcast_ref::() { let options = parquet.options(); maybe_some_type = Some(FileFormatType::Parquet(protobuf::ParquetFormat { @@ -1050,7 +1323,7 @@ impl AsLogicalPlan for LogicalPlanNode { })); }; - if let Some(csv) = any.downcast_ref::() { + if let Some(csv) = format.downcast_ref::() { let options = csv.options(); maybe_some_type = Some(FileFormatType::Csv(protobuf::CsvFormat { @@ -1058,7 +1331,7 @@ impl AsLogicalPlan for LogicalPlanNode { })); } - if let Some(json) = any.downcast_ref::() { + if let Some(json) = format.downcast_ref::() { let options = json.options(); maybe_some_type = Some(FileFormatType::Json(protobuf::NdJsonFormat { @@ -1067,12 +1340,12 @@ impl AsLogicalPlan for LogicalPlanNode { } #[cfg(feature = "avro")] - if any.is::() { + if format.is::() { maybe_some_type = Some(FileFormatType::Avro(protobuf::AvroFormat {})) } - if any.is::() { + if format.is::() { maybe_some_type = Some(FileFormatType::Arrow(protobuf::ArrowFormat {})) } @@ -1133,7 +1406,9 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { file_format_type: Some(file_format_type), - table_name: Some(table_name.clone().into()), + table_name: Some(protobuf::TableReference::from_proto( + table_name.clone(), + )), collect_stat: options.collect_stat, file_extension: options.file_extension.clone(), table_partition_cols: partition_columns, @@ -1150,12 +1425,14 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }) - } else if let Some(view_table) = source.downcast_ref::() { + } else if let Some(view_table) = provider.downcast_ref::() { let schema: protobuf::Schema = schema.as_ref().try_into()?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { - table_name: Some(table_name.clone().into()), + table_name: Some(protobuf::TableReference::from_proto( + table_name.clone(), + )), input: Some(Box::new( LogicalPlanNode::try_from_logical_plan( view_table.logical_plan(), @@ -1171,7 +1448,8 @@ impl AsLogicalPlan for LogicalPlanNode { }, ))), }) - } else if let Some(cte_work_table) = source.downcast_ref::() + } else if let Some(cte_work_table) = + provider.downcast_ref::() { let name = cte_work_table.name().to_string(); let schema = cte_work_table.schema(); @@ -1185,6 +1463,21 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }) + } else if provider.downcast_ref::().is_some() { + let schema: protobuf::Schema = schema.as_ref().try_into()?; + + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::EmptyTableScan( + protobuf::EmptyTableScanNode { + table_name: Some(protobuf::TableReference::from_proto( + table_name.clone(), + )), + schema: Some(schema), + projection, + filters, + }, + )), + }) } else { let schema: protobuf::Schema = schema.as_ref().try_into()?; let mut bytes = vec![]; @@ -1192,7 +1485,9 @@ impl AsLogicalPlan for LogicalPlanNode { .try_encode_table_provider(table_name, provider, &mut bytes) .map_err(|e| context!("Error serializing custom table", e))?; let scan = CustomScan(CustomTableScanNode { - table_name: Some(table_name.clone().into()), + table_name: Some(protobuf::TableReference::from_proto( + table_name.clone(), + )), projection, schema: Some(schema), filters, @@ -1229,10 +1524,10 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), - expr: Some(serialize_expr( + expr: Some(Box::new(serialize_expr( &filter.predicate, extension_codec, - )?), + )?)), }, ))), }) @@ -1320,7 +1615,9 @@ impl AsLogicalPlan for LogicalPlanNode { join_type, join_constraint, null_equality, - .. + null_aware, + // Not encoded; recomputed by `Join::try_new` on decode. + schema: _, }) => { let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( left.as_ref(), @@ -1341,14 +1638,14 @@ impl AsLogicalPlan for LogicalPlanNode { .collect::, ToProtoError>>()? .into_iter() .unzip(); - let join_type: protobuf::JoinType = join_type.to_owned().into(); - let join_constraint: protobuf::JoinConstraint = - join_constraint.to_owned().into(); - let null_equality: protobuf::NullEquality = - null_equality.to_owned().into(); + let join_type = protobuf::JoinType::from_proto(join_type.to_owned()); + let join_constraint = + protobuf::JoinConstraint::from_proto(join_constraint.to_owned()); + let null_equality = + protobuf::NullEquality::from_proto(null_equality.to_owned()); let filter = filter .as_ref() - .map(|e| serialize_expr(e, extension_codec)) + .map(|e| serialize_expr(e, extension_codec).map(Box::new)) .map_or(Ok(None), |v| v.map(Some))?; Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( @@ -1361,12 +1658,19 @@ impl AsLogicalPlan for LogicalPlanNode { right_join_key, null_equality: null_equality.into(), filter, + null_aware: *null_aware, }, ))), }) } - LogicalPlan::Subquery(_) => { - not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") + LogicalPlan::Subquery(subquery) => { + // Serialize the inner subquery plan directly — the + // LogicalPlan::Subquery wrapper is reconstructed during + // expression deserialization. + LogicalPlanNode::try_from_logical_plan( + &subquery.subquery, + extension_codec, + ) } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( @@ -1377,7 +1681,9 @@ impl AsLogicalPlan for LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::SubqueryAlias(Box::new( protobuf::SubqueryAliasNode { input: Some(Box::new(input)), - alias: Some((*alias).clone().into()), + alias: Some(protobuf::TableReference::from_proto( + (*alias).clone(), + )), }, ))), }) @@ -1449,8 +1755,13 @@ impl AsLogicalPlan for LogicalPlanNode { Partitioning::RoundRobinBatch(partition_count) => { PartitionMethod::RoundRobin(*partition_count as u64) } + Partitioning::Range(_) => { + // TODO: Support range repartition protobuf serialization. + // Tracked by https://github.com/apache/datafusion/issues/22787 + return not_impl_err!("Range repartition"); + } Partitioning::DistributeBy(_) => { - return not_impl_err!("DistributeBy") + return not_impl_err!("DistributeBy"); } }; @@ -1472,8 +1783,8 @@ impl AsLogicalPlan for LogicalPlanNode { }, )), }), - LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { + LogicalPlan::Ddl(DdlStatement::CreateExternalTable(ce)) => { + let CreateExternalTable { name, location, file_type, @@ -1488,8 +1799,7 @@ impl AsLogicalPlan for LogicalPlanNode { constraints, column_defaults, temporary, - }, - )) => { + } = ce.as_ref(); let mut converted_order_exprs: Vec = vec![]; for order in order_exprs { let temp = SortExprNodeCollection { @@ -1508,7 +1818,9 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { - name: Some(name.clone().into()), + name: Some(protobuf::TableReference::from_proto( + name.clone(), + )), location: location.clone(), file_type: file_type.clone(), schema: Some(df_schema.try_into()?), @@ -1535,7 +1847,7 @@ impl AsLogicalPlan for LogicalPlanNode { })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { - name: Some(name.clone().into()), + name: Some(protobuf::TableReference::from_proto(name.clone())), input: Some(Box::new(LogicalPlanNode::try_from_logical_plan( input, extension_codec, @@ -1584,6 +1896,23 @@ impl AsLogicalPlan for LogicalPlanNode { protobuf::AnalyzeNode { input: Some(Box::new(input)), verbose: a.verbose, + analyze_level: a + .analyze_level + .map(|m| metric_type_to_proto(m) as i32), + analyze_categories: a + .analyze_categories + .as_ref() + .map(explain_analyze_categories_to_proto), + format: match &a.format { + ExplainFormat::Indent => protobuf::ExplainFormat::Indent, + ExplainFormat::Tree => protobuf::ExplainFormat::Tree, + ExplainFormat::PostgresJSON => { + protobuf::ExplainFormat::Pgjson + } + ExplainFormat::Graphviz => { + protobuf::ExplainFormat::Graphviz + } + } as i32, }, ))), }) @@ -1598,6 +1927,18 @@ impl AsLogicalPlan for LogicalPlanNode { protobuf::ExplainNode { input: Some(Box::new(input)), verbose: a.verbose, + format: match &a.explain_format { + ExplainFormat::Indent => protobuf::ExplainFormat::Indent, + ExplainFormat::Tree => protobuf::ExplainFormat::Tree, + ExplainFormat::PostgresJSON => { + protobuf::ExplainFormat::Pgjson + } + ExplainFormat::Graphviz => { + protobuf::ExplainFormat::Graphviz + } + } + .into(), + show_statistics: a.show_statistics, }, ))), }) @@ -1696,7 +2037,7 @@ impl AsLogicalPlan for LogicalPlanNode { .map(|c| *c as u64) .collect(), schema: Some(schema.try_into()?), - options: Some(options.into()), + options: Some(protobuf::UnnestOptions::from_proto(options)), }, ))), }) @@ -1717,7 +2058,7 @@ impl AsLogicalPlan for LogicalPlanNode { })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DropView( protobuf::DropViewNode { - name: Some(name.clone().into()), + name: Some(protobuf::TableReference::from_proto(name.clone())), if_exists: *if_exists, schema: Some(schema.try_into()?), }, @@ -1744,7 +2085,7 @@ impl AsLogicalPlan for LogicalPlanNode { }) => { let input = LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; - let dml_type: dml_node::Type = op.into(); + let dml_type = dml_node::Type::from_proto(op); Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Dml(Box::new(DmlNode { input: Some(Box::new(input)), @@ -1753,7 +2094,9 @@ impl AsLogicalPlan for LogicalPlanNode { Arc::clone(target), extension_codec, )?)), - table_name: Some(table_name.clone().into()), + table_name: Some(protobuf::TableReference::from_proto( + table_name.clone(), + )), dml_type: dml_type.into(), }))), }) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..71a6bd824a369 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,21 +22,24 @@ use std::collections::HashMap; use datafusion_common::{NullEquality, TableReference, UnnestOptions}; +use datafusion_expr::WriteOp; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, NullTreatment, Placeholder, ScalarFunction, Unnest, }; -use datafusion_expr::WriteOp; +use datafusion_expr::logical_plan::Subquery; use datafusion_expr::{ - logical_plan::PlanType, logical_plan::StringifiedPlan, Expr, JoinConstraint, - JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + Expr, JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, logical_plan::PlanType, + logical_plan::StringifiedPlan, }; use crate::protobuf::RecursionUnnestOption; use crate::protobuf::{ - self, + self, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, + LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, + PlaceholderNode, RollupNode, ToProtoError as Error, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, FinalPhysicalPlan, FinalPhysicalPlanWithSchema, FinalPhysicalPlanWithStats, @@ -44,15 +47,14 @@ use crate::protobuf::{ InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, PhysicalPlanError, }, - AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, - OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, - ToProtoError as Error, }; -use super::LogicalExtensionCodec; +use super::{AsLogicalPlan, LogicalExtensionCodec}; +use crate::convert::{FromProto, TryFromProto}; +use crate::protobuf::LogicalPlanNode; -impl From<&UnnestOptions> for protobuf::UnnestOptions { - fn from(opts: &UnnestOptions) -> Self { +impl FromProto<&UnnestOptions> for protobuf::UnnestOptions { + fn from_proto(opts: &UnnestOptions) -> Self { Self { preserve_nulls: opts.preserve_nulls, recursions: opts @@ -68,8 +70,8 @@ impl From<&UnnestOptions> for protobuf::UnnestOptions { } } -impl From<&StringifiedPlan> for protobuf::StringifiedPlan { - fn from(stringified_plan: &StringifiedPlan) -> Self { +impl FromProto<&StringifiedPlan> for protobuf::StringifiedPlan { + fn from_proto(stringified_plan: &StringifiedPlan) -> Self { Self { plan_type: match stringified_plan.clone().plan_type { PlanType::InitialLogicalPlan => Some(protobuf::PlanType { @@ -129,8 +131,8 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { } } -impl From for protobuf::WindowFrameUnits { - fn from(units: WindowFrameUnits) -> Self { +impl FromProto for protobuf::WindowFrameUnits { + fn from_proto(units: WindowFrameUnits) -> Self { match units { WindowFrameUnits::Rows => Self::Rows, WindowFrameUnits::Range => Self::Range, @@ -139,10 +141,10 @@ impl From for protobuf::WindowFrameUnits { } } -impl TryFrom<&WindowFrameBound> for protobuf::WindowFrameBound { +impl TryFromProto<&WindowFrameBound> for protobuf::WindowFrameBound { type Error = Error; - fn try_from(bound: &WindowFrameBound) -> Result { + fn try_from_proto(bound: &WindowFrameBound) -> Result { Ok(match bound { WindowFrameBound::CurrentRow => Self { window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow @@ -161,15 +163,18 @@ impl TryFrom<&WindowFrameBound> for protobuf::WindowFrameBound { } } -impl TryFrom<&WindowFrame> for protobuf::WindowFrame { +impl TryFromProto<&WindowFrame> for protobuf::WindowFrame { type Error = Error; - fn try_from(window: &WindowFrame) -> Result { + fn try_from_proto(window: &WindowFrame) -> Result { Ok(Self { - window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(), - start_bound: Some((&window.start_bound).try_into()?), + window_frame_units: protobuf::WindowFrameUnits::from_proto(window.units) + .into(), + start_bound: Some(protobuf::WindowFrameBound::try_from_proto( + &window.start_bound, + )?), end_bound: Some(protobuf::window_frame::EndBound::Bound( - (&window.end_bound).try_into()?, + protobuf::WindowFrameBound::try_from_proto(&window.end_bound)?, )), }) } @@ -208,7 +213,7 @@ pub fn serialize_expr( expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), relation: relation .to_owned() - .map(|r| vec![r.into()]) + .map(|r| vec![protobuf::TableReference::from_proto(r)]) .unwrap_or(vec![]), alias: name.to_owned(), metadata: metadata @@ -307,16 +312,16 @@ pub fn serialize_expr( } Expr::WindowFunction(window_fun) => { let expr::WindowFunction { - ref fun, + fun, params: expr::WindowFunctionParams { - ref args, - ref partition_by, - ref order_by, - ref window_frame, - ref null_treatment, - ref distinct, - ref filter, + args, + partition_by, + order_by, + window_frame, + null_treatment, + distinct, + filter, }, } = window_fun.as_ref(); let mut buf = Vec::new(); @@ -338,8 +343,7 @@ pub fn serialize_expr( let partition_by = serialize_exprs(partition_by, codec)?; let order_by = serialize_sorts(order_by, codec)?; - let window_frame: Option = - Some(window_frame.try_into()?); + let window_frame = Some(protobuf::WindowFrame::try_from_proto(window_frame)?); let window_expr = protobuf::WindowExprNode { exprs: serialize_exprs(args, codec)?, @@ -353,7 +357,7 @@ pub fn serialize_expr( None => None, }, null_treatment: null_treatment - .map(|nt| protobuf::NullTreatment::from(nt).into()), + .map(|nt| protobuf::NullTreatment::from_proto(nt).into()), fun_definition, }; protobuf::LogicalExprNode { @@ -361,14 +365,14 @@ pub fn serialize_expr( } } Expr::AggregateFunction(expr::AggregateFunction { - ref func, + func, params: AggregateFunctionParams { - ref args, - ref distinct, - ref filter, - ref order_by, - ref null_treatment, + args, + distinct, + filter, + order_by, + null_treatment, }, }) => { let mut buf = Vec::new(); @@ -386,7 +390,7 @@ pub fn serialize_expr( order_by: serialize_sorts(order_by, codec)?, fun_definition: (!buf.is_empty()).then_some(buf), null_treatment: null_treatment - .map(|nt| protobuf::NullTreatment::from(nt).into()), + .map(|nt| protobuf::NullTreatment::from_proto(nt).into()), }, ))), } @@ -395,7 +399,7 @@ pub fn serialize_expr( Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported".to_string(), - )) + )); } Expr::ScalarFunction(ScalarFunction { func, args }) => { let mut buf = Vec::new(); @@ -522,19 +526,23 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, field }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + arrow_type: Some(field.data_type().try_into()?), + metadata: field.metadata().clone(), + nullable: Some(field.is_nullable()), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), @@ -573,16 +581,25 @@ pub fn serialize_expr( #[expect(deprecated)] Expr::Wildcard { qualifier, .. } => protobuf::LogicalExprNode { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.to_owned().map(|x| x.into()), + qualifier: qualifier + .to_owned() + .map(protobuf::TableReference::from_proto), })), }, - Expr::ScalarSubquery(_) - | Expr::InSubquery(_) - | Expr::Exists { .. } - | Expr::OuterReferenceColumn { .. } => { - // we would need to add logical plan operators to datafusion.proto to support this - // see discussion in https://github.com/apache/datafusion/issues/2565 - return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); + Expr::ScalarSubquery(subquery) => protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarSubqueryExpr(Box::new( + protobuf::ScalarSubqueryExprNode { + subquery: Some(Box::new(serialize_subquery(subquery, codec)?)), + }, + ))), + }, + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::SetComparison(_) => { + return Err(Error::General(format!( + "Proto serialization error: {expr} is not yet supported" + ))); } Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { @@ -622,11 +639,29 @@ pub fn serialize_expr( .unwrap_or(HashMap::new()), })), }, + Expr::HigherOrderFunction(_) | Expr::Lambda(_) | Expr::LambdaVariable(_) => { + return Err(Error::General( + "Proto serialization error: Lambda not implemented".to_string(), + )); + } }; Ok(expr_node) } +fn serialize_subquery( + subquery: &Subquery, + codec: &dyn LogicalExtensionCodec, +) -> Result { + let plan = LogicalPlanNode::try_from_logical_plan(&subquery.subquery, codec) + .map_err(|e| Error::General(e.to_string()))?; + let outer_ref_columns = serialize_exprs(&subquery.outer_ref_columns, codec)?; + Ok(protobuf::SubqueryNode { + subquery: Some(Box::new(plan)), + outer_ref_columns, + }) +} + pub fn serialize_sorts<'a, I>( sorts: I, codec: &dyn LogicalExtensionCodec, @@ -651,8 +686,8 @@ where .collect::, Error>>() } -impl From for protobuf::TableReference { - fn from(t: TableReference) -> Self { +impl FromProto for protobuf::TableReference { + fn from_proto(t: TableReference) -> Self { use protobuf::table_reference::TableReferenceEnum; let table_reference_enum = match t { TableReference::Bare { table } => { @@ -683,8 +718,8 @@ impl From for protobuf::TableReference { } } -impl From for protobuf::JoinType { - fn from(t: JoinType) -> Self { +impl FromProto for protobuf::JoinType { + fn from_proto(t: JoinType) -> Self { match t { JoinType::Inner => protobuf::JoinType::Inner, JoinType::Left => protobuf::JoinType::Left, @@ -700,8 +735,8 @@ impl From for protobuf::JoinType { } } -impl From for protobuf::JoinConstraint { - fn from(t: JoinConstraint) -> Self { +impl FromProto for protobuf::JoinConstraint { + fn from_proto(t: JoinConstraint) -> Self { match t { JoinConstraint::On => protobuf::JoinConstraint::On, JoinConstraint::Using => protobuf::JoinConstraint::Using, @@ -709,8 +744,8 @@ impl From for protobuf::JoinConstraint { } } -impl From for protobuf::NullEquality { - fn from(t: NullEquality) -> Self { +impl FromProto for protobuf::NullEquality { + fn from_proto(t: NullEquality) -> Self { match t { NullEquality::NullEqualsNothing => protobuf::NullEquality::NullEqualsNothing, NullEquality::NullEqualsNull => protobuf::NullEquality::NullEqualsNull, @@ -718,8 +753,8 @@ impl From for protobuf::NullEquality { } } -impl From<&WriteOp> for protobuf::dml_node::Type { - fn from(t: &WriteOp) -> Self { +impl FromProto<&WriteOp> for protobuf::dml_node::Type { + fn from_proto(t: &WriteOp) -> Self { match t { WriteOp::Insert(InsertOp::Append) => protobuf::dml_node::Type::InsertAppend, WriteOp::Insert(InsertOp::Overwrite) => { @@ -729,12 +764,13 @@ impl From<&WriteOp> for protobuf::dml_node::Type { WriteOp::Delete => protobuf::dml_node::Type::Delete, WriteOp::Update => protobuf::dml_node::Type::Update, WriteOp::Ctas => protobuf::dml_node::Type::Ctas, + WriteOp::Truncate => protobuf::dml_node::Type::Truncate, } } } -impl From for protobuf::NullTreatment { - fn from(t: NullTreatment) -> Self { +impl FromProto for protobuf::NullTreatment { + fn from_proto(t: NullTreatment) -> Self { match t { NullTreatment::RespectNulls => protobuf::NullTreatment::RespectNulls, NullTreatment::IgnoreNulls => protobuf::NullTreatment::IgnoreNulls, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index f1a9abe6ea7b1..36751d8a61a3e 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,15 +21,12 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::compute::SortOptions; -use arrow::datatypes::Field; +use arrow::datatypes::{Field, Schema}; use arrow::ipc::reader::StreamReader; use chrono::{TimeZone, Utc}; -use datafusion_expr::dml::InsertOp; -use object_store::path::Path; -use object_store::ObjectMeta; - -use arrow::datatypes::Schema; -use datafusion_common::{internal_datafusion_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, internal_datafusion_err, not_impl_err, +}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; @@ -42,45 +39,55 @@ use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; use datafusion_expr::WindowFunctionDefinition; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::execution_props::SubqueryIndex; +use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs}; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion_physical_plan::expressions::{ - in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, + BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, + LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, }; +use datafusion_physical_plan::joins::HashExpr; use datafusion_physical_plan::windows::{create_window_expr, schema_add_window_field}; -use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; +use datafusion_physical_plan::{ + Partitioning, PhysicalExpr, RangePartitioning, SplitPoint, WindowExpr, +}; use datafusion_proto_common::common::proto_error; +use object_store::ObjectMeta; +use object_store::path::Path; -use crate::convert_required; -use crate::logical_plan::{self}; -use crate::protobuf; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, PhysicalPlanDecodeContext, + PhysicalProtoConverterExtension, +}; +use crate::convert::TryFromProto; use crate::protobuf::physical_expr_node::ExprType; - -use super::PhysicalExtensionCodec; - -impl From<&protobuf::PhysicalColumn> for Column { - fn from(c: &protobuf::PhysicalColumn) -> Column { - Column::new(&c.name, c.index as usize) - } -} +use crate::{convert_required, convert_required_proto, protobuf}; +use datafusion_physical_expr::expressions::{ + DynamicFilterInner, DynamicFilterPhysicalExpr, +}; /// Parses a physical sort expression from a protobuf. /// /// # Arguments /// /// * `proto` - Input proto with physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), ctx, input_schema, codec)?; + let expr = + proto_converter.proto_to_physical_expr(expr.as_ref(), input_schema, ctx)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -96,19 +103,23 @@ pub fn parse_physical_sort_expr( /// # Arguments /// /// * `proto` - Input proto with vector of physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { proto .iter() - .map(|sort_expr| parse_physical_sort_expr(sort_expr, ctx, input_schema, codec)) + .map(|sort_expr| { + parse_physical_sort_expr(sort_expr, ctx, input_schema, proto_converter) + }) .collect() } @@ -118,26 +129,30 @@ pub fn parse_physical_sort_exprs( /// /// * `proto` - Input proto with physical window expression node. /// * `name` - Name of the window expression. -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names -/// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `input_schema` - The Arrow schema for the input, used for determining +/// expression data types when performing type coercion. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. pub fn parse_physical_window_expr( proto: &protobuf::PhysicalWindowExprNode, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let window_node_expr = parse_physical_exprs(&proto.args, ctx, input_schema, codec)?; + let window_node_expr = + parse_physical_exprs(&proto.args, ctx, input_schema, proto_converter)?; let partition_by = - parse_physical_exprs(&proto.partition_by, ctx, input_schema, codec)?; + parse_physical_exprs(&proto.partition_by, ctx, input_schema, proto_converter)?; - let order_by = parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, codec)?; + let order_by = + parse_physical_sort_exprs(&proto.order_by, ctx, input_schema, proto_converter)?; let window_frame = proto .window_frame .as_ref() - .map(|wf| wf.clone().try_into()) + .map(|wf| datafusion_expr::WindowFrame::try_from_proto(wf.clone())) .transpose() .map_err(|e| internal_datafusion_err!("{e}"))? .ok_or_else(|| { @@ -148,14 +163,20 @@ pub fn parse_physical_window_expr( match window_func { protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { - Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, - None => ctx.udaf(udaf_name).or_else(|_| codec.try_decode_udaf(udaf_name, &[]))?, + Some(buf) => ctx.codec().try_decode_udaf(udaf_name, buf)?, + None => ctx + .task_ctx() + .udaf(udaf_name) + .or_else(|_| ctx.codec().try_decode_udaf(udaf_name, &[]))?, }) } protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { - Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, - None => ctx.udwf(udwf_name).or_else(|_| codec.try_decode_udwf(udwf_name, &[]))? + Some(buf) => ctx.codec().try_decode_udwf(udwf_name, buf)?, + None => ctx + .task_ctx() + .udwf(udwf_name) + .or_else(|_| ctx.codec().try_decode_udwf(udwf_name, &[]))? }) } } @@ -183,16 +204,16 @@ pub fn parse_physical_window_expr( pub fn parse_physical_exprs<'a, I>( protos: I, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>> where I: IntoIterator, { protos .into_iter() - .map(|p| parse_physical_expr(p, ctx, input_schema, codec)) + .map(|p| proto_converter.proto_to_physical_expr(p, input_schema, ctx)) .collect::>>() } @@ -201,45 +222,69 @@ where /// # Arguments /// /// * `proto` - Input proto with physical expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function names -/// * `input_schema` - The Arrow schema for the input, used for determining expression data types -/// when performing type coercion. -/// * `codec` - An extension codec used to decode custom UDFs. +/// * `ctx` - Task context used to resolve registered functions. +/// * `input_schema` - The Arrow schema for the input, used for determining +/// expression data types when performing type coercion. +/// * `codec` - Physical extension codec used to construct the root decode +/// context for deserialization. pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, ctx: &TaskContext, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, +) -> Result> { + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, codec); + parse_physical_expr_with_converter( + proto, + input_schema, + &decode_ctx, + &DefaultPhysicalProtoConverter {}, + ) +} + +/// Parses a physical expression from a protobuf. +/// +/// # Arguments +/// +/// * `proto` - Input proto with physical expression node +/// * `input_schema` - The Arrow schema for the input, used for determining +/// expression data types when performing type coercion. +/// * `ctx` - Decode context carrying the task context, extension codec, and +/// any scoped state needed during recursive deserialization. +/// * `proto_converter` - Converter hooks used for recursive physical plan and +/// expression deserialization. +pub fn parse_physical_expr_with_converter( + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let expr_type = proto .expr_type .as_ref() .ok_or_else(|| proto_error("Unexpected empty physical expression"))?; + // Decoder context handed to per-expression `try_from_proto` constructors. + // This is the new shape the codebase is migrating toward (see #21835); + // the remaining `ExprType` variants stay matched inline until they migrate. + let decoder = ConverterDecoder { + ctx, + proto_converter, + }; + let decode_ctx = + datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx::new( + input_schema, + &decoder, + ); + let pexpr: Arc = match expr_type { - ExprType::Column(c) => { - let pcol: Column = c.into(); - Arc::new(pcol) - } - ExprType::UnknownColumn(c) => Arc::new(UnKnownColumn::new(&c.name)), - ExprType::Literal(scalar) => Arc::new(Literal::new(scalar.try_into()?)), - ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new( - parse_required_physical_expr( - binary_expr.l.as_deref(), - ctx, - "left", - input_schema, - codec, - )?, - logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, - parse_required_physical_expr( - binary_expr.r.as_deref(), - ctx, - "right", - input_schema, - codec, - )?, - )), + // Migrated expressions take the whole `PhysicalExprNode` and unwrap + // their own `ExprType` variant — see #21835. This match only routes + // to the right constructor. + ExprType::Column(_) => Column::try_from_proto(proto, &decode_ctx)?, + ExprType::UnknownColumn(_) => UnKnownColumn::try_from_proto(proto, &decode_ctx)?, + ExprType::Literal(_) => Literal::try_from_proto(proto, &decode_ctx)?, + ExprType::BinaryExpr(_) => BinaryExpr::try_from_proto(proto, &decode_ctx)?, ExprType::AggregateExpr(_) => { return not_impl_err!( "Cannot convert aggregate expr node to physical expression" @@ -253,116 +298,27 @@ pub fn parse_physical_expr( ExprType::Sort(_) => { return not_impl_err!("Cannot convert sort expr node to physical expression"); } - ExprType::IsNullExpr(e) => { - Arc::new(IsNullExpr::new(parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?)) - } - ExprType::IsNotNullExpr(e) => { - Arc::new(IsNotNullExpr::new(parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?)) - } - ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?)), - ExprType::Negative(e) => { - Arc::new(NegativeExpr::new(parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?)) - } - ExprType::InList(e) => in_list( - parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?, - parse_physical_exprs(&e.list, ctx, input_schema, codec)?, - &e.negated, - input_schema, - )?, - ExprType::Case(e) => Arc::new(CaseExpr::try_new( - e.expr - .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) - .transpose()?, - e.when_then_expr - .iter() - .map(|e| { - Ok(( - parse_required_physical_expr( - e.when_expr.as_ref(), - ctx, - "when_expr", - input_schema, - codec, - )?, - parse_required_physical_expr( - e.then_expr.as_ref(), - ctx, - "then_expr", - input_schema, - codec, - )?, - )) - }) - .collect::>>()?, - e.else_expr - .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), ctx, input_schema, codec)) - .transpose()?, - )?), - ExprType::Cast(e) => Arc::new(CastExpr::new( - parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?, - convert_required!(e.arrow_type)?, - None, - )), - ExprType::TryCast(e) => Arc::new(TryCastExpr::new( - parse_required_physical_expr( - e.expr.as_deref(), - ctx, - "expr", - input_schema, - codec, - )?, - convert_required!(e.arrow_type)?, - )), + ExprType::IsNullExpr(_) => IsNullExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::IsNotNullExpr(_) => IsNotNullExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::NotExpr(_) => NotExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::Negative(_) => NegativeExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::InList(_) => InListExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::Case(_) => CaseExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::Cast(_) => CastExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::TryCast(_) => TryCastExpr::try_from_proto(proto, &decode_ctx)?, ExprType::ScalarUdf(e) => { let udf = match &e.fun_definition { - Some(buf) => codec.try_decode_udf(&e.name, buf)?, + Some(buf) => ctx.codec().try_decode_udf(&e.name, buf)?, None => ctx + .task_ctx() .udf(e.name.as_str()) - .or_else(|_| codec.try_decode_udf(&e.name, &[]))?, + .or_else(|_| ctx.codec().try_decode_udf(&e.name, &[]))?, }; let scalar_fun_def = Arc::clone(&udf); - let args = parse_physical_exprs(&e.args, ctx, input_schema, codec)?; + let args = parse_physical_exprs(&e.args, ctx, input_schema, proto_converter)?; - let config_options = Arc::clone(ctx.session_config().options()); + let config_options = Arc::clone(ctx.task_ctx().session_config().options()); Arc::new( ScalarFunctionExpr::new( @@ -380,31 +336,84 @@ pub fn parse_physical_expr( .with_nullable(e.nullable), ) } - ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new( - like_expr.negated, - like_expr.case_insensitive, - parse_required_physical_expr( - like_expr.expr.as_deref(), + ExprType::LikeExpr(_) => LikeExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::HashExpr(_) => HashExpr::try_from_proto(proto, &decode_ctx)?, + ExprType::ScalarSubquery(sq) => { + let data_type: arrow::datatypes::DataType = sq + .data_type + .as_ref() + .ok_or_else(|| { + proto_error("Missing data_type in PhysicalScalarSubqueryExprNode") + })? + .try_into()?; + let results = ctx.scalar_subquery_results().ok_or_else(|| { + proto_error( + "ScalarSubqueryExpr can only be deserialized as part \ + of a surrounding ScalarSubqueryExec", + ) + })?; + Arc::new(ScalarSubqueryExpr::new( + data_type, + sq.nullable, + SubqueryIndex::new(sq.index as usize), + results.clone(), + )) + } + ExprType::DynamicFilter(dynamic_filter) => { + let children = parse_physical_exprs( + &dynamic_filter.children, ctx, - "expr", input_schema, - codec, - )?, - parse_required_physical_expr( - like_expr.pattern.as_deref(), + proto_converter, + )?; + + let remapped_children = if !dynamic_filter.remapped_children.is_empty() { + Some(parse_physical_exprs( + &dynamic_filter.remapped_children, + ctx, + input_schema, + proto_converter, + )?) + } else { + None + }; + + let inner_expr = parse_required_physical_expr( + dynamic_filter.inner_expr.as_deref(), ctx, - "pattern", + "inner_expr", input_schema, - codec, - )?, - )), + proto_converter, + )?; + + let expression_id = proto.expr_id.ok_or_else(|| { + proto_error( + "DynamicFilterPhysicalExpr requires PhysicalExprNode.expr_id \ + to be set by the serializer", + ) + })?; + + let base_filter: Arc = + Arc::new(DynamicFilterPhysicalExpr::from_parts( + children, + remapped_children, + DynamicFilterInner { + expression_id, + generation: dynamic_filter.generation, + expr: inner_expr, + is_complete: dynamic_filter.is_complete, + }, + )); + base_filter + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs .iter() - .map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + .map(|e| proto_converter.proto_to_physical_expr(e, input_schema, ctx)) .collect::>()?; - (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ + ctx.codec() + .try_decode_expr(extension.expr.as_slice(), &inputs)? as _ } }; @@ -413,26 +422,30 @@ pub fn parse_physical_expr( fn parse_required_physical_expr( expr: Option<&protobuf::PhysicalExprNode>, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, field: &str, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - expr.map(|e| parse_physical_expr(e, ctx, input_schema, codec)) + expr.map(|e| proto_converter.proto_to_physical_expr(e, input_schema, ctx)) .transpose()? .ok_or_else(|| internal_datafusion_err!("Missing required field {field:?}")) } pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(hash_part) => { - let expr = - parse_physical_exprs(&hash_part.hash_expr, ctx, input_schema, codec)?; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + ctx, + input_schema, + proto_converter, + )?; Ok(Some(Partitioning::Hash( expr, @@ -445,9 +458,9 @@ pub fn parse_protobuf_hash_partitioning( pub fn parse_protobuf_partitioning( partitioning: Option<&protobuf::Partitioning>, - ctx: &TaskContext, + ctx: &PhysicalPlanDecodeContext<'_>, input_schema: &Schema, - codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { match partitioning { Some(protobuf::Partitioning { partition_method }) => match partition_method { @@ -461,9 +474,17 @@ pub fn parse_protobuf_partitioning( Some(hash_repartition), ctx, input_schema, - codec, + proto_converter, ) } + Some(protobuf::partitioning::PartitionMethod::Range(range_partitioning)) => { + Ok(Some(parse_protobuf_range_partitioning( + range_partitioning, + ctx, + input_schema, + proto_converter, + )?)) + } Some(protobuf::partitioning::PartitionMethod::Unknown(partition_count)) => { Ok(Some(Partitioning::UnknownPartitioning( *partition_count as usize, @@ -475,6 +496,49 @@ pub fn parse_protobuf_partitioning( } } +fn parse_protobuf_range_partitioning( + range_partitioning: &protobuf::PhysicalRangePartitioning, + ctx: &PhysicalPlanDecodeContext<'_>, + input_schema: &Schema, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + let sort_exprs = parse_physical_sort_exprs( + &range_partitioning.sort_expr, + ctx, + input_schema, + proto_converter, + )?; + let sort_expr_count = sort_exprs.len(); + let ordering = LexOrdering::new(sort_exprs).ok_or_else(|| { + internal_datafusion_err!("Range partitioning requires non-empty ordering") + })?; + if ordering.len() != sort_expr_count { + return Err(internal_datafusion_err!( + "Range partitioning ordering must not contain duplicate expressions" + )); + } + let split_points = range_partitioning + .split_point + .iter() + .map(parse_protobuf_range_split_point) + .collect::>()?; + Ok(Partitioning::Range(RangePartitioning::try_new( + ordering, + split_points, + )?)) +} + +fn parse_protobuf_range_split_point( + split_point: &protobuf::PhysicalRangeSplitPoint, +) -> Result { + let values = split_point + .value + .iter() + .map(|value| ScalarValue::try_from(value).map_err(Into::into)) + .collect::>()?; + Ok(SplitPoint::new(values)) +} + pub fn parse_protobuf_file_scan_schema( proto: &protobuf::FileScanExecConf, ) -> Result> { @@ -509,21 +573,18 @@ pub fn parse_table_schema_from_proto( .with_metadata(schema.metadata.clone()), ); - Ok(TableSchema::new(file_schema, table_partition_cols)) + Ok(TableSchema::builder(file_schema) + .with_table_partition_cols(table_partition_cols) + .build()) } pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, - ctx: &TaskContext, - codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, file_source: Arc, ) -> Result { let schema: Arc = parse_protobuf_file_scan_schema(proto)?; - let projection = proto - .projection - .iter() - .map(|i| *i as usize) - .collect::>(); let constraints = convert_required!(proto.constraints)?; let statistics = convert_required!(proto.statistics)?; @@ -531,7 +592,7 @@ pub fn parse_protobuf_file_scan_config( let file_groups = proto .file_groups .iter() - .map(|f| f.try_into()) + .map(FileGroup::try_from_proto) .collect::, _>>()?; let object_store_url = match proto.object_store_url.is_empty() { @@ -545,19 +606,46 @@ pub fn parse_protobuf_file_scan_config( &node_collection.physical_sort_expr_nodes, ctx, &schema, - codec, + proto_converter, )?; output_ordering.extend(LexOrdering::new(sort_exprs)); } + // Parse projection expressions if present and apply to file source + let file_source = if let Some(proto_projection_exprs) = &proto.projection_exprs { + let projection_exprs: Vec = proto_projection_exprs + .projections + .iter() + .map(|proto_expr| { + let expr = proto_converter.proto_to_physical_expr( + proto_expr.expr.as_ref().ok_or_else(|| { + internal_datafusion_err!("ProjectionExpr missing expr field") + })?, + &schema, + ctx, + )?; + Ok(ProjectionExpr::new(expr, proto_expr.alias.clone())) + }) + .collect::>>()?; + + let projection_exprs = ProjectionExprs::new(projection_exprs); + + // Apply projection to file source + file_source + .try_pushdown_projection(&projection_exprs)? + .unwrap_or(file_source) + } else { + file_source + }; + let config = FileScanConfigBuilder::new(object_store_url, file_source) .with_file_groups(file_groups) .with_constraints(constraints) .with_statistics(statistics) - .with_projection_indices(Some(projection)) .with_limit(proto.limit.as_ref().map(|sl| sl.limit as usize)) .with_output_ordering(output_ordering) .with_batch_size(proto.batch_size.map(|s| s as usize)) + .with_partitioned_by_file_group(proto.partitioned_by_file_group.unwrap_or(false)) .build(); Ok(config) } @@ -574,41 +662,44 @@ pub fn parse_record_batches(buf: &[u8]) -> Result> { Ok(batches) } -impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { +impl TryFromProto<&protobuf::PartitionedFile> for PartitionedFile { type Error = DataFusionError; - fn try_from(val: &protobuf::PartitionedFile) -> Result { - Ok(PartitionedFile { - object_meta: ObjectMeta { - location: Path::parse(val.path.as_str()).map_err(|e| { - proto_error(format!("Invalid object_store path: {e}")) - })?, - last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), - size: val.size, - e_tag: None, - version: None, - }, - partition_values: val - .partition_values + fn try_from_proto(val: &protobuf::PartitionedFile) -> Result { + let mut pf = PartitionedFile::new_from_meta(ObjectMeta { + location: Path::parse(val.path.as_str()) + .map_err(|e| proto_error(format!("Invalid object_store path: {e}")))?, + last_modified: Utc.timestamp_nanos(val.last_modified_ns as i64), + size: val.size, + e_tag: None, + version: None, + }) + .with_partition_values( + val.partition_values .iter() .map(|v| v.try_into()) .collect::, _>>()?, - range: val.range.as_ref().map(|v| v.try_into()).transpose()?, - statistics: val - .statistics - .as_ref() - .map(|v| v.try_into().map(Arc::new)) - .transpose()?, - extensions: None, - metadata_size_hint: None, - }) + ); + if let Some(proto_schema) = val.arrow_schema.as_ref() { + pf = pf.with_arrow_schema(Arc::new( + proto_schema.try_into().map_err(DataFusionError::from)?, + )); + } + if let Some(range) = val.range.as_ref() { + let file_range = FileRange::try_from_proto(range)?; + pf = pf.with_range(file_range.start, file_range.end); + } + if let Some(proto_stats) = val.statistics.as_ref() { + pf = pf.with_statistics(Arc::new(proto_stats.try_into()?)); + } + Ok(pf) } } -impl TryFrom<&protobuf::FileRange> for FileRange { +impl TryFromProto<&protobuf::FileRange> for FileRange { type Error = DataFusionError; - fn try_from(value: &protobuf::FileRange) -> Result { + fn try_from_proto(value: &protobuf::FileRange) -> Result { Ok(FileRange { start: value.start, end: value.end, @@ -616,61 +707,61 @@ impl TryFrom<&protobuf::FileRange> for FileRange { } } -impl TryFrom<&protobuf::FileGroup> for FileGroup { +impl TryFromProto<&protobuf::FileGroup> for FileGroup { type Error = DataFusionError; - fn try_from(val: &protobuf::FileGroup) -> Result { + fn try_from_proto(val: &protobuf::FileGroup) -> Result { let files = val .files .iter() - .map(|f| f.try_into()) + .map(PartitionedFile::try_from_proto) .collect::, _>>()?; Ok(FileGroup::new(files)) } } -impl TryFrom<&protobuf::JsonSink> for JsonSink { +impl TryFromProto<&protobuf::JsonSink> for JsonSink { type Error = DataFusionError; - fn try_from(value: &protobuf::JsonSink) -> Result { + fn try_from_proto(value: &protobuf::JsonSink) -> Result { Ok(Self::new( - convert_required!(value.config)?, + convert_required_proto!(FileSinkConfig, value.config)?, convert_required!(value.writer_options)?, )) } } #[cfg(feature = "parquet")] -impl TryFrom<&protobuf::ParquetSink> for ParquetSink { +impl TryFromProto<&protobuf::ParquetSink> for ParquetSink { type Error = DataFusionError; - fn try_from(value: &protobuf::ParquetSink) -> Result { + fn try_from_proto(value: &protobuf::ParquetSink) -> Result { Ok(Self::new( - convert_required!(value.config)?, + convert_required_proto!(FileSinkConfig, value.config)?, convert_required!(value.parquet_options)?, )) } } -impl TryFrom<&protobuf::CsvSink> for CsvSink { +impl TryFromProto<&protobuf::CsvSink> for CsvSink { type Error = DataFusionError; - fn try_from(value: &protobuf::CsvSink) -> Result { + fn try_from_proto(value: &protobuf::CsvSink) -> Result { Ok(Self::new( - convert_required!(value.config)?, + convert_required_proto!(FileSinkConfig, value.config)?, convert_required!(value.writer_options)?, )) } } -impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { +impl TryFromProto<&protobuf::FileSinkConfig> for FileSinkConfig { type Error = DataFusionError; - fn try_from(conf: &protobuf::FileSinkConfig) -> Result { + fn try_from_proto(conf: &protobuf::FileSinkConfig) -> Result { let file_group = FileGroup::new( conf.file_groups .iter() - .map(|f| f.try_into()) + .map(PartitionedFile::try_from_proto) .collect::>>()?, ); let table_paths = conf @@ -691,6 +782,17 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { protobuf::InsertOp::Overwrite => InsertOp::Overwrite, protobuf::InsertOp::Replace => InsertOp::Replace, }; + let file_output_mode = match conf.file_output_mode() { + protobuf::FileOutputMode::Automatic => { + datafusion_datasource::file_sink_config::FileOutputMode::Automatic + } + protobuf::FileOutputMode::SingleFile => { + datafusion_datasource::file_sink_config::FileOutputMode::SingleFile + } + protobuf::FileOutputMode::Directory => { + datafusion_datasource::file_sink_config::FileOutputMode::Directory + } + }; Ok(Self { original_url: String::default(), object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, @@ -701,49 +803,95 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { insert_op, keep_partition_by_columns: conf.keep_partition_by_columns, file_extension: conf.file_extension.clone(), + file_output_mode, }) } } +/// Concrete [`PhysicalExprDecode`] driver that backs +/// [`PhysicalExprDecodeCtx`] inside `parse_physical_expr_with_converter`. +/// +/// Today this is a thin wrapper that re-enters the central match through +/// `proto_to_physical_expr`; once more expressions migrate, the central match +/// shrinks and a future builder-style decoder can take over. +/// +/// [`PhysicalExprDecode`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecode +/// [`PhysicalExprDecodeCtx`]: datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecodeCtx +struct ConverterDecoder<'a, 'b> { + ctx: &'a PhysicalPlanDecodeContext<'b>, + proto_converter: &'a dyn PhysicalProtoConverterExtension, +} + +impl datafusion_physical_expr_common::physical_expr::proto_decode::PhysicalExprDecode + for ConverterDecoder<'_, '_> +{ + fn decode( + &self, + node: &protobuf::PhysicalExprNode, + schema: &Schema, + ) -> Result> { + self.proto_converter + .proto_to_physical_expr(node, schema, self.ctx) + } +} + #[cfg(test)] mod tests { + use super::*; - use chrono::{TimeZone, Utc}; - use datafusion_datasource::PartitionedFile; - use object_store::path::Path; - use object_store::ObjectMeta; #[test] fn partitioned_file_path_roundtrip_percent_encoded() { let path_str = "foo/foo%2Fbar/baz%252Fqux"; - let pf = PartitionedFile { - object_meta: ObjectMeta { - location: Path::parse(path_str).unwrap(), - last_modified: Utc.timestamp_nanos(1_000), - size: 42, - e_tag: None, - version: None, - }, - partition_values: vec![], - range: None, - statistics: None, - extensions: None, - metadata_size_hint: None, - }; - - let proto = protobuf::PartitionedFile::try_from(&pf).unwrap(); + let pf = PartitionedFile::new_from_meta(ObjectMeta { + location: Path::parse(path_str).unwrap(), + last_modified: Utc.timestamp_nanos(1_000), + size: 42, + e_tag: None, + version: None, + }); + + let proto = protobuf::PartitionedFile::try_from_proto(&pf).unwrap(); assert_eq!(proto.path, path_str); - let pf2 = PartitionedFile::try_from(&proto).unwrap(); + let pf2 = PartitionedFile::try_from_proto(&proto).unwrap(); assert_eq!(pf2.object_meta.location.as_ref(), path_str); assert_eq!(pf2.object_meta.location, pf.object_meta.location); assert_eq!(pf2.object_meta.size, pf.object_meta.size); assert_eq!(pf2.object_meta.last_modified, pf.object_meta.last_modified); } + #[test] + fn partitioned_file_arrow_schema_roundtrip() { + use arrow::datatypes::{DataType, Field, Schema}; + use std::collections::HashMap; + + let arrow_schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("id", DataType::Int64, false), + Field::new("value", DataType::Utf8, true).with_metadata(HashMap::from([ + ("field_meta".to_string(), "field_value".to_string()), + ])), + ], + HashMap::from([("schema_meta".to_string(), "schema_value".to_string())]), + )); + let pf = PartitionedFile::new("foo/bar.parquet", 10) + .with_arrow_schema(Arc::clone(&arrow_schema)); + + let proto = protobuf::PartitionedFile::try_from_proto(&pf).unwrap(); + assert!(proto.arrow_schema.is_some()); + + let decoded = PartitionedFile::try_from_proto(&proto).unwrap(); + assert_eq!( + decoded.arrow_schema.as_ref().map(|s| s.as_ref()), + Some(arrow_schema.as_ref()) + ); + } + #[test] fn partitioned_file_from_proto_invalid_path() { let proto = protobuf::PartitionedFile { + arrow_schema: None, path: "foo//bar".to_string(), size: 1, last_modified_ns: 0, @@ -752,7 +900,7 @@ mod tests { statistics: None, }; - let err = PartitionedFile::try_from(&proto).unwrap_err(); + let err = PartitionedFile::try_from_proto(&proto).unwrap_err(); assert!(err.to_string().contains("Invalid object_store path")); } } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index fc7818fe461a6..9efcd25fcb412 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -15,37 +15,21 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::cell::RefCell; +use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; -use self::from_proto::parse_protobuf_partitioning; -use self::to_proto::{serialize_partitioning, serialize_physical_expr}; -use crate::common::{byte_to_string, str_to_byte}; -use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_physical_window_expr, parse_protobuf_file_scan_config, parse_record_batches, - parse_table_schema_from_proto, -}; -use crate::physical_plan::to_proto::{ - serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, - serialize_physical_sort_exprs, serialize_physical_window_expr, - serialize_record_batches, -}; -use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::{ - self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest, SortExprNode, - SortMergeJoinExecNode, -}; -use crate::{convert_required, into_required}; - use arrow::compute::SortOptions; -use arrow::datatypes::{IntervalMonthDayNanoType, SchemaRef}; +use arrow::datatypes::{IntervalMonthDayNanoType, Schema, SchemaRef}; use datafusion_catalog::memory::MemorySourceConfig; use datafusion_common::config::CsvOptions; +use datafusion_common::display::StringifiedPlan; +use datafusion_common::format::ExplainFormat; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, + DataFusionError, JoinType, NullEquality, Result, internal_datafusion_err, + internal_err, not_impl_err, }; #[cfg(feature = "parquet")] use datafusion_datasource::file::FileSource; @@ -53,6 +37,7 @@ use datafusion_datasource::file_compression_type::FileCompressionType; use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder}; use datafusion_datasource::sink::DataSinkExec; use datafusion_datasource::source::{DataSource, DataSourceExec}; +use datafusion_datasource_arrow::source::ArrowSource; #[cfg(feature = "avro")] use datafusion_datasource_avro::source::AvroSource; use datafusion_datasource_csv::file_format::CsvSink; @@ -60,52 +45,188 @@ use datafusion_datasource_csv::source::CsvSource; use datafusion_datasource_json::file_format::JsonSink; use datafusion_datasource_json::source::JsonSource; #[cfg(feature = "parquet")] +use datafusion_datasource_parquet::CachedParquetFileReaderFactory; +#[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::ParquetSink; #[cfg(feature = "parquet")] use datafusion_datasource_parquet::source::ParquetSource; +#[cfg(feature = "parquet")] +use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{FunctionRegistry, TaskContext}; +use datafusion_expr::execution_props::{ScalarSubqueryResults, SubqueryIndex}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion_functions_table::generate_series::{ Empty, GenSeriesArgs, GenerateSeriesTable, GenericSeriesState, TimestampValue, }; -use datafusion_physical_expr::aggregate::AggregateExprBuilder; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; +use datafusion_physical_expr::expressions::DynamicFilterPhysicalExpr; use datafusion_physical_expr::{LexOrdering, LexRequirement, PhysicalExprRef}; -use datafusion_physical_plan::aggregates::AggregateMode; -use datafusion_physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, +}; use datafusion_physical_plan::analyze::AnalyzeExec; +use datafusion_physical_plan::async_func::AsyncFuncExec; +use datafusion_physical_plan::buffer::BufferExec; +#[expect(deprecated)] use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::coop::CooperativeExec; use datafusion_physical_plan::empty::EmptyExec; use datafusion_physical_plan::explain::ExplainExec; use datafusion_physical_plan::expressions::PhysicalSortExpr; -use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::filter::{FilterExec, FilterExecBuilder}; use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion_physical_plan::joins::{ - CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode, - SymmetricHashJoinExec, + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, + StreamJoinPartitionMode, SymmetricHashJoinExec, }; -use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode}; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::memory::LazyMemoryExec; -use datafusion_physical_plan::metrics::MetricType; +use datafusion_physical_plan::metrics::MetricCategory; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::scalar_subquery::{ScalarSubqueryExec, ScalarSubqueryLink}; use datafusion_physical_plan::sorts::sort::SortExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::{InterleaveExec, UnionExec}; use datafusion_physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion_physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion_physical_plan::{ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr}; - -use prost::bytes::BufMut; use prost::Message; +use prost::bytes::BufMut; + +use self::from_proto::parse_protobuf_partitioning; +use self::to_proto::serialize_partitioning; +use crate::common::{byte_to_string, str_to_byte}; +use crate::convert::{FromProto, TryFromProto}; +use crate::convert_required; +use crate::physical_plan::from_proto::{ + parse_physical_expr_with_converter, parse_physical_sort_expr, + parse_physical_sort_exprs, parse_physical_window_expr, + parse_protobuf_file_scan_config, parse_record_batches, parse_table_schema_from_proto, +}; +use crate::physical_plan::to_proto::{ + serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, + serialize_physical_expr_with_converter, serialize_physical_sort_exprs, + serialize_physical_window_expr, serialize_record_batches, +}; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::physical_plan_node::PhysicalPlanType; +use crate::protobuf::{ + self, ListUnnest as ProtoListUnnest, SortExprNode, SortMergeJoinExecNode, + proto_error, window_agg_exec_node, +}; pub mod from_proto; pub mod to_proto; +const HUMAN_DISPLAY_ALIAS_PREFIX: &str = "\u{1f}datafusion_human_display_alias_v1:"; + +fn encode_human_display_alias(human_display: &str, alias: &str) -> String { + format!( + "{HUMAN_DISPLAY_ALIAS_PREFIX}{}:{alias}{human_display}", + alias.len() + ) +} + +fn split_human_display_alias<'a>( + human_display: &'a str, + name: &'a str, +) -> (&'a str, Option<&'a str>) { + if let Some(encoded) = human_display.strip_prefix(HUMAN_DISPLAY_ALIAS_PREFIX) + && let Some((alias_len, encoded)) = encoded.split_once(':') + && let Ok(alias_len) = alias_len.parse::() + && let Some(alias) = encoded.get(..alias_len) + && let Some(human_display) = encoded.get(alias_len..) + && alias == name + && !human_display.is_empty() + { + return (human_display, Some(alias)); + } + + (human_display, None) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn split_human_display_alias_ignores_mismatched_alias() { + let encoded = encode_human_display_alias("sum(value)", "revenue"); + + assert_eq!( + split_human_display_alias(&encoded, "other"), + (encoded.as_str(), None) + ); + } + + #[test] + fn split_human_display_alias_keeps_malformed_prefix_literal() { + let display = format!("{HUMAN_DISPLAY_ALIAS_PREFIX}not-an-encoding"); + + assert_eq!( + split_human_display_alias(&display, "agg"), + (display.as_str(), None) + ); + } +} + +/// Context threaded through physical-plan deserialization. +/// +/// This bundles the stable per-call inputs for deserialization and the +/// per-scope `ScalarSubqueryResults` handle needed while reconstructing +/// `ScalarSubqueryExpr` nodes inside a `ScalarSubqueryExec` input plan. +#[derive(Clone)] +pub struct PhysicalPlanDecodeContext<'a> { + task_ctx: &'a TaskContext, + codec: &'a dyn PhysicalExtensionCodec, + scalar_subquery_results: Option, +} + +impl<'a> PhysicalPlanDecodeContext<'a> { + /// Creates a new root decode context. + pub fn new(task_ctx: &'a TaskContext, codec: &'a dyn PhysicalExtensionCodec) -> Self { + Self { + task_ctx, + codec, + scalar_subquery_results: None, + } + } + + /// Returns the task context used for deserialization. + pub fn task_ctx(&self) -> &'a TaskContext { + self.task_ctx + } + + /// Returns the physical extension codec used for deserialization. + pub fn codec(&self) -> &'a dyn PhysicalExtensionCodec { + self.codec + } + + /// Returns the scalar subquery results container for the current scope, if + /// one is active. + pub fn scalar_subquery_results(&self) -> Option<&ScalarSubqueryResults> { + self.scalar_subquery_results.as_ref() + } + + /// Returns a child context with a different scalar subquery results + /// container. + pub fn with_scalar_subquery_results( + &self, + scalar_subquery_results: ScalarSubqueryResults, + ) -> Self { + Self { + task_ctx: self.task_ctx, + codec: self.codec, + scalar_subquery_results: Some(scalar_subquery_results), + } + } +} + impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -129,350 +250,454 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { fn try_into_physical_plan( &self, ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + ) -> Result> { + self.try_into_physical_plan_with_converter( + ctx, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } + + fn try_from_physical_plan( + plan: Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + Self::try_from_physical_plan_with_converter( + plan, + codec, + &DefaultPhysicalProtoConverter {}, + ) + } +} + +/// Extension methods on [`protobuf::PhysicalPlanNode`]. +/// +/// The prost-generated `PhysicalPlanNode` struct lives in +/// `datafusion-proto-models`, which is foreign to this crate, so the orphan +/// rule forbids inherent `impl` blocks here. Instead, all (de)serialization +/// helpers are exposed through this trait. Callers can bring it in scope with +/// `use datafusion_proto::physical_plan::PhysicalPlanNodeExt;`. +/// +/// Method bodies live in the default trait implementation. To make the trait +/// usable as if it were inherent (i.e. let bodies access fields on `self`), +/// implementors provide [`PhysicalPlanNodeExt::node`] returning a reference +/// back to the concrete `protobuf::PhysicalPlanNode`. Default method bodies +/// then go through `self.node()` to read fields. +pub trait PhysicalPlanNodeExt: Sized { + /// Returns a reference to the underlying [`protobuf::PhysicalPlanNode`]. + fn node(&self) -> &protobuf::PhysicalPlanNode; + + fn try_into_physical_plan_with_converter( + &self, + ctx: &TaskContext, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, codec); + self.try_into_physical_plan_with_context(&decode_ctx, proto_converter) + } - extension_codec: &dyn PhysicalExtensionCodec, + fn try_into_physical_plan_with_context( + &self, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let plan = self.physical_plan_type.as_ref().ok_or_else(|| { + let plan = self.node().physical_plan_type.as_ref().ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unsupported physical plan '{self:?}'" + "physical_plan::from_proto() Unsupported physical plan '{:?}'", + self.node(), )) })?; match plan { PhysicalPlanType::Explain(explain) => { - self.try_into_explain_physical_plan(explain, ctx, extension_codec) + self.try_into_explain_physical_plan(explain, ctx, proto_converter) } PhysicalPlanType::Projection(projection) => { - self.try_into_projection_physical_plan(projection, ctx, extension_codec) + self.try_into_projection_physical_plan(projection, ctx, proto_converter) } PhysicalPlanType::Filter(filter) => { - self.try_into_filter_physical_plan(filter, ctx, extension_codec) + self.try_into_filter_physical_plan(filter, ctx, proto_converter) } PhysicalPlanType::CsvScan(scan) => { - self.try_into_csv_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_csv_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::JsonScan(scan) => { - self.try_into_json_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_json_scan_physical_plan(scan, ctx, proto_converter) } - #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetScan(scan) => { - self.try_into_parquet_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_parquet_scan_physical_plan(scan, ctx, proto_converter) } - #[cfg_attr(not(feature = "avro"), allow(unused_variables))] PhysicalPlanType::AvroScan(scan) => { - self.try_into_avro_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_avro_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::MemoryScan(scan) => { - self.try_into_memory_scan_physical_plan(scan, ctx, extension_codec) + self.try_into_memory_scan_physical_plan(scan, ctx, proto_converter) + } + PhysicalPlanType::ArrowScan(scan) => { + self.try_into_arrow_scan_physical_plan(scan, ctx, proto_converter) } PhysicalPlanType::CoalesceBatches(coalesce_batches) => self .try_into_coalesce_batches_physical_plan( coalesce_batches, ctx, - extension_codec, + proto_converter, ), PhysicalPlanType::Merge(merge) => { - self.try_into_merge_physical_plan(merge, ctx, extension_codec) + self.try_into_merge_physical_plan(merge, ctx, proto_converter) } PhysicalPlanType::Repartition(repart) => { - self.try_into_repartition_physical_plan(repart, ctx, extension_codec) + self.try_into_repartition_physical_plan(repart, ctx, proto_converter) } PhysicalPlanType::GlobalLimit(limit) => { - self.try_into_global_limit_physical_plan(limit, ctx, extension_codec) + self.try_into_global_limit_physical_plan(limit, ctx, proto_converter) } PhysicalPlanType::LocalLimit(limit) => { - self.try_into_local_limit_physical_plan(limit, ctx, extension_codec) + self.try_into_local_limit_physical_plan(limit, ctx, proto_converter) } PhysicalPlanType::Window(window_agg) => { - self.try_into_window_physical_plan(window_agg, ctx, extension_codec) + self.try_into_window_physical_plan(window_agg, ctx, proto_converter) } PhysicalPlanType::Aggregate(hash_agg) => { - self.try_into_aggregate_physical_plan(hash_agg, ctx, extension_codec) + self.try_into_aggregate_physical_plan(hash_agg, ctx, proto_converter) } PhysicalPlanType::HashJoin(hashjoin) => { - self.try_into_hash_join_physical_plan(hashjoin, ctx, extension_codec) + self.try_into_hash_join_physical_plan(hashjoin, ctx, proto_converter) } PhysicalPlanType::SymmetricHashJoin(sym_join) => self .try_into_symmetric_hash_join_physical_plan( sym_join, ctx, - extension_codec, + proto_converter, ), PhysicalPlanType::Union(union) => { - self.try_into_union_physical_plan(union, ctx, extension_codec) + self.try_into_union_physical_plan(union, ctx, proto_converter) } PhysicalPlanType::Interleave(interleave) => { - self.try_into_interleave_physical_plan(interleave, ctx, extension_codec) + self.try_into_interleave_physical_plan(interleave, ctx, proto_converter) } PhysicalPlanType::CrossJoin(crossjoin) => { - self.try_into_cross_join_physical_plan(crossjoin, ctx, extension_codec) + self.try_into_cross_join_physical_plan(crossjoin, ctx, proto_converter) } PhysicalPlanType::Empty(empty) => { - self.try_into_empty_physical_plan(empty, ctx, extension_codec) + self.try_into_empty_physical_plan(empty, ctx, proto_converter) + } + PhysicalPlanType::PlaceholderRow(placeholder) => { + self.try_into_placeholder_row_physical_plan(placeholder, ctx) } - PhysicalPlanType::PlaceholderRow(placeholder) => self - .try_into_placeholder_row_physical_plan( - placeholder, - ctx, - extension_codec, - ), PhysicalPlanType::Sort(sort) => { - self.try_into_sort_physical_plan(sort, ctx, extension_codec) + self.try_into_sort_physical_plan(sort, ctx, proto_converter) } PhysicalPlanType::SortPreservingMerge(sort) => self - .try_into_sort_preserving_merge_physical_plan(sort, ctx, extension_codec), + .try_into_sort_preserving_merge_physical_plan(sort, ctx, proto_converter), PhysicalPlanType::Extension(extension) => { - self.try_into_extension_physical_plan(extension, ctx, extension_codec) + self.try_into_extension_physical_plan(extension, ctx, proto_converter) } PhysicalPlanType::NestedLoopJoin(join) => { - self.try_into_nested_loop_join_physical_plan(join, ctx, extension_codec) + self.try_into_nested_loop_join_physical_plan(join, ctx, proto_converter) } PhysicalPlanType::Analyze(analyze) => { - self.try_into_analyze_physical_plan(analyze, ctx, extension_codec) + self.try_into_analyze_physical_plan(analyze, ctx, proto_converter) } PhysicalPlanType::JsonSink(sink) => { - self.try_into_json_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_json_sink_physical_plan(sink, ctx, proto_converter) } PhysicalPlanType::CsvSink(sink) => { - self.try_into_csv_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_csv_sink_physical_plan(sink, ctx, proto_converter) } #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetSink(sink) => { - self.try_into_parquet_sink_physical_plan(sink, ctx, extension_codec) + self.try_into_parquet_sink_physical_plan(sink, ctx, proto_converter) } PhysicalPlanType::Unnest(unnest) => { - self.try_into_unnest_physical_plan(unnest, ctx, extension_codec) + self.try_into_unnest_physical_plan(unnest, ctx, proto_converter) } PhysicalPlanType::Cooperative(cooperative) => { - self.try_into_cooperative_physical_plan(cooperative, ctx, extension_codec) + self.try_into_cooperative_physical_plan(cooperative, ctx, proto_converter) } PhysicalPlanType::GenerateSeries(generate_series) => { self.try_into_generate_series_physical_plan(generate_series) } PhysicalPlanType::SortMergeJoin(sort_join) => { - self.try_into_sort_join(sort_join, ctx, extension_codec) + self.try_into_sort_join(sort_join, ctx, proto_converter) + } + PhysicalPlanType::AsyncFunc(async_func) => { + self.try_into_async_func_physical_plan(async_func, ctx, proto_converter) + } + PhysicalPlanType::Buffer(buffer) => { + self.try_into_buffer_physical_plan(buffer, ctx, proto_converter) + } + PhysicalPlanType::ScalarSubquery(sq) => { + self.try_into_scalar_subquery_physical_plan(sq, ctx, proto_converter) } } } - fn try_from_physical_plan( + fn try_from_physical_plan_with_converter( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result - where - Self: Sized, - { + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { let plan_clone = Arc::clone(&plan); - let plan = plan.as_any(); + let plan = plan.as_ref() as &dyn Any; if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_explain_exec( - exec, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_explain_exec(exec, codec); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_projection_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_analyze_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_filter_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_global_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(limit) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_local_limit_exec( limit, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cross_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_aggregate_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(empty) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_empty_exec( - empty, - extension_codec, - ); + return protobuf::PhysicalPlanNode::try_from_empty_exec(empty, codec); } if let Some(empty) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_placeholder_row_exec( - empty, - extension_codec, + empty, codec, ); } + #[expect(deprecated)] if let Some(coalesce_batches) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_batches_exec( coalesce_batches, - extension_codec, + codec, + proto_converter, ); } - if let Some(data_source_exec) = plan.downcast_ref::() { - if let Some(node) = protobuf::PhysicalPlanNode::try_from_data_source_exec( + if let Some(data_source_exec) = plan.downcast_ref::() + && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_source_exec( data_source_exec, - extension_codec, - )? { - return Ok(node); - } + codec, + proto_converter, + )? + { + return Ok(node); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_coalesce_partitions_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_repartition_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { - return protobuf::PhysicalPlanNode::try_from_sort_exec(exec, extension_codec); + return protobuf::PhysicalPlanNode::try_from_sort_exec( + exec, + codec, + proto_converter, + ); } if let Some(union) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_union_exec( union, - extension_codec, + codec, + proto_converter, ); } if let Some(interleave) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_interleave_exec( interleave, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_sort_preserving_merge_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_nested_loop_join_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_bounded_window_agg_exec( exec, - extension_codec, + codec, + proto_converter, ); } - if let Some(exec) = plan.downcast_ref::() { - if let Some(node) = protobuf::PhysicalPlanNode::try_from_data_sink_exec( + if let Some(exec) = plan.downcast_ref::() + && let Some(node) = protobuf::PhysicalPlanNode::try_from_data_sink_exec( exec, - extension_codec, - )? { - return Ok(node); - } + codec, + proto_converter, + )? + { + return Ok(node); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_unnest_exec( exec, - extension_codec, + codec, + proto_converter, ); } if let Some(exec) = plan.downcast_ref::() { return protobuf::PhysicalPlanNode::try_from_cooperative_exec( exec, - extension_codec, + codec, + proto_converter, ); } - if let Some(exec) = plan.downcast_ref::() { - if let Some(node) = + if let Some(exec) = plan.downcast_ref::() + && let Some(node) = protobuf::PhysicalPlanNode::try_from_lazy_memory_exec(exec)? - { - return Ok(node); - } + { + return Ok(node); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_async_func_exec( + exec, + codec, + proto_converter, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_buffer_exec( + exec, + codec, + proto_converter, + ); + } + + if let Some(exec) = plan.downcast_ref::() { + return protobuf::PhysicalPlanNode::try_from_scalar_subquery_exec( + exec, + codec, + proto_converter, + ); } let mut buf: Vec = vec![]; - match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { + match codec.try_encode(Arc::clone(&plan_clone), &mut buf, proto_converter) { Ok(_) => { let inputs: Vec = plan_clone .children() .into_iter() .cloned() .map(|i| { - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( i, - extension_codec, + codec, + proto_converter, ) }) .collect::>()?; @@ -488,22 +713,19 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ), } } -} -impl protobuf::PhysicalPlanNode { fn try_into_explain_physical_plan( &self, explain: &protobuf::ExplainExecNode, - _ctx: &TaskContext, - - _extension_codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { Ok(Arc::new(ExplainExec::new( Arc::new(explain.schema.as_ref().unwrap().try_into()?), explain .stringified_plans .iter() - .map(|plan| plan.into()) + .map(StringifiedPlan::from_proto) .collect(), explain.verbose, ))) @@ -512,23 +734,21 @@ impl protobuf::PhysicalPlanNode { fn try_into_projection_physical_plan( &self, projection: &protobuf::ProjectionExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&projection.input, ctx, extension_codec)?; + into_physical_plan(&projection.input, ctx, proto_converter)?; let exprs = projection .expr .iter() .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, - ctx, input.schema().as_ref(), - extension_codec, + ctx, )?, name.to_string(), )) @@ -544,18 +764,17 @@ impl protobuf::PhysicalPlanNode { fn try_into_filter_physical_plan( &self, filter: &protobuf::FilterExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&filter.input, ctx, extension_codec)?; + into_physical_plan(&filter.input, ctx, proto_converter)?; let predicate = filter .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr(expr, input.schema().as_ref(), ctx) }) .transpose()? .ok_or_else(|| { @@ -565,20 +784,28 @@ impl protobuf::PhysicalPlanNode { })?; let filter_selectivity = filter.default_filter_selectivity.try_into(); - let projection = if !filter.projection.is_empty() { - Some( - filter - .projection - .iter() - .map(|i| *i as usize) - .collect::>(), - ) - } else { + // Preserve the `None` state across proto boundaries. Proto cannot distinguish + // between `None` (full projection) and `Some(vec![])` (empty projection) since + // both serialize as an empty list. If all columns are included, we reconstruct + // `None` to avoid losing this semantic distinction on deserialization. + let num_fields = input.schema().fields().len(); + let mut is_full_projection = filter.projection.len() == num_fields; + let mut projection_vec: Vec = Vec::with_capacity(filter.projection.len()); + for (i, idx) in filter.projection.iter().enumerate() { + let idx = *idx as usize; + is_full_projection &= idx == i; + projection_vec.push(idx); + } + let projection = if is_full_projection { None + } else { + Some(projection_vec) }; - - let filter = - FilterExec::try_new(predicate, input)?.with_projection(projection)?; + let filter = FilterExecBuilder::new(predicate, input) + .apply_projection(projection)? + .with_batch_size(filter.batch_size as usize) + .with_fetch(filter.fetch.map(|f| f as usize)) + .build()?; match filter_selectivity { Ok(filter_selectivity) => Ok(Arc::new( filter.with_default_selectivity(filter_selectivity)?, @@ -592,9 +819,8 @@ impl protobuf::PhysicalPlanNode { fn try_into_csv_scan_physical_plan( &self, scan: &protobuf::CsvScanExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let escape = if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape(escape)) = @@ -622,6 +848,7 @@ impl protobuf::PhysicalPlanNode { has_header: Some(scan.has_header), delimiter: str_to_byte(&scan.delimiter, "delimiter")?, quote: str_to_byte(&scan.quote, "quote")?, + newlines_in_values: Some(scan.newlines_in_values), ..Default::default() }; let source = Arc::new( @@ -634,10 +861,9 @@ impl protobuf::PhysicalPlanNode { let conf = FileScanConfigBuilder::from(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + proto_converter, source, )?) - .with_newlines_in_values(scan.newlines_in_values) .with_file_compression_type(FileCompressionType::UNCOMPRESSED) .build(); Ok(DataSourceExec::from_data_source(conf)) @@ -646,28 +872,45 @@ impl protobuf::PhysicalPlanNode { fn try_into_json_scan_physical_plan( &self, scan: &protobuf::JsonScanExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let base_conf = scan.base_conf.as_ref().unwrap(); let table_schema = parse_table_schema_from_proto(base_conf)?; let scan_conf = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + proto_converter, Arc::new(JsonSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(scan_conf)) } - #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] + fn try_into_arrow_scan_physical_plan( + &self, + scan: &protobuf::ArrowScanExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let base_conf = scan.base_conf.as_ref().ok_or_else(|| { + internal_datafusion_err!("base_conf in ArrowScanExecNode is missing.") + })?; + let table_schema = parse_table_schema_from_proto(base_conf)?; + let scan_conf = parse_protobuf_file_scan_config( + base_conf, + ctx, + proto_converter, + Arc::new(ArrowSource::new_file_source(table_schema)), + )?; + Ok(DataSourceExec::from_data_source(scan_conf)) + } + + #[cfg_attr(not(feature = "parquet"), expect(unused_variables))] fn try_into_parquet_scan_physical_plan( &self, scan: &protobuf::ParquetScanExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { @@ -684,7 +927,7 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|&i| schema.field(i as usize).clone()) .collect(); - Arc::new(arrow::datatypes::Schema::new(projected_fields)) + Arc::new(Schema::new(projected_fields)) } else { schema }; @@ -693,11 +936,10 @@ impl protobuf::PhysicalPlanNode { .predicate .as_ref() .map(|expr| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( expr, - ctx, predicate_schema.as_ref(), - extension_codec, + ctx, ) }) .transpose()?; @@ -709,9 +951,25 @@ impl protobuf::PhysicalPlanNode { // Parse table schema with partition columns let table_schema = parse_table_schema_from_proto(base_conf)?; - - let mut source = - ParquetSource::new(table_schema).with_table_parquet_options(options); + let object_store_url = match base_conf.object_store_url.is_empty() { + false => ObjectStoreUrl::parse(&base_conf.object_store_url)?, + true => ObjectStoreUrl::local_filesystem(), + }; + let store = ctx + .task_ctx() + .runtime_env() + .object_store(object_store_url)?; + let metadata_cache = ctx + .task_ctx() + .runtime_env() + .cache_manager + .get_file_metadata_cache(); + let reader_factory = + Arc::new(CachedParquetFileReaderFactory::new(store, metadata_cache)); + + let mut source = ParquetSource::new(table_schema) + .with_parquet_file_reader_factory(reader_factory) + .with_table_parquet_options(options); if let Some(predicate) = predicate { source = source.with_predicate(predicate); @@ -719,21 +977,23 @@ impl protobuf::PhysicalPlanNode { let base_config = parse_protobuf_file_scan_config( base_conf, ctx, - extension_codec, + proto_converter, Arc::new(source), )?; Ok(DataSourceExec::from_data_source(base_config)) } #[cfg(not(feature = "parquet"))] - panic!("Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled") + panic!( + "Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled" + ) } - #[cfg_attr(not(feature = "avro"), allow(unused_variables))] + #[cfg_attr(not(feature = "avro"), expect(unused_variables))] fn try_into_avro_scan_physical_plan( &self, scan: &protobuf::AvroScanExecNode, - ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "avro")] { @@ -742,11 +1002,12 @@ impl protobuf::PhysicalPlanNode { let conf = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), ctx, - extension_codec, + proto_converter, Arc::new(AvroSource::new(table_schema)), )?; Ok(DataSourceExec::from_data_source(conf)) } + #[cfg(not(feature = "avro"))] panic!("Unable to process a Avro PhysicalPlan when `avro` feature is not enabled") } @@ -754,9 +1015,8 @@ impl protobuf::PhysicalPlanNode { fn try_into_memory_scan_physical_plan( &self, scan: &protobuf::MemoryScanExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let partitions = scan .partitions @@ -786,7 +1046,7 @@ impl protobuf::PhysicalPlanNode { &ordering.physical_sort_expr_nodes, ctx, &schema, - extension_codec, + proto_converter, )?; sort_information.extend(LexOrdering::new(sort_exprs)); } @@ -803,13 +1063,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_coalesce_batches_physical_plan( &self, coalesce_batches: &protobuf::CoalesceBatchesExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&coalesce_batches.input, ctx, extension_codec)?; + into_physical_plan(&coalesce_batches.input, ctx, proto_converter)?; Ok(Arc::new( + #[expect(deprecated)] CoalesceBatchesExec::new(input, coalesce_batches.target_batch_size as usize) .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), )) @@ -818,12 +1078,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_merge_physical_plan( &self, merge: &protobuf::CoalescePartitionsExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&merge.input, ctx, extension_codec)?; + into_physical_plan(&merge.input, ctx, proto_converter)?; Ok(Arc::new( CoalescePartitionsExec::new(input) .with_fetch(merge.fetch.map(|f| f as usize)), @@ -833,33 +1092,32 @@ impl protobuf::PhysicalPlanNode { fn try_into_repartition_physical_plan( &self, repart: &protobuf::RepartitionExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&repart.input, ctx, extension_codec)?; + into_physical_plan(&repart.input, ctx, proto_converter)?; let partitioning = parse_protobuf_partitioning( repart.partitioning.as_ref(), ctx, input.schema().as_ref(), - extension_codec, + proto_converter, )?; - Ok(Arc::new(RepartitionExec::try_new( - input, - partitioning.unwrap(), - )?)) + let mut repart_exec = RepartitionExec::try_new(input, partitioning.unwrap())?; + if repart.preserve_order { + repart_exec = repart_exec.with_preserve_order(); + } + Ok(Arc::new(repart_exec)) } fn try_into_global_limit_physical_plan( &self, limit: &protobuf::GlobalLimitExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, proto_converter)?; let fetch = if limit.fetch >= 0 { Some(limit.fetch as usize) } else { @@ -875,24 +1133,22 @@ impl protobuf::PhysicalPlanNode { fn try_into_local_limit_physical_plan( &self, limit: &protobuf::LocalLimitExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&limit.input, ctx, extension_codec)?; + into_physical_plan(&limit.input, ctx, proto_converter)?; Ok(Arc::new(LocalLimitExec::new(input, limit.fetch as usize))) } fn try_into_window_physical_plan( &self, window_agg: &protobuf::WindowAggExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&window_agg.input, ctx, extension_codec)?; + into_physical_plan(&window_agg.input, ctx, proto_converter)?; let input_schema = input.schema(); let physical_window_expr: Vec> = window_agg @@ -903,7 +1159,7 @@ impl protobuf::PhysicalPlanNode { window_expr, ctx, input_schema.as_ref(), - extension_codec, + proto_converter, ) }) .collect::, _>>()?; @@ -912,7 +1168,7 @@ impl protobuf::PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter.proto_to_physical_expr(expr, input.schema().as_ref(), ctx) }) .collect::>>>()?; @@ -945,12 +1201,11 @@ impl protobuf::PhysicalPlanNode { fn try_into_aggregate_physical_plan( &self, hash_agg: &protobuf::AggregateExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&hash_agg.input, ctx, extension_codec)?; + into_physical_plan(&hash_agg.input, ctx, proto_converter)?; let mode = protobuf::AggregateMode::try_from(hash_agg.mode).map_err(|_| { proto_error(format!( "Received a AggregateNode message with unknown AggregateMode {}", @@ -965,6 +1220,7 @@ impl protobuf::PhysicalPlanNode { protobuf::AggregateMode::SinglePartitioned => { AggregateMode::SinglePartitioned } + protobuf::AggregateMode::PartialReduce => AggregateMode::PartialReduce, }; let num_expr = hash_agg.group_expr.len(); @@ -974,7 +1230,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, input.schema().as_ref(), ctx) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -984,7 +1241,8 @@ impl protobuf::PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec) + proto_converter + .proto_to_physical_expr(expr, input.schema().as_ref(), ctx) .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -999,6 +1257,8 @@ impl protobuf::PhysicalPlanNode { vec![] }; + let has_grouping_set = hash_agg.has_grouping_set; + let input_schema = hash_agg.input_schema.as_ref().ok_or_else(|| { internal_datafusion_err!("input_schema in AggregateNode is missing.") })?; @@ -1011,7 +1271,7 @@ impl protobuf::PhysicalPlanNode { expr.expr .as_ref() .map(|e| { - parse_physical_expr(e, ctx, &physical_schema, extension_codec) + proto_converter.proto_to_physical_expr(e, &physical_schema, ctx) }) .transpose() }) @@ -1032,11 +1292,10 @@ impl protobuf::PhysicalPlanNode { .expr .iter() .map(|e| { - parse_physical_expr( + proto_converter.proto_to_physical_expr( e, - ctx, &physical_schema, - extension_codec, + ctx, ) }) .collect::>>()?; @@ -1048,7 +1307,7 @@ impl protobuf::PhysicalPlanNode { e, ctx, &physical_schema, - extension_codec, + proto_converter, ) }) .collect::>()?; @@ -1058,23 +1317,39 @@ impl protobuf::PhysicalPlanNode { .map(|func| match func { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = match &agg_node.fun_definition { - Some(buf) => extension_codec - .try_decode_udaf(udaf_name, buf)?, - None => ctx.udaf(udaf_name).or_else(|_| { - extension_codec - .try_decode_udaf(udaf_name, &[]) - })?, + Some(buf) => { + ctx.codec().try_decode_udaf(udaf_name, buf)? + } + None => ctx.task_ctx().udaf(udaf_name).or_else( + |_| { + ctx.codec() + .try_decode_udaf(udaf_name, &[]) + }, + )?, }; - AggregateExprBuilder::new(agg_udf, input_phy_expr) - .schema(Arc::clone(&physical_schema)) - .alias(name) - .human_display(agg_node.human_display.clone()) - .with_ignore_nulls(agg_node.ignore_nulls) - .with_distinct(agg_node.distinct) - .order_by(order_bys) - .build() - .map(Arc::new) + let (human_display, human_display_alias) = + split_human_display_alias( + &agg_node.human_display, + name, + ); + let builder = AggregateExprBuilder::new( + agg_udf, + input_phy_expr, + ) + .schema(Arc::clone(&physical_schema)) + .alias(name) + .with_ignore_nulls(agg_node.ignore_nulls) + .with_distinct(agg_node.distinct) + .order_by(order_bys) + .human_display(human_display); + let builder = if let Some(alias) = human_display_alias + { + builder.human_display_alias(alias) + } else { + builder + }; + builder.build().map(Arc::new) } }) .transpose()? @@ -1089,21 +1364,44 @@ impl protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let limit = hash_agg - .limit - .as_ref() - .map(|lit_value| lit_value.limit as usize); - + let physical_schema_ref = Arc::clone(&physical_schema); let agg = AggregateExec::try_new( agg_mode, - PhysicalGroupBy::new(group_expr, null_expr, groups), + PhysicalGroupBy::new(group_expr, null_expr, groups, has_grouping_set), physical_aggr_expr, physical_filter_expr, input, physical_schema, )?; - let agg = agg.with_limit(limit); + let agg = if let Some(limit_proto) = &hash_agg.limit { + let limit = limit_proto.limit as usize; + let limit_options = match limit_proto.descending { + Some(descending) => LimitOptions::new_with_order(limit, descending), + None => LimitOptions::new(limit), + }; + agg.with_limit_options(Some(limit_options)) + } else { + agg + }; + + let agg = if let Some(dynamic_filter_proto) = &hash_agg.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + physical_schema_ref.as_ref(), + ctx, + )?; + let df = (dynamic_filter_expr as Arc) + .downcast::() + .map_err(|_| { + internal_datafusion_err!( + "AggregateExec dynamic_filter did not decode to a DynamicFilterPhysicalExpr" + ) + })?; + agg.with_dynamic_filter_expr(df)? + } else { + agg + }; Ok(Arc::new(agg)) } @@ -1111,31 +1409,28 @@ impl protobuf::PhysicalPlanNode { fn try_into_hash_join_physical_plan( &self, hashjoin: &protobuf::HashJoinExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&hashjoin.left, ctx, extension_codec)?; + into_physical_plan(&hashjoin.left, ctx, proto_converter)?; let right: Arc = - into_physical_plan(&hashjoin.right, ctx, extension_codec)?; + into_physical_plan(&hashjoin.right, ctx, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), - ctx, left_schema.as_ref(), - extension_codec, + ctx, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), - ctx, right_schema.as_ref(), - extension_codec, + ctx, )?; Ok((left, right)) }) @@ -1164,12 +1459,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - extension_codec, + &schema, + ctx, )?; let column_indices = f.column_indices .iter() @@ -1214,44 +1509,60 @@ impl protobuf::PhysicalPlanNode { } else { None }; - Ok(Arc::new(HashJoinExec::try_new( + let mut hash_join = HashJoinExec::try_new( left, right, on, filter, - &join_type.into(), + &JoinType::from_proto(join_type), projection, partition_mode, - null_equality.into(), - )?)) + NullEquality::from_proto(null_equality), + hashjoin.null_aware, + )?; + + if let Some(dynamic_filter_proto) = &hashjoin.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + right_schema.as_ref(), + ctx, + )?; + let df = (dynamic_filter_expr as Arc) + .downcast::() + .map_err(|_| { + internal_datafusion_err!( + "HashJoinExec dynamic_filter did not decode to a DynamicFilterPhysicalExpr" + ) + })?; + hash_join = hash_join.with_dynamic_filter_expr(df)?; + } + + Ok(Arc::new(hash_join)) } fn try_into_symmetric_hash_join_physical_plan( &self, sym_join: &protobuf::SymmetricHashJoinExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sym_join.left, ctx, extension_codec)?; - let right = into_physical_plan(&sym_join.right, ctx, extension_codec)?; + let left = into_physical_plan(&sym_join.left, ctx, proto_converter)?; + let right = into_physical_plan(&sym_join.right, ctx, proto_converter)?; let left_schema = left.schema(); let right_schema = right.schema(); let on = sym_join .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), - ctx, left_schema.as_ref(), - extension_codec, + ctx, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), - ctx, right_schema.as_ref(), - extension_codec, + ctx, )?; Ok((left, right)) }) @@ -1280,12 +1591,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - extension_codec, + &schema, + ctx, )?; let column_indices = f.column_indices .iter() @@ -1311,7 +1622,7 @@ impl protobuf::PhysicalPlanNode { &sym_join.left_sort_exprs, ctx, &left_schema, - extension_codec, + proto_converter, )?; let left_sort_exprs = LexOrdering::new(left_sort_exprs); @@ -1319,7 +1630,7 @@ impl protobuf::PhysicalPlanNode { &sym_join.right_sort_exprs, ctx, &right_schema, - extension_codec, + proto_converter, )?; let right_sort_exprs = LexOrdering::new(right_sort_exprs); @@ -1345,8 +1656,8 @@ impl protobuf::PhysicalPlanNode { right, on, filter, - &join_type.into(), - null_equality.into(), + &JoinType::from_proto(join_type), + NullEquality::from_proto(null_equality), left_sort_exprs, right_sort_exprs, partition_mode, @@ -1357,13 +1668,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_union_physical_plan( &self, union: &protobuf::UnionExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &union.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(input, ctx)?); } UnionExec::try_new(inputs) } @@ -1371,13 +1681,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_interleave_physical_plan( &self, interleave: &protobuf::InterleaveExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let mut inputs: Vec> = vec![]; for input in &interleave.inputs { - inputs.push(input.try_into_physical_plan(ctx, extension_codec)?); + inputs.push(proto_converter.proto_to_execution_plan(input, ctx)?); } Ok(Arc::new(InterleaveExec::try_new(inputs)?)) } @@ -1385,23 +1694,21 @@ impl protobuf::PhysicalPlanNode { fn try_into_cross_join_physical_plan( &self, crossjoin: &protobuf::CrossJoinExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&crossjoin.left, ctx, extension_codec)?; + into_physical_plan(&crossjoin.left, ctx, proto_converter)?; let right: Arc = - into_physical_plan(&crossjoin.right, ctx, extension_codec)?; + into_physical_plan(&crossjoin.right, ctx, proto_converter)?; Ok(Arc::new(CrossJoinExec::new(left, right))) } fn try_into_empty_physical_plan( &self, empty: &protobuf::EmptyExecNode, - _ctx: &TaskContext, - - _extension_codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let schema = Arc::new(convert_required!(empty.schema)?); Ok(Arc::new(EmptyExec::new(schema))) @@ -1410,9 +1717,7 @@ impl protobuf::PhysicalPlanNode { fn try_into_placeholder_row_physical_plan( &self, placeholder: &protobuf::PlaceholderRowExecNode, - _ctx: &TaskContext, - - _extension_codec: &dyn PhysicalExtensionCodec, + _ctx: &PhysicalPlanDecodeContext<'_>, ) -> Result> { let schema = Arc::new(convert_required!(placeholder.schema)?); Ok(Arc::new(PlaceholderRowExec::new(schema))) @@ -1421,18 +1726,18 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_physical_plan( &self, sort: &protobuf::SortExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let node = self.node(); + let input = into_physical_plan(&sort.input, ctx, proto_converter)?; let exprs = sort .expr .iter() .map(|expr| { let expr = expr.expr_type.as_ref().ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected expr {self:?}" + "physical_plan::from_proto() Unexpected expr {node:?}" )) })?; if let ExprType::Sort(sort_expr) = expr { @@ -1441,12 +1746,16 @@ impl protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected sort expr {self:?}" + "physical_plan::from_proto() Unexpected sort expr {node:?}" )) })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, ctx, input.schema().as_ref(), extension_codec)?, + expr: proto_converter.proto_to_physical_expr( + expr, + input.schema().as_ref(), + ctx, + )?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -1454,7 +1763,7 @@ impl protobuf::PhysicalPlanNode { }) } else { internal_err!( - "physical_plan::from_proto() {self:?}" + "physical_plan::from_proto() {node:?}" ) } }) @@ -1467,24 +1776,42 @@ impl protobuf::PhysicalPlanNode { .with_fetch(fetch) .with_preserve_partitioning(sort.preserve_partitioning); + let new_sort = if let Some(dynamic_filter_proto) = &sort.dynamic_filter { + let dynamic_filter_expr = proto_converter.proto_to_physical_expr( + dynamic_filter_proto, + new_sort.input().schema().as_ref(), + ctx, + )?; + let df = (dynamic_filter_expr as Arc) + .downcast::() + .map_err(|_| { + internal_datafusion_err!( + "SortExec dynamic_filter did not decode to a DynamicFilterPhysicalExpr" + ) + })?; + new_sort.with_dynamic_filter_expr(df)? + } else { + new_sort + }; + Ok(Arc::new(new_sort)) } fn try_into_sort_preserving_merge_physical_plan( &self, sort: &protobuf::SortPreservingMergeExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sort.input, ctx, extension_codec)?; + let node = self.node(); + let input = into_physical_plan(&sort.input, ctx, proto_converter)?; let exprs = sort .expr .iter() .map(|expr| { let expr = expr.expr_type.as_ref().ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected expr {self:?}" + "physical_plan::from_proto() Unexpected expr {node:?}" )) })?; if let ExprType::Sort(sort_expr) = expr { @@ -1493,16 +1820,15 @@ impl protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| { proto_error(format!( - "physical_plan::from_proto() Unexpected sort expr {self:?}" + "physical_plan::from_proto() Unexpected sort expr {node:?}" )) })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr( + expr: proto_converter.proto_to_physical_expr( expr, - ctx, input.schema().as_ref(), - extension_codec, + ctx, )?, options: SortOptions { descending: !sort_expr.asc, @@ -1510,7 +1836,7 @@ impl protobuf::PhysicalPlanNode { }, }) } else { - internal_err!("physical_plan::from_proto() {self:?}") + internal_err!("physical_plan::from_proto() {node:?}") } }) .collect::>>()?; @@ -1526,18 +1852,21 @@ impl protobuf::PhysicalPlanNode { fn try_into_extension_physical_plan( &self, extension: &protobuf::PhysicalExtensionNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let inputs: Vec> = extension .inputs .iter() - .map(|i| i.try_into_physical_plan(ctx, extension_codec)) + .map(|i| proto_converter.proto_to_execution_plan(i, ctx)) .collect::>()?; - let extension_node = - extension_codec.try_decode(extension.node.as_slice(), &inputs, ctx)?; + let extension_node = ctx.codec().try_decode( + extension.node.as_slice(), + &inputs, + ctx.task_ctx(), + proto_converter, + )?; Ok(extension_node) } @@ -1545,14 +1874,13 @@ impl protobuf::PhysicalPlanNode { fn try_into_nested_loop_join_physical_plan( &self, join: &protobuf::NestedLoopJoinExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let left: Arc = - into_physical_plan(&join.left, ctx, extension_codec)?; + into_physical_plan(&join.left, ctx, proto_converter)?; let right: Arc = - into_physical_plan(&join.right, ctx, extension_codec)?; + into_physical_plan(&join.right, ctx, proto_converter)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( "Received a NestedLoopJoinExecNode message with unknown JoinType {}", @@ -1569,12 +1897,13 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter + .proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - extension_codec, + &schema, + ctx, )?; let column_indices = f.column_indices .iter() @@ -1611,7 +1940,7 @@ impl protobuf::PhysicalPlanNode { left, right, filter, - &join_type.into(), + &JoinType::from_proto(join_type), projection, )?)) } @@ -1619,35 +1948,60 @@ impl protobuf::PhysicalPlanNode { fn try_into_analyze_physical_plan( &self, analyze: &protobuf::AnalyzeExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { let input: Arc = - into_physical_plan(&analyze.input, ctx, extension_codec)?; - Ok(Arc::new(AnalyzeExec::new( - analyze.verbose, - analyze.show_statistics, - vec![MetricType::SUMMARY, MetricType::DEV], - input, - Arc::new(convert_required!(analyze.schema)?), - ))) + into_physical_plan(&analyze.input, ctx, proto_converter)?; + let metric_categories = if analyze.has_metric_categories { + let cats: Result> = analyze + .metric_categories + .iter() + .map(|s| s.parse::()) + .collect(); + Some(cats?) + } else { + None + }; + let pb_format = + protobuf::ExplainFormat::try_from(analyze.format).map_err(|_| { + DataFusionError::Internal(format!( + "Received an AnalyzeExecNode message with unknown ExplainFormat {}", + analyze.format + )) + })?; + let format = match pb_format { + protobuf::ExplainFormat::Indent => ExplainFormat::Indent, + protobuf::ExplainFormat::Tree => ExplainFormat::Tree, + protobuf::ExplainFormat::Pgjson => ExplainFormat::PostgresJSON, + protobuf::ExplainFormat::Graphviz => ExplainFormat::Graphviz, + }; + Ok(Arc::new( + AnalyzeExec::builder( + analyze.verbose, + analyze.show_statistics, + input, + Arc::new(convert_required!(analyze.schema)?), + ) + .with_metric_categories(metric_categories) + .with_format(format) + .build(), + )) } fn try_into_json_sink_physical_plan( &self, sink: &protobuf::JsonSinkExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, proto_converter)?; - let data_sink: JsonSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; + let data_sink = JsonSink::try_from_proto( + sink.sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))?, + )?; let sink_schema = input.schema(); let sort_order = sink .sort_order @@ -1657,7 +2011,7 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1675,17 +2029,16 @@ impl protobuf::PhysicalPlanNode { fn try_into_csv_sink_physical_plan( &self, sink: &protobuf::CsvSinkExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, proto_converter)?; - let data_sink: CsvSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; + let data_sink = CsvSink::try_from_proto( + sink.sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))?, + )?; let sink_schema = input.schema(); let sort_order = sink .sort_order @@ -1695,7 +2048,7 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1714,19 +2067,18 @@ impl protobuf::PhysicalPlanNode { fn try_into_parquet_sink_physical_plan( &self, sink: &protobuf::ParquetSinkExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { #[cfg(feature = "parquet")] { - let input = into_physical_plan(&sink.input, ctx, extension_codec)?; + let input = into_physical_plan(&sink.input, ctx, proto_converter)?; - let data_sink: ParquetSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; + let data_sink = ParquetSink::try_from_proto( + sink.sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))?, + )?; let sink_schema = input.schema(); let sort_order = sink .sort_order @@ -1736,7 +2088,7 @@ impl protobuf::PhysicalPlanNode { &collection.physical_sort_expr_nodes, ctx, &sink_schema, - extension_codec, + proto_converter, ) .map(|sort_exprs| { LexRequirement::new(sort_exprs.into_iter().map(Into::into)) @@ -1757,11 +2109,10 @@ impl protobuf::PhysicalPlanNode { fn try_into_unnest_physical_plan( &self, unnest: &protobuf::UnnestExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&unnest.input, ctx, extension_codec)?; + let input = into_physical_plan(&unnest.input, ctx, proto_converter)?; Ok(Arc::new(UnnestExec::new( input, @@ -1775,7 +2126,11 @@ impl protobuf::PhysicalPlanNode { .collect(), unnest.struct_type_columns.iter().map(|c| *c as _).collect(), Arc::new(convert_required!(unnest.schema)?), - into_required!(unnest.options)?, + unnest + .options + .as_ref() + .map(datafusion_common::UnnestOptions::from_proto) + .ok_or_else(|| proto_error("Missing required field in protobuf"))?, )?)) } @@ -1788,13 +2143,12 @@ impl protobuf::PhysicalPlanNode { fn try_into_sort_join( &self, sort_join: &SortMergeJoinExecNode, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let left = into_physical_plan(&sort_join.left, ctx, extension_codec)?; + let left = into_physical_plan(&sort_join.left, ctx, proto_converter)?; let left_schema = left.schema(); - let right = into_physical_plan(&sort_join.right, ctx, extension_codec)?; + let right = into_physical_plan(&sort_join.right, ctx, proto_converter)?; let right_schema = right.schema(); let filter = sort_join @@ -1807,13 +2161,12 @@ impl protobuf::PhysicalPlanNode { .ok_or_else(|| proto_error("Missing JoinFilter schema"))? .try_into()?; - let expression = parse_physical_expr( + let expression = proto_converter.proto_to_physical_expr( f.expression.as_ref().ok_or_else(|| { proto_error("Unexpected empty filter expression") })?, - ctx, &schema, - extension_codec, + ctx, )?; let column_indices = f .column_indices @@ -1870,17 +2223,15 @@ impl protobuf::PhysicalPlanNode { .on .iter() .map(|col| { - let left = parse_physical_expr( + let left = proto_converter.proto_to_physical_expr( &col.left.clone().unwrap(), - ctx, left_schema.as_ref(), - extension_codec, + ctx, )?; - let right = parse_physical_expr( + let right = proto_converter.proto_to_physical_expr( &col.right.clone().unwrap(), - ctx, right_schema.as_ref(), - extension_codec, + ctx, )?; Ok((left, right)) }) @@ -1891,9 +2242,9 @@ impl protobuf::PhysicalPlanNode { right, on, filter, - join_type.into(), + JoinType::from_proto(join_type), sort_options, - null_equality.into(), + NullEquality::from_proto(null_equality), )?)) } @@ -1906,7 +2257,9 @@ impl protobuf::PhysicalPlanNode { let args = match &generate_series.args { Some(protobuf::generate_series_node::Args::ContainsNull(args)) => { GenSeriesArgs::ContainsNull { - name: Self::generate_series_name_to_str(args.name()), + name: protobuf::PhysicalPlanNode::generate_series_name_to_str( + args.name(), + ), } } Some(protobuf::generate_series_node::Args::Int64Args(args)) => { @@ -1915,7 +2268,9 @@ impl protobuf::PhysicalPlanNode { end: args.end, step: args.step, include_end: args.include_end, - name: Self::generate_series_name_to_str(args.name()), + name: protobuf::PhysicalPlanNode::generate_series_name_to_str( + args.name(), + ), } } Some(protobuf::generate_series_node::Args::TimestampArgs(args)) => { @@ -1933,7 +2288,9 @@ impl protobuf::PhysicalPlanNode { step, tz: args.tz.as_ref().map(|s| Arc::from(s.as_str())), include_end: args.include_end, - name: Self::generate_series_name_to_str(args.name()), + name: protobuf::PhysicalPlanNode::generate_series_name_to_str( + args.name(), + ), } } Some(protobuf::generate_series_node::Args::DateArgs(args)) => { @@ -1950,33 +2307,117 @@ impl protobuf::PhysicalPlanNode { end: args.end, step, include_end: args.include_end, - name: Self::generate_series_name_to_str(args.name()), + name: protobuf::PhysicalPlanNode::generate_series_name_to_str( + args.name(), + ), } } None => return internal_err!("Missing args in GenerateSeriesNode"), }; - let table = GenerateSeriesTable::new(Arc::clone(&schema), args); - let generator = table.as_generator(generate_series.target_batch_size as usize)?; + let table = GenerateSeriesTable::new(Arc::clone(&schema), args); + let generator = table.as_generator(generate_series.target_batch_size as usize)?; + + Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?)) + } + + fn try_into_cooperative_physical_plan( + &self, + field_stream: &protobuf::CooperativeExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let input = into_physical_plan(&field_stream.input, ctx, proto_converter)?; + Ok(Arc::new(CooperativeExec::new(input))) + } + + fn try_into_async_func_physical_plan( + &self, + async_func: &protobuf::AsyncFuncExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let input: Arc = + into_physical_plan(&async_func.input, ctx, proto_converter)?; + + if async_func.async_exprs.len() != async_func.async_expr_names.len() { + return internal_err!( + "AsyncFuncExecNode async_exprs length does not match async_expr_names" + ); + } - Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?)) + let async_exprs = async_func + .async_exprs + .iter() + .zip(async_func.async_expr_names.iter()) + .map(|(expr, name)| { + let physical_expr = proto_converter.proto_to_physical_expr( + expr, + input.schema().as_ref(), + ctx, + )?; + + Ok(Arc::new(AsyncFuncExpr::try_new( + name.clone(), + physical_expr, + input.schema().as_ref(), + )?)) + }) + .collect::>>()?; + + Ok(Arc::new(AsyncFuncExec::try_new(async_exprs, input)?)) } - fn try_into_cooperative_physical_plan( + fn try_into_buffer_physical_plan( &self, - field_stream: &protobuf::CooperativeExecNode, - ctx: &TaskContext, + buffer: &protobuf::BufferExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let input: Arc = + into_physical_plan(&buffer.input, ctx, proto_converter)?; - extension_codec: &dyn PhysicalExtensionCodec, + Ok(Arc::new(BufferExec::new(input, buffer.capacity as usize))) + } + + fn try_into_scalar_subquery_physical_plan( + &self, + sq: &protobuf::ScalarSubqueryExecNode, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let input = into_physical_plan(&field_stream.input, ctx, extension_codec)?; - Ok(Arc::new(CooperativeExec::new(input))) + // First, deserialize the main input plan. We set up the subquery results + // container first, so that ScalarSubqueryExpr nodes can reference it. + let subquery_results = ScalarSubqueryResults::new(sq.subqueries.len()); + let input_ctx = ctx.with_scalar_subquery_results(subquery_results.clone()); + let input = into_physical_plan(&sq.input, &input_ctx, proto_converter)?; + + // Now deserialize the subquery children. + let subqueries: Vec = sq + .subqueries + .iter() + .enumerate() + .map(|(index, sq_plan)| { + let plan = + sq_plan.try_into_physical_plan_with_context(ctx, proto_converter)?; + Ok(ScalarSubqueryLink { + plan, + index: SubqueryIndex::new(index), + }) + }) + .collect::>>()?; + + Ok(Arc::new(ScalarSubqueryExec::new( + input, + subqueries, + subquery_results, + ))) } fn try_from_explain_exec( exec: &ExplainExec, - _extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Explain( protobuf::ExplainExecNode { @@ -1984,7 +2425,7 @@ impl protobuf::PhysicalPlanNode { stringified_plans: exec .stringified_plans() .iter() - .map(|plan| plan.into()) + .map(protobuf::StringifiedPlan::from_proto) .collect(), verbose: exec.verbose(), }, @@ -1994,16 +2435,20 @@ impl protobuf::PhysicalPlanNode { fn try_from_projection_exec( exec: &ProjectionExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() - .map(|proj_expr| serialize_physical_expr(&proj_expr.expr, extension_codec)) + .map(|proj_expr| { + proto_converter.physical_expr_to_proto(&proj_expr.expr, codec) + }) .collect::>>()?; let expr_name = exec .expr() @@ -2023,12 +2468,24 @@ impl protobuf::PhysicalPlanNode { fn try_from_analyze_exec( exec: &AnalyzeExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; + let (has_metric_categories, metric_categories) = match exec.metric_categories() { + Some(cats) => (true, cats.iter().map(|c| c.to_string()).collect()), + None => (false, vec![]), + }; + let format = match exec.format() { + ExplainFormat::Indent => protobuf::ExplainFormat::Indent, + ExplainFormat::Tree => protobuf::ExplainFormat::Tree, + ExplainFormat::PostgresJSON => protobuf::ExplainFormat::Pgjson, + ExplainFormat::Graphviz => protobuf::ExplainFormat::Graphviz, + } as i32; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Analyze(Box::new( protobuf::AnalyzeExecNode { @@ -2036,6 +2493,9 @@ impl protobuf::PhysicalPlanNode { show_statistics: exec.show_statistics(), input: Some(Box::new(input)), schema: Some(exec.schema().as_ref().try_into()?), + has_metric_categories, + metric_categories, + format, }, ))), }) @@ -2043,24 +2503,31 @@ impl protobuf::PhysicalPlanNode { fn try_from_filter_exec( exec: &FilterExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(serialize_physical_expr( - exec.predicate(), - extension_codec, - )?), + expr: Some( + proto_converter + .physical_expr_to_proto(exec.predicate(), codec)?, + ), default_filter_selectivity: exec.default_selectivity() as u32, - projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { - v.iter().map(|x| *x as u32).collect::>() - }), + projection: match exec.projection() { + None => (0..exec.input().schema().fields().len()) + .map(|i| i as u32) + .collect(), + Some(v) => v.iter().map(|x| *x as u32).collect(), + }, + batch_size: exec.batch_size() as u32, + fetch: exec.fetch().map(|f| f as u32), }, ))), }) @@ -2068,11 +2535,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_global_limit_exec( limit: &GlobalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -2091,11 +2560,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_local_limit_exec( limit: &LocalLimitExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( limit.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::LocalLimit(Box::new( @@ -2109,36 +2580,39 @@ impl protobuf::PhysicalPlanNode { fn try_from_hash_join_exec( exec: &HashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on: Vec = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), }) }) .collect::>()?; - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let null_equality: protobuf::NullEquality = exec.null_equality().into(); + let join_type = protobuf::JoinType::from_proto(exec.join_type().to_owned()); + let null_equality = protobuf::NullEquality::from_proto(exec.null_equality()); let filter = exec .filter() .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2165,6 +2639,15 @@ impl protobuf::PhysicalPlanNode { PartitionMode::Auto => protobuf::PartitionMode::Auto, }; + let dynamic_filter = exec + .dynamic_filter_expr() + .map(|df| { + let df_expr: Arc = + Arc::clone(df) as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( protobuf::HashJoinExecNode { @@ -2178,6 +2661,8 @@ impl protobuf::PhysicalPlanNode { projection: exec.projection.as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), + null_aware: exec.null_aware, + dynamic_filter, }, ))), }) @@ -2185,36 +2670,39 @@ impl protobuf::PhysicalPlanNode { fn try_from_symmetric_hash_join_exec( exec: &SymmetricHashJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), }) }) .collect::>()?; - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let null_equality: protobuf::NullEquality = exec.null_equality().into(); + let join_type = protobuf::JoinType::from_proto(exec.join_type().to_owned()); + let null_equality = protobuf::NullEquality::from_proto(exec.null_equality()); let filter = exec .filter() .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2251,10 +2739,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2271,10 +2759,10 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -2303,36 +2791,39 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_merge_join_exec( exec: &SortMergeJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; let on = exec .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(&tuple.0, extension_codec)?; - let r = serialize_physical_expr(&tuple.1, extension_codec)?; + let l = proto_converter.physical_expr_to_proto(&tuple.0, codec)?; + let r = proto_converter.physical_expr_to_proto(&tuple.1, codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), }) }) .collect::>()?; - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); - let null_equality: protobuf::NullEquality = exec.null_equality().into(); + let join_type = protobuf::JoinType::from_proto(exec.join_type().to_owned()); + let null_equality = protobuf::NullEquality::from_proto(exec.null_equality()); let filter = exec .filter() .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2372,7 +2863,7 @@ impl protobuf::PhysicalPlanNode { Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new( - protobuf::SortMergeJoinExecNode { + SortMergeJoinExecNode { left: Some(Box::new(left)), right: Some(Box::new(right)), on, @@ -2387,15 +2878,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_cross_join_exec( exec: &CrossJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CrossJoin(Box::new( @@ -2409,8 +2903,9 @@ impl protobuf::PhysicalPlanNode { fn try_from_aggregate_exec( exec: &AggregateExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { let groups: Vec = exec .group_expr() .groups() @@ -2429,13 +2924,15 @@ impl protobuf::PhysicalPlanNode { let filter = exec .filter_expr() .iter() - .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec)) + .map(|expr| serialize_maybe_filter(expr.to_owned(), codec, proto_converter)) .collect::>>()?; let agg = exec .aggr_expr() .iter() - .map(|expr| serialize_physical_aggr_expr(expr.to_owned(), extension_codec)) + .map(|expr| { + serialize_physical_aggr_expr(expr.to_owned(), codec, proto_converter) + }) .collect::>>()?; let agg_names = exec @@ -2452,29 +2949,32 @@ impl protobuf::PhysicalPlanNode { AggregateMode::SinglePartitioned => { protobuf::AggregateMode::SinglePartitioned } + AggregateMode::PartialReduce => protobuf::AggregateMode::PartialReduce, }; let input_schema = exec.input_schema(); - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let null_expr = exec .group_expr() .null_expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) + .map(|expr| proto_converter.physical_expr_to_proto(&expr.0, codec)) .collect::>>()?; - let limit = exec.limit().map(|value| protobuf::AggLimit { - limit: value as u64, + let limit = exec.limit_options().map(|config| protobuf::AggLimit { + limit: config.limit() as u64, + descending: config.descending(), }); Ok(protobuf::PhysicalPlanNode { @@ -2491,6 +2991,15 @@ impl protobuf::PhysicalPlanNode { null_expr, groups, limit, + has_grouping_set: exec.group_expr().has_grouping_set(), + dynamic_filter: exec + .dynamic_filter_expr() + .map(|df| { + let df_expr: Arc = + Arc::clone(df) as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?, }, ))), }) @@ -2498,8 +3007,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_empty_exec( empty: &EmptyExec, - _extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Empty(protobuf::EmptyExecNode { @@ -2510,8 +3019,8 @@ impl protobuf::PhysicalPlanNode { fn try_from_placeholder_row_exec( empty: &PlaceholderRowExec, - _extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { let schema = empty.schema().as_ref().try_into()?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::PlaceholderRow( @@ -2522,13 +3031,16 @@ impl protobuf::PhysicalPlanNode { }) } + #[expect(deprecated)] fn try_from_coalesce_batches_exec( coalesce_batches: &CoalesceBatchesExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( coalesce_batches.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CoalesceBatches(Box::new( @@ -2543,18 +3055,20 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_source_exec( data_source_exec: &DataSourceExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result> { + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { let data_source = data_source_exec.data_source(); - if let Some(maybe_csv) = data_source.as_any().downcast_ref::() { + if let Some(maybe_csv) = data_source.downcast_ref::() { let source = maybe_csv.file_source(); - if let Some(csv_config) = source.as_any().downcast_ref::() { + if let Some(csv_config) = source.downcast_ref::() { return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CsvScan( protobuf::CsvScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_csv, - extension_codec, + codec, + proto_converter, )?), has_header: csv_config.has_header(), delimiter: byte_to_string( @@ -2579,7 +3093,7 @@ impl protobuf::PhysicalPlanNode { } else { None }, - newlines_in_values: maybe_csv.newlines_in_values(), + newlines_in_values: csv_config.newlines_in_values(), truncate_rows: csv_config.truncate_rows(), }, )), @@ -2587,15 +3101,33 @@ impl protobuf::PhysicalPlanNode { } } - if let Some(scan_conf) = data_source.as_any().downcast_ref::() { + if let Some(scan_conf) = data_source.downcast_ref::() { let source = scan_conf.file_source(); - if let Some(_json_source) = source.as_any().downcast_ref::() { + if let Some(_json_source) = source.downcast_ref::() { return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::JsonScan( protobuf::JsonScanExecNode { base_conf: Some(serialize_file_scan_config( scan_conf, - extension_codec, + codec, + proto_converter, + )?), + }, + )), + })); + } + } + + if let Some(scan_conf) = data_source.downcast_ref::() { + let source = scan_conf.file_source(); + if let Some(_arrow_source) = source.downcast_ref::() { + return Ok(Some(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ArrowScan( + protobuf::ArrowScanExecNode { + base_conf: Some(serialize_file_scan_config( + scan_conf, + codec, + proto_converter, )?), }, )), @@ -2609,14 +3141,15 @@ impl protobuf::PhysicalPlanNode { { let predicate = conf .filter() - .map(|pred| serialize_physical_expr(&pred, extension_codec)) + .map(|pred| proto_converter.physical_expr_to_proto(&pred, codec)) .transpose()?; return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_parquet, - extension_codec, + codec, + proto_converter, )?), predicate, parquet_options: Some(conf.table_parquet_options().try_into()?), @@ -2626,15 +3159,16 @@ impl protobuf::PhysicalPlanNode { } #[cfg(feature = "avro")] - if let Some(maybe_avro) = data_source.as_any().downcast_ref::() { + if let Some(maybe_avro) = data_source.downcast_ref::() { let source = maybe_avro.file_source(); - if source.as_any().downcast_ref::().is_some() { + if source.downcast_ref::().is_some() { return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::AvroScan( protobuf::AvroScanExecNode { base_conf: Some(serialize_file_scan_config( maybe_avro, - extension_codec, + codec, + proto_converter, )?), }, )), @@ -2642,9 +3176,7 @@ impl protobuf::PhysicalPlanNode { } } - if let Some(source_conf) = - data_source.as_any().downcast_ref::() - { + if let Some(source_conf) = data_source.downcast_ref::() { let proto_partitions = source_conf .partitions() .iter() @@ -2667,7 +3199,8 @@ impl protobuf::PhysicalPlanNode { .map(|ordering| { let sort_exprs = serialize_physical_sort_exprs( ordering.to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok::<_, DataFusionError>(protobuf::PhysicalSortExprNodeCollection { physical_sort_expr_nodes: sort_exprs, @@ -2694,11 +3227,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_coalesce_partitions_exec( exec: &CoalescePartitionsExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Merge(Box::new( @@ -2712,21 +3247,24 @@ impl protobuf::PhysicalPlanNode { fn try_from_repartition_exec( exec: &RepartitionExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let pb_partitioning = - serialize_partitioning(exec.partitioning(), extension_codec)?; + serialize_partitioning(exec.partitioning(), codec, proto_converter)?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( protobuf::RepartitionExecNode { input: Some(Box::new(input)), partitioning: Some(pb_partitioning), + preserve_order: exec.preserve_order(), }, ))), }) @@ -2734,29 +3272,35 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_exec( exec: &SortExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( - exec.input().to_owned(), - extension_codec, - )?; + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = proto_converter.execution_plan_to_proto(exec.input(), codec)?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; + let dynamic_filter = exec + .dynamic_filter_expr() + .map(|df| { + let df_expr: Arc = df as Arc; + proto_converter.physical_expr_to_proto(&df_expr, codec) + }) + .transpose()?; + Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Sort(Box::new( protobuf::SortExecNode { @@ -2767,6 +3311,7 @@ impl protobuf::PhysicalPlanNode { _ => -1, }, preserve_partitioning: exec.preserve_partitioning(), + dynamic_filter, }, ))), }) @@ -2774,14 +3319,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_union_exec( union: &UnionExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { let mut inputs: Vec = vec![]; for input in union.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Union(protobuf::UnionExecNode { @@ -2792,14 +3341,18 @@ impl protobuf::PhysicalPlanNode { fn try_from_interleave_exec( interleave: &InterleaveExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { let mut inputs: Vec = vec![]; for input in interleave.inputs() { - inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( - input.to_owned(), - extension_codec, - )?); + inputs.push( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + input.to_owned(), + codec, + proto_converter, + )?, + ); } Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Interleave( @@ -2810,25 +3363,27 @@ impl protobuf::PhysicalPlanNode { fn try_from_sort_preserving_merge_exec( exec: &SortPreservingMergeExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let expr = exec .expr() .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter.physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(ExprType::Sort(sort_expr)), }) }) @@ -2846,24 +3401,27 @@ impl protobuf::PhysicalPlanNode { fn try_from_nested_loop_join_exec( exec: &NestedLoopJoinExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let left = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let left = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.left().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let right = protobuf::PhysicalPlanNode::try_from_physical_plan( + let right = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.right().to_owned(), - extension_codec, + codec, + proto_converter, )?; - let join_type: protobuf::JoinType = exec.join_type().to_owned().into(); + let join_type = protobuf::JoinType::from_proto(exec.join_type().to_owned()); let filter = exec .filter() .as_ref() .map(|f| { let expression = - serialize_physical_expr(f.expression(), extension_codec)?; + proto_converter.physical_expr_to_proto(f.expression(), codec)?; let column_indices = f .column_indices() .iter() @@ -2891,7 +3449,7 @@ impl protobuf::PhysicalPlanNode { right: Some(Box::new(right)), join_type: join_type.into(), filter, - projection: exec.projection().map_or_else(Vec::new, |v| { + projection: exec.projection().as_ref().map_or_else(Vec::new, |v| { v.iter().map(|x| *x as u32).collect::>() }), }, @@ -2901,23 +3459,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_window_agg_exec( exec: &WindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; Ok(protobuf::PhysicalPlanNode { @@ -2934,23 +3494,25 @@ impl protobuf::PhysicalPlanNode { fn try_from_bounded_window_agg_exec( exec: &BoundedWindowAggExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e, extension_codec)) + .map(|e| serialize_physical_window_expr(e, codec, proto_converter)) .collect::>>()?; let partition_keys = exec .partition_keys() .iter() - .map(|e| serialize_physical_expr(e, extension_codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -2983,12 +3545,14 @@ impl protobuf::PhysicalPlanNode { fn try_from_data_sink_exec( exec: &DataSinkExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result> { + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { let input: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan( + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; let sort_order = match exec.sort_order() { Some(requirements) => { @@ -2997,10 +3561,10 @@ impl protobuf::PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr( - &expr.expr, - extension_codec, - )?)), + expr: Some(Box::new( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + )), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; @@ -3014,12 +3578,12 @@ impl protobuf::PhysicalPlanNode { None => None, }; - if let Some(sink) = exec.sink().as_any().downcast_ref::() { + if let Some(sink) = exec.sink().downcast_ref::() { return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::JsonSink(Box::new( protobuf::JsonSinkExecNode { input: Some(Box::new(input)), - sink: Some(sink.try_into()?), + sink: Some(protobuf::JsonSink::try_from_proto(sink)?), sink_schema: Some(exec.schema().as_ref().try_into()?), sort_order, }, @@ -3027,12 +3591,12 @@ impl protobuf::PhysicalPlanNode { })); } - if let Some(sink) = exec.sink().as_any().downcast_ref::() { + if let Some(sink) = exec.sink().downcast_ref::() { return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( protobuf::CsvSinkExecNode { input: Some(Box::new(input)), - sink: Some(sink.try_into()?), + sink: Some(protobuf::CsvSink::try_from_proto(sink)?), sink_schema: Some(exec.schema().as_ref().try_into()?), sort_order, }, @@ -3041,12 +3605,12 @@ impl protobuf::PhysicalPlanNode { } #[cfg(feature = "parquet")] - if let Some(sink) = exec.sink().as_any().downcast_ref::() { + if let Some(sink) = exec.sink().downcast_ref::() { return Ok(Some(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( protobuf::ParquetSinkExecNode { input: Some(Box::new(input)), - sink: Some(sink.try_into()?), + sink: Some(protobuf::ParquetSink::try_from_proto(sink)?), sink_schema: Some(exec.schema().as_ref().try_into()?), sort_order, }, @@ -3060,11 +3624,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_unnest_exec( exec: &UnnestExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3085,7 +3651,7 @@ impl protobuf::PhysicalPlanNode { .iter() .map(|c| *c as _) .collect(), - options: Some(exec.options().into()), + options: Some(protobuf::UnnestOptions::from_proto(exec.options())), }, ))), }) @@ -3093,11 +3659,13 @@ impl protobuf::PhysicalPlanNode { fn try_from_cooperative_exec( exec: &CooperativeExec, - extension_codec: &dyn PhysicalExtensionCodec, - ) -> Result { - let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( exec.input().to_owned(), - extension_codec, + codec, + proto_converter, )?; Ok(protobuf::PhysicalPlanNode { @@ -3117,7 +3685,9 @@ impl protobuf::PhysicalPlanNode { } } - fn try_from_lazy_memory_exec(exec: &LazyMemoryExec) -> Result> { + fn try_from_lazy_memory_exec( + exec: &LazyMemoryExec, + ) -> Result> { let generators = exec.generators(); // ensure we only have one generator @@ -3135,7 +3705,9 @@ impl protobuf::PhysicalPlanNode { target_batch_size: 8192, // Default batch size args: Some(protobuf::generate_series_node::Args::ContainsNull( protobuf::GenerateSeriesArgsContainsNull { - name: Self::str_to_generate_series_name(empty_gen.name())? as i32, + name: protobuf::PhysicalPlanNode::str_to_generate_series_name( + empty_gen.name(), + )? as i32, }, )), }; @@ -3159,7 +3731,9 @@ impl protobuf::PhysicalPlanNode { end: *int_64.end(), step: *int_64.step(), include_end: int_64.include_end(), - name: Self::str_to_generate_series_name(int_64.name())? as i32, + name: protobuf::PhysicalPlanNode::str_to_generate_series_name( + int_64.name(), + )? as i32, }, )), }; @@ -3186,7 +3760,9 @@ impl protobuf::PhysicalPlanNode { nanos: step_value.nanoseconds, }); let include_end = timestamp_args.include_end(); - let name = Self::str_to_generate_series_name(timestamp_args.name())? as i32; + let name = protobuf::PhysicalPlanNode::str_to_generate_series_name( + timestamp_args.name(), + )? as i32; let args = match timestamp_args.current().tz_str() { Some(tz) => protobuf::generate_series_node::Args::TimestampArgs( @@ -3223,6 +3799,96 @@ impl protobuf::PhysicalPlanNode { Ok(None) } + + fn try_from_async_func_exec( + exec: &AsyncFuncExec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + codec, + proto_converter, + )?; + + let mut async_exprs = vec![]; + let mut async_expr_names = vec![]; + + for async_expr in exec.async_exprs() { + async_exprs + .push(proto_converter.physical_expr_to_proto(&async_expr.func, codec)?); + async_expr_names.push(async_expr.name.clone()) + } + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::AsyncFunc(Box::new( + protobuf::AsyncFuncExecNode { + input: Some(Box::new(input)), + async_exprs, + async_expr_names, + }, + ))), + }) + } + + fn try_from_buffer_exec( + exec: &BufferExec, + extension_codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + extension_codec, + proto_converter, + )?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Buffer(Box::new( + protobuf::BufferExecNode { + input: Some(Box::new(input)), + capacity: exec.capacity() as u64, + }, + ))), + }) + } + + fn try_from_scalar_subquery_exec( + exec: &ScalarSubqueryExec, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(exec.input()), + codec, + proto_converter, + )?; + let subqueries = exec + .subqueries() + .iter() + .map(|sq| { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(&sq.plan), + codec, + proto_converter, + ) + }) + .collect::>>()?; + + Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ScalarSubquery(Box::new( + protobuf::ScalarSubqueryExecNode { + input: Some(Box::new(input)), + subqueries, + }, + ))), + }) + } +} + +impl PhysicalPlanNodeExt for protobuf::PhysicalPlanNode { + fn node(&self) -> &protobuf::PhysicalPlanNode { + self + } } pub trait AsExecutionPlan: Debug + Send + Sync + Clone { @@ -3239,26 +3905,32 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { &self, ctx: &TaskContext, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result>; fn try_from_physical_plan( plan: Arc, - extension_codec: &dyn PhysicalExtensionCodec, + codec: &dyn PhysicalExtensionCodec, ) -> Result where Self: Sized; } -pub trait PhysicalExtensionCodec: Debug + Send + Sync { +pub trait PhysicalExtensionCodec: Debug + Send + Sync + Any { fn try_decode( &self, buf: &[u8], inputs: &[Arc], ctx: &TaskContext, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result>; - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()>; + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()>; fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!("PhysicalExtensionCodec is not provided for scalar function {name}") @@ -3312,6 +3984,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { _buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { not_impl_err!("PhysicalExtensionCodec is not provided") } @@ -3320,11 +3993,65 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { &self, _node: Arc, _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { not_impl_err!("PhysicalExtensionCodec is not provided") } } +/// Controls the conversion of physical plans and expressions to and from their +/// Protobuf variants. Using this trait, users can perform optimizations on the +/// conversion process or collect performance metrics. +pub trait PhysicalProtoConverterExtension { + fn proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result>; + + fn default_proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> + where + Self: Sized, + { + proto.try_into_physical_plan_with_context(ctx, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result>; + + fn default_proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> + where + Self: Sized, + { + parse_physical_expr_with_converter(proto, input_schema, ctx, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result; +} + /// DataEncoderTuple captures the position of the encoder /// in the codec list that was used to encode the data and actual encoded data #[derive(Clone, PartialEq, prost::Message)] @@ -3338,6 +4065,184 @@ struct DataEncoderTuple { pub blob: Vec, } +pub struct DefaultPhysicalProtoConverter {} + +impl PhysicalProtoConverterExtension for DefaultPhysicalProtoConverter { + fn proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> { + proto.try_into_physical_plan_with_context(ctx, self) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> + where + Self: Sized, + { + // Default implementation calls the free function + parse_physical_expr_with_converter(proto, input_schema, ctx, self) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + +/// Internal deserializer that caches expressions by their `expression_id()` so +/// multiple occurrences of the same expression are deduped. +#[derive(Default)] +struct DeduplicatingDeserializer { + /// Cache mapping expression_id to deserialized expressions. + cache: RefCell>>, +} + +impl PhysicalProtoConverterExtension for DeduplicatingDeserializer { + fn proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> { + proto.try_into_physical_plan_with_context(ctx, self) + } + + fn execution_plan_to_proto( + &self, + _plan: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + internal_err!("DeduplicatingDeserializer cannot serialize execution plans") + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> + where + Self: Sized, + { + // `expr_id` is the generic identity slot on `PhysicalExprNode`. + // The default serializer populates it from `PhysicalExpr::expression_id`. + // A missing id means this expression type doesn't participate in deduping. + let Some(id) = proto.expr_id else { + return parse_physical_expr_with_converter(proto, input_schema, ctx, self); + }; + + let parsed = parse_physical_expr_with_converter(proto, input_schema, ctx, self)?; + + let mut cache = self.cache.borrow_mut(); + if let Some(cached) = cache.get(&id) { + // Since expressions may manage their own internal state when deriving + // expressions via `with_new_children`, we use `with_new_children` + // to opt into the same behavior. + // + // For example, one `DynamicFilterPhysicalExpr` may be derived from + // another resulting in shared references. Using `with_new_children` + // is meant to preserve those references. + let children: Vec<_> = parsed.children().into_iter().cloned().collect(); + return Arc::clone(cached).with_new_children(children); + } + + cache.insert(id, Arc::clone(&parsed)); + Ok(parsed) + } + + fn physical_expr_to_proto( + &self, + _expr: &Arc, + _codec: &dyn PhysicalExtensionCodec, + ) -> Result { + internal_err!("DeduplicatingDeserializer cannot serialize physical expressions") + } +} + +/// A proto converter that deduplicates [`PhysicalExpr`] by [`PhysicalExpr::expression_id`]. +/// This helps preserve referential integrity when deserializing [`ExecutionPlan`]s +/// which may contain multiple occurrences of the same [`PhysicalExpr`] (ex. when +/// [`DynamicFilterPhysicalExpr`] are pushed down, it is important to preserve +/// referential integrity). +/// +/// +/// [`DynamicFilterPhysicalExpr`]: https://docs.rs/datafusion-physical-expr/latest/datafusion_physical_expr/expressions/struct.DynamicFilterPhysicalExpr.html +#[derive(Debug, Default, Clone, Copy)] +pub struct DeduplicatingProtoConverter {} + +impl PhysicalProtoConverterExtension for DeduplicatingProtoConverter { + fn proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> { + let deserializer = DeduplicatingDeserializer::default(); + proto.try_into_physical_plan_with_context(ctx, &deserializer) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + protobuf::PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &protobuf::PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> + where + Self: Sized, + { + let deserializer = DeduplicatingDeserializer::default(); + deserializer.proto_to_physical_expr(proto, input_schema, ctx) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + serialize_physical_expr_with_converter(expr, codec, self) + } +} + /// A PhysicalExtensionCodec that tries one of multiple inner codecs /// until one works #[derive(Debug)] @@ -3412,12 +4317,22 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { buf: &[u8], inputs: &[Arc], ctx: &TaskContext, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - self.decode_protobuf(buf, |codec, data| codec.try_decode(data, inputs, ctx)) + self.decode_protobuf(buf, |codec, data| { + codec.try_decode(data, inputs, ctx, proto_converter) + }) } - fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { - self.encode_protobuf(buf, |codec, data| codec.try_encode(Arc::clone(&node), data)) + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { + self.encode_protobuf(buf, |codec, data| { + codec.try_encode(Arc::clone(&node), data, proto_converter) + }) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { @@ -3439,12 +4354,11 @@ impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { fn into_physical_plan( node: &Option>, - ctx: &TaskContext, - - extension_codec: &dyn PhysicalExtensionCodec, + ctx: &PhysicalPlanDecodeContext<'_>, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { if let Some(field) = node { - field.try_into_physical_plan(ctx, extension_codec) + proto_converter.proto_to_execution_plan(field, ctx) } else { Err(proto_error("Missing required field in protobuf")) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 1ae85618b92ad..d9315af431e22 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -21,50 +21,62 @@ use arrow::array::RecordBatch; use arrow::datatypes::Schema; use arrow::ipc::writer::StreamWriter; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, DataFusionError, Result, + DataFusionError, Result, internal_datafusion_err, internal_err, not_impl_err, }; use datafusion_datasource::file_scan_config::FileScanConfig; -use datafusion_datasource::file_sink_config::FileSink; -use datafusion_datasource::file_sink_config::FileSinkConfig; +use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig}; use datafusion_datasource::{FileRange, PartitionedFile}; use datafusion_datasource_csv::file_format::CsvSink; use datafusion_datasource_json::file_format::JsonSink; #[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::ParquetSink; use datafusion_expr::WindowFrame; -use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr::ScalarFunctionExpr; -use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; +use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use datafusion_physical_plan::expressions::LikeExpr; -use datafusion_physical_plan::expressions::{ - BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, -}; +use datafusion_physical_plan::expressions::DynamicFilterPhysicalExpr; use datafusion_physical_plan::udaf::AggregateFunctionExpr; use datafusion_physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; -use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; +use datafusion_physical_plan::{ + Partitioning, PhysicalExpr, RangePartitioning, SplitPoint, WindowExpr, +}; +use super::{ + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, + PhysicalProtoConverterExtension, encode_human_display_alias, +}; +use crate::convert::TryFromProto; use crate::protobuf::{ - self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode, - PhysicalSortExprNodeCollection, + self, PhysicalSortExprNode, PhysicalSortExprNodeCollection, + physical_aggregate_expr_node, physical_window_expr_node, }; -use super::PhysicalExtensionCodec; - #[expect(clippy::needless_pass_by_value)] pub fn serialize_physical_aggr_expr( aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { - let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; - let order_bys = - serialize_physical_sort_exprs(aggr_expr.order_bys().iter().cloned(), codec)?; + let expressions = + serialize_physical_exprs(&aggr_expr.expressions(), codec, proto_converter)?; + let order_bys = serialize_physical_sort_exprs( + aggr_expr.order_bys().iter().cloned(), + codec, + proto_converter, + )?; let name = aggr_expr.fun().name().to_string(); let mut buf = Vec::new(); codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; + let human_display = match (aggr_expr.human_display(), aggr_expr.human_display_alias()) + { + (Some(display), Some(alias)) => encode_human_display_alias(display, alias), + (Some(display), None) => display.to_string(), + (None, _) => String::new(), + }; Ok(protobuf::PhysicalExprNode { + expr_id: None, expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), @@ -73,7 +85,7 @@ pub fn serialize_physical_aggr_expr( distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), fun_definition: (!buf.is_empty()).then_some(buf), - human_display: aggr_expr.human_display().to_string(), + human_display, }, )), }) @@ -99,9 +111,10 @@ fn serialize_physical_window_aggr_expr( pub fn serialize_physical_window_expr( window_expr: &Arc, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let expr = window_expr.as_any(); - let args = window_expr.expressions().to_vec(); + let mut args = window_expr.expressions().to_vec(); let window_frame = window_expr.get_window_frame(); let (window_function, fun_definition, ignore_nulls, distinct) = @@ -137,6 +150,7 @@ pub fn serialize_physical_window_expr( { let mut buf = Vec::new(); codec.try_encode_udwf(expr.fun(), &mut buf)?; + args = expr.args().to_vec(); ( physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( expr.fun().name().to_string(), @@ -154,12 +168,15 @@ pub fn serialize_physical_window_expr( return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - let args = serialize_physical_exprs(&args, codec)?; - let partition_by = serialize_physical_exprs(window_expr.partition_by(), codec)?; - let order_by = serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?; - let window_frame: protobuf::WindowFrame = window_frame - .as_ref() - .try_into() + let args = serialize_physical_exprs(&args, codec, proto_converter)?; + let partition_by = + serialize_physical_exprs(window_expr.partition_by(), codec, proto_converter)?; + let order_by = serialize_physical_sort_exprs( + window_expr.order_by().to_vec(), + codec, + proto_converter, + )?; + let window_frame = protobuf::WindowFrame::try_from_proto(window_frame.as_ref()) .map_err(|e| internal_datafusion_err!("{e}"))?; Ok(protobuf::PhysicalWindowExprNode { @@ -178,22 +195,24 @@ pub fn serialize_physical_window_expr( pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator, { sort_exprs .into_iter() - .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec)) + .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec, proto_converter)) .collect() } pub fn serialize_physical_sort_expr( sort_expr: PhysicalSortExpr, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let PhysicalSortExpr { expr, options } = sort_expr; - let expr = serialize_physical_expr(&expr, codec)?; + let expr = proto_converter.physical_expr_to_proto(&expr, codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !options.descending, @@ -204,13 +223,14 @@ pub fn serialize_physical_sort_expr( pub fn serialize_physical_exprs<'a, I>( values: I, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> where I: IntoIterator>, { values .into_iter() - .map(|value| serialize_physical_expr(value, codec)) + .map(|value| proto_converter.physical_expr_to_proto(value, codec)) .collect() } @@ -222,145 +242,73 @@ pub fn serialize_physical_expr( value: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { - // Snapshot the expr in case it has dynamic predicate state so - // it can be serialized - let value = snapshot_physical_expr(Arc::clone(value))?; - let expr = value.as_any(); + serialize_physical_expr_with_converter( + value, + codec, + &DefaultPhysicalProtoConverter {}, + ) +} - if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Column( - protobuf::PhysicalColumn { - name: expr.name().to_string(), - index: expr.index() as u32, - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::UnknownColumn( - protobuf::UnknownColumn { - name: expr.name().to_string(), - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(serialize_physical_expr(expr.left(), codec)?)), - r: Some(Box::new(serialize_physical_expr(expr.right(), codec)?)), - op: format!("{:?}", expr.op()), - }); +/// Concrete [`PhysicalExprEncode`] driver used to back +/// [`PhysicalExprEncodeCtx`] when expressions invoke `PhysicalExpr::to_proto`. +/// +/// Wraps the existing extension codec + converter pair so individual +/// expressions can recurse into children without depending on +/// `datafusion-proto` directly. +/// +/// [`PhysicalExprEncode`]: datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncode +/// [`PhysicalExprEncodeCtx`]: datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx +struct ConverterEncoder<'a> { + codec: &'a dyn PhysicalExtensionCodec, + proto_converter: &'a dyn PhysicalProtoConverterExtension, +} - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( - binary_expr, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::Case( - Box::new( - protobuf::PhysicalCaseNode { - expr: expr - .expr() - .map(|exp| { - serialize_physical_expr(exp, codec).map(Box::new) - }) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - serialize_when_then_expr(when_expr, then_expr, codec) - }) - .collect::, - DataFusionError, - >>()?, - else_expr: expr - .else_expr() - .map(|a| serialize_physical_expr(a, codec).map(Box::new)) - .transpose()?, - }, - ), - ), - ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( - protobuf::PhysicalNot { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), - }, - ))), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - list: serialize_physical_exprs(expr.list(), codec)?, - negated: expr.negated(), - }, - ))), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( - protobuf::PhysicalNegativeNode { - expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), - }, - ))), - }) - } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, - )), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( - protobuf::PhysicalCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), - arrow_type: Some(cast.cast_type().try_into()?), - }, - ))), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( - protobuf::PhysicalTryCastNode { - expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), - arrow_type: Some(cast.cast_type().try_into()?), - }, - ))), - }) - } else if let Some(expr) = expr.downcast_ref::() { +impl datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncode + for ConverterEncoder<'_> +{ + fn encode(&self, expr: &Arc) -> Result { + self.proto_converter + .physical_expr_to_proto(expr, self.codec) + } +} + +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]). +/// A [`PhysicalProtoConverterExtension`] can be provided to handle the +/// conversion process (see [`PhysicalProtoConverterExtension::physical_expr_to_proto`]). +pub fn serialize_physical_expr_with_converter( + value: &Arc, + codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + let expr = value.as_ref(); + let expr_id = value.expression_id(); + + // Give the expression a chance to serialize itself first. Returning + // `Ok(Some(node))` lets expressions with private state (e.g. + // `DynamicFilterPhysicalExpr`) avoid exposing pub-for-proto accessors. + // `Ok(None)` falls through to the downcast chain below — that's the + // default for built-in expressions which haven't been migrated yet. + let encoder = ConverterEncoder { + codec, + proto_converter, + }; + let ctx = datafusion_physical_expr_common::physical_expr::proto_encode::PhysicalExprEncodeCtx::new(&encoder); + if let Some(node) = expr.try_to_proto(&ctx)? { + return Ok(node); + } + + if let Some(expr) = expr.downcast_ref::() { let mut buf = Vec::new(); codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { + expr_id, expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), - args: serialize_physical_exprs(expr.args(), codec)?, + args: serialize_physical_exprs(expr.args(), codec, proto_converter)?, fun_definition: (!buf.is_empty()).then_some(buf), return_type: Some(expr.return_type().try_into()?), nullable: expr.nullable(), @@ -371,30 +319,61 @@ pub fn serialize_physical_expr( }, )), }) - } else if let Some(expr) = expr.downcast_ref::() { + } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( - protobuf::PhysicalLikeExprNode { - negated: expr.negated(), - case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), - pattern: Some(Box::new(serialize_physical_expr( - expr.pattern(), - codec, - )?)), + expr_id, + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarSubquery( + protobuf::PhysicalScalarSubqueryExprNode { + data_type: Some(expr.data_type().try_into()?), + nullable: expr.nullable(), + index: expr.index().as_usize() as u32, }, - ))), + )), + }) + } else if let Some(df) = expr.downcast_ref::() { + let children = df + .original_children() + .iter() + .map(|child| proto_converter.physical_expr_to_proto(child, codec)) + .collect::>>()?; + + let remapped_children = if let Some(remapped) = df.remapped_children() { + remapped + .iter() + .map(|child| proto_converter.physical_expr_to_proto(child, codec)) + .collect::>>()? + } else { + vec![] + }; + + // Atomic snapshot of inner state. + let inner = df.inner(); + let inner_expr = + Box::new(proto_converter.physical_expr_to_proto(&inner.expr, codec)?); + + Ok(protobuf::PhysicalExprNode { + expr_id, + expr_type: Some(protobuf::physical_expr_node::ExprType::DynamicFilter( + Box::new(protobuf::PhysicalDynamicFilterNode { + children, + remapped_children, + generation: inner.generation, + inner_expr: Some(inner_expr), + is_complete: inner.is_complete, + }), + )), }) } else { let mut buf: Vec = vec![]; - match codec.try_encode_expr(&value, &mut buf) { + match codec.try_encode_expr(value, &mut buf) { Ok(_) => { let inputs: Vec = value .children() .into_iter() - .map(|e| serialize_physical_expr(e, codec)) + .map(|e| proto_converter.physical_expr_to_proto(e, codec)) .collect::>()?; Ok(protobuf::PhysicalExprNode { + expr_id, expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, )), @@ -410,6 +389,7 @@ pub fn serialize_physical_expr( pub fn serialize_partitioning( partitioning: &Partitioning, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let serialized_partitioning = match partitioning { Partitioning::RoundRobinBatch(partition_count) => protobuf::Partitioning { @@ -418,7 +398,8 @@ pub fn serialize_partitioning( )), }, Partitioning::Hash(exprs, partition_count) => { - let serialized_exprs = serialize_physical_exprs(exprs, codec)?; + let serialized_exprs = + serialize_physical_exprs(exprs, codec, proto_converter)?; protobuf::Partitioning { partition_method: Some(protobuf::partitioning::PartitionMethod::Hash( protobuf::PhysicalHashRepartition { @@ -428,6 +409,11 @@ pub fn serialize_partitioning( )), } } + Partitioning::Range(range) => protobuf::Partitioning { + partition_method: Some(protobuf::partitioning::PartitionMethod::Range( + serialize_range_partitioning(range, codec, proto_converter)?, + )), + }, Partitioning::UnknownPartitioning(partition_count) => protobuf::Partitioning { partition_method: Some(protobuf::partitioning::PartitionMethod::Unknown( *partition_count as u64, @@ -437,21 +423,44 @@ pub fn serialize_partitioning( Ok(serialized_partitioning) } -fn serialize_when_then_expr( - when_expr: &Arc, - then_expr: &Arc, +fn serialize_range_partitioning( + range: &RangePartitioning, codec: &dyn PhysicalExtensionCodec, -) -> Result { - Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_physical_expr(when_expr, codec)?), - then_expr: Some(serialize_physical_expr(then_expr, codec)?), + proto_converter: &dyn PhysicalProtoConverterExtension, +) -> Result { + Ok(protobuf::PhysicalRangePartitioning { + sort_expr: serialize_physical_sort_exprs( + range.ordering().iter().cloned(), + codec, + proto_converter, + )?, + split_point: range + .split_points() + .iter() + .map(serialize_range_split_point) + .collect::>()?, }) } -impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { +fn serialize_range_split_point( + split_point: &SplitPoint, +) -> Result { + Ok(protobuf::PhysicalRangeSplitPoint { + value: split_point + .values() + .iter() + .map(|value| { + TryInto::::try_into(value) + .map_err(Into::into) + }) + .collect::>()?, + }) +} + +impl TryFromProto<&PartitionedFile> for protobuf::PartitionedFile { type Error = DataFusionError; - fn try_from(pf: &PartitionedFile) -> Result { + fn try_from_proto(pf: &PartitionedFile) -> Result { let last_modified = pf.object_meta.last_modified; let last_modified_ns = last_modified.timestamp_nanos_opt().ok_or_else(|| { DataFusionError::Plan(format!( @@ -459,6 +468,11 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { )) })? as u64; Ok(protobuf::PartitionedFile { + arrow_schema: pf + .arrow_schema + .as_ref() + .map(|s| s.as_ref().try_into()) + .transpose()?, path: pf.object_meta.location.as_ref().to_owned(), size: pf.object_meta.size, last_modified_ns, @@ -467,16 +481,20 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { .iter() .map(|v| v.try_into()) .collect::, _>>()?, - range: pf.range.as_ref().map(|r| r.try_into()).transpose()?, + range: pf + .range + .as_ref() + .map(protobuf::FileRange::try_from_proto) + .transpose()?, statistics: pf.statistics.as_ref().map(|s| s.as_ref().into()), }) } } -impl TryFrom<&FileRange> for protobuf::FileRange { +impl TryFromProto<&FileRange> for protobuf::FileRange { type Error = DataFusionError; - fn try_from(value: &FileRange) -> Result { + fn try_from_proto(value: &FileRange) -> Result { Ok(protobuf::FileRange { start: value.start, end: value.end, @@ -484,14 +502,14 @@ impl TryFrom<&FileRange> for protobuf::FileRange { } } -impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { +impl TryFromProto<&[PartitionedFile]> for protobuf::FileGroup { type Error = DataFusionError; - fn try_from(gr: &[PartitionedFile]) -> Result { + fn try_from_proto(gr: &[PartitionedFile]) -> Result { Ok(protobuf::FileGroup { files: gr .iter() - .map(|f| f.try_into()) + .map(protobuf::PartitionedFile::try_from_proto) .collect::, _>>()?, }) } @@ -500,16 +518,18 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { pub fn serialize_file_scan_config( conf: &FileScanConfig, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { let file_groups = conf .file_groups .iter() - .map(|p| p.files().try_into()) + .map(|p| protobuf::FileGroup::try_from_proto(p.files())) .collect::, _>>()?; let mut output_orderings = vec![]; for order in &conf.output_ordering { - let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?; + let ordering = + serialize_physical_sort_exprs(order.to_vec(), codec, proto_converter)?; output_orderings.push(ordering) } @@ -524,22 +544,37 @@ pub fn serialize_file_scan_config( fields.extend(conf.table_partition_cols().iter().cloned()); let schema = Arc::new( - arrow::datatypes::Schema::new(fields.clone()) - .with_metadata(conf.file_schema().metadata.clone()), + Schema::new(fields.clone()).with_metadata(conf.file_schema().metadata.clone()), ); + let projection_exprs = conf + .file_source + .projection() + .as_ref() + .map(|projection_exprs| { + let projections = projection_exprs.iter().cloned().collect::>(); + Ok::<_, DataFusionError>(protobuf::ProjectionExprs { + projections: projections + .into_iter() + .map(|expr| { + Ok(protobuf::ProjectionExpr { + alias: expr.alias.to_string(), + expr: Some( + proto_converter + .physical_expr_to_proto(&expr.expr, codec)?, + ), + }) + }) + .collect::>>()?, + }) + }) + .transpose()?; + Ok(protobuf::FileScanExecConf { file_groups, statistics: Some((&conf.statistics()).into()), limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }), - projection: conf - .projection_exprs - .as_ref() - .map(|p| p.column_indices()) - .unwrap_or((0..schema.fields().len()).collect::>()) - .iter() - .map(|n| *n as u32) - .collect(), + projection: vec![], schema: Some(schema.as_ref().try_into()?), table_partition_cols: conf .table_partition_cols() @@ -555,17 +590,20 @@ pub fn serialize_file_scan_config( .collect::>(), constraints: Some(conf.constraints.clone().into()), batch_size: conf.batch_size.map(|s| s as u64), + projection_exprs, + partitioned_by_file_group: Some(conf.partitioned_by_file_group), }) } pub fn serialize_maybe_filter( expr: Option>, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result { match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_physical_expr(&expr, codec)?), + expr: Some(proto_converter.physical_expr_to_proto(&expr, codec)?), }), } } @@ -584,48 +622,48 @@ pub fn serialize_record_batches(batches: &[RecordBatch]) -> Result> { Ok(buf) } -impl TryFrom<&JsonSink> for protobuf::JsonSink { +impl TryFromProto<&JsonSink> for protobuf::JsonSink { type Error = DataFusionError; - fn try_from(value: &JsonSink) -> Result { + fn try_from_proto(value: &JsonSink) -> Result { Ok(Self { - config: Some(value.config().try_into()?), + config: Some(protobuf::FileSinkConfig::try_from_proto(value.config())?), writer_options: Some(value.writer_options().try_into()?), }) } } -impl TryFrom<&CsvSink> for protobuf::CsvSink { +impl TryFromProto<&CsvSink> for protobuf::CsvSink { type Error = DataFusionError; - fn try_from(value: &CsvSink) -> Result { + fn try_from_proto(value: &CsvSink) -> Result { Ok(Self { - config: Some(value.config().try_into()?), + config: Some(protobuf::FileSinkConfig::try_from_proto(value.config())?), writer_options: Some(value.writer_options().try_into()?), }) } } #[cfg(feature = "parquet")] -impl TryFrom<&ParquetSink> for protobuf::ParquetSink { +impl TryFromProto<&ParquetSink> for protobuf::ParquetSink { type Error = DataFusionError; - fn try_from(value: &ParquetSink) -> Result { + fn try_from_proto(value: &ParquetSink) -> Result { Ok(Self { - config: Some(value.config().try_into()?), + config: Some(protobuf::FileSinkConfig::try_from_proto(value.config())?), parquet_options: Some(value.parquet_options().try_into()?), }) } } -impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { +impl TryFromProto<&FileSinkConfig> for protobuf::FileSinkConfig { type Error = DataFusionError; - fn try_from(conf: &FileSinkConfig) -> Result { + fn try_from_proto(conf: &FileSinkConfig) -> Result { let file_groups = conf .file_group .iter() - .map(TryInto::try_into) + .map(protobuf::PartitionedFile::try_from_proto) .collect::>>()?; let table_paths = conf .table_paths @@ -642,6 +680,17 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { }) }) .collect::>>()?; + let file_output_mode = match conf.file_output_mode { + datafusion_datasource::file_sink_config::FileOutputMode::Automatic => { + protobuf::FileOutputMode::Automatic + } + datafusion_datasource::file_sink_config::FileOutputMode::SingleFile => { + protobuf::FileOutputMode::SingleFile + } + datafusion_datasource::file_sink_config::FileOutputMode::Directory => { + protobuf::FileOutputMode::Directory + } + }; Ok(Self { object_store_url: conf.object_store_url.to_string(), file_groups, @@ -651,6 +700,7 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { keep_partition_by_columns: conf.keep_partition_by_columns, insert_op: conf.insert_op as i32, file_extension: conf.file_extension.to_string(), + file_output_mode: file_output_mode.into(), }) } } diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index aec6c1de30309..3abbaccf79673 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -26,7 +26,6 @@ use datafusion_expr::{ }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; -use std::any::Any; use std::fmt::Debug; use std::hash::Hash; use std::sync::Arc; @@ -56,9 +55,6 @@ impl MyRegexUdf { /// Implement the ScalarUDFImpl trait for MyRegexUdf impl ScalarUDFImpl for MyRegexUdf { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "regex_udf" } @@ -105,9 +101,6 @@ impl MyAggregateUDF { } impl AggregateUDFImpl for MyAggregateUDF { - fn as_any(&self) -> &dyn Any { - self - } fn name(&self) -> &str { "aggregate_udf" } @@ -150,10 +143,6 @@ impl CustomUDWF { } impl WindowUDFImpl for CustomUDWF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "custom_udwf" } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 989589dfb8b2d..7f1d0a666fdce 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -16,21 +16,23 @@ // under the License. use arrow::array::{ - ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, + ArrayRef, FixedSizeListArray, Int32Builder, LargeListViewArray, ListViewArray, + MapArray, MapBuilder, StringBuilder, }; use arrow::datatypes::{ - DataType, Field, FieldRef, Fields, Int32Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, - UnionMode, DECIMAL256_MAX_PRECISION, + DECIMAL256_MAX_PRECISION, DataType, Field, FieldRef, Fields, Int32Type, + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, + TimeUnit, UnionFields, UnionMode, }; use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::file_format::json::{JsonFormat, JsonFormatFactory}; use datafusion::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; -use datafusion::execution::options::ArrowReadOptions; -use datafusion::optimizer::optimize_unions::OptimizeUnions; +use datafusion::execution::options::{ArrowReadOptions, JsonReadOptions}; use datafusion::optimizer::Optimizer; +use datafusion::optimizer::optimize_unions::OptimizeUnions; +use datafusion_common::parquet_config::DFParquetWriterVersion; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_functions_aggregate::sum::sum_distinct; use prost::Message; @@ -42,13 +44,13 @@ use std::sync::Arc; use std::vec; use datafusion::catalog::{TableProvider, TableProviderFactory}; +use datafusion::datasource::DefaultTableSource; use datafusion::datasource::file_format::arrow::ArrowFormatFactory; use datafusion::datasource::file_format::csv::CsvFormatFactory; use datafusion::datasource::file_format::parquet::ParquetFormatFactory; -use datafusion::datasource::file_format::{format_as_file_type, DefaultFileType}; -use datafusion::datasource::DefaultTableSource; -use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::datasource::file_format::{DefaultFileType, format_as_file_type}; use datafusion::execution::FunctionRegistry; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, @@ -66,10 +68,13 @@ use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; +use datafusion_common::format::{ + ExplainAnalyzeCategories, ExplainFormat, MetricCategory, MetricType, +}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, TableReference, + DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, + internal_datafusion_err, internal_err, not_impl_err, plan_err, }; use datafusion_execution::TaskContext; use datafusion_expr::dml::CopyTo; @@ -77,7 +82,9 @@ use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, NullTreatment, ScalarFunction, Unnest, WildcardOptions, }; -use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; +use datafusion_expr::logical_plan::{ + ExplainOption, Extension, UserDefinedLogicalNodeCore, +}; use datafusion_expr::{ Accumulator, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, LimitEffect, Literal, LogicalPlan, LogicalPlanBuilder, Operator, PartitionEvaluator, @@ -102,7 +109,7 @@ use datafusion_proto::logical_plan::file_formats::{ }; use datafusion_proto::logical_plan::to_proto::serialize_expr; use datafusion_proto::logical_plan::{ - from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec, + DefaultLogicalExtensionCodec, LogicalExtensionCodec, from_proto, }; use datafusion_proto::protobuf; @@ -132,7 +139,8 @@ fn roundtrip_expr_test_with_codec( ) { let proto: protobuf::LogicalExprNode = serialize_expr(&initial_struct, codec) .unwrap_or_else(|e| panic!("Error serializing expression: {e:?}")); - let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, codec).unwrap(); + let round_trip: Expr = + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), codec).unwrap(); assert_eq!(format!("{:?}", &initial_struct), format!("{round_trip:?}")); @@ -214,8 +222,6 @@ impl LogicalExtensionCodec for TestTableProviderCodec { buf: &mut Vec, ) -> Result<()> { let table = node - .as_ref() - .as_any() .downcast_ref::() .expect("Can't encode non-test tables"); let msg = TestTableProto { @@ -276,6 +282,95 @@ async fn roundtrip_custom_memory_tables() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_explain_format_tree() -> Result<()> { + let ctx = SessionContext::new(); + let plan = ctx + .state() + .create_logical_plan("EXPLAIN FORMAT TREE SELECT 1") + .await?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match logical_round_trip { + LogicalPlan::Explain(explain) => { + assert_eq!(explain.explain_format, ExplainFormat::Tree); + } + plan => panic!("expected Explain plan, got {plan:?}"), + } + + Ok(()) +} + +/// Build an `EXPLAIN`/`EXPLAIN ANALYZE` plan with statement-level overrides +/// set directly via the builder, then assert the proto round-trip preserves +/// every field. Going through the builder avoids depending on parser support +/// for the parenthesized option syntax in this test crate. +async fn assert_explain_roundtrip(option: ExplainOption) -> Result<()> { + let ctx = SessionContext::new(); + let input = ctx.sql("SELECT 1 AS x").await?.into_optimized_plan()?; + let plan = LogicalPlanBuilder::from(input) + .explain_option_format(option)? + .build()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let round_trip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + assert_eq!(plan, round_trip); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_explain_show_statistics_override() -> Result<()> { + for show_statistics in [None, Some(true), Some(false)] { + assert_explain_roundtrip( + ExplainOption::default() + .with_format(ExplainFormat::Indent) + .with_show_statistics(show_statistics), + ) + .await?; + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_analyze_level_override() -> Result<()> { + for analyze_level in [None, Some(MetricType::Summary), Some(MetricType::Dev)] { + assert_explain_roundtrip( + ExplainOption::default() + .with_analyze(true) + .with_analyze_level(analyze_level), + ) + .await?; + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_analyze_categories_override() -> Result<()> { + let cases = [ + None, + Some(ExplainAnalyzeCategories::All), + Some(ExplainAnalyzeCategories::Only(vec![])), + Some(ExplainAnalyzeCategories::Only(vec![MetricCategory::Rows])), + Some(ExplainAnalyzeCategories::Only(vec![ + MetricCategory::Rows, + MetricCategory::Bytes, + MetricCategory::Timing, + MetricCategory::Uncategorized, + ])), + ]; + for analyze_categories in cases { + assert_explain_roundtrip( + ExplainOption::default() + .with_analyze(true) + .with_analyze_categories(analyze_categories), + ) + .await?; + } + Ok(()) +} + #[tokio::test] async fn roundtrip_custom_listing_tables() -> Result<()> { let ctx = SessionContext::new(); @@ -412,6 +507,7 @@ async fn roundtrip_logical_plan_dml() -> Result<()> { "DELETE FROM T1", "UPDATE T1 SET a = 1", "CREATE TABLE T2 AS SELECT * FROM T1", + "TRUNCATE TABLE T1", ]; for query in queries { let plan = ctx.sql(query).await?.into_optimized_plan()?; @@ -464,7 +560,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { parquet_format.global.bloom_filter_on_read = true; parquet_format.global.created_by = "DataFusion Test".to_string(); - parquet_format.global.writer_version = "PARQUET_2_0".to_string(); + parquet_format.global.writer_version = DFParquetWriterVersion::V2_0; parquet_format.global.write_batch_size = 111; parquet_format.global.data_pagesize_limit = 222; parquet_format.global.data_page_row_count_limit = 333; @@ -549,6 +645,8 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.timestamp_format = Some("HH:mm:ss.SSSSSS".to_string()); csv_format.time_format = Some("HH:mm:ss".to_string()); csv_format.null_value = Some("NIL".to_string()); + csv_format.compression = CompressionTypeVariant::GZIP; + csv_format.compression_level = Some(6); let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options( csv_format.clone(), @@ -584,7 +682,6 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { let format_factory = file_type.as_format_factory(); let csv_factory = format_factory .as_ref() - .as_any() .downcast_ref::() .unwrap(); let csv_config = csv_factory.options.as_ref().unwrap(); @@ -593,7 +690,9 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { assert_eq!(csv_format.datetime_format, csv_config.datetime_format); assert_eq!(csv_format.timestamp_format, csv_config.timestamp_format); assert_eq!(csv_format.time_format, csv_config.time_format); - assert_eq!(csv_format.null_value, csv_config.null_value) + assert_eq!(csv_format.null_value, csv_config.null_value); + assert_eq!(csv_format.compression, csv_config.compression); + assert_eq!(csv_format.compression_level, csv_config.compression_level); } _ => panic!(), } @@ -651,7 +750,6 @@ async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { let format_factory = file_type.as_format_factory(); let json_factory = format_factory .as_ref() - .as_any() .downcast_ref::() .unwrap(); let json_config = json_factory.options.as_ref().unwrap(); @@ -723,7 +821,6 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { let format_factory = file_type.as_format_factory(); let parquet_factory = format_factory .as_ref() - .as_any() .downcast_ref::() .unwrap(); let parquet_config = parquet_factory.options.as_ref().unwrap(); @@ -737,6 +834,189 @@ async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_default_codec_csv() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_csv_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut csv_format = table_options.csv; + csv_format.delimiter = b'|'; + csv_format.has_header = Some(true); + csv_format.compression = CompressionTypeVariant::GZIP; + + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options( + csv_format.clone(), + ))); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.csv".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.csv", copy_to.output_url); + assert_eq!("csv", copy_to.file_type.get_ext()); + let dt = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let csv = dt + .as_format_factory() + .as_ref() + .downcast_ref::() + .unwrap(); + let decoded = csv.options.as_ref().unwrap(); + assert_eq!(csv_format.delimiter, decoded.delimiter); + assert_eq!(csv_format.has_header, decoded.has_header); + assert_eq!(csv_format.compression, decoded.compression); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_default_codec_json() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_json_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut json_format = table_options.json; + json_format.compression = CompressionTypeVariant::GZIP; + json_format.schema_infer_max_rec = Some(500); + + let file_type = format_as_file_type(Arc::new(JsonFormatFactory::new_with_options( + json_format.clone(), + ))); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.json".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.json", copy_to.output_url); + assert_eq!("json", copy_to.file_type.get_ext()); + let dt = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let json = dt + .as_format_factory() + .as_ref() + .downcast_ref::() + .unwrap(); + let decoded = json.options.as_ref().unwrap(); + assert_eq!(json_format.compression, decoded.compression); + assert_eq!( + json_format.schema_infer_max_rec, + decoded.schema_infer_max_rec + ); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_default_codec_parquet() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_parquet_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut parquet_format = table_options.parquet; + parquet_format.global.bloom_filter_on_read = true; + parquet_format.global.created_by = "DefaultCodecTest".to_string(); + + let file_type = format_as_file_type(Arc::new( + ParquetFormatFactory::new_with_options(parquet_format.clone()), + )); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.parquet".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!("parquet", copy_to.file_type.get_ext()); + let dt = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let pq = dt + .as_format_factory() + .as_ref() + .downcast_ref::() + .unwrap(); + let decoded = pq.options.as_ref().unwrap(); + assert!(decoded.global.bloom_filter_on_read); + assert_eq!("DefaultCodecTest", decoded.global.created_by); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + +#[tokio::test] +async fn roundtrip_default_codec_arrow() -> Result<()> { + let ctx = SessionContext::new(); + let input = create_csv_scan(&ctx).await?; + + let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); + + let plan = LogicalPlan::Copy(CopyTo::new( + Arc::new(input), + "test.arrow".to_string(), + vec![], + file_type, + Default::default(), + )); + + let bytes = logical_plan_to_bytes(&plan)?; + let roundtrip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + match roundtrip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.arrow", copy_to.output_url); + assert_eq!("arrow", copy_to.file_type.get_ext()); + } + _ => panic!("Expected CopyTo plan"), + } + Ok(()) +} + async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; @@ -749,7 +1029,7 @@ async fn create_json_scan(ctx: &SessionContext) -> Result Result<()> let prepared = LogicalPlanBuilder::new(plan) .prepare( "".to_string(), - vec![Field::new("", DataType::Int32, true) - .with_metadata( - [("some_key".to_string(), "some_value".to_string())].into(), - ) - .into()], + vec![ + Field::new("", DataType::Int32, true) + .with_metadata( + [("some_key".to_string(), "some_value".to_string())].into(), + ) + .into(), + ], ) .unwrap() .plan() @@ -1305,7 +1587,9 @@ impl LogicalExtensionCodec for UDFExtensionCodec { fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); - let udf = binding.as_any().downcast_ref::().unwrap(); + let udf = (binding.as_ref() as &dyn Any) + .downcast_ref::() + .unwrap(); let proto = MyRegexUdfNode { pattern: udf.pattern.clone(), }; @@ -1331,7 +1615,9 @@ impl LogicalExtensionCodec for UDFExtensionCodec { fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); - let udf = binding.as_any().downcast_ref::().unwrap(); + let udf = (binding.as_ref() as &dyn Any) + .downcast_ref::() + .unwrap(); let proto = MyAggregateUdfNode { result: udf.result.clone(), }; @@ -1344,253 +1630,277 @@ impl LogicalExtensionCodec for UDFExtensionCodec { #[test] fn round_trip_scalar_values_and_data_types() { - let should_pass: Vec = vec![ - ScalarValue::Boolean(None), - ScalarValue::Float32(None), - ScalarValue::Float64(None), - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), - ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), - ScalarValue::Date32(None), - ScalarValue::Boolean(Some(true)), - ScalarValue::Boolean(Some(false)), - ScalarValue::Float32(Some(1.0)), - ScalarValue::Float32(Some(f32::MAX)), - ScalarValue::Float32(Some(f32::MIN)), - ScalarValue::Float32(Some(-2000.0)), - ScalarValue::Float64(Some(1.0)), - ScalarValue::Float64(Some(f64::MAX)), - ScalarValue::Float64(Some(f64::MIN)), - ScalarValue::Float64(Some(-2000.0)), - ScalarValue::Int8(Some(i8::MIN)), - ScalarValue::Int8(Some(i8::MAX)), - ScalarValue::Int8(Some(0)), - ScalarValue::Int8(Some(-15)), - ScalarValue::Int16(Some(i16::MIN)), - ScalarValue::Int16(Some(i16::MAX)), - ScalarValue::Int16(Some(0)), - ScalarValue::Int16(Some(-15)), - ScalarValue::Int32(Some(i32::MIN)), - ScalarValue::Int32(Some(i32::MAX)), - ScalarValue::Int32(Some(0)), - ScalarValue::Int32(Some(-15)), - ScalarValue::Int64(Some(i64::MIN)), - ScalarValue::Int64(Some(i64::MAX)), - ScalarValue::Int64(Some(0)), - ScalarValue::Int64(Some(-15)), - ScalarValue::UInt8(Some(u8::MAX)), - ScalarValue::UInt8(Some(0)), - ScalarValue::UInt16(Some(u16::MAX)), - ScalarValue::UInt16(Some(0)), - ScalarValue::UInt32(Some(u32::MAX)), - ScalarValue::UInt32(Some(0)), - ScalarValue::UInt64(Some(u64::MAX)), - ScalarValue::UInt64(Some(0)), - ScalarValue::Utf8(Some(String::from("Test string "))), - ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), - ScalarValue::Utf8View(Some(String::from("Test stringview"))), - ScalarValue::BinaryView(Some(b"binaryview".to_vec())), - ScalarValue::Date32(Some(0)), - ScalarValue::Date32(Some(i32::MAX)), - ScalarValue::Date32(None), - ScalarValue::Date64(Some(0)), - ScalarValue::Date64(Some(i64::MAX)), - ScalarValue::Date64(None), - ScalarValue::Time32Second(Some(0)), - ScalarValue::Time32Second(Some(i32::MAX)), - ScalarValue::Time32Second(None), - ScalarValue::Time32Millisecond(Some(0)), - ScalarValue::Time32Millisecond(Some(i32::MAX)), - ScalarValue::Time32Millisecond(None), - ScalarValue::Time64Microsecond(Some(0)), - ScalarValue::Time64Microsecond(Some(i64::MAX)), - ScalarValue::Time64Microsecond(None), - ScalarValue::Time64Nanosecond(Some(0)), - ScalarValue::Time64Nanosecond(Some(i64::MAX)), - ScalarValue::Time64Nanosecond(None), - ScalarValue::TimestampNanosecond(Some(0), None), - ScalarValue::TimestampNanosecond(Some(i64::MAX), None), - ScalarValue::TimestampNanosecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampNanosecond(None, None), - ScalarValue::TimestampMicrosecond(Some(0), None), - ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), - ScalarValue::TimestampMicrosecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampMicrosecond(None, None), - ScalarValue::TimestampMillisecond(Some(0), None), - ScalarValue::TimestampMillisecond(Some(i64::MAX), None), - ScalarValue::TimestampMillisecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampMillisecond(None, None), - ScalarValue::TimestampSecond(Some(0), None), - ScalarValue::TimestampSecond(Some(i64::MAX), None), - ScalarValue::TimestampSecond(Some(0), Some("UTC".into())), - ScalarValue::TimestampSecond(None, None), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 0))), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(1, 2))), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value( - i32::MAX, - i32::MAX, - ))), - ScalarValue::IntervalDayTime(None), - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - 0, 0, 0, - ))), - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - 1, 2, 3, - ))), - ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNanoType::make_value( - i32::MAX, - i32::MAX, - i64::MAX, - ))), - ScalarValue::IntervalMonthDayNano(None), - ScalarValue::List(ScalarValue::new_list_nullable( - &[ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ], - &DataType::Float32, - )), - ScalarValue::LargeList(ScalarValue::new_large_list( - &[ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ], - &DataType::Float32, - )), - ScalarValue::List(ScalarValue::new_list_nullable( - &[ - ScalarValue::List(ScalarValue::new_list_nullable( - &[], - &DataType::Float32, - )), - ScalarValue::List(ScalarValue::new_list_nullable( - &[ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ], - &DataType::Float32, - )), - ], - &DataType::List(new_arc_field("item", DataType::Float32, true)), - )), - ScalarValue::LargeList(ScalarValue::new_large_list( - &[ - ScalarValue::LargeList(ScalarValue::new_large_list( - &[], - &DataType::Float32, - )), - ScalarValue::LargeList(ScalarValue::new_large_list( - &[ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ], - &DataType::Float32, - )), - ], - &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), - )), - ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< - Int32Type, - _, - _, - >( - vec![Some(vec![Some(1), Some(2), Some(3)])], - 3, - ))), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("foo")), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(None)), - ), - ScalarValue::Binary(Some(b"bar".to_vec())), - ScalarValue::Binary(None), - ScalarValue::LargeBinary(Some(b"bar".to_vec())), - ScalarValue::LargeBinary(None), - ScalarStructBuilder::new() - .with_scalar( + let should_pass: Vec = + vec![ + ScalarValue::Boolean(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), + ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), + ScalarValue::Date32(None), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(Some(false)), + ScalarValue::Float32(Some(1.0)), + ScalarValue::Float32(Some(f32::MAX)), + ScalarValue::Float32(Some(f32::MIN)), + ScalarValue::Float32(Some(-2000.0)), + ScalarValue::Float64(Some(1.0)), + ScalarValue::Float64(Some(f64::MAX)), + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(-2000.0)), + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ScalarValue::Int8(Some(0)), + ScalarValue::Int8(Some(-15)), + ScalarValue::Int16(Some(i16::MIN)), + ScalarValue::Int16(Some(i16::MAX)), + ScalarValue::Int16(Some(0)), + ScalarValue::Int16(Some(-15)), + ScalarValue::Int32(Some(i32::MIN)), + ScalarValue::Int32(Some(i32::MAX)), + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(-15)), + ScalarValue::Int64(Some(i64::MIN)), + ScalarValue::Int64(Some(i64::MAX)), + ScalarValue::Int64(Some(0)), + ScalarValue::Int64(Some(-15)), + ScalarValue::UInt8(Some(u8::MAX)), + ScalarValue::UInt8(Some(0)), + ScalarValue::UInt16(Some(u16::MAX)), + ScalarValue::UInt16(Some(0)), + ScalarValue::UInt32(Some(u32::MAX)), + ScalarValue::UInt32(Some(0)), + ScalarValue::UInt64(Some(u64::MAX)), + ScalarValue::UInt64(Some(0)), + ScalarValue::Utf8(Some(String::from("Test string "))), + ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), + ScalarValue::Utf8View(Some(String::from("Test stringview"))), + ScalarValue::BinaryView(Some(b"binaryview".to_vec())), + ScalarValue::Date32(Some(0)), + ScalarValue::Date32(Some(i32::MAX)), + ScalarValue::Date32(None), + ScalarValue::Date64(Some(0)), + ScalarValue::Date64(Some(i64::MAX)), + ScalarValue::Date64(None), + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(Some(i32::MAX)), + ScalarValue::Time32Second(None), + ScalarValue::Time32Millisecond(Some(0)), + ScalarValue::Time32Millisecond(Some(i32::MAX)), + ScalarValue::Time32Millisecond(None), + ScalarValue::Time64Microsecond(Some(0)), + ScalarValue::Time64Microsecond(Some(i64::MAX)), + ScalarValue::Time64Microsecond(None), + ScalarValue::Time64Nanosecond(Some(0)), + ScalarValue::Time64Nanosecond(Some(i64::MAX)), + ScalarValue::Time64Nanosecond(None), + ScalarValue::TimestampNanosecond(Some(0), None), + ScalarValue::TimestampNanosecond(Some(i64::MAX), None), + ScalarValue::TimestampNanosecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampNanosecond(None, None), + ScalarValue::TimestampMicrosecond(Some(0), None), + ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), + ScalarValue::TimestampMicrosecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampMillisecond(Some(0), None), + ScalarValue::TimestampMillisecond(Some(i64::MAX), None), + ScalarValue::TimestampMillisecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampMillisecond(None, None), + ScalarValue::TimestampSecond(Some(0), None), + ScalarValue::TimestampSecond(Some(i64::MAX), None), + ScalarValue::TimestampSecond(Some(0), Some("UTC".into())), + ScalarValue::TimestampSecond(None, None), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 0))), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(1, 2))), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value( + i32::MAX, + i32::MAX, + ))), + ScalarValue::IntervalDayTime(None), + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(0, 0, 0), + )), + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(1, 2, 3), + )), + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(i32::MAX, i32::MAX, i64::MAX), + )), + ScalarValue::IntervalMonthDayNano(None), + ScalarValue::List(ScalarValue::new_list_nullable( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list_nullable( + &[ + ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list_nullable( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::List(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), + )), + ScalarValue::FixedSizeList(Arc::new( + FixedSizeListArray::from_iter_primitive::( + vec![Some(vec![Some(1), Some(2), Some(3)])], + 3, + ), + )), + ScalarValue::ListView(Arc::new(ListViewArray::from_iter_primitive::< + Int32Type, + _, + _, + >(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]))), + ScalarValue::LargeListView(Arc::new( + LargeListViewArray::from_iter_primitive::(vec![Some( + vec![Some(1), None, Some(3)], + )]), + )), + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::from("foo")), + ), + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(None)), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + Box::new(ScalarValue::from("foo")), + ), + ScalarValue::RunEndEncoded( + Field::new("run_ends", DataType::Int32, false).into(), + Field::new("values", DataType::Utf8, true).into(), + Box::new(ScalarValue::Utf8(None)), + ), + ScalarValue::Binary(Some(b"bar".to_vec())), + ScalarValue::Binary(None), + ScalarValue::LargeBinary(Some(b"bar".to_vec())), + ScalarValue::LargeBinary(None), + ScalarStructBuilder::new() + .with_scalar( + Field::new("a", DataType::Int32, true), + ScalarValue::from(23i32), + ) + .with_scalar( + Field::new("b", DataType::Boolean, false), + ScalarValue::from(false), + ) + .build() + .unwrap(), + ScalarStructBuilder::new() + .with_scalar( + Field::new("a", DataType::Int32, true), + ScalarValue::from(23i32), + ) + .with_scalar( + Field::new("b", DataType::Boolean, false), + ScalarValue::from(false), + ) + .build() + .unwrap(), + ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ Field::new("a", DataType::Int32, true), - ScalarValue::from(23i32), - ) - .with_scalar( Field::new("b", DataType::Boolean, false), - ScalarValue::from(false), - ) - .build() + ]))) .unwrap(), - ScalarStructBuilder::new() - .with_scalar( + ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ Field::new("a", DataType::Int32, true), - ScalarValue::from(23i32), - ) - .with_scalar( Field::new("b", DataType::Boolean, false), - ScalarValue::from(false), - ) - .build() + ]))) .unwrap(), - ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Boolean, false), - ]))) - .unwrap(), - ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Boolean, false), - ]))) - .unwrap(), - ScalarValue::try_from(&DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct(Fields::from(vec![ - Field::new("key", DataType::Int32, true), - Field::new("value", DataType::Utf8, false), - ])), - false, - )), - false, - )) - .unwrap(), - ScalarValue::try_from(&DataType::Map( - Arc::new(Field::new( - "entries", - DataType::Struct(Fields::from(vec![ - Field::new("key", DataType::Int32, true), - Field::new("value", DataType::Utf8, true), - ])), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, false), + ])), + false, + )), false, - )), - true, - )) - .unwrap(), - ScalarValue::Map(Arc::new(create_map_array_test_case())), - ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), - ScalarValue::FixedSizeBinary(0, None), - ScalarValue::FixedSizeBinary(5, None), - ]; + )) + .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, true), + ])), + false, + )), + true, + )) + .unwrap(), + ScalarValue::Map(Arc::new(create_map_array_test_case())), + ScalarValue::FixedSizeBinary( + b"bar".to_vec().len() as i32, + Some(b"bar".to_vec()), + ), + ScalarValue::FixedSizeBinary(0, None), + ScalarValue::FixedSizeBinary(5, None), + ]; // ScalarValue directly for test_case in should_pass.iter() { @@ -1773,19 +2083,20 @@ fn round_trip_datatype() { ), ])), DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![7, 5, 3], vec![ Field::new("nullable", DataType::Boolean, false), Field::new("name", DataType::Utf8, false), Field::new("datatype", DataType::Binary, false), ], - ), + ) + .unwrap(), UnionMode::Sparse, ), DataType::Union( - UnionFields::new( - vec![5, 8, 1], + UnionFields::try_new( + vec![5, 8, 1, 100], vec![ Field::new("nullable", DataType::Boolean, false), Field::new("name", DataType::Utf8, false), @@ -1800,7 +2111,8 @@ fn round_trip_datatype() { true, ), ], - ), + ) + .unwrap(), UnionMode::Dense, ), DataType::Dictionary( @@ -1981,6 +2293,10 @@ fn roundtrip_binary_op() { test(Operator::RegexNotMatch); test(Operator::RegexIMatch); test(Operator::RegexMatch); + test(Operator::LikeMatch); + test(Operator::ILikeMatch); + test(Operator::NotLikeMatch); + test(Operator::NotILikeMatch); test(Operator::BitwiseShiftRight); test(Operator::BitwiseShiftLeft); test(Operator::BitwiseAnd); @@ -2350,7 +2666,8 @@ fn roundtrip_scalar_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2363,7 +2680,8 @@ fn roundtrip_aggregate_udf_extension_codec() { let ctx = SessionContext::new(); let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); let round_trip = - from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + from_proto::parse_expr(&proto, ctx.task_ctx().as_ref(), &UDFExtensionCodec) + .expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); @@ -2567,10 +2885,6 @@ fn roundtrip_window() { } impl WindowUDFImpl for SimpleWindowUDF { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "dummy_udwf" } @@ -2879,3 +3193,140 @@ async fn roundtrip_mixed_case_table_reference() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn roundtrip_empty_table_scan() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let table = Arc::new(datafusion::datasource::empty::EmptyTable::new(Arc::clone( + &schema, + ))); + + let ctx = SessionContext::new(); + ctx.register_table("empty", table)?; + + let plan = ctx.table("empty").await?.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let restored = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + assert_eq!( + format!("{}", plan.display_indent_schema()), + format!("{}", restored.display_indent_schema()), + ); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_empty_table_scan_with_projection() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let table = Arc::new(datafusion::datasource::empty::EmptyTable::new(Arc::clone( + &schema, + ))); + + let ctx = SessionContext::new(); + ctx.register_table("empty", table)?; + + let plan = ctx + .table("empty") + .await? + .select_columns(&["name"])? + .into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let restored = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + + assert_eq!( + format!("{}", plan.display_indent_schema()), + format!("{}", restored.display_indent_schema()), + ); + Ok(()) +} + +// Regression test for https://github.com/apache/datafusion/issues/22065: +// the decoder must preserve `null_aware = true` (NOT IN semantics) +// across a to_proto -> from_proto round trip. `null_equality` is at +// its default (`NullEqualsNothing`). +#[tokio::test] +async fn roundtrip_join_null_aware() -> Result<()> { + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; + use datafusion_expr::JoinType; + + let ctx = SessionContext::new(); + let sql = " + SELECT id + FROM (VALUES (1), (2), (3)) AS t1(id) + WHERE id NOT IN ( + SELECT bad_id + FROM (VALUES (CAST(1 AS INT)), (CAST(NULL AS INT))) AS excludes(bad_id) + ) + "; + + let df = ctx.sql(sql).await?; + let plan = ctx.state().optimize(df.logical_plan())?; + + let mut found_null_aware = false; + plan.apply(|n| { + if let LogicalPlan::Join(j) = n + && j.join_type == JoinType::LeftAnti + && j.null_aware + { + found_null_aware = true; + } + Ok(TreeNodeRecursion::Continue) + })?; + assert!(found_null_aware); + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + + Ok(()) +} + +// Regression test for `null_equality` round-trip (related to #22065): +// the decoder must preserve a non-default `null_equality` +// (`NullEqualsNull`) across a to_proto -> from_proto round trip. +// `null_aware` is at its default (`false`). +#[tokio::test] +async fn roundtrip_join_null_equality() -> Result<()> { + use datafusion_common::NullEquality; + use datafusion_expr::JoinType; + use datafusion_expr::logical_plan::{Join, JoinConstraint}; + + let ctx = SessionContext::new(); + + let left_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let right_schema = + Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, true)])); + ctx.register_table( + "t1", + Arc::new(datafusion::datasource::empty::EmptyTable::new(left_schema)), + )?; + ctx.register_table( + "t2", + Arc::new(datafusion::datasource::empty::EmptyTable::new(right_schema)), + )?; + let left = ctx.table("t1").await?.into_optimized_plan()?; + let right = ctx.table("t2").await?.into_optimized_plan()?; + + let join = LogicalPlan::Join(Join::try_new( + Arc::new(left), + Arc::new(right), + vec![(col("t1.a"), col("t2.b"))], + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNull, + false, + )?); + + let bytes = logical_plan_to_bytes(&join)?; + let rt = logical_plan_from_bytes(&bytes, &ctx.task_ctx())?; + assert_eq!(format!("{join:?}"), format!("{rt:?}")); + + Ok(()) +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index c50f41625c70d..8e80467788598 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,34 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::collections::HashMap; use std::fmt::{Display, Formatter}; - -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::vec; -use crate::cases::{ - CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, - MyRegexUdfNode, -}; - use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::{Fields, TimeUnit}; -use datafusion::physical_expr::aggregate::AggregateExprBuilder; -use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::metrics::MetricType; -use datafusion_datasource::TableSchema; -use datafusion_expr::dml::InsertOp; -use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; -use datafusion_functions_aggregate::array_agg::array_agg_udaf; -use datafusion_functions_aggregate::min_max::max_udaf; -use prost::Message; - use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; -use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; +use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef}; use datafusion::datasource::empty::EmptyTable; use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::{JsonFormat, JsonSink}; @@ -52,32 +35,40 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileGroup, - FileScanConfigBuilder, FileSinkConfig, ParquetSource, + ArrowSource, FileGroup, FileOutputMode, FileScanConfig, FileScanConfigBuilder, + FileSinkConfig, ParquetSource, wrap_partition_type_in_dict, + wrap_partition_value_in_dict, }; use datafusion::datasource::sink::DataSinkExec; use datafusion::datasource::source::DataSourceExec; use datafusion::execution::TaskContext; use datafusion::functions_aggregate::count::count_udaf; +use datafusion::functions_aggregate::first_last::first_value_udaf; use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::functions_window::nth_value::nth_value_udwf; use datafusion::functions_window::row_number::row_number_udwf; -use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; +use datafusion::logical_expr::{JoinType, Operator, Volatility, create_udf}; +use datafusion::physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; use datafusion::physical_expr::{ LexOrdering, PhysicalSortRequirement, ScalarFunctionExpr, }; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_optimizer::filter_pushdown::FilterPushdown; use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, + AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; +#[expect(deprecated)] +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, cast, col, in_list, like, lit, BinaryExpr, Column, NotExpr, PhysicalSortExpr, + BinaryExpr, Column, DynamicFilterPhysicalExpr, NotExpr, PhysicalSortExpr, binary, + cast, col, in_list, like, lit, }; -use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::filter::{FilterExec, FilterExecBuilder}; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, @@ -86,15 +77,20 @@ use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; +use datafusion::physical_plan::scalar_subquery::{ + ScalarSubqueryExec, ScalarSubqueryLink, +}; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{ - create_udwf_window_expr, BoundedWindowAggExec, PlainAggregateWindowExpr, - WindowAggExec, + BoundedWindowAggExec, PlainAggregateWindowExpr, WindowAggExec, + create_udwf_window_expr, }; use datafusion::physical_plan::{ - displayable, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics, + DisplayAs, DisplayFormatType, ExecutionPlan, InputOrderMode, Partitioning, + PhysicalExpr, RangePartitioning, SendableRecordBatchStream, SplitPoint, Statistics, + displayable, }; use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion::scalar::ScalarValue; @@ -104,20 +100,45 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - internal_datafusion_err, internal_err, not_impl_err, DataFusionError, NullEquality, - Result, UnnestOptions, + DataFusionError, NullEquality, Result, UnnestOptions, exec_datafusion_err, + internal_datafusion_err, internal_err, not_impl_err, }; +use datafusion_datasource::file::FileSource; +use datafusion_datasource::{TableSchema, TableSchemaBuilder}; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, WindowUDF, + execution_props::{ScalarSubqueryResults, SubqueryIndex}, }; +use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::min_max::max_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; use datafusion_functions_aggregate::string_agg::string_agg_udaf; +use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; +use datafusion_proto::bytes::{ + physical_plan_from_bytes_with_proto_converter, + physical_plan_to_bytes_with_proto_converter, +}; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr_with_converter; use datafusion_proto::physical_plan::{ - AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, + AsExecutionPlan, DeduplicatingProtoConverter, DefaultPhysicalExtensionCodec, + DefaultPhysicalProtoConverter, PhysicalExtensionCodec, PhysicalPlanDecodeContext, + PhysicalPlanNodeExt, PhysicalProtoConverterExtension, +}; +use datafusion_proto::protobuf; +use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; +use prost::Message; + +use crate::cases::{ + CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, + MyRegexUdfNode, }; -use datafusion_proto::protobuf::{self, PhysicalPlanNode}; +use datafusion_physical_expr::utils::reassign_expr_columns; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -125,7 +146,8 @@ use datafusion_proto::protobuf::{self, PhysicalPlanNode}; fn roundtrip_test(exec_plan: Arc) -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, &ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; Ok(()) } @@ -139,13 +161,19 @@ fn roundtrip_test_and_return( exec_plan: Arc, ctx: &SessionContext, codec: &dyn PhysicalExtensionCodec, + proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) - .expect("to proto"); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx.task_ctx(), codec) - .expect("from proto"); + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan), + codec, + proto_converter, + )?; + let result_exec_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + codec, + proto_converter, + )?; pretty_assertions::assert_eq!( format!("{exec_plan:?}"), @@ -165,7 +193,8 @@ fn roundtrip_test_with_context( ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - roundtrip_test_and_return(exec_plan, ctx, &codec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(exec_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -173,9 +202,10 @@ fn roundtrip_test_with_context( /// query results are identical. async fn roundtrip_test_sql_with_context(sql: &str, ctx: &SessionContext) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; - roundtrip_test_and_return(initial_plan, ctx, &codec)?; + roundtrip_test_and_return(initial_plan, ctx, &codec, &proto_converter)?; Ok(()) } @@ -282,6 +312,7 @@ fn roundtrip_hash_join() -> Result<()> { None, *partition_mode, NullEquality::NullEqualsNothing, + false, )?))?; } } @@ -596,14 +627,13 @@ fn roundtrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates = - vec![ - AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("AVG(b)") - .build() - .map(Arc::new)?, - ]; + let aggregates = vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build() + .map(Arc::new)?, + ]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -613,7 +643,7 @@ fn roundtrip_aggregate_with_limit() -> Result<()> { Arc::new(EmptyExec::new(schema.clone())), schema, )?; - let agg = agg.with_limit(Some(12)); + let agg = agg.with_limit_options(Some(LimitOptions::new_with_order(12, false))); roundtrip_test(Arc::new(agg)) } @@ -626,14 +656,16 @@ fn roundtrip_aggregate_with_approx_pencentile_cont() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates = vec![AggregateExprBuilder::new( - approx_percentile_cont_udaf(), - vec![col("b", &schema)?, lit(0.5)], - ) - .schema(Arc::clone(&schema)) - .alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)") - .build() - .map(Arc::new)?]; + let aggregates = vec![ + AggregateExprBuilder::new( + approx_percentile_cont_udaf(), + vec![col("b", &schema)?, lit(0.5)], + ) + .schema(Arc::clone(&schema)) + .alias("APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY b)") + .build() + .map(Arc::new)?, + ]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -662,15 +694,14 @@ fn roundtrip_aggregate_with_sort() -> Result<()> { }, }]; - let aggregates = - vec![ - AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("ARRAY_AGG(b)") - .order_by(sort_exprs) - .build() - .map(Arc::new)?, - ]; + let aggregates = vec![ + AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("ARRAY_AGG(b)") + .order_by(sort_exprs) + .build() + .map(Arc::new)?, + ]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -730,14 +761,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates = - vec![ - AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) - .alias("example_agg") - .build() - .map(Arc::new)?, - ]; + let aggregates = vec![ + AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("example_agg") + .build() + .map(Arc::new)?, + ]; roundtrip_test_with_context( Arc::new(AggregateExec::try_new( @@ -775,6 +805,19 @@ fn roundtrip_filter_with_not_and_in_list() -> Result<()> { )?)) } +#[test] +fn roundtrip_filter_with_fetch() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let predicate = col("a", &schema)?; + let filter = FilterExecBuilder::new(predicate, Arc::new(EmptyExec::new(schema))) + .with_fetch(Some(10)) + .build()?; + assert_eq!(filter.fetch(), Some(10)); + roundtrip_test(Arc::new(filter)) +} + #[test] fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); @@ -843,11 +886,13 @@ fn roundtrip_coalesce_batches_with_fetch() -> Result<()> { let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); + #[expect(deprecated)] roundtrip_test(Arc::new(CoalesceBatchesExec::new( Arc::new(EmptyExec::new(schema.clone())), 8096, )))?; + #[expect(deprecated)] roundtrip_test(Arc::new( CoalesceBatchesExec::new(Arc::new(EmptyExec::new(schema)), 8096) .with_fetch(Some(10)), @@ -908,6 +953,80 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { roundtrip_test(DataSourceExec::from_data_source(scan_config)) } +#[test] +fn roundtrip_parquet_exec_attaches_cached_reader_factory_after_roundtrip() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let file_source = Arc::new(ParquetSource::new(Arc::clone(&file_schema))); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )])]) + .with_statistics(Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&file_schema), + }) + .build(); + let exec_plan = DataSourceExec::from_data_source(scan_config); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtripped = + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; + + let data_source = roundtripped + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected DataSourceExec after roundtrip") + })?; + let file_scan = data_source + .data_source() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected FileScanConfig after roundtrip") + })?; + let parquet_source = file_scan + .file_source() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected ParquetSource after roundtrip") + })?; + + assert!( + parquet_source.parquet_file_reader_factory().is_some(), + "Parquet reader factory should be attached after decoding from protobuf" + ); + Ok(()) +} + +#[test] +fn roundtrip_arrow_scan() -> Result<()> { + let file_schema = + Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + + let table_schema = TableSchema::from(&file_schema); + let file_source = Arc::new(ArrowSource::new_file_source(table_schema)); + + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.arrow".to_string(), + 1024, + )])]) + .with_statistics(Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&file_schema), + }) + .build(); + + roundtrip_test(DataSourceExec::from_data_source(scan_config)) +} + #[tokio::test] async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { let mut file_group = @@ -916,21 +1035,19 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { vec![wrap_partition_value_in_dict(ScalarValue::Int64(Some(0)))]; let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); - let table_schema = TableSchema::new( - schema.clone(), - vec![Arc::new(Field::new( + let table_schema = TableSchemaBuilder::from(&schema) + .with_table_partition_cols(vec![Arc::new(Field::new( "part".to_string(), wrap_partition_type_in_dict(DataType::Int16), false, - ))], - ); + ))]) + .build(); let file_source = Arc::new(ParquetSource::new(table_schema.clone())); let scan_config = FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) - .with_projection_indices(Some(vec![0, 1])) + .with_projection_indices(Some(vec![0, 1]))? .with_file_group(FileGroup::new(vec![file_group])) - .with_newlines_in_values(false) .build(); roundtrip_test(DataSourceExec::from_data_source(scan_config)) @@ -984,16 +1101,12 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { } impl Display for CustomPredicateExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CustomPredicateExpr") } } impl PhysicalExpr for CustomPredicateExpr { - fn as_any(&self) -> &dyn Any { - self - } - fn data_type(&self, _input_schema: &Schema) -> Result { unreachable!() } @@ -1030,6 +1143,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { _buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { unreachable!() } @@ -1038,6 +1152,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { &self, _node: Arc, _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { unreachable!() } @@ -1061,11 +1176,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { node: &Arc, buf: &mut Vec, ) -> Result<()> { - if node - .as_any() - .downcast_ref::() - .is_some() - { + if node.downcast_ref::().is_some() { buf.extend_from_slice("CustomPredicateExpr".as_bytes()); Ok(()) } else { @@ -1077,7 +1188,12 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { let exec_plan = DataSourceExec::from_data_source(scan_config); let ctx = SessionContext::new(); - roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; + roundtrip_test_and_return( + exec_plan, + &ctx, + &CustomPhysicalExtensionCodec {}, + &DefaultPhysicalProtoConverter {}, + )?; Ok(()) } @@ -1138,6 +1254,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { _buf: &[u8], _inputs: &[Arc], _ctx: &TaskContext, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result> { not_impl_err!("No extension codec provided") } @@ -1146,6 +1263,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { &self, _node: Arc, _buf: &mut Vec, + _proto_converter: &dyn PhysicalProtoConverterExtension, ) -> Result<()> { not_impl_err!("No extension codec provided") } @@ -1164,7 +1282,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); - if let Some(udf) = binding.as_any().downcast_ref::() { + if let Some(udf) = binding.downcast_ref::() { let proto = MyRegexUdfNode { pattern: udf.pattern.clone(), }; @@ -1191,7 +1309,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); - if let Some(udf) = binding.as_any().downcast_ref::() { + if let Some(udf) = binding.downcast_ref::() { let proto = MyAggregateUdfNode { result: udf.result.clone(), }; @@ -1218,7 +1336,7 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); - if let Some(udwf) = binding.as_any().downcast_ref::() { + if let Some(udwf) = binding.downcast_ref::() { let proto = CustomUDWFNode { payload: udwf.payload.clone(), }; @@ -1275,7 +1393,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, - PhysicalGroupBy::new(vec![], vec![], vec![]), + PhysicalGroupBy::new(vec![], vec![], vec![], false), vec![aggr_expr], vec![None], window, @@ -1283,7 +1401,8 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1330,7 +1449,8 @@ fn roundtrip_udwf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1393,7 +1513,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, - PhysicalGroupBy::new(vec![], vec![], vec![]), + PhysicalGroupBy::new(vec![], vec![], vec![], false), vec![aggr_expr], vec![None], window, @@ -1401,7 +1521,8 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let ctx = SessionContext::new(); - roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + let proto_converter = DefaultPhysicalProtoConverter {}; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec, &proto_converter)?; Ok(()) } @@ -1436,13 +1557,9 @@ fn roundtrip_analyze() -> Result<()> { let schema = Schema::new(vec![field_a, field_b]); let input = Arc::new(PlaceholderRowExec::new(Arc::new(schema.clone()))); - roundtrip_test(Arc::new(AnalyzeExec::new( - false, - false, - vec![MetricType::SUMMARY, MetricType::DEV], - input, - Arc::new(schema), - ))) + roundtrip_test(Arc::new( + AnalyzeExec::builder(false, false, input, Arc::new(schema)).build(), + )) } #[tokio::test] @@ -1471,6 +1588,7 @@ fn roundtrip_json_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "json".into(), + file_output_mode: FileOutputMode::SingleFile, }; let data_sink = Arc::new(JsonSink::new( file_sink_config, @@ -1509,6 +1627,7 @@ fn roundtrip_csv_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "csv".into(), + file_output_mode: FileOutputMode::Directory, }; let data_sink = Arc::new(CsvSink::new( file_sink_config, @@ -1525,22 +1644,17 @@ fn roundtrip_csv_sink() -> Result<()> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtrip_plan = roundtrip_test_and_return( Arc::new(DataSinkExec::new(input, data_sink, Some(sort_order))), &ctx, &codec, - ) - .unwrap(); + &proto_converter, + )?; - let roundtrip_plan = roundtrip_plan - .as_any() - .downcast_ref::() - .unwrap(); - let csv_sink = roundtrip_plan - .sink() - .as_any() - .downcast_ref::() - .unwrap(); + let roundtrip_plan = roundtrip_plan.downcast_ref::().unwrap(); + let csv_sink = roundtrip_plan.sink().downcast_ref::().unwrap(); assert_eq!( CompressionTypeVariant::ZSTD, csv_sink.writer_options().compression @@ -1566,6 +1680,7 @@ fn roundtrip_parquet_sink() -> Result<()> { insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, file_extension: "parquet".into(), + file_output_mode: FileOutputMode::Automatic, }; let data_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1657,6 +1772,50 @@ fn roundtrip_union() -> Result<()> { roundtrip_test(union) } +#[test] +fn roundtrip_repartition_preserve_order() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a])); + let sort_exprs: LexOrdering = [PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions::default(), + }] + .into(); + + // Create two sorted single-partition inputs, then union them to get + // a sorted input with 2 partitions. + let source1 = SortExec::new( + sort_exprs.clone(), + Arc::new(EmptyExec::new(Arc::clone(&schema))), + ); + let source2 = SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema))); + let union = UnionExec::try_new(vec![ + Arc::new(source1) as Arc, + Arc::new(source2) as Arc, + ])?; + + let repartition = RepartitionExec::try_new(union, Partitioning::RoundRobinBatch(10))? + .with_preserve_order(); + assert!(repartition.preserve_order()); + + roundtrip_test(Arc::new(repartition)) +} + +#[test] +fn roundtrip_range_partitioning() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let input = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let range_partitioning = Partitioning::Range(RangePartitioning::new( + [PhysicalSortExpr::new_default(col("a", &schema)?)].into(), + vec![SplitPoint::new(vec![ScalarValue::Int64(Some(10))])], + )); + // RepartitionExec is used only to carry the partitioning through proto. + // Executing range repartitioning is intentionally unsupported. + let repartition = RepartitionExec::try_new(input, range_partitioning)?; + + roundtrip_test(Arc::new(repartition)) +} + #[test] fn roundtrip_interleave() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false); @@ -1813,15 +1972,16 @@ async fn roundtrip_projection_source() -> Result<()> { 1024, )])]) .with_statistics(statistics) - .with_projection_indices(Some(vec![0, 1, 2])) + .with_projection_indices(Some(vec![0, 1, 2]))? .build(); let filter = Arc::new( - FilterExec::try_new( + FilterExecBuilder::new( Arc::new(BinaryExpr::new(col("c", &schema)?, Operator::Eq, lit(1))), DataSourceExec::from_data_source(scan_config), - )? - .with_projection(Some(vec![0, 1]))?, + ) + .apply_projection(Some(vec![0, 1]))? + .build()?, ); roundtrip_test(filter) @@ -1971,6 +2131,7 @@ async fn test_serialize_deserialize_tpch_queries() -> Result<()> { // serialize the physical plan let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; @@ -2022,9 +2183,107 @@ async fn test_round_trip_human_display() -> Result<()> { let sql = "select r_name, count(r_name) from region group by r_name"; roundtrip_test_sql_with_context(sql, &ctx).await?; + let sql = "select count(*) as count_star from region"; + roundtrip_test_sql_with_context(sql, &ctx).await?; + + Ok(()) +} + +#[test] +fn test_round_trip_aliased_reverse_human_display() -> Result<()> { + let aggregate_expr = roundtrip_first_value_aggregate( + "agg", + "first_value(b) ORDER BY [b ASC NULLS LAST]", + Some("agg"), + )?; + let reversed = aggregate_expr + .reverse_expr() + .expect("expected reverse expr"); + + assert_eq!(reversed.name(), "agg"); + assert_eq!(reversed.human_display_alias(), Some("agg")); + assert_eq!( + reversed.human_display(), + Some("last_value(b) ORDER BY [b DESC NULLS FIRST]") + ); + + Ok(()) +} + +#[test] +fn test_round_trip_human_display_alias_with_colon() -> Result<()> { + let aggregate_expr = roundtrip_first_value_aggregate( + "agg:one", + "first_value(b) ORDER BY [b ASC NULLS LAST]", + Some("agg:one"), + )?; + + assert_eq!(aggregate_expr.name(), "agg:one"); + assert_eq!(aggregate_expr.human_display_alias(), Some("agg:one")); + assert_eq!( + aggregate_expr.human_display(), + Some("first_value(b) ORDER BY [b ASC NULLS LAST]") + ); + + Ok(()) +} + +#[test] +fn test_round_trip_non_aliased_human_display_ending_like_alias() -> Result<()> { + let aggregate_expr = + roundtrip_first_value_aggregate("agg", "first_value(b) as agg", None)?; + + assert_eq!(aggregate_expr.name(), "agg"); + assert_eq!( + aggregate_expr.human_display(), + Some("first_value(b) as agg") + ); + assert_eq!(aggregate_expr.human_display_alias(), None); + Ok(()) } +fn roundtrip_first_value_aggregate( + alias: &str, + human_display: &str, + human_display_alias: Option<&str>, +) -> Result> { + let schema = Arc::new(Schema::new(vec![Field::new("b", DataType::Int64, true)])); + let mut builder = + AggregateExprBuilder::new(first_value_udaf(), vec![col("b", &schema)?]) + .order_by(vec![PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions::new(false, false), + }]) + .schema(Arc::clone(&schema)) + .alias(alias) + .human_display(human_display); + if let Some(human_display_alias) = human_display_alias { + builder = builder.human_display_alias(human_display_alias); + } + let agg_expr = builder.build().map(Arc::new)?; + + let plan = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![], false), + vec![agg_expr], + vec![None], + Arc::new(EmptyExec::new(Arc::clone(&schema))), + schema, + )?); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let roundtrip_plan = roundtrip_test_and_return(plan, &ctx, &codec, &proto_converter)?; + let aggregate = roundtrip_plan + .as_ref() + .downcast_ref::() + .expect("expected AggregateExec after roundtrip"); + + Ok(Arc::clone(&aggregate.aggr_expr()[0])) +} + // Bug 2 of https://github.com/apache/datafusion/issues/16772 /// Test that PhysicalGroupBy groups field is correctly serialized/deserialized /// for simple aggregates (no GROUP BY clause). @@ -2092,6 +2351,7 @@ async fn test_tpch_part_in_list_query_with_real_parquet_data() -> Result<()> { // Serialize the physical plan - bug may happen here already but not necessarily manifests let codec = DefaultPhysicalExtensionCodec {}; + let proto = PhysicalPlanNode::try_from_physical_plan(physical_plan.clone(), &codec)?; // This will fail with the bug, but should succeed when fixed @@ -2263,3 +2523,1607 @@ async fn roundtrip_listing_table_with_schema_metadata() -> Result<()> { roundtrip_test(plan) } + +#[tokio::test] +async fn roundtrip_async_func_exec() -> Result<()> { + #[derive(Debug, PartialEq, Eq, Hash)] + struct TestAsyncUDF { + signature: Signature, + } + + impl TestAsyncUDF { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Int64], Volatility::Volatile), + } + } + } + + impl ScalarUDFImpl for TestAsyncUDF { + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + not_impl_err!("Must call from `invoke_async_with_args`") + } + } + + #[async_trait::async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDF { + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + Ok(args.args[0].clone()) + } + } + + let ctx = SessionContext::new(); + let async_udf = AsyncScalarUDF::new(Arc::new(TestAsyncUDF::new())); + ctx.register_udf(async_udf.into_scalar_udf()); + + let physical_plan = ctx + .sql("select test_async_udf(1)") + .await? + .create_physical_plan() + .await?; + + roundtrip_test_with_context(physical_plan, &ctx)?; + + Ok(()) +} + +/// Test that HashTableLookupExpr serializes to lit(true) +/// +/// HashTableLookupExpr contains a runtime hash table that cannot be serialized. +/// The serialization code replaces it with lit(true) which is safe because +/// it's a performance optimization filter, not a correctness requirement. +#[test] +fn roundtrip_hash_table_lookup_expr_to_lit() -> Result<()> { + use datafusion::physical_plan::joins::join_hash_map::JoinHashMapU32; + use datafusion::physical_plan::joins::{HashTableLookupExpr, Map}; + + // Create a simple schema and input plan + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)])); + let input = Arc::new(EmptyExec::new(schema.clone())); + + // Create a HashTableLookupExpr - it will be replaced with lit(true) during serialization + let hash_map = Arc::new(Map::HashMap(Box::new(JoinHashMapU32::with_capacity(0)))); + let on_columns = vec![datafusion::physical_plan::expressions::col("col", &schema)?]; + let lookup_expr: Arc = Arc::new(HashTableLookupExpr::new( + on_columns, + datafusion::physical_plan::joins::SeededRandomState::with_seed(0), + hash_map, + "test_lookup".to_string(), + )); + + // Create a filter with the lookup expression + let filter = Arc::new(FilterExec::try_new(lookup_expr, input)?); + + // Serialize + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + + let proto: PhysicalPlanNode = + PhysicalPlanNode::try_from_physical_plan(filter.clone(), &codec) + .expect("serialization should succeed"); + + // Deserialize + let result: Arc = proto + .try_into_physical_plan(&ctx.task_ctx(), &codec) + .expect("deserialization should succeed"); + + // The deserialized plan should have lit(true) instead of HashTableLookupExpr + // Verify the filter predicate is a Literal(true) + let result_filter = result.downcast_ref::().unwrap(); + let predicate = result_filter.predicate(); + let literal = predicate.downcast_ref::().unwrap(); + assert_eq!(*literal.value(), ScalarValue::Boolean(Some(true))); + + Ok(()) +} + +#[test] +fn roundtrip_hash_expr() -> Result<()> { + use datafusion::physical_plan::joins::{HashExpr, SeededRandomState}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + ])); + + // Create a HashExpr with test columns and seeds + let on_columns = vec![col("a", &schema)?, col("b", &schema)?]; + let hash_expr: Arc = Arc::new(HashExpr::new( + on_columns, + SeededRandomState::with_seed(0), // arbitrary random seed for testing + "test_hash".to_string(), + )); + + // Wrap in a filter by comparing hash value to a literal + // hash_expr > 0 is always boolean + let filter_expr = binary(hash_expr, Operator::Gt, lit(0u64), &schema)?; + let filter = Arc::new(FilterExec::try_new( + filter_expr, + Arc::new(EmptyExec::new(schema)), + )?); + + // Confirm that the debug string contains the random state seeds + assert!( + format!("{filter:?}").contains("test_hash(a@0, b@1, [0])"), + "Debug string missing seeds: {filter:?}" + ); + roundtrip_test(filter) +} + +#[test] +fn custom_proto_converter_intercepts() -> Result<()> { + #[derive(Default)] + struct CustomConverterInterceptor { + num_proto_plans: RwLock, + num_physical_plans: RwLock, + num_proto_exprs: RwLock, + num_physical_exprs: RwLock, + } + + impl PhysicalProtoConverterExtension for CustomConverterInterceptor { + fn proto_to_execution_plan( + &self, + proto: &protobuf::PhysicalPlanNode, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> { + { + let mut counter = self + .num_proto_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + self.default_proto_to_execution_plan(proto, ctx) + } + + fn execution_plan_to_proto( + &self, + plan: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result + where + Self: Sized, + { + { + let mut counter = self + .num_physical_plans + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + PhysicalPlanNode::try_from_physical_plan_with_converter( + Arc::clone(plan), + codec, + self, + ) + } + + fn proto_to_physical_expr( + &self, + proto: &PhysicalExprNode, + input_schema: &Schema, + ctx: &PhysicalPlanDecodeContext<'_>, + ) -> Result> + where + Self: Sized, + { + { + let mut counter = self + .num_proto_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + self.default_proto_to_physical_expr(proto, input_schema, ctx) + } + + fn physical_expr_to_proto( + &self, + expr: &Arc, + codec: &dyn PhysicalExtensionCodec, + ) -> Result { + { + let mut counter = self + .num_physical_exprs + .write() + .map_err(|err| exec_datafusion_err!("{err}"))?; + *counter += 1; + } + serialize_physical_expr_with_converter(expr, codec, self) + } + } + + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let sort_exprs = [ + PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: true, + nulls_first: false, + }, + }, + PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }, + ] + .into(); + + let exec_plan = Arc::new(SortExec::new(sort_exprs, Arc::new(EmptyExec::new(schema)))); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = CustomConverterInterceptor::default(); + roundtrip_test_and_return(exec_plan, &ctx, &codec, &proto_converter)?; + + assert_eq!(*proto_converter.num_proto_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_exprs.read().unwrap(), 2); + assert_eq!(*proto_converter.num_proto_plans.read().unwrap(), 2); + assert_eq!(*proto_converter.num_physical_plans.read().unwrap(), 2); + + Ok(()) +} + +#[test] +fn roundtrip_call_null_scalar_struct_dict() -> Result<()> { + let data_type = DataType::Struct(Fields::from(vec![Field::new( + "item", + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + true, + )])); + + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)])); + let scan = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let scalar = lit(ScalarValue::try_from(data_type)?); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)), + scan, + )?); + + roundtrip_test(filter) +} + +/// Create a [`DynamicFilterPhysicalExpr`] with child column expression "a" @ index 0. +fn make_dynamic_filter() -> Arc { + Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )) as Arc +} + +/// Update a [`DynamicFilterPhysicalExpr`]'s children to support child schema "b" @ 0, "a" @ 1. +fn make_reassigned_dynamic_filter( + filter: Arc, +) -> Result<(Arc, Arc)> { + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int64, false), + Field::new("a", DataType::Int64, false), + ])); + let reassigned = reassign_expr_columns(filter, &schema)?; + Ok((schema, reassigned)) +} + +/// Extract the expression id from a [`PhysicalExpr`] proto. Populated by the +/// default serializer from `PhysicalExpr::expression_id`. +fn proto_expression_id(expr: &PhysicalExprNode) -> u64 { + expr.expr_id + .expect("expected PhysicalExprNode.expr_id to be populated") +} + +/// Roundtrip a single physical expression shaped like so: +/// +/// ```text +/// BinaryExpr(AND) +/// / \ +/// filter_expr_1 filter_expr_2 +/// ``` +/// +/// Returns filter_expr_1 and filter_expr_2 after deserialization. +fn roundtrip_dynamic_filter_expr_pair( + filter_expr_1: Arc, + filter_expr_2: Arc, + schema: Arc, +) -> Result<(Arc, Arc)> { + let pair_expr = Arc::new(BinaryExpr::new( + Arc::clone(&filter_expr_1), + Operator::And, + Arc::clone(&filter_expr_2), + )) as Arc; + + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let proto = converter.physical_expr_to_proto(&pair_expr, &codec)?; + let ctx = SessionContext::new(); + let task_ctx = ctx.task_ctx(); + let decode_ctx = PhysicalPlanDecodeContext::new(task_ctx.as_ref(), &codec); + let deserialized_expr = + converter.proto_to_physical_expr(&proto, &schema, &decode_ctx)?; + + let binary = deserialized_expr + .downcast_ref::() + .expect("Expected BinaryExpr"); + + Ok((Arc::clone(binary.left()), Arc::clone(binary.right()))) +} + +/// Roundtrip an execution plan shaped like so: +/// +/// ```text +/// FilterExec(dynamic_filter_1 on a@0) +/// ProjectionExec(a := Column("a", source_index)) +/// DataSourceExec +/// ParquetSource(predicate = dynamic_filter_2) +/// ``` +/// +/// `dynamic_filter_1` and `dynamic_filter_2` are the same dynamic filter, except with +/// different children. +/// +/// Returns +/// - `dynamic_filter_1` before serialization +/// - `dynamic_filter_2` before serialization +/// - `dynamic_filter_1` after serialization +/// - `dynamic_filter_2` after serialization +#[allow(clippy::type_complexity)] +fn roundtrip_dynamic_filter_plan_pair() -> Result<( + Arc, + Arc, + Arc, + Arc, +)> { + let filter_expr_1 = make_dynamic_filter(); + let (data_source_schema, filter_expr_2) = + make_reassigned_dynamic_filter(Arc::clone(&filter_expr_1))?; + let left_before = Arc::clone(&filter_expr_1); + let right_before = Arc::clone(&filter_expr_2); + let file_source = Arc::new( + ParquetSource::new(Arc::clone(&data_source_schema)) + .with_predicate(Arc::clone(&filter_expr_2)), + ); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )])]) + .build(); + let data_source_exec = + DataSourceExec::from_data_source(scan_config) as Arc; + + let projection_exec = Arc::new(ProjectionExec::try_new( + vec![ProjectionExpr { + expr: Arc::new(Column::new("a", 1)) as Arc, + alias: "a".to_string(), + }], + data_source_exec, + )?) as Arc; + let filter_exec = Arc::new(FilterExec::try_new( + Arc::clone(&filter_expr_1), + projection_exec, + )?) as Arc; + + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let proto = converter.execution_plan_to_proto(&filter_exec, &codec)?; + + let ctx = SessionContext::new(); + let task_ctx = ctx.task_ctx(); + let decode_ctx = PhysicalPlanDecodeContext::new(task_ctx.as_ref(), &codec); + let deserialized_plan = converter.proto_to_execution_plan(&proto, &decode_ctx)?; + + let outer_filter = deserialized_plan + .downcast_ref::() + .expect("Expected outer FilterExec"); + let left_filter = Arc::clone(outer_filter.predicate()); + let projection = outer_filter.children()[0] + .downcast_ref::() + .expect("Expected ProjectionExec"); + let data_source = projection + .input() + .downcast_ref::() + .expect("Expected DataSourceExec"); + let scan_config = data_source + .data_source() + .downcast_ref::() + .expect("Expected FileScanConfig"); + let right_filter = scan_config + .file_source() + .filter() + .expect("Expected pushed-down predicate"); + + Ok((left_before, right_before, left_filter, right_filter)) +} + +/// Takes two [`DynamicFilterPhysicalExpr`] and asserts that updates to one are visible +/// via the other. This helps assert that referential integrity is maintained after +/// deserializing. +fn assert_dynamic_filter_update_is_visible( + left_filter: &Arc, + right_filter: &Arc, +) -> Result<()> { + let left_filter = left_filter + .downcast_ref::() + .expect("Expected dynamic filter"); + let right_filter = right_filter + .downcast_ref::() + .expect("Expected dynamic filter"); + + // Sanity check that the filters have the same generation. + let original_generation = left_filter.snapshot_generation(); + assert_eq!(original_generation, right_filter.snapshot_generation(),); + + left_filter.update(lit(123_i64))?; + + // Assert that both generations updated. + assert_eq!(original_generation + 1, right_filter.snapshot_generation(),); + assert_eq!( + left_filter.snapshot_generation(), + right_filter.snapshot_generation(), + ); + + // Ensure both filters have the updated expr. + let expected_current = r#"Literal { value: Int64(123), field: Field { name: "lit", data_type: Int64 } }"#; + assert_eq!(expected_current, format!("{:?}", left_filter.current()?),); + assert_eq!(expected_current, format!("{:?}", right_filter.current()?),); + + Ok(()) +} + +/// Extract the dynamic-filter predicate that was pushed down to the parquet +/// scan at the bottom of the plan tree. +fn parquet_source_predicate(child: &Arc) -> Arc { + let data_source = child + .downcast_ref::() + .expect("Child should be DataSourceExec"); + let (_, parquet_source) = data_source + .downcast_to_file_source::() + .expect("Should be ParquetSource"); + parquet_source + .filter() + .expect("ParquetSource should have a predicate after roundtrip") +} + +/// Assert that two dynamic filters are equal both structurally (Debug output) +/// and by identity (`expression_id`). +fn assert_dynamic_filters_equal( + expected: &Arc, + actual: &Arc, +) { + // Structural. + let expected_dbg = format!("{expected:?}"); + let actual_dbg = format!("{actual:?}"); + if expected_dbg == actual_dbg { + return; + } + + // Note that the `DeduplicatingDeserializer` routes every cache hit through + // `with_new_children`. This produces an equivalent expression, but with + // remapped children that are equal to the original. Handle that case here. + let rewritten = Arc::clone(expected) + .with_new_children(expected.children().iter().map(|c| Arc::clone(c)).collect()) + .expect("with_new_children on a dynamic filter should not fail"); + assert_eq!(format!("{rewritten:?}"), actual_dbg); +} + +// Two clones of a dynamic filter expression should be deduped to the exact same expression. +#[test] +fn test_dynamic_filter_roundtrip_dedupe() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let filter_expr_1 = make_dynamic_filter(); + let filter_expr_2 = Arc::clone(&filter_expr_1); + + let (filter_expr_1_after_roundtrip, filter_expr_2_after_roundtrip) = + roundtrip_dynamic_filter_expr_pair( + Arc::clone(&filter_expr_1), + Arc::clone(&filter_expr_2), + schema, + )?; + + // Assert the filters are not modified during roundtrip. + assert_dynamic_filters_equal(&filter_expr_1, &filter_expr_1_after_roundtrip); + assert_dynamic_filters_equal(&filter_expr_2, &filter_expr_2_after_roundtrip); + assert_dynamic_filters_equal( + &filter_expr_1_after_roundtrip, + &filter_expr_2_after_roundtrip, + ); + + // Assert referential integrity. + assert_dynamic_filter_update_is_visible( + &filter_expr_1_after_roundtrip, + &filter_expr_2_after_roundtrip, + )?; + + Ok(()) +} + +/// Roundtrip test for an execution plan where there are multiple instances of a dynamic filter +/// with different children. +#[test] +fn test_dynamic_filter_plan_roundtrip_dedupe() -> Result<()> { + let ( + filter_expr_1, + filter_expr_2, + filter_expr_1_after_roundtrip, + filter_expr_2_after_roundtrip, + ) = roundtrip_dynamic_filter_plan_pair()?; + + // Assert the filters are not modified during roundtrip. + assert_dynamic_filters_equal(&filter_expr_1, &filter_expr_1_after_roundtrip); + assert_dynamic_filters_equal(&filter_expr_2, &filter_expr_2_after_roundtrip); + + // Assert referential integrity. + assert_dynamic_filter_update_is_visible( + &filter_expr_1_after_roundtrip, + &filter_expr_2_after_roundtrip, + )?; + + Ok(()) +} + +#[test] +fn test_dynamic_filter_expression_id_is_stable_between_serializations() -> Result<()> { + let filter_expr = make_dynamic_filter(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DeduplicatingProtoConverter {}; + + let proto1 = proto_converter.physical_expr_to_proto(&filter_expr, &codec)?; + let expr_id1 = proto_expression_id(&proto1); + + let proto2 = proto_converter.physical_expr_to_proto(&filter_expr, &codec)?; + let expr_id2 = proto_expression_id(&proto2); + + assert_eq!( + expr_id1, expr_id2, + "Expected the same dynamic filter expression id across serializations" + ); + + Ok(()) +} + +/// Tests that `lead` window function with offset and default value args +/// survives a protobuf round-trip. This is a regression test for a bug +/// where `expressions()` (used during serialization) returns only the +/// column expression for lead/lag, silently dropping the offset and +/// default value literal args. +#[test] +fn roundtrip_lead_with_default_value() -> Result<()> { + use datafusion::functions_window::lead_lag::lead_udwf; + + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + // lead(a, 2, 42) — column a, offset 2, default value 42 + let lead_window = create_udwf_window_expr( + &lead_udwf(), + &[col("a", &schema)?, lit(2i64), lit(42i64)], + schema.as_ref(), + "test lead with default".to_string(), + false, + )?; + + let udwf_expr = Arc::new(StandardWindowExpr::new( + lead_window, + &[col("b", &schema)?], + &[PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + Arc::new(WindowFrame::new(None)), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + InputOrderMode::Sorted, + true, + )?)) +} + +/// Verify that ScalarSubqueryExpr nodes in the input plan are connected to the +/// same shared results container as ScalarSubqueryExec after a proto round-trip. +#[test] +fn roundtrip_scalar_subquery_exec() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let results = ScalarSubqueryResults::new(1); + + // Build the input plan: a filter whose predicate references the + // scalar subquery result via ScalarSubqueryExpr. + let sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + SubqueryIndex::new(0), + results.clone(), + )); + let predicate = binary(col("a", &schema)?, Operator::Eq, sq_expr, &schema)?; + let filter = + FilterExec::try_new(predicate, Arc::new(EmptyExec::new(schema.clone())))?; + + // Build a trivial subquery plan. + let subquery_plan = + Arc::new(EmptyExec::new(Arc::new(Schema::new(vec![Field::new( + "x", + DataType::Int64, + true, + )])))); + + let exec: Arc = Arc::new(ScalarSubqueryExec::new( + Arc::new(filter), + vec![ScalarSubqueryLink { + plan: subquery_plan, + index: SubqueryIndex::new(0), + }], + results, + )); + + // Perform the round-trip using DeduplicatingProtoConverter, which + // creates a DeduplicatingDeserializer that threads scalar subquery + // results through expression deserialization. + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec), + &codec, + &converter, + )?; + let ctx = SessionContext::new(); + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Verify the deserialized ScalarSubqueryExec's results container is + // shared with the ScalarSubqueryExpr in the input plan. + let sq_exec = deserialized + .downcast_ref::() + .expect("expected ScalarSubqueryExec"); + let exec_results = sq_exec.results(); + + // Walk the input plan to find the ScalarSubqueryExpr and verify it + // points to the same results container. + let filter_exec = sq_exec + .input() + .downcast_ref::() + .expect("expected FilterExec"); + let binary_expr = filter_exec + .predicate() + .downcast_ref::() + .expect("expected BinaryExpr"); + let deserialized_sq_expr = binary_expr + .right() + .downcast_ref::() + .expect("expected ScalarSubqueryExpr"); + + assert!( + ScalarSubqueryResults::ptr_eq(exec_results, deserialized_sq_expr.results()), + "ScalarSubqueryExpr should share the same results container as ScalarSubqueryExec" + ); + Ok(()) +} + +/// Verify that nested ScalarSubqueryExec nodes deserialize with distinct +/// scoped results containers, and that each ScalarSubqueryExpr is wired to the +/// container for its own surrounding ScalarSubqueryExec. +#[test] +fn roundtrip_nested_scalar_subquery_exec_scopes_results() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let subquery_schema = + Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)])); + + let inner_results = ScalarSubqueryResults::new(1); + let inner_sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + SubqueryIndex::new(0), + inner_results.clone(), + )); + let inner_predicate = + binary(col("a", &schema)?, Operator::Eq, inner_sq_expr, &schema)?; + let inner_filter = Arc::new(FilterExec::try_new( + inner_predicate, + Arc::new(EmptyExec::new(schema.clone())), + )?); + let inner_exec: Arc = Arc::new(ScalarSubqueryExec::new( + inner_filter, + vec![ScalarSubqueryLink { + plan: Arc::new(EmptyExec::new(subquery_schema.clone())), + index: SubqueryIndex::new(0), + }], + inner_results, + )); + + let outer_results = ScalarSubqueryResults::new(1); + let outer_sq_expr = Arc::new(ScalarSubqueryExpr::new( + DataType::Int64, + true, + SubqueryIndex::new(0), + outer_results.clone(), + )); + let outer_predicate = + binary(col("a", &schema)?, Operator::Eq, outer_sq_expr, &schema)?; + let outer_filter = Arc::new(FilterExec::try_new(outer_predicate, inner_exec)?); + let outer_exec: Arc = Arc::new(ScalarSubqueryExec::new( + outer_filter, + vec![ScalarSubqueryLink { + plan: Arc::new(EmptyExec::new(subquery_schema)), + index: SubqueryIndex::new(0), + }], + outer_results, + )); + + let bytes = datafusion_proto::bytes::physical_plan_to_bytes(Arc::clone(&outer_exec))?; + let ctx = SessionContext::new(); + let deserialized = datafusion_proto::bytes::physical_plan_from_bytes( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + )?; + + let outer_exec = deserialized + .downcast_ref::() + .expect("expected outer ScalarSubqueryExec"); + let outer_results = outer_exec.results(); + let outer_filter = outer_exec + .input() + .downcast_ref::() + .expect("expected outer FilterExec"); + let outer_binary = outer_filter + .predicate() + .downcast_ref::() + .expect("expected outer BinaryExpr"); + let outer_sq_expr = outer_binary + .right() + .downcast_ref::() + .expect("expected outer ScalarSubqueryExpr"); + + let inner_exec = outer_filter + .input() + .downcast_ref::() + .expect("expected inner ScalarSubqueryExec"); + let inner_results = inner_exec.results(); + let inner_filter = inner_exec + .input() + .downcast_ref::() + .expect("expected inner FilterExec"); + let inner_binary = inner_filter + .predicate() + .downcast_ref::() + .expect("expected inner BinaryExpr"); + let inner_sq_expr = inner_binary + .right() + .downcast_ref::() + .expect("expected inner ScalarSubqueryExpr"); + + assert!( + ScalarSubqueryResults::ptr_eq(outer_results, outer_sq_expr.results()), + "outer ScalarSubqueryExpr should use outer ScalarSubqueryExec results" + ); + assert!( + ScalarSubqueryResults::ptr_eq(inner_results, inner_sq_expr.results()), + "inner ScalarSubqueryExpr should use inner ScalarSubqueryExec results" + ); + assert!( + !ScalarSubqueryResults::ptr_eq(outer_results, inner_results), + "nested ScalarSubqueryExec nodes should not share results containers" + ); + assert!( + !ScalarSubqueryResults::ptr_eq(outer_results, inner_sq_expr.results()), + "inner ScalarSubqueryExpr must not read from outer results" + ); + assert!( + !ScalarSubqueryResults::ptr_eq(inner_results, outer_sq_expr.results()), + "outer ScalarSubqueryExpr must not read from inner results" + ); + + Ok(()) +} + +/// Verify that the default physical plan bytes round-trip preserves executable +/// scalar subquery plans. +#[tokio::test] +async fn roundtrip_scalar_subquery_exec_with_default_converter_executes() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "SELECT x + (SELECT max(y) FROM (VALUES (10), (20)) AS u(y)) AS s \ + FROM (VALUES (2), (1)) AS t(x) \ + ORDER BY s"; + + let initial_plan = ctx.sql(sql).await?.create_physical_plan().await?; + assert!( + format!("{initial_plan:?}").contains("ScalarSubqueryExec"), + "expected ScalarSubqueryExec in plan:\n{initial_plan:?}" + ); + + let bytes = + datafusion_proto::bytes::physical_plan_to_bytes(Arc::clone(&initial_plan))?; + let roundtripped = datafusion_proto::bytes::physical_plan_from_bytes( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + )?; + assert!( + format!("{roundtripped:?}").contains("ScalarSubqueryExec"), + "expected ScalarSubqueryExec after roundtrip:\n{roundtripped:?}" + ); + + let batches = datafusion::physical_plan::common::collect( + roundtripped.execute(0, ctx.task_ctx())?, + ) + .await?; + datafusion::assert_batches_eq!( + &["+----+", "| s |", "+----+", "| 21 |", "| 22 |", "+----+",], + &batches + ); + + Ok(()) +} + +/// Test that a chain of the same operator (a AND b AND c) is linearized +/// and roundtrips correctly. +#[test] +fn roundtrip_binary_expr_chain_same_op() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Boolean, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + let ab = binary( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + &schema, + )?; + let abc = binary(ab, Operator::And, col("c", &schema)?, &schema)?; + roundtrip_test(Arc::new(FilterExec::try_new( + abc, + Arc::new(EmptyExec::new(schema)), + )?)) +} + +/// Test that mixed operators (a AND b OR c) are NOT linearized together — +/// only chains of the same operator are flattened. +#[test] +fn roundtrip_binary_expr_mixed_ops() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Boolean, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + // (a AND b) OR c — AND and OR are different operators, so linearization stops + let a_and_b = binary( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + &schema, + )?; + let expr = binary(a_and_b, Operator::Or, col("c", &schema)?, &schema)?; + roundtrip_test(Arc::new(FilterExec::try_new( + expr, + Arc::new(EmptyExec::new(schema)), + )?)) +} + +/// Test that a deeply nested chain of AND expressions (like many WHERE conditions) +/// roundtrips correctly. This is the scenario from issue #18602. +#[test] +fn roundtrip_binary_expr_deeply_nested_and_chain() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + // Build a chain: a AND a AND a AND ... (100 times) + let col_a = col("a", &schema)?; + let mut expr = Arc::clone(&col_a); + for _ in 0..99 { + expr = binary(expr, Operator::And, Arc::clone(&col_a), &schema)?; + } + + roundtrip_test(Arc::new(FilterExec::try_new( + expr, + Arc::new(EmptyExec::new(schema)), + )?)) +} + +/// Test that a deeply nested chain of OR expressions roundtrips correctly. +#[test] +fn roundtrip_binary_expr_deeply_nested_or_chain() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_a])); + + let col_a = col("a", &schema)?; + let mut expr = Arc::clone(&col_a); + for _ in 0..99 { + expr = binary(expr, Operator::Or, Arc::clone(&col_a), &schema)?; + } + + roundtrip_test(Arc::new(FilterExec::try_new( + expr, + Arc::new(EmptyExec::new(schema)), + )?)) +} + +/// Test that alternating AND/OR operators produce correct results — +/// each sub-chain gets linearized independently. +#[test] +fn roundtrip_binary_expr_alternating_and_or() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Boolean, false); + let field_c = Field::new("c", DataType::Boolean, false); + let field_d = Field::new("d", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c, field_d])); + + // (a AND b) OR (c AND d) + let a_and_b = binary( + col("a", &schema)?, + Operator::And, + col("b", &schema)?, + &schema, + )?; + let c_and_d = binary( + col("c", &schema)?, + Operator::And, + col("d", &schema)?, + &schema, + )?; + let expr = binary(a_and_b, Operator::Or, c_and_d, &schema)?; + + roundtrip_test(Arc::new(FilterExec::try_new( + expr, + Arc::new(EmptyExec::new(schema)), + )?)) +} + +/// Verify that the linearized proto format has a flat operands list +/// rather than deeply nested l/r fields. +#[test] +fn test_linearization_produces_flat_operands() -> Result<()> { + // Build: a AND a AND a AND a (4 operands, 3 levels of nesting) + let col_a: Arc = Arc::new(Column::new("a", 0)); + let expr: Arc = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::clone(&col_a), + Operator::And, + Arc::clone(&col_a), + )), + Operator::And, + Arc::clone(&col_a), + )), + Operator::And, + Arc::clone(&col_a), + )); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let proto = proto_converter.physical_expr_to_proto(&expr, &codec)?; + + // The top-level should use the operands field with 4 entries + match &proto.expr_type { + Some(protobuf::physical_expr_node::ExprType::BinaryExpr(b)) => { + assert!( + b.l.is_none(), + "l should be None when using linearized operands" + ); + assert!( + b.r.is_none(), + "r should be None when using linearized operands" + ); + assert_eq!( + b.operands.len(), + 4, + "Expected 4 linearized operands for a AND a AND a AND a" + ); + assert_eq!(b.op, "And"); + } + other => panic!("Expected BinaryExpr, got {other:?}"), + } + + Ok(()) +} + +/// Test that linearization stops when encountering a different operator. +/// For (a AND b) OR c, only the top-level OR should be represented, and +/// the left-hand AND subtree should be a separate nested BinaryExpr. +#[test] +fn test_linearization_stops_at_different_op() -> Result<()> { + // (a AND b) OR c + let a_and_b: Arc = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::And, + Arc::new(Column::new("b", 1)), + )); + let expr: Arc = Arc::new(BinaryExpr::new( + a_and_b, + Operator::Or, + Arc::new(Column::new("c", 2)), + )); + + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let proto = proto_converter.physical_expr_to_proto(&expr, &codec)?; + + // The top-level OR should have only 2 operands (can't linearize through AND) + match &proto.expr_type { + Some(protobuf::physical_expr_node::ExprType::BinaryExpr(b)) => { + assert_eq!( + b.operands.len(), + 2, + "Expected 2 operands for (a AND b) OR c" + ); + assert_eq!(b.op, "Or"); + // The first operand should be a nested AND BinaryExpr + match &b.operands[0].expr_type { + Some(protobuf::physical_expr_node::ExprType::BinaryExpr(inner)) => { + assert_eq!(inner.op, "And"); + assert_eq!(inner.operands.len(), 2); + } + other => panic!("Expected inner BinaryExpr(AND), got {other:?}"), + } + } + other => panic!("Expected BinaryExpr, got {other:?}"), + } + + Ok(()) +} + +/// Create a DataSourceExec backed by a ParquetSource that accepts filter pushdown, +/// along with a ConfigOptions that enables all dynamic filter pushdown options. +fn datasource_for_dynamic_filter_pushdown( + schema: &Arc, +) -> (Arc, ConfigOptions) { + let mut parquet_options = TableParquetOptions::new(); + parquet_options.global.pushdown_filters = true; + let source = Arc::new( + ParquetSource::new(Arc::clone(schema)) + .with_table_parquet_options(parquet_options), + ); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), source) + .with_file(PartitionedFile::new("/path/to/file.parquet", 1024)) + .build(); + + let mut config = ConfigOptions::default(); + config.execution.parquet.pushdown_filters = true; + config.optimizer.enable_join_dynamic_filter_pushdown = true; + config.optimizer.enable_aggregate_dynamic_filter_pushdown = true; + config.optimizer.enable_topk_dynamic_filter_pushdown = true; + + (DataSourceExec::from_data_source(scan_config), config) +} + +/// Test that plan containing a HashJoinExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_hash_join_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)])); + + let left_child = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let (right_child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let on: Vec<(Arc, Arc)> = vec![( + Arc::new(Column::new("col", 0)), + Arc::new(Column::new("col", 0)), + )]; + + let hash_join = Arc::new(HashJoinExec::try_new( + left_child, + right_child, + on, + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + false, + )?) as Arc; + + // Run the optimizer rule for filter pushdown. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(hash_join, &config)?; + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let deserialized = roundtrip_test_and_return(plan, &ctx, &codec, &converter)?; + + // Extract the deserialized HashJoinExec and its dynamic filter. + let deserialized_join = deserialized + .downcast_ref::() + .expect("Should be HashJoinExec"); + let deserialized_hash_join_df = deserialized_join + .dynamic_filter_expr() + .expect("HashJoinExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter pushed down to the probe side's ParquetSource. + let deserialized_predicate = parquet_source_predicate(deserialized_join.right()); + + // The HashJoinExec's dynamic filter and the probe side's predicate should + // refer to the same underlying expression. + let plan_df: Arc = deserialized_hash_join_df.clone(); + assert_dynamic_filters_equal(&plan_df, &deserialized_predicate); + assert_dynamic_filter_update_is_visible(&plan_df, &deserialized_predicate)?; + + Ok(()) +} + +/// returns a SessionContext with an empty `netflow` table registered +fn netflow_context() -> Result { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("dst_geo_country_name", DataType::Utf8, true), + Field::new("dst_geo_city_name", DataType::Utf8, true), + Field::new("packets", DataType::UInt64, true), + Field::new("src_addr", DataType::Utf8, true), + Field::new("dst_addr", DataType::Utf8, true), + ])); + + ctx.register_table("netflow", Arc::new(EmptyTable::new(schema)))?; + + Ok(ctx) +} + +/// Regression test for issue #18602: +/// https://github.com/apache/datafusion/issues/18602 +/// +/// The physical filter expression here contains a long chain of `AND` predicates. +/// Before linearizing `PhysicalBinaryExprNode`, encoding then decoding the protobuf +/// could fail with `DecodeError: recursion limit reached`. +#[tokio::test] +async fn roundtrip_issue_18602_complex_filter_decode_recursion() -> Result<()> { + let ctx = netflow_context()?; + let sql = "SELECT \ + dst_geo_country_name AS x_axis_1, \ + dst_geo_city_name AS x_axis_2, \ + sum(packets) AS y_axis_1 \ + FROM netflow \ + WHERE dst_geo_country_name IS NOT NULL \ + AND src_addr NOT LIKE '10.201.%' \ + AND dst_addr NOT LIKE '10.201.%' \ + AND src_addr NOT LIKE '10.202.%' \ + AND dst_addr NOT LIKE '10.202.%' \ + AND src_addr NOT LIKE '10.203.%' \ + AND dst_addr NOT LIKE '10.203.%' \ + AND src_addr NOT LIKE '10.204.%' \ + AND dst_addr NOT LIKE '10.204.%' \ + AND src_addr NOT LIKE '172.16.186.%' \ + AND dst_addr NOT LIKE '172.16.186.%' \ + AND src_addr NOT LIKE '172.16.187.%' \ + AND dst_addr NOT LIKE '172.16.187.%' \ + AND src_addr NOT LIKE '172.16.188.%' \ + AND dst_addr NOT LIKE '172.16.188.%' \ + AND src_addr NOT LIKE '10.102.45.%' \ + AND dst_addr NOT LIKE '10.102.45.%' \ + AND src_addr NOT LIKE '172.25.210.%' \ + AND dst_addr NOT LIKE '172.25.210.%' \ + AND src_addr NOT LIKE '172.25.211.%' \ + AND dst_addr NOT LIKE '172.25.211.%' \ + AND src_addr NOT LIKE '141.226.101.%' \ + AND dst_addr NOT LIKE '141.226.101.%' \ + AND src_addr NOT LIKE '167.86.40.%' \ + AND dst_addr NOT LIKE '167.86.40.%' \ + AND src_addr NOT LIKE '66.22.38.%' \ + AND dst_addr NOT LIKE '66.22.38.%' \ + AND src_addr != '168.143.191.55' \ + AND dst_addr != '168.143.191.55' \ + AND src_addr != '82.112.107.142' \ + AND dst_addr != '82.112.107.142' \ + AND src_addr != '20.76.39.176' \ + AND dst_addr != '20.76.39.176' \ + AND src_addr != '162.159.129.83' \ + AND dst_addr != '162.159.129.83' \ + AND src_addr != '34.201.223.155' \ + AND dst_addr != '34.201.223.155' \ + AND src_addr != '34.201.223.156' \ + AND dst_addr != '34.201.223.156' \ + AND src_addr != '34.201.223.157' \ + AND dst_addr != '34.201.223.157' \ + AND src_addr != '134.201.223.157' \ + AND dst_addr != '134.201.223.157' \ + AND src_addr != '341.201.223.157' \ + AND dst_addr != '341.201.223.157' \ + GROUP BY x_axis_1, x_axis_2 \ + ORDER BY y_axis_1 DESC \ + LIMIT 20"; + + roundtrip_test_sql_with_context(sql, &ctx).await +} + +/// Test that plan containing a AggregateExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_aggregate_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let col_a: Arc = Arc::new(Column::new("a", 0)); + + let (child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let agg = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![]), + vec![ + AggregateExprBuilder::new( + datafusion::functions_aggregate::min_max::min_udaf(), + vec![Arc::clone(&col_a)], + ) + .schema(Arc::clone(&schema)) + .alias("min_a") + .build() + .map(Arc::new)?, + ], + vec![None], + child, + Arc::clone(&schema), + )?) as Arc; + + // Run the optimizer rule for filter pushdown. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(agg, &config)?; + + // Roundtrip with deduplication. + // + // Note: We don't use `roundtrip_test_and_return` here because there's a + // pre-existing issue with PhysicalGroupBy serialization where empty groups + // `[[]]` become `[]` after roundtrip. This behavior is unrelated to this test. + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&plan), + &codec, + &converter, + )?; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Extract the deserialized AggregateExec and its dynamic filter. + let deserialized_agg = deserialized + .downcast_ref::() + .expect("Should be AggregateExec"); + let deserialized_agg_df = deserialized_agg + .dynamic_filter_expr() + .expect("AggregateExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter pushed down to the child ParquetSource. + let deserialized_predicate = parquet_source_predicate(deserialized_agg.input()); + + // The AggregateExec's dynamic filter and the child's predicate should + // refer to the same underlying expression. + let plan_df: Arc = deserialized_agg_df.clone(); + assert_dynamic_filters_equal(&plan_df, &deserialized_predicate); + assert_dynamic_filter_update_is_visible(&plan_df, &deserialized_predicate)?; + + Ok(()) +} + +/// Test that plan containing a SortExec with dynamic filter pushdown +/// can be serialized and deserialized while preserving references to the dynamic filter. +#[test] +fn test_sort_topk_with_dynamic_filter_roundtrip() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let col_a: Arc = Arc::new(Column::new("a", 0)); + + let (child, config) = datasource_for_dynamic_filter_pushdown(&schema); + + let sort = Arc::new( + SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_a), + options: SortOptions::default(), + }]) + .unwrap(), + child, + ) + .with_fetch(Some(10)), + ) as Arc; + + // Verify the optimizer kept the dynamic filter on the SortExec. + let optimizer = FilterPushdown::new_post_optimization(); + let plan = optimizer.optimize(sort, &config)?; + + // Roundtrip with deduplication. + // + // Note: We don't use `roundtrip_test_and_return` here because + // `DeduplicatingDeserializer` rewrites cache hits via `with_new_children`, + // which sets `remapped_children: Some(...)` on the second encounter of a + // shared `DynamicFilterPhysicalExpr`. SortExec's `Debug` includes its + // dynamic filter, so the original-vs-deserialized structural equality check + // would fail purely on this artifact. + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let converter = DeduplicatingProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&plan), + &codec, + &converter, + )?; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &converter, + )?; + + // Extract the deserialized SortExec and its dynamic filter. + let deserialized_sort = deserialized + .downcast_ref::() + .expect("Should be SortExec"); + let deserialized_sort_df = deserialized_sort + .dynamic_filter_expr() + .expect("SortExec should have a dynamic filter after roundtrip"); + + // Extract the dynamic filter pushed down to the child ParquetSource. + let deserialized_predicate = parquet_source_predicate(deserialized_sort.input()); + + // The SortExec's dynamic filter and the child's predicate should + // refer to the same underlying expression. + let plan_df: Arc = deserialized_sort_df; + assert_dynamic_filters_equal(&plan_df, &deserialized_predicate); + assert_dynamic_filter_update_is_visible(&plan_df, &deserialized_predicate)?; + + Ok(()) +} + +/// A custom [`ExecutionPlan`] which stores [`PhysicalExpr`]s. +struct CustomExecWithExprs { + exprs: Vec>, + child: Arc, +} + +#[derive(Clone, PartialEq, Message)] +struct CustomExecWithExprsProto { + #[prost(message, repeated, tag = "1")] + exprs: Vec, +} + +impl std::fmt::Debug for CustomExecWithExprs { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CustomExecWithExprs") + .field("exprs", &self.exprs) + .field("child", &self.child) + .finish() + } +} + +impl CustomExecWithExprs { + fn new(exprs: Vec>, child: Arc) -> Self { + Self { exprs, child } + } +} + +impl DisplayAs for CustomExecWithExprs { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "CustomExecWithExprs") + } +} + +impl ExecutionPlan for CustomExecWithExprs { + fn name(&self) -> &str { + "CustomExecWithExprs" + } + + fn schema(&self) -> SchemaRef { + self.child.schema() + } + + fn properties(&self) -> &Arc { + self.child.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unreachable!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!() + } +} + +/// A [`PhysicalExtensionCodec`] for [`CustomExecWithExprs`]. +#[derive(Debug)] +struct CustomExecWithExprsCodec {} + +impl PhysicalExtensionCodec for CustomExecWithExprsCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + ctx: &TaskContext, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result> { + let decode_ctx = PhysicalPlanDecodeContext::new(ctx, self); + let input_schema = inputs[0].schema(); + let proto = CustomExecWithExprsProto::decode(buf) + .map_err(|e| internal_datafusion_err!("Failed to decode custom exec: {e}"))?; + let exprs = proto + .exprs + .iter() + .map(|expr_proto| { + proto_converter.proto_to_physical_expr( + expr_proto, + input_schema.as_ref(), + &decode_ctx, + ) + }) + .collect::>>()?; + + Ok(Arc::new(CustomExecWithExprs::new(exprs, inputs[0].clone()))) + } + + fn try_encode( + &self, + node: Arc, + buf: &mut Vec, + proto_converter: &dyn PhysicalProtoConverterExtension, + ) -> Result<()> { + let custom = node + .downcast_ref::() + .ok_or_else(|| internal_datafusion_err!("Expected CustomExecWithExprs"))?; + let proto = CustomExecWithExprsProto { + exprs: custom + .exprs + .iter() + .map(|expr| proto_converter.physical_expr_to_proto(expr, self)) + .collect::>>()?, + }; + proto + .encode(buf) + .map_err(|e| internal_datafusion_err!("Failed to encode custom exec: {e}"))?; + + Ok(()) + } +} + +/// Tests that a custom [`ExecutionPlan`] with [`PhysicalExpr`] can +/// dedupe dynamic filters by using the proto converter in its +/// [`PhysicalExtensionCodec`] implementation. +#[test] +fn test_custom_node_with_dynamic_filter_dedup_roundtrip() -> Result<()> { + // Create the plan: + // + // FilterExec(dynamic_filter) + // -> CustomExecWithExprs(exprs: [dynamic_filter]) + // -> EmptyExec + // + // The same dynamic filter expression is saved in both the FilterExec and CustomExecWithExprs. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new( + vec![Arc::new(Column::new("a", 0)) as Arc], + lit(true), + )); + let dynamic_filter_expr: Arc = dynamic_filter; + + let empty = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let custom_exec = Arc::new(CustomExecWithExprs::new( + vec![Arc::clone(&dynamic_filter_expr)], + empty, + )); + let filter_exec = Arc::new(FilterExec::try_new( + Arc::clone(&dynamic_filter_expr), + custom_exec, + )?) as Arc; + + // Roundtrip with DeduplicatingProtoConverter + let codec = CustomExecWithExprsCodec {}; + let converter = DeduplicatingProtoConverter {}; + + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&filter_exec), + &codec, + &converter, + )?; + + let ctx = SessionContext::new(); + let deser_converter = DeduplicatingProtoConverter {}; + let deserialized = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &deser_converter, + )?; + + // Extract the deserialized FilterExec's dynamic filter + let deser_filter = deserialized + .downcast_ref::() + .expect("Top-level should be FilterExec"); + let deser_filter_df = deser_filter.predicate(); + + // Extract the deserialized custom node's dynamic filter + let deser_custom = deser_filter + .input() + .downcast_ref::() + .expect("FilterExec child should be CustomExecWithExprs"); + assert_eq!(deser_custom.exprs.len(), 1, "Should have one expression"); + let [deser_custom_df] = deser_custom.exprs.as_slice() else { + return internal_err!("Custom node should have one expression"); + }; + + // Pass the un-remapped filter first so the helper's `with_new_children` + // rewrite can reconstruct the remapped form on the other side. + assert_dynamic_filters_equal(deser_custom_df, deser_filter_df); + assert_dynamic_filter_update_is_visible(deser_custom_df, deser_filter_df)?; + + Ok(()) +} + +#[test] +fn roundtrip_parquet_exec_partitioned_by_file_group() -> Result<()> { + use datafusion::datasource::physical_plan::FileScanConfig; + + let file_schema = + Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); + let file_source = Arc::new(ParquetSource::new(Arc::clone(&file_schema))); + let scan_config = + FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_source) + .with_file_groups(vec![FileGroup::new(vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )])]) + .with_partitioned_by_file_group(true) + .build(); + + assert!(scan_config.partitioned_by_file_group); + + let exec_plan: Arc = DataSourceExec::from_data_source(scan_config); + + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + let proto_converter = DefaultPhysicalProtoConverter {}; + let bytes = physical_plan_to_bytes_with_proto_converter( + Arc::clone(&exec_plan), + &codec, + &proto_converter, + )?; + let result_plan = physical_plan_from_bytes_with_proto_converter( + bytes.as_ref(), + ctx.task_ctx().as_ref(), + &codec, + &proto_converter, + )?; + + let data_source_exec = result_plan + .downcast_ref::() + .expect("Expected DataSourceExec"); + let file_scan_config = data_source_exec + .data_source() + .downcast_ref::() + .expect("Expected FileScanConfig"); + + assert!(file_scan_config.partitioned_by_file_group); + + Ok(()) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index f45a62e948740..850fd42ce131b 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -23,12 +23,12 @@ use arrow::datatypes::{DataType, Field}; use datafusion::execution::FunctionRegistry; use datafusion::prelude::SessionContext; use datafusion_expr::expr::Placeholder; -use datafusion_expr::{col, create_udf, lit, ColumnarValue}; +use datafusion_expr::{ColumnarValue, col, create_udf, lit}; use datafusion_expr::{Expr, Volatility}; use datafusion_functions::string; use datafusion_proto::bytes::Serializeable; -use datafusion_proto::logical_plan::to_proto::serialize_expr; use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec; +use datafusion_proto::logical_plan::to_proto::serialize_expr; #[test] #[should_panic( @@ -42,7 +42,7 @@ fn bad_decode() { #[cfg(feature = "json")] fn plan_to_json() { use datafusion_common::DFSchema; - use datafusion_expr::{logical_plan::EmptyRelation, LogicalPlan}; + use datafusion_expr::{LogicalPlan, logical_plan::EmptyRelation}; use datafusion_proto::bytes::logical_plan_to_json; let plan = LogicalPlan::EmptyRelation(EmptyRelation { @@ -77,7 +77,8 @@ fn udf_roundtrip_with_registry() { .call(vec![lit("")]); let bytes = expr.to_bytes().unwrap(); - let deserialized_expr = Expr::from_bytes_with_registry(&bytes, &ctx).unwrap(); + let deserialized_expr = + Expr::from_bytes_with_ctx(&bytes, ctx.task_ctx().as_ref()).unwrap(); assert_eq!(expr, deserialized_expr); } @@ -281,7 +282,8 @@ fn test_expression_serialization_roundtrip() { let extension_codec = DefaultLogicalExtensionCodec {}; let proto = serialize_expr(&expr, &extension_codec).unwrap(); - let deserialize = parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let deserialize = + parse_expr(&proto, ctx.task_ctx().as_ref(), &extension_codec).unwrap(); let serialize_name = extract_function_name(&expr); let deserialize_name = extract_function_name(&deserialize); diff --git a/datafusion/pruning/Cargo.toml b/datafusion/pruning/Cargo.toml index bd898cba202ba..e6f4bb6f273c9 100644 --- a/datafusion/pruning/Cargo.toml +++ b/datafusion/pruning/Cargo.toml @@ -23,10 +23,10 @@ datafusion-expr-common = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } -itertools = { workspace = true } log = { workspace = true } [dev-dependencies] datafusion-expr = { workspace = true } datafusion-functions-nested = { workspace = true } insta = { workspace = true } +itertools = { workspace = true } diff --git a/datafusion/pruning/LICENSE.txt b/datafusion/pruning/LICENSE.txt new file mode 120000 index 0000000000000..1ef648f64b34f --- /dev/null +++ b/datafusion/pruning/LICENSE.txt @@ -0,0 +1 @@ +../../LICENSE.txt \ No newline at end of file diff --git a/datafusion/pruning/NOTICE.txt b/datafusion/pruning/NOTICE.txt new file mode 120000 index 0000000000000..fb051c92b10b2 --- /dev/null +++ b/datafusion/pruning/NOTICE.txt @@ -0,0 +1 @@ +../../NOTICE.txt \ No newline at end of file diff --git a/datafusion/pruning/src/file_pruner.rs b/datafusion/pruning/src/file_pruner.rs index ee86a8cc8cd58..661832915c40f 100644 --- a/datafusion/pruning/src/file_pruner.rs +++ b/datafusion/pruning/src/file_pruner.rs @@ -19,113 +19,151 @@ use std::sync::Arc; -use arrow::datatypes::{FieldRef, Schema, SchemaRef}; -use datafusion_common::{ - pruning::{ - CompositePruningStatistics, PartitionPruningStatistics, PrunableStatistics, - PruningStatistics, - }, - Result, -}; +use arrow::datatypes::{FieldRef, SchemaRef}; +use datafusion_common::{Result, internal_datafusion_err, pruning::PrunableStatistics}; use datafusion_datasource::PartitionedFile; -use datafusion_physical_expr_common::physical_expr::{snapshot_generation, PhysicalExpr}; +use datafusion_physical_expr::DynamicFilterTracking; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::Count; -use itertools::Itertools; use log::debug; use crate::build_pruning_predicate; -/// Prune based on partition values and file-level statistics. +/// Prune based on file-level statistics. +/// +/// Note: Partition column pruning is handled earlier via `replace_columns_with_literals` +/// which substitutes partition column references with their literal values before +/// the predicate reaches this pruner. pub struct FilePruner { - predicate_generation: Option, predicate: Arc, - /// Schema used for pruning, which combines the file schema and partition fields. - /// Partition fields are always at the end, as they are during scans. - pruning_schema: Arc, - partitioned_file: PartitionedFile, - partition_fields: Vec, + /// Tracks the dynamic filters inside `predicate` so we only rebuild the + /// pruning predicate when one of them has actually moved. + tracking: DynamicFilterTracking, + /// Whether [`Self::should_prune`] has built+evaluated the pruning predicate + /// at least once. The first check always runs; subsequent checks only run + /// when a watched dynamic filter changed. + checked_once: bool, + /// Schema used for pruning (the logical file schema). + file_schema: SchemaRef, + file_stats_pruning: PrunableStatistics, predicate_creation_errors: Count, } impl FilePruner { + #[deprecated( + since = "52.0.0", + note = "Use `try_new` instead which returns None if no statistics are available" + )] + #[expect(clippy::needless_pass_by_value)] pub fn new( predicate: Arc, logical_file_schema: &SchemaRef, - partition_fields: Vec, + _partition_fields: Vec, partitioned_file: PartitionedFile, predicate_creation_errors: Count, ) -> Result { - // Build a pruning schema that combines the file fields and partition fields. - // Partition fields are always at the end. - let pruning_schema = Arc::new( - Schema::new( - logical_file_schema - .fields() - .iter() - .cloned() - .chain(partition_fields.iter().cloned()) - .collect_vec(), + Self::try_new( + predicate, + logical_file_schema, + &partitioned_file, + predicate_creation_errors, + ) + .ok_or_else(|| { + internal_datafusion_err!( + "FilePruner::new called on a file without statistics: {:?}", + partitioned_file ) - .with_metadata(logical_file_schema.metadata().clone()), - ); - Ok(Self { - // Initialize the predicate generation to None so that the first time we call `should_prune` we actually check the predicate - // Subsequent calls will only do work if the predicate itself has changed. - // See `snapshot_generation` for more info. - predicate_generation: None, + }) + } + + /// Create a file pruner for this file, or `None` when pruning it cannot + /// help. + /// + /// Returns `None` when the file has no statistics struct to evaluate a + /// pruning predicate against, or when the predicate is purely static and the + /// file has no usable column statistics — in that case planning already did + /// everything such a pruner could. A predicate carrying a dynamic filter is + /// always accepted (given a statistics struct), since it may prune via + /// partition-value folding even without column statistics. + pub fn try_new( + predicate: Arc, + file_schema: &SchemaRef, + partitioned_file: &PartitionedFile, + predicate_creation_errors: Count, + ) -> Option { + // A pruning predicate is evaluated against a statistics struct, so one + // must exist (its columns may all be `Absent`). + let file_stats = partitioned_file.statistics.as_ref()?; + let tracking = DynamicFilterTracking::classify(&predicate); + // Only build a pruner when it could prune something planning didn't + // already: the file has real column statistics, or the predicate carries + // a dynamic filter (whose value, or folded partition columns, can prune + // even without column statistics). For a purely static predicate with no + // usable stats there is nothing to gain. + if !partitioned_file.has_statistics() && !tracking.contains_dynamic_filter() { + return None; + } + let file_stats_pruning = + PrunableStatistics::new(vec![file_stats.clone()], Arc::clone(file_schema)); + Some(Self { predicate, - pruning_schema, - partitioned_file, - partition_fields, + tracking, + checked_once: false, + file_schema: Arc::clone(file_schema), + file_stats_pruning, predicate_creation_errors, }) } + /// Returns `true` if this pruner watches a dynamic filter that can still + /// change, meaning [`Self::should_prune`] is worth re-checking as the scan + /// progresses. When `false`, the predicate is effectively static for the + /// remainder of the scan and the caller can avoid wrapping the stream in a + /// per-batch re-pruning adapter. + pub fn is_watching(&self) -> bool { + matches!(self.tracking, DynamicFilterTracking::Watching(_)) + } + pub fn should_prune(&mut self) -> Result { - let new_generation = snapshot_generation(&self.predicate); - if let Some(current_generation) = self.predicate_generation.as_mut() { - if *current_generation == new_generation { - return Ok(false); - } - *current_generation = new_generation; + // Building the pruning predicate is expensive (it involves expression + // analysis), so we only do it on the first check and whenever a dynamic + // filter inside the predicate has actually moved. + // + // Dynamic filter expressions can change their values during query + // execution; `DynamicFilterTracking` watches the still-incomplete + // filters and reports a change at most once per update. A purely static + // predicate (or one whose dynamic filters have all completed) is checked + // exactly once. + let should_build = if self.checked_once { + self.tracking.watcher().is_some_and(|w| w.changed()) } else { - self.predicate_generation = Some(new_generation); + self.checked_once = true; + true + }; + if !should_build { + return Ok(false); } let pruning_predicate = build_pruning_predicate( Arc::clone(&self.predicate), - &self.pruning_schema, + &self.file_schema, &self.predicate_creation_errors, ); - if let Some(pruning_predicate) = pruning_predicate { - // The partition column schema is the schema of the table - the schema of the file - let mut pruning = Box::new(PartitionPruningStatistics::try_new( - vec![self.partitioned_file.partition_values.clone()], - self.partition_fields.clone(), - )?) as Box; - if let Some(stats) = &self.partitioned_file.statistics { - let stats_pruning = Box::new(PrunableStatistics::new( - vec![Arc::clone(stats)], - Arc::clone(&self.pruning_schema), - )); - pruning = Box::new(CompositePruningStatistics::new(vec![ - pruning, - stats_pruning, - ])); - } - match pruning_predicate.prune(pruning.as_ref()) { - Ok(values) => { - assert!(values.len() == 1); - // We expect a single container -> if all containers are false skip this file - if values.into_iter().all(|v| !v) { - return Ok(true); - } - } - // Stats filter array could not be built, so we can't prune - Err(e) => { - debug!("Ignoring error building pruning predicate for file: {e}"); - self.predicate_creation_errors.add(1); + let Some(pruning_predicate) = pruning_predicate else { + return Ok(false); + }; + match pruning_predicate.prune(&self.file_stats_pruning) { + Ok(values) => { + assert!(values.len() == 1); + // We expect a single container -> if all containers are false skip this file + if values.into_iter().all(|v| !v) { + return Ok(true); } } + // Stats filter array could not be built, so we can't prune + Err(e) => { + debug!("Ignoring error building pruning predicate for file: {e}"); + self.predicate_creation_errors.add(1); + } } Ok(false) diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs index 35e1baef239b3..be17f29eaafa0 100644 --- a/datafusion/pruning/src/lib.rs +++ b/datafusion/pruning/src/lib.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] mod file_pruner; @@ -24,6 +22,6 @@ mod pruning_predicate; pub use file_pruner::FilePruner; pub use pruning_predicate::{ - build_pruning_predicate, PredicateRewriter, PruningPredicate, PruningStatistics, - RequiredColumns, UnhandledPredicateHook, + PredicateRewriter, PruningPredicate, PruningStatistics, RequiredColumns, + UnhandledPredicateHook, build_pruning_predicate, }; diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index 4084da820c0d4..bacdd7032ead2 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use arrow::array::AsArray; use arrow::{ - array::{new_null_array, ArrayRef, BooleanArray}, + array::{ArrayRef, BooleanArray, new_null_array}, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; @@ -35,17 +35,17 @@ use datafusion_physical_plan::metrics::Count; use log::{debug, trace}; use datafusion_common::error::Result; -use datafusion_common::tree_node::TransformedResult; -use datafusion_common::{assert_eq_or_internal_err, Column, DFSchema}; +use datafusion_common::tree_node::{TransformedResult, TreeNodeRecursion}; +use datafusion_common::{Column, DFSchema, assert_eq_or_internal_err}; use datafusion_common::{ - internal_datafusion_err, plan_datafusion_err, plan_err, + ScalarValue, internal_datafusion_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, - ScalarValue, }; +use datafusion_expr_common::casts::try_cast_literal_to_type; use datafusion_expr_common::operator::Operator; -use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; -use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; -use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr; +use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee}; +use datafusion_physical_expr::{PhysicalExprRef, expressions as phys_expr}; +use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr_opt; use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; /// Used to prove that arbitrary predicates (boolean expression) can not @@ -86,7 +86,7 @@ use datafusion_physical_plan::{ColumnarValue, PhysicalExpr}; /// example of how to use `PruningPredicate` to prune files based on min/max /// values. /// -/// [`pruning.rs` example in the `datafusion-examples`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/pruning.rs +/// [`pruning.rs` example in the `datafusion-examples`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/query_planning/pruning.rs /// /// Given an expression like `x = 5` and statistics for 3 containers (Row /// Groups, files, etc) `A`, `B`, and `C`: @@ -455,10 +455,29 @@ impl PruningPredicate { /// /// See the struct level documentation on [`PruningPredicate`] for more /// details. - pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { - // Get a (simpler) snapshot of the physical expr here to use with `PruningPredicate` - // which does not handle dynamic exprs in general - let expr = snapshot_physical_expr(expr)?; + /// + /// Note that `PruningPredicate` does not attempt to normalize or simplify + /// the input expression unless calling [`snapshot_physical_expr_opt`] + /// returns a new expression. + /// It is recommended that you pass the expressions through [`PhysicalExprSimplifier`] + /// before calling this method to make sure the expressions can be used for pruning. + pub fn try_new(mut expr: Arc, schema: SchemaRef) -> Result { + // Get a (simpler) snapshot of the physical expr here to use with `PruningPredicate`. + // In particular this unravels any `DynamicFilterPhysicalExpr`s by snapshotting them + // so that PruningPredicate can work with a static expression. + let tf = snapshot_physical_expr_opt(expr)?; + if tf.transformed { + // If we had an expression such as Dynamic(part_col < 5 and col < 10) + // (this could come from something like `select * from t order by part_col, col, limit 10`) + // after snapshotting and because `DynamicFilterPhysicalExpr` applies child replacements to its + // children after snapshotting and previously `replace_columns_with_literals` may have been called with partition values + // the expression we have now is `8 < 5 and col < 10`. + // Thus we need as simplifier pass to get `false and col < 10` => `false` here. + let simplifier = PhysicalExprSimplifier::new(&schema); + expr = simplifier.simplify(tf.data)?; + } else { + expr = tf.data; + } let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; // build predicate expression once @@ -473,7 +492,6 @@ impl PruningPredicate { // Simplify the newly created predicate to get rid of redundant casts, comparisons, etc. let predicate_expr = PhysicalExprSimplifier::new(&predicate_schema).simplify(predicate_expr)?; - let literal_guarantees = LiteralGuarantee::analyze(&expr); Ok(Self { @@ -585,8 +603,6 @@ impl PruningPredicate { is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } - // this is only used by `parquet` feature right now - #[allow(dead_code)] pub fn required_columns(&self) -> &RequiredColumns { &self.required_columns } @@ -680,15 +696,13 @@ impl BoolVecBuilder { } fn is_always_true(expr: &Arc) -> bool { - expr.as_any() - .downcast_ref::() + expr.downcast_ref::() .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(true)))) .unwrap_or_default() } fn is_always_false(expr: &Arc) -> bool { - expr.as_any() - .downcast_ref::() + expr.downcast_ref::() .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(false)))) .unwrap_or_default() } @@ -725,8 +739,6 @@ impl RequiredColumns { /// * `a > 5 OR a < 10` returns `Some(a)` /// * `a > 5 OR b < 10` returns `None` /// * `true` returns None - #[allow(dead_code)] - // this fn is only used by `parquet` feature right now, thus the `allow(dead_code)` pub fn single_column(&self) -> Option<&phys_expr::Column> { if self.columns.windows(2).all(|w| { // check if all columns are the same (ignoring statistics and field) @@ -915,7 +927,7 @@ fn build_statistics_record_batch( StatisticsType::Min => statistics.min_values(&column), StatisticsType::Max => statistics.max_values(&column), StatisticsType::NullCount => statistics.null_counts(&column), - StatisticsType::RowCount => statistics.row_counts(&column), + StatisticsType::RowCount => statistics.row_counts(), }; let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers)); @@ -959,24 +971,41 @@ impl<'a> PruningExpressionBuilder<'a> { fn try_new( left: &'a Arc, right: &'a Arc, + left_columns: ColumnReferenceCount, + right_columns: ColumnReferenceCount, op: Operator, schema: &'a SchemaRef, required_columns: &'a mut RequiredColumns, ) -> Result { // find column name; input could be a more complicated expression - let left_columns = collect_columns(left); - let right_columns = collect_columns(right); - let (column_expr, scalar_expr, columns, correct_operator) = - match (left_columns.len(), right_columns.len()) { - (1, 0) => (left, right, left_columns, op), - (0, 1) => (right, left, right_columns, reverse_operator(op)?), - _ => { - // if more than one column used in expression - not supported - return plan_err!( - "Multi-column expressions are not currently supported" - ); - } - }; + let (column_expr, scalar_expr, column, correct_operator) = match ( + left_columns, + right_columns, + ) { + (ColumnReferenceCount::One(column), ColumnReferenceCount::Zero) => { + (left, right, column, op) + } + (ColumnReferenceCount::Zero, ColumnReferenceCount::One(column)) => { + (right, left, column, reverse_operator(op)?) + } + (ColumnReferenceCount::One(_), ColumnReferenceCount::One(_)) => { + // both sides have one column - not supported + return plan_err!( + "Expression not supported for pruning: left has 1 column, right has 1 column" + ); + } + (ColumnReferenceCount::Zero, ColumnReferenceCount::Zero) => { + // both sides are literals - should be handled before calling try_new + return plan_err!( + "Pruning literal expressions is not supported, please call PhysicalExprSimplifier first" + ); + } + (ColumnReferenceCount::Many, _) | (_, ColumnReferenceCount::Many) => { + return plan_err!( + "Expression not supported for pruning: left or right has multiple columns" + ); + } + }; let df_schema = DFSchema::try_from(Arc::clone(schema))?; let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable( @@ -985,7 +1014,6 @@ impl<'a> PruningExpressionBuilder<'a> { scalar_expr, df_schema, )?; - let column = columns.iter().next().unwrap().clone(); let field = match schema.column_with_name(column.name()) { Some((_, f)) => f, _ => { @@ -1084,58 +1112,62 @@ fn rewrite_expr_to_prunable( return plan_err!("rewrite_expr_to_prunable only support compare expression"); } - let column_expr_any = column_expr.as_any(); - - if column_expr_any - .downcast_ref::() - .is_some() - { + if column_expr.downcast_ref::().is_some() { // `col op lit()` Ok((Arc::clone(column_expr), op, Arc::clone(scalar_expr))) - } else if let Some(cast) = column_expr_any.downcast_ref::() { + } else if let Some(cast) = column_expr.downcast_ref::() { // `cast(col) op lit()` - let arrow_schema = schema.as_arrow(); - let from_type = cast.expr().data_type(arrow_schema)?; - verify_support_type_for_prune(&from_type, cast.cast_type())?; - let (left, op, right) = - rewrite_expr_to_prunable(cast.expr(), op, scalar_expr, schema)?; - let left = Arc::new(phys_expr::CastExpr::new( + let (left, op, right) = rewrite_cast_child_to_prunable( + cast.expr(), + cast.cast_type(), + op, + scalar_expr, + schema, + )?; + let left = Arc::new(phys_expr::CastExpr::new_with_target_field( left, - cast.cast_type().clone(), + Arc::clone(cast.target_field()), None, )); + // PruningPredicate does not support pruning on nested fields yet. + // End-to-end nested-field pruning also requires Parquet statistics + // extraction to agree with PruningPredicate on a stats representation + // for nested field expressions. Ok((left, op, right)) - } else if let Some(try_cast) = - column_expr_any.downcast_ref::() - { + } else if let Some(try_cast) = column_expr.downcast_ref::() { // `try_cast(col) op lit()` - let arrow_schema = schema.as_arrow(); - let from_type = try_cast.expr().data_type(arrow_schema)?; - verify_support_type_for_prune(&from_type, try_cast.cast_type())?; - let (left, op, right) = - rewrite_expr_to_prunable(try_cast.expr(), op, scalar_expr, schema)?; + let (left, op, right) = rewrite_cast_child_to_prunable( + try_cast.expr(), + try_cast.cast_type(), + op, + scalar_expr, + schema, + )?; let left = Arc::new(phys_expr::TryCastExpr::new( left, try_cast.cast_type().clone(), )); Ok((left, op, right)) - } else if let Some(neg) = column_expr_any.downcast_ref::() { + } else if let Some(neg) = column_expr.downcast_ref::() { // `-col > lit()` --> `col < -lit()` let (left, op, right) = rewrite_expr_to_prunable(neg.arg(), op, scalar_expr, schema)?; let right = Arc::new(phys_expr::NegativeExpr::new(right)); Ok((left, reverse_operator(op)?, right)) - } else if let Some(not) = column_expr_any.downcast_ref::() { + } else if let Some(not) = column_expr.downcast_ref::() { // `!col = true` --> `col = !true` - if op != Operator::Eq && op != Operator::NotEq { - return plan_err!("Not with operator other than Eq / NotEq is not supported"); + if !matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + ) { + return plan_err!( + "Not with operator other than Eq / NotEq / IsDistinctFrom / IsNotDistinctFrom is not supported" + ); } - if not - .arg() - .as_any() - .downcast_ref::() - .is_some() - { + if not.arg().downcast_ref::().is_some() { let left = Arc::clone(not.arg()); let right = Arc::new(phys_expr::NotExpr::new(Arc::clone(scalar_expr))); Ok((left, reverse_operator(op)?, right)) @@ -1147,6 +1179,20 @@ fn rewrite_expr_to_prunable( } } +fn rewrite_cast_child_to_prunable( + cast_child_expr: &PhysicalExprRef, + cast_type: &DataType, + op: Operator, + scalar_expr: &PhysicalExprRef, + schema: DFSchema, +) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef)> { + verify_support_type_for_prune( + &cast_child_expr.data_type(schema.as_arrow())?, + cast_type, + )?; + rewrite_expr_to_prunable(cast_child_expr, op, scalar_expr, schema) +} + fn is_compare_op(op: Operator) -> bool { matches!( op, @@ -1156,18 +1202,13 @@ fn is_compare_op(op: Operator) -> bool { | Operator::LtEq | Operator::Gt | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom | Operator::LikeMatch | Operator::NotLikeMatch ) } -fn is_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - // The pruning logic is based on the comparing the min/max bounds. // Must make sure the two type has order. // For example, casts from string to numbers is not correct. @@ -1176,20 +1217,20 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re // Dictionary casts are always supported as long as the value types are supported let from_type = match from_type { DataType::Dictionary(_, t) => { - return verify_support_type_for_prune(t.as_ref(), to_type) + return verify_support_type_for_prune(t.as_ref(), to_type); } _ => from_type, }; let to_type = match to_type { DataType::Dictionary(_, t) => { - return verify_support_type_for_prune(from_type, t.as_ref()) + return verify_support_type_for_prune(from_type, t.as_ref()); } _ => to_type, }; // If both types are strings or both are not strings (number, timestamp, etc) // then we can compare them. // PruningPredicate does not support casting of strings to numbers and such. - if is_string_type(from_type) == is_string_type(to_type) { + if from_type.is_string() == to_type.is_string() { Ok(()) } else { plan_err!( @@ -1205,10 +1246,10 @@ fn rewrite_column_expr( column_new: &phys_expr::Column, ) -> Result> { e.transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - if column == column_old { - return Ok(Transformed::yes(Arc::new(column_new.clone()))); - } + if let Some(column) = expr.downcast_ref::() + && column == column_old + { + return Ok(Transformed::yes(Arc::new(column_new.clone()))); } Ok(Transformed::no(expr)) @@ -1236,7 +1277,7 @@ fn build_single_column_expr( ) -> Option> { let field = schema.field_with_name(column.name()).ok()?; - if matches!(field.data_type(), &DataType::Boolean) { + if *field.data_type() == DataType::Boolean { let col_ref = Arc::new(column.clone()) as _; let min = required_columns @@ -1279,7 +1320,7 @@ fn build_is_null_column_expr( required_columns: &mut RequiredColumns, with_not: bool, ) -> Option> { - if let Some(col) = expr.as_any().downcast_ref::() { + if let Some(col) = expr.downcast_ref::() { let field = schema.field_with_name(col.name()).ok()?; let null_count_field = &Field::new(field.name(), DataType::UInt64, true); @@ -1396,12 +1437,11 @@ fn build_predicate_expression( return Arc::clone(expr); } // predicate expression can only be a binary expression - let expr_any = expr.as_any(); - if let Some(is_null) = expr_any.downcast_ref::() { + if let Some(is_null) = expr.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) .unwrap_or_else(|| unhandled_hook.handle(expr)); } - if let Some(is_not_null) = expr_any.downcast_ref::() { + if let Some(is_not_null) = expr.downcast_ref::() { return build_is_null_column_expr( is_not_null.arg(), schema, @@ -1410,20 +1450,20 @@ fn build_predicate_expression( ) .unwrap_or_else(|| unhandled_hook.handle(expr)); } - if let Some(col) = expr_any.downcast_ref::() { + if let Some(col) = expr.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) .unwrap_or_else(|| unhandled_hook.handle(expr)); } - if let Some(not) = expr_any.downcast_ref::() { + if let Some(not) = expr.downcast_ref::() { // match !col (don't do so recursively) - if let Some(col) = not.arg().as_any().downcast_ref::() { + if let Some(col) = not.arg().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { return unhandled_hook.handle(expr); } } - if let Some(in_list) = expr_any.downcast_ref::() { + if let Some(in_list) = expr.downcast_ref::() { if !in_list.list().is_empty() && in_list.list().len() <= MAX_LIST_VALUE_SIZE_REWRITE { @@ -1461,13 +1501,13 @@ fn build_predicate_expression( } let (left, op, right) = { - if let Some(bin_expr) = expr_any.downcast_ref::() { + if let Some(bin_expr) = expr.downcast_ref::() { ( Arc::clone(bin_expr.left()), *bin_expr.op(), Arc::clone(bin_expr.right()), ) - } else if let Some(like_expr) = expr_any.downcast_ref::() { + } else if let Some(like_expr) = expr.downcast_ref::() { if like_expr.case_insensitive() { return unhandled_hook.handle(expr); } @@ -1514,8 +1554,17 @@ fn build_predicate_expression( return expr; } - let expr_builder = - PruningExpressionBuilder::try_new(&left, &right, op, schema, required_columns); + let left_columns = ColumnReferenceCount::from_expression(&left); + let right_columns = ColumnReferenceCount::from_expression(&right); + let expr_builder = PruningExpressionBuilder::try_new( + &left, + &right, + left_columns, + right_columns, + op, + schema, + required_columns, + ); let mut expr_builder = match expr_builder { Ok(builder) => builder, // allow partial failure in predicate expression generation @@ -1530,49 +1579,62 @@ fn build_predicate_expression( .unwrap_or_else(|_| unhandled_hook.handle(expr)) } +/// Count of distinct column references in an expression. +/// This is the same as [`collect_columns`] but optimized to stop counting +/// once more than one distinct column is found. +/// +/// For example, in expression `col1 + col2`, the count is `Many`. +/// In expression `col1 + 5`, the count is `One`. +/// In expression `5 + 10`, the count is `Zero`. +/// +/// [`collect_columns`]: datafusion_physical_expr::utils::collect_columns +#[derive(Debug, PartialEq, Eq)] +enum ColumnReferenceCount { + /// no column references + Zero, + /// Only one column reference + One(phys_expr::Column), + /// More than one column reference + Many, +} + +impl ColumnReferenceCount { + /// Count the number of distinct column references in an expression + fn from_expression(expr: &Arc) -> Self { + let mut seen = HashSet::::new(); + expr.apply(|expr| { + if let Some(column) = expr.downcast_ref::() { + seen.insert(column.clone()); + if seen.len() > 1 { + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + }) + // pre_visit always returns OK, so this will always too + .expect("no way to return error during recursion"); + match seen.len() { + 0 => ColumnReferenceCount::Zero, + 1 => ColumnReferenceCount::One( + seen.into_iter().next().expect("just checked len==1"), + ), + _ => ColumnReferenceCount::Many, + } + } +} + fn build_statistics_expr( expr_builder: &mut PruningExpressionBuilder, ) -> Result> { let statistics_expr: Arc = match expr_builder.op() { - Operator::NotEq => { - // column != literal => (min, max) = literal => - // !(min != literal && max != literal) ==> - // min != literal || literal != max - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::BinaryExpr::new( - min_column_expr, - Operator::NotEq, - Arc::clone(expr_builder.scalar_expr()), - )), - Operator::Or, - Arc::new(phys_expr::BinaryExpr::new( - Arc::clone(expr_builder.scalar_expr()), - Operator::NotEq, - max_column_expr, - )), - )) - } + Operator::NotEq => build_ne_statistics_expr(expr_builder)?, Operator::Eq => { // column = literal => (min, max) = literal => min <= literal && literal <= max // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::BinaryExpr::new( - min_column_expr, - Operator::LtEq, - Arc::clone(expr_builder.scalar_expr()), - )), - Operator::And, - Arc::new(phys_expr::BinaryExpr::new( - Arc::clone(expr_builder.scalar_expr()), - Operator::LtEq, - max_column_expr, - )), - )) + build_eq_statistics_expr(expr_builder)? } + Operator::IsDistinctFrom => return build_is_distinct_from(expr_builder), + Operator::IsNotDistinctFrom => return build_is_not_distinct_from(expr_builder), Operator::NotLikeMatch => build_not_like_match(expr_builder)?, Operator::LikeMatch => build_like_match(expr_builder).ok_or_else(|| { plan_datafusion_err!( @@ -1622,27 +1684,156 @@ fn build_statistics_expr( Ok(statistics_expr) } +fn binary_expr( + left: Arc, + op: Operator, + right: Arc, +) -> Arc { + Arc::new(phys_expr::BinaryExpr::new(left, op, right)) +} + +fn and_expr( + left: Arc, + right: Arc, +) -> Arc { + binary_expr(left, Operator::And, right) +} + +fn or_expr( + left: Arc, + right: Arc, +) -> Arc { + binary_expr(left, Operator::Or, right) +} + +fn build_eq_statistics_expr( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + Ok(and_expr( + binary_expr( + min_column_expr, + Operator::LtEq, + Arc::clone(expr_builder.scalar_expr()), + ), + binary_expr( + Arc::clone(expr_builder.scalar_expr()), + Operator::LtEq, + max_column_expr, + ), + )) +} + +fn build_ne_statistics_expr( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + Ok(or_expr( + binary_expr( + min_column_expr, + Operator::NotEq, + Arc::clone(expr_builder.scalar_expr()), + ), + binary_expr( + Arc::clone(expr_builder.scalar_expr()), + Operator::NotEq, + max_column_expr, + ), + )) +} + +fn column_has_nulls_expr( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + Ok(binary_expr( + expr_builder.null_count_column_expr()?, + Operator::Gt, + Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))), + )) +} + +fn column_has_non_nulls_expr( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + Ok(binary_expr( + expr_builder.null_count_column_expr()?, + Operator::NotEq, + expr_builder.row_count_column_expr()?, + )) +} + +fn build_is_distinct_from( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + let scalar_expr = Arc::clone(expr_builder.scalar_expr()); + + Ok(or_expr( + and_expr( + Arc::new(phys_expr::IsNullExpr::new(Arc::clone(&scalar_expr))), + column_has_non_nulls_expr(expr_builder)?, + ), + and_expr( + Arc::new(phys_expr::IsNotNullExpr::new(scalar_expr)), + or_expr( + column_has_nulls_expr(expr_builder)?, + build_ne_statistics_expr(expr_builder)?, + ), + ), + )) +} + +fn build_is_not_distinct_from( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + let scalar_expr = Arc::clone(expr_builder.scalar_expr()); + + Ok(or_expr( + and_expr( + Arc::new(phys_expr::IsNullExpr::new(Arc::clone(&scalar_expr))), + column_has_nulls_expr(expr_builder)?, + ), + and_expr( + Arc::new(phys_expr::IsNotNullExpr::new(scalar_expr)), + and_expr( + column_has_non_nulls_expr(expr_builder)?, + build_eq_statistics_expr(expr_builder)?, + ), + ), + )) +} + /// returns the string literal of the scalar value if it is a string fn unpack_string(s: &ScalarValue) -> Option<&str> { s.try_as_str().flatten() } fn extract_string_literal(expr: &Arc) -> Option<&str> { - if let Some(lit) = expr.as_any().downcast_ref::() { + if let Some(lit) = expr.downcast_ref::() { let s = unpack_string(lit.value())?; return Some(s); } None } +/// Wrap a string in a `Literal` whose `ScalarValue` matches `target_type` +fn string_literal_as(value: String, target_type: &DataType) -> Arc { + let utf8 = ScalarValue::Utf8(Some(value)); + let scalar = try_cast_literal_to_type(&utf8, target_type).unwrap_or(utf8); + Arc::new(phys_expr::Literal::new(scalar)) +} + /// Convert `column LIKE literal` where P is a constant prefix of the literal /// to a range check on the column: `P <= column && column < P'`, where P' is the /// lowest string after all P* strings. fn build_like_match( expr_builder: &mut PruningExpressionBuilder, ) -> Option> { - // column LIKE literal => (min, max) LIKE literal split at % => min <= split literal && split literal <= max + // column LIKE literal => (min, max) LIKE literal split at unescaped % => min <= split literal && split literal <= max // column LIKE 'foo%' => min <= 'foo' && 'foo' <= max + // column LIKE 'foo\_%' => min <= 'foo_' && 'foo_' <= max (the _ is escaped) + // column LIKE 'foo\%%' => min <= 'foo%' && 'foo%' <= max (the % is escaped) // column LIKE '%foo' => min <= '' && '' <= max => true // column LIKE '%foo%' => min <= '' && '' <= max => true // column LIKE 'foo' => min <= 'foo' && 'foo' <= max @@ -1652,28 +1843,25 @@ fn build_like_match( let min_column_expr = expr_builder.min_column_expr().ok()?; let max_column_expr = expr_builder.max_column_expr().ok()?; let scalar_expr = expr_builder.scalar_expr(); + // Synthesized bounds must match the column type (e.g. `Utf8View`). + let target_type = expr_builder.field.data_type(); // check that the scalar is a string literal let s = extract_string_literal(scalar_expr)?; // ANSI SQL specifies two wildcards: % and _. % matches zero or more characters, _ matches exactly one character. - let first_wildcard_index = s.find(['%', '_']); - if first_wildcard_index == Some(0) { - // there's no filtering we could possibly do, return an error and have this be handled by the unhandled hook + let (decoded_prefix, rest) = split_constant_prefix(s); + let has_wildcard = !rest.is_empty(); + if has_wildcard && decoded_prefix.is_empty() { + // there's no filtering we could possibly do, return None and have this be handled by the unhandled hook return None; } - let (lower_bound, upper_bound) = if let Some(wildcard_index) = first_wildcard_index { - let prefix = &s[..wildcard_index]; - let lower_bound_lit = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some( - prefix.to_string(), - )))); - let upper_bound_lit = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some( - increment_utf8(prefix)?, - )))); + let (lower_bound, upper_bound) = if has_wildcard { + let incremented_prefix = increment_utf8(&decoded_prefix)?; + let lower_bound_lit = string_literal_as(decoded_prefix, target_type); + let upper_bound_lit = string_literal_as(incremented_prefix, target_type); (lower_bound_lit, upper_bound_lit) } else { // the like expression is a literal and can be converted into a comparison - let bound = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some( - s.to_string(), - )))); + let bound = string_literal_as(decoded_prefix, target_type); (Arc::clone(&bound), bound) }; let lower_bound_expr = Arc::new(phys_expr::BinaryExpr::new( @@ -1753,19 +1941,20 @@ fn build_not_like_match( } /// Returns unescaped constant prefix of a LIKE pattern (possibly empty) and the remaining pattern (possibly empty) -fn split_constant_prefix(pattern: &str) -> (&str, &str) { - let char_indices = pattern.char_indices().collect::>(); - for i in 0..char_indices.len() { - let (idx, char) = char_indices[i]; - if char == '%' || char == '_' { - if i != 0 && char_indices[i - 1].1 == '\\' { - // ecsaped by `\` - continue; - } - return (&pattern[..idx], &pattern[idx..]); +fn split_constant_prefix(pattern: &str) -> (String, &str) { + let mut prefix = String::with_capacity(pattern.len()); + let mut iter = pattern.char_indices(); + while let Some((idx, c)) = iter.next() { + match c { + '%' | '_' => return (prefix, &pattern[idx..]), + '\\' => match iter.next() { + Some((_, escaped)) => prefix.push(escaped), + None => prefix.push('\\'), + }, + _ => prefix.push(c), } } - (pattern, "") + (prefix, "") } /// Increment a UTF8 string by one, returning `None` if it can't be incremented. @@ -1801,13 +1990,13 @@ fn increment_utf8(data: &str) -> Option { let original = code_points[idx] as u32; // Try incrementing the code point - if let Some(next_char) = char::from_u32(original + 1) { - if is_valid_unicode(next_char) { - code_points[idx] = next_char; - // truncate the string to the current index - code_points.truncate(idx + 1); - return Some(code_points.into_iter().collect()); - } + if let Some(next_char) = char::from_u32(original + 1) + && is_valid_unicode(next_char) + { + code_points[idx] = next_char; + // truncate the string to the current index + code_points.truncate(idx + 1); + return Some(code_points.into_iter().collect()); } } @@ -1838,19 +2027,11 @@ fn wrap_null_count_check_expr( statistics_expr: Arc, expr_builder: &mut PruningExpressionBuilder, ) -> Result> { - // x_null_count != x_row_count - let not_when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new( - expr_builder.null_count_column_expr()?, - Operator::NotEq, - expr_builder.row_count_column_expr()?, - )); - // (x_null_count != x_row_count) AND () - Ok(Arc::new(phys_expr::BinaryExpr::new( - not_when_null_count_eq_row_count, - Operator::And, + Ok(and_expr( + column_has_non_nulls_expr(expr_builder)?, statistics_expr, - ))) + )) } #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -1869,6 +2050,7 @@ mod tests { use super::*; use datafusion_common::test_util::batches_to_string; use datafusion_expr::{and, col, lit, or}; + use datafusion_physical_expr::utils::collect_columns; use insta::assert_snapshot; use arrow::array::Decimal128Array; @@ -1877,10 +2059,13 @@ mod tests { datatypes::TimeUnit, }; use datafusion_expr::expr::InList; - use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_expr::{BinaryExpr, Expr, cast, is_null, try_cast}; use datafusion_functions_nested::expr_fn::{array_has, make_array}; - use datafusion_physical_expr::expressions as phys_expr; + use datafusion_physical_expr::expressions::{ + self as phys_expr, DynamicFilterPhysicalExpr, + }; use datafusion_physical_expr::planner::logical2physical; + use itertools::Itertools; #[derive(Debug, Default)] /// Mock statistic provider for tests @@ -2064,6 +2249,7 @@ mod tests { } /// Add contained information. + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key pub fn with_contained( mut self, values: impl IntoIterator, @@ -2078,6 +2264,7 @@ mod tests { } /// get any contained information for the specified values + #[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn contained(&self, find_values: &HashSet) -> Option { // find the one with the matching values self.contained @@ -2204,11 +2391,10 @@ mod tests { .unwrap_or(None) } - fn row_counts(&self, column: &Column) -> Option { + fn row_counts(&self) -> Option { self.stats - .get(column) - .map(|container_stats| container_stats.row_counts()) - .unwrap_or(None) + .values() + .find_map(|container_stats| container_stats.row_counts()) } fn contained( @@ -2246,7 +2432,7 @@ mod tests { None } - fn row_counts(&self, _column: &Column) -> Option { + fn row_counts(&self) -> Option { None } @@ -2759,6 +2945,163 @@ mod tests { Ok(()) } + /// Test that non-boolean literal expressions don't prune any containers and error gracefully by not pruning anything instead of e.g. panicking + #[test] + fn row_group_predicate_non_boolean() { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[true]; + prune_with_expr(lit(1), &schema, &statistics, expected_ret); + } + + // Test that literal-to-literal comparisons are correctly evaluated. + // When both sides are constants, the expression should be evaluated directly + // and if it's false, all containers should be pruned. + #[test] + fn row_group_predicate_literal_false() { + // lit(1) = lit(2) is always false, so all containers should be pruned + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[false]; + prune_with_simplified_expr(lit(1).eq(lit(2)), &schema, &statistics, expected_ret); + } + + /// Test nested/complex literal expression trees. + /// This is an integration test that PhysicalExprSimplifier + PruningPredicate work together as expected. + #[test] + fn row_group_predicate_literal_true() { + // lit(1) = lit(1) is always true, so no containers should be pruned + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[true]; + prune_with_simplified_expr(lit(1).eq(lit(1)), &schema, &statistics, expected_ret); + } + + /// Test nested/complex literal expression trees. + /// This is an integration test that PhysicalExprSimplifier + PruningPredicate work together as expected. + #[test] + fn row_group_predicate_literal_null() { + // lit(1) = null is always null, so no containers should be pruned + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + let expected_ret = &[true]; + prune_with_simplified_expr( + lit(1).eq(lit(ScalarValue::Null)), + &schema, + &statistics, + expected_ret, + ); + } + + /// Test nested/complex literal expression trees. + /// This is an integration test that PhysicalExprSimplifier + PruningPredicate work together as expected. + #[test] + fn row_group_predicate_complex_literals() { + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)])); + let statistics = TestStatistics::new() + .with("c1", ContainerStats::new_i32(vec![Some(0)], vec![Some(10)])); + + // (1 + 2) > 0 is always true + prune_with_simplified_expr( + (lit(1) + lit(2)).gt(lit(0)), + &schema, + &statistics, + &[true], + ); + + // (1 + 2) < 0 is always false + prune_with_simplified_expr( + (lit(1) + lit(2)).lt(lit(0)), + &schema, + &statistics, + &[false], + ); + + // Nested AND of literals: true AND false = false + prune_with_simplified_expr( + lit(true).and(lit(false)), + &schema, + &statistics, + &[false], + ); + + // Nested OR of literals: true OR false = true + prune_with_simplified_expr( + lit(true).or(lit(false)), + &schema, + &statistics, + &[true], + ); + + // Complex nested: (1 < 2) AND (3 > 1) = true AND true = true + prune_with_simplified_expr( + lit(1).lt(lit(2)).and(lit(3).gt(lit(1))), + &schema, + &statistics, + &[true], + ); + + // Complex nested: (1 > 2) OR (3 < 1) = false OR false = false + prune_with_simplified_expr( + lit(1).gt(lit(2)).or(lit(3).lt(lit(1))), + &schema, + &statistics, + &[false], + ); + } + + /// Integration test demonstrating that a dynamic filter with replaced children as literals will be snapshotted, simplified and then pruned correctly. + #[test] + fn row_group_predicate_dynamic_filter_with_literals() { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("part", DataType::Utf8, true), + ])); + let statistics = TestStatistics::new() + // Note that we have no stats, pruning can only happen via partition value pruning from the dynamic filter + .with_row_counts("c1", vec![Some(10)]); + let dynamic_filter_expr = col("c1").gt(lit(5)).and(col("part").eq(lit("B"))); + let phys_expr = logical2physical(&dynamic_filter_expr, &schema); + let children = collect_columns(&phys_expr) + .iter() + .map(|c| Arc::new(c.clone()) as Arc) + .collect_vec(); + let dynamic_phys_expr = + Arc::new(DynamicFilterPhysicalExpr::new(children, phys_expr)) + as Arc; + // Simulate the partition value substitution that would happen in ParquetOpener + let remapped_expr = dynamic_phys_expr + .children() + .into_iter() + .map(|child_expr| { + let Some(col_expr) = child_expr.downcast_ref::() + else { + return Arc::clone(child_expr); + }; + if col_expr.name() == "part" { + // simulate dynamic filter replacement with literal "A" + Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some( + "A".to_string(), + )))) as Arc + } else { + Arc::clone(child_expr) + } + }) + .collect_vec(); + let dynamic_filter_expr = + dynamic_phys_expr.with_new_children(remapped_expr).unwrap(); + // After substitution the expression is c1 > 5 AND part = "B" which should prune the file since the partition value is "A" + let expected = &[false]; + let p = + PruningPredicate::try_new(dynamic_filter_expr, Arc::clone(&schema)).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, expected); + } + #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); @@ -2790,7 +3133,7 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); println!("required_columns: {required_columns:#?}"); // for debugging assertions below - // c1 < 1 should add c1_min + // c1 < 1 should add c1_min let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( required_columns.columns[0], @@ -3748,6 +4091,90 @@ mod tests { ); } + #[test] + fn prune_int32_col_is_not_distinct_from() { + let (schema, statistics) = int32_setup(); + + // Without null counts, IS NOT DISTINCT FROM a non-null literal can + // still use min/max ranges, but unknown all-null containers must be kept. + let expected_ret = &[true, false, false, true, false]; + + prune_with_expr( + is_not_distinct_from(col("i"), lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // The operator is symmetric, so the scalar-left form should prune the + // same row groups. + prune_with_expr( + is_not_distinct_from(lit(0), col("i")), + &schema, + &statistics, + expected_ret, + ); + + let statistics = statistics + .with_row_counts("i", vec![Some(10), Some(9), None, Some(4), Some(10)]) + .with_null_counts("i", vec![Some(0), Some(1), None, Some(4), Some(0)]); + + let expected_ret = &[true, false, false, false, false]; + prune_with_expr( + is_not_distinct_from(col("i"), lit(0)), + &schema, + &statistics, + expected_ret, + ); + + let expected_ret = &[false, true, true, true, false]; + prune_with_expr( + is_not_distinct_from(col("i"), lit(ScalarValue::Int32(None))), + &schema, + &statistics, + expected_ret, + ); + } + + #[test] + fn prune_int32_col_is_distinct_from() { + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32( + vec![Some(0), Some(0), Some(5), None], + vec![Some(0), Some(2), Some(5), None], + ) + .with_row_counts(vec![Some(2), Some(2), Some(2), Some(2)]) + .with_null_counts(vec![Some(0), Some(0), Some(0), Some(2)]), + ); + + let expected_ret = &[false, true, true, true]; + prune_with_expr( + is_distinct_from(col("i"), lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // The operator is symmetric, so the scalar-left form should prune the + // same row groups. + prune_with_expr( + is_distinct_from(lit(0), col("i")), + &schema, + &statistics, + expected_ret, + ); + + let expected_ret = &[true, true, true, false]; + prune_with_expr( + is_distinct_from(col("i"), lit(ScalarValue::Int32(None))), + &schema, + &statistics, + expected_ret, + ); + } + #[test] fn prune_int32_col_eq_zero_cast() { let (schema, statistics) = int32_setup(); @@ -3939,7 +4366,7 @@ mod tests { } #[test] - fn prune_cast_column_scalar() { + fn prune_cast_scalar() { // The data type of column i is INT32 let (schema, statistics) = int32_setup(); let expected_ret = &[true, true, false, true, true]; @@ -4397,6 +4824,174 @@ mod tests { prune_with_expr(expr, &schema, &statistics, expected_ret); } + // `build_like_match()` must honor `\` escapes when scanning the pattern for + // wildcards. + #[test] + fn prune_utf8_like_escaped_chars() { + let schema = Arc::new(Schema::new(vec![Field::new("s1", DataType::Utf8, true)])); + let statistics = TestStatistics::new().with( + "s1", + ContainerStats::new_utf8( + vec![ + Some("foo_aaa"), + Some(r#"foo\aaa"#), + Some("foo"), + Some("bar"), + Some("foo%aaa"), + Some("%foo_aaa"), + ], // min + vec![ + Some("foo_zzz"), + Some(r#"foo\zzz"#), + Some("foozzz"), + Some("baz"), + Some("foo%zzz"), + Some("%foo_zzz"), + ], // max + ), + ); + + let expr = col("s1").like(lit(r#"foo\_%"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => every value starts with literal + // "foo_" and matches the pattern; must keep. + true, + // s1 ["foo\aaa", "foo\zzz"] => no rows can pass (not keep) + false, + // s1 ["foo", "foozzz"] => stats don't prove "foo_" is or isn't in + // range; must conservatively keep. + true, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => no rows can pass (not keep) + false, + // s1 ["%foo_aaa", "%foo_zzz"] => no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit(r#"foo\\%"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => no rows can pass (not keep) + false, + // s1 ["foo\aaa", "foo\zzz"] => every value starts with literal + // "foo\" and matches the pattern; must keep. + true, + // s1 ["foo", "foozzz"] => stats don't prove "foo\" is or isn't in + // range; must conservatively keep. + true, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => no rows can pass (not keep) + false, + // s1 ["%foo_aaa", "%foo_zzz"] => no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + let expr = col("s1").like(lit(r#"foo\%%"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => no rows can pass (not keep) + false, + // s1 ["foo\aaa", "foo\zzz"] => no rows can pass (not keep) + false, + // s1 ["foo", "foozzz"] => range straddles "foo%"; must keep. + true, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => every value starts with literal + // "foo%" and matches the pattern; must keep. + true, + // s1 ["%foo_aaa", "%foo_zzz"] => no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + // No wildcard after escapes: pattern reduces to an equality check on + // the literal "foo_". + let expr = col("s1").like(lit(r#"foo\_"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => no rows can pass (not keep) + false, + // s1 ["foo\aaa", "foo\zzz"] => no rows can pass (not keep) + false, + // s1 ["foo", "foozzz"] => "foo_" is within the range; must keep. + true, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => no rows can pass (not keep) + false, + // s1 ["%foo_aaa", "%foo_zzz"] => no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + // Leading escaped `%`: prefix is "%foo" (non-empty), so the guard + // for "all wildcards" must NOT bail out here. + let expr = col("s1").like(lit(r#"\%foo%"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => no rows can pass (not keep) + false, + // s1 ["foo\aaa", "foo\zzz"] => no rows can pass (not keep) + false, + // s1 ["foo", "foozzz"] => no rows can pass (not keep) + false, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => no rows can pass (not keep) + false, + // s1 ["%foo_aaa", "%foo_zzz"] => every value starts with literal + // "%foo" and matches the pattern; must keep. + true, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + // Two escaped wildcards, no real wildcard: equality on "foo%_". + let expr = col("s1").like(lit(r#"foo\%\_"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => no rows can pass (not keep) + false, + // s1 ["foo\aaa", "foo\zzz"] => no rows can pass (not keep) + false, + // s1 ["foo", "foozzz"] => "foo%_" is within the range; must keep. + true, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => no rows can pass (not keep) + false, + // s1 ["%foo_aaa", "%foo_zzz"] => no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + + // Escaped backslash followed by more literal chars before the + // wildcard: prefix is "foo\bar". + let expr = col("s1").like(lit(r#"foo\\bar%"#)); + #[rustfmt::skip] + let expected_ret = &[ + // s1 ["foo_aaa", "foo_zzz"] => no rows can pass (not keep) + false, + // s1 ["foo\aaa", "foo\zzz"] => range straddles "foo\bar"; must + // keep. + true, + // s1 ["foo", "foozzz"] => range straddles "foo\bar"; must keep. + true, + // s1 ["bar", "baz"] => no rows can pass (not keep) + false, + // s1 ["foo%aaa", "foo%zzz"] => no rows can pass (not keep) + false, + // s1 ["%foo_aaa", "%foo_zzz"] => no rows can pass (not keep) + false, + ]; + prune_with_expr(expr, &schema, &statistics, expected_ret); + } + #[test] fn prune_utf8_not_like_one() { let (schema, statistics) = utf8_setup(); @@ -4422,7 +5017,7 @@ mod tests { true, // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) true, - // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate // original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") true, ]; @@ -5122,6 +5717,37 @@ mod tests { assert_eq!(result, expected); } + fn prune_with_simplified_expr( + expr: Expr, + schema: &SchemaRef, + statistics: &TestStatistics, + expected: &[bool], + ) { + println!("Pruning with expr: {expr}"); + let expr = logical2physical(&expr, schema); + let simplifier = PhysicalExprSimplifier::new(schema); + let expr = simplifier.simplify(expr).unwrap(); + let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); + let result = p.prune(statistics).unwrap(); + assert_eq!(result, expected); + } + + fn is_not_distinct_from(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::IsNotDistinctFrom, + Box::new(right), + )) + } + + fn is_distinct_from(left: Expr, right: Expr) -> Expr { + Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + Operator::IsDistinctFrom, + Box::new(right), + )) + } + fn test_build_predicate_expression( expr: &Expr, schema: &Schema, diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 279c88b525d3c..93987b553f2f5 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -29,6 +29,10 @@ edition = { workspace = true } [package.metadata.docs.rs] all-features = true +[features] +default = [] +core = ["datafusion"] + # Note: add additional linter rules in lib.rs. # Rust does not support workspace + new linter rules in subcrates yet # https://github.com/rust-lang/cargo/issues/13157 @@ -43,19 +47,61 @@ arrow = { workspace = true } bigdecimal = { workspace = true } chrono = { workspace = true } crc32fast = "1.4" +# Optional dependency for SessionStateBuilderSpark extension trait +datafusion = { workspace = true, optional = true, default-features = false } datafusion-catalog = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true, features = ["crypto_expressions"] } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-functions-nested = { workspace = true } log = { workspace = true } +num-traits = { workspace = true } +percent-encoding = "2.3.2" rand = { workspace = true } -sha1 = "0.10" +serde_json = { workspace = true } +sha1 = "0.11" +sha2 = { workspace = true } +twox-hash = "2.1" url = { workspace = true } [dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } criterion = { workspace = true } +# for SessionStateBuilderSpark tests +datafusion = { workspace = true, default-features = false, features = ["sql"] } +tokio = { workspace = true, features = ["rt"] } [[bench]] harness = false name = "char" + +[[bench]] +harness = false +name = "space" + +[[bench]] +harness = false +name = "hex" + +[[bench]] +harness = false +name = "slice" + +[[bench]] +harness = false +name = "substring" + +[[bench]] +harness = false +name = "unhex" + +[[bench]] +harness = false +name = "sha2" + +[[bench]] +harness = false +name = "floor" diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs index 02eab7630d070..38d9ebdeb4f5f 100644 --- a/datafusion/spark/benches/char.rs +++ b/datafusion/spark/benches/char.rs @@ -15,11 +15,9 @@ // specific language governing permissions and limitations // under the License. -extern crate criterion; - use arrow::datatypes::{DataType, Field}; use arrow::{array::PrimitiveArray, datatypes::Int64Type}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use datafusion_common::config::ConfigOptions; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_spark::function::string::char; diff --git a/datafusion/spark/benches/floor.rs b/datafusion/spark/benches/floor.rs new file mode 100644 index 0000000000000..ecb1590acc542 --- /dev/null +++ b/datafusion/spark/benches/floor.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::floor::SparkFloor; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_float64_data(size: usize, null_density: f32) -> Float64Array { + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(-1_000_000.0..1_000_000.0)) + } + }) + .collect() +} + +fn generate_decimal128_data(size: usize, null_density: f32) -> Decimal128Array { + let mut rng = seedable_rng(); + let array: Decimal128Array = (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(-999_999_999..999_999_999)) + } + }) + .collect(); + array.with_precision_and_scale(18, 2).unwrap() +} + +fn run_benchmark( + c: &mut Criterion, + name: &str, + size: usize, + array: Arc, + return_type: &DataType, +) { + let floor_func = SparkFloor::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + floor_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new( + "f", + return_type.clone(), + true, + )), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let data = generate_float64_data(size, null_density); + run_benchmark(c, "floor_float64", size, Arc::new(data), &DataType::Int64); + } + + for &size in &sizes { + let data = generate_decimal128_data(size, null_density); + run_benchmark( + c, + "floor_decimal128", + size, + Arc::new(data), + &DataType::Decimal128(17, 0), + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/hex.rs b/datafusion/spark/benches/hex.rs new file mode 100644 index 0000000000000..9785371cc5827 --- /dev/null +++ b/datafusion/spark/benches/hex.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::hex::SparkHex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_int64_data(size: usize, null_density: f32) -> PrimitiveArray { + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(-999_999_999_999..999_999_999_999)) + } + }) + .collect() +} + +fn generate_utf8_data(size: usize, null_density: f32) -> StringArray { + let mut rng = seedable_rng(); + let mut builder = StringBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let s: String = + std::iter::repeat_with(|| rng.random_range(b'a'..=b'z') as char) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn generate_int64_dict_data( + size: usize, + null_density: f32, +) -> DictionaryArray { + let mut rng = seedable_rng(); + let mut builder = PrimitiveDictionaryBuilder::::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + builder.append_value( + rng.random_range::(-999_999_999_999..999_999_999_999), + ); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let hex_func = SparkHex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + hex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let data = generate_int64_data(size, null_density); + run_benchmark(c, "hex_int64", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_utf8_data(size, null_density); + run_benchmark(c, "hex_utf8", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_binary_data(size, null_density); + run_benchmark(c, "hex_binary", size, Arc::new(data)); + } + + for &size in &sizes { + let data = generate_int64_dict_data(size, null_density); + run_benchmark(c, "hex_int64_dict", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/sha2.rs b/datafusion/spark/benches/sha2.rs new file mode 100644 index 0000000000000..6e835984703f0 --- /dev/null +++ b/datafusion/spark/benches/sha2.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::*; +use arrow::datatypes::*; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::hash::sha2::SparkSha2; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray { + let mut rng = seedable_rng(); + let mut builder = BinaryBuilder::new(); + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(1..=100); + let bytes: Vec = (0..len).map(|_| rng.random()).collect(); + builder.append_value(&bytes); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, args: &[ColumnarValue]) { + let sha2_func = SparkSha2::new(); + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + sha2_func + .invoke_with_args(ScalarFunctionArgs { + args: args.to_vec(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + // Scalar benchmark (avoid array expansion) + let scalar_args = vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some(b"Spark".to_vec()))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(256))), + ]; + run_benchmark(c, "sha2/scalar", 1, &scalar_args); + + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + for &size in &sizes { + let values: ArrayRef = Arc::new(generate_binary_data(size, null_density)); + let bit_lengths: ArrayRef = Arc::new(Int32Array::from(vec![256; size])); + + let array_args = vec![ + ColumnarValue::Array(Arc::clone(&values)), + ColumnarValue::Array(Arc::clone(&bit_lengths)), + ]; + run_benchmark(c, "sha2/array_binary_256", size, &array_args); + + let array_scalar_args = vec![ + ColumnarValue::Array(Arc::clone(&values)), + ColumnarValue::Scalar(ScalarValue::Int32(Some(256))), + ]; + run_benchmark(c, "sha2/array_scalar_binary_256", size, &array_scalar_args); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/slice.rs b/datafusion/spark/benches/slice.rs new file mode 100644 index 0000000000000..da392dc042f92 --- /dev/null +++ b/datafusion/spark/benches/slice.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Int64Array, ListArray, ListViewArray, NullBufferBuilder, PrimitiveArray, +}; +use arrow::buffer::{OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::{DataType, Field, Int64Type}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::array::slice; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn create_inputs( + rng: &mut StdRng, + size: usize, + child_array_size: usize, + null_density: f32, +) -> (ListArray, ListViewArray) { + let mut nulls_builder = NullBufferBuilder::new(size); + let mut sizes = Vec::with_capacity(size); + + for _ in 0..size { + if rng.random::() < null_density { + nulls_builder.append_null(); + } else { + nulls_builder.append_non_null(); + } + sizes.push(rng.random_range(1..child_array_size)); + } + let nulls = nulls_builder.finish(); + + let length = sizes.iter().sum(); + let values: PrimitiveArray = + (0..length).map(|_| Some(rng.random())).collect(); + let values = Arc::new(values); + + let offsets = OffsetBuffer::from_lengths(sizes.clone()); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets.clone(), + values.clone(), + nulls.clone(), + ); + + let offsets = ScalarBuffer::from(offsets.slice(0, size - 1)); + let sizes = ScalarBuffer::from_iter(sizes.into_iter().map(|v| v as i32)); + let list_view_array = ListViewArray::new( + Arc::new(Field::new_list_field(DataType::Int64, true)), + offsets, + sizes, + values, + nulls, + ); + + (list_array, list_view_array) +} + +fn random_from_to( + rng: &mut StdRng, + size: i64, + null_density: f32, +) -> (Option, Option) { + let from = if rng.random::() < null_density { + None + } else { + Some(rng.random_range(1..=size)) + }; + + let to = if rng.random::() < null_density { + None + } else { + match from { + Some(from) => Some(rng.random_range(from..=size)), + None => Some(rng.random_range(1..=size)), + } + }; + + (from, to) +} + +fn array_slice_benchmark( + name: &str, + input: ColumnarValue, + mut args: Vec, + c: &mut Criterion, + size: usize, +) { + args.insert(0, input); + + let array_slice = slice(); + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| { + >::from(Field::new(format!("arg_{idx}"), arg.data_type(), true)) + }) + .collect::>(); + c.bench_function(name, |b| { + b.iter(|| { + black_box( + array_slice + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Field::new_list_field(args[0].data_type(), true) + .into(), + config_options: Arc::new(ConfigOptions::default()), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let rng = &mut StdRng::seed_from_u64(42); + let size = 1_000_000; + let child_array_size = 100; + let null_density = 0.1; + + let (list_array, list_view_array) = + create_inputs(rng, size, child_array_size, null_density); + + let mut array_from = Vec::with_capacity(size); + let mut array_to = Vec::with_capacity(size); + for child_array_size in list_array.offsets().lengths() { + let (from, to) = random_from_to(rng, child_array_size as i64, null_density); + array_from.push(from); + array_to.push(to); + } + + // input + let list_array = ColumnarValue::Array(Arc::new(list_array)); + let list_view_array = ColumnarValue::Array(Arc::new(list_view_array)); + + // args + let array_from = ColumnarValue::Array(Arc::new(Int64Array::from(array_from))); + let array_to = ColumnarValue::Array(Arc::new(Int64Array::from(array_to))); + let scalar_from = ColumnarValue::Scalar(ScalarValue::from(1i64)); + let scalar_to = ColumnarValue::Scalar(ScalarValue::from(child_array_size as i64 / 2)); + + for input in [list_array, list_view_array] { + let input_type = input.data_type().to_string(); + + array_slice_benchmark( + &format!("slice: input {input_type}, array args, no stride"), + input.clone(), + vec![array_from.clone(), array_to.clone()], + c, + size, + ); + + array_slice_benchmark( + &format!("slice: input {input_type}, scalar args, no stride"), + input.clone(), + vec![scalar_from.clone(), scalar_to.clone()], + c, + size, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/space.rs b/datafusion/spark/benches/space.rs new file mode 100644 index 0000000000000..bd9d370ca37fe --- /dev/null +++ b/datafusion/spark/benches/space.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::PrimitiveArray; +use arrow::datatypes::{DataType, Field, Int32Type}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::string::space; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let space_func = space(); + let size = 1024; + let input: PrimitiveArray = { + let null_density = 0.2; + let mut rng = StdRng::seed_from_u64(42); + (0..size) + .map(|_| { + if rng.random::() < null_density { + None + } else { + Some(rng.random_range::(1i32..10)) + } + }) + .collect() + }; + let input = Arc::new(input); + let args = vec![ColumnarValue::Array(input)]; + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + c.bench_function("space", |b| { + b.iter(|| { + black_box( + space_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Utf8, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/substring.rs b/datafusion/spark/benches/substring.rs new file mode 100644 index 0000000000000..d6eac817c322f --- /dev/null +++ b/datafusion/spark/benches/substring.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::{DataType, Field}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{Criterion, SamplingMode, criterion_group, criterion_main}; +use datafusion_common::DataFusionError; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; +use datafusion_spark::function::string::substring; +use std::hint::black_box; +use std::sync::Arc; + +fn create_args_without_count( + size: usize, + str_len: usize, + start_half_way: bool, + force_view_types: bool, +) -> Vec { + let start_array = Arc::new(Int64Array::from( + (0..size) + .map(|_| { + if start_half_way { + (str_len / 2) as i64 + } else { + 1i64 + } + }) + .collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ] + } +} + +fn create_args_with_count( + size: usize, + str_len: usize, + count_max: usize, + force_view_types: bool, +) -> Vec { + let start_array = + Arc::new(Int64Array::from((0..size).map(|_| 1).collect::>())); + let count = count_max.min(str_len) as i64; + let count_array = Arc::new(Int64Array::from( + (0..size).map(|_| count).collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ColumnarValue::Array(count_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef), + ] + } +} + +#[expect(clippy::needless_pass_by_value)] +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect::>(); + let config_options = Arc::new(ConfigOptions::default()); + + substring().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: Field::new("f", DataType::Utf8View, true).into(), + config_options: Arc::clone(&config_options), + }) +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096] { + // string_len = 12, substring_len=6 (see `create_args_without_count`) + let len = 12; + let mut group = c.benchmark_group("SHORTER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true); + group.bench_function( + format!("substr_string_view [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_without_count::(size, len, false, false); + group.bench_function(format!("substr_string [size={size}, strlen={len}]"), |b| { + b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))) + }); + + let args = create_args_without_count::(size, len, true, false); + group.bench_function( + format!("substr_large_string [size={size}, strlen={len}]"), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // string_len = 128, start=1, count=64, substring_len=64 + let len = 128; + let count = 64; + let mut group = c.benchmark_group("LONGER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + + // string_len = 128, start=1, count=6, substring_len=6 + let len = 128; + let count = 6; + let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!("substr_string_view [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!("substr_large_string [size={size}, count={count}, strlen={len}]",), + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), + ); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/unhex.rs b/datafusion/spark/benches/unhex.rs new file mode 100644 index 0000000000000..7dce683485bc7 --- /dev/null +++ b/datafusion/spark/benches/unhex.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder, + StringViewArray, StringViewBuilder, +}; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_spark::function::math::unhex::SparkUnhex; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint::black_box; +use std::sync::Arc; + +fn generate_hex_string_data(size: usize, null_density: f32) -> StringArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringBuilder::with_capacity(size, 0); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_hex_large_string_data(size: usize, null_density: f32) -> LargeStringArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = LargeStringBuilder::with_capacity(size, 0); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn generate_hex_utf8view_data(size: usize, null_density: f32) -> StringViewArray { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringViewBuilder::with_capacity(size); + let hex_chars = b"0123456789abcdefABCDEF"; + + for _ in 0..size { + if rng.random::() < null_density { + builder.append_null(); + } else { + let len = rng.random_range::(2..=100); + let s: String = std::iter::repeat_with(|| { + hex_chars[rng.random_range(0..hex_chars.len())] as char + }) + .take(len) + .collect(); + builder.append_value(&s); + } + } + builder.finish() +} + +fn run_benchmark(c: &mut Criterion, name: &str, size: usize, array: Arc) { + let unhex_func = SparkUnhex::new(); + let args = vec![ColumnarValue::Array(array)]; + let arg_fields: Vec<_> = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into()) + .collect(); + let config_options = Arc::new(ConfigOptions::default()); + + c.bench_function(&format!("{name}/size={size}"), |b| { + b.iter(|| { + black_box( + unhex_func + .invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields: arg_fields.clone(), + number_rows: size, + return_field: Arc::new(Field::new("f", DataType::Binary, true)), + config_options: Arc::clone(&config_options), + }) + .unwrap(), + ) + }) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let sizes = vec![1024, 4096, 8192]; + let null_density = 0.1; + + // Benchmark with hex string + for &size in &sizes { + let data = generate_hex_string_data(size, null_density); + run_benchmark(c, "unhex_utf8", size, Arc::new(data)); + } + + // Benchmark with hex large string + for &size in &sizes { + let data = generate_hex_large_string_data(size, null_density); + run_benchmark(c, "unhex_large_utf8", size, Arc::new(data)); + } + + // Benchmark with hex Utf8View + for &size in &sizes { + let data = generate_hex_utf8view_data(size, null_density); + run_benchmark(c, "unhex_utf8view", size, Arc::new(data)); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 4a7adc515bbc4..5f4d2c253a2dc 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -16,22 +16,26 @@ // under the License. use arrow::array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, BooleanArray, Int64Array, + PrimitiveArray, builder::PrimitiveBuilder, cast::AsArray, types::{Float64Type, Int64Type}, - Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray, }; use arrow::compute::sum; use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::types::{logical_float64, NativeType}; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::types::{NativeType, logical_float64}; +use datafusion_common::{Result, ScalarValue, not_impl_err}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, TypeSignatureClass, Volatility, }; -use std::{any::Any, sync::Arc}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, +}; +use std::sync::Arc; /// AVG aggregate expression /// Spark average aggregate expression. Differs from standard DataFusion average aggregate @@ -68,10 +72,6 @@ impl SparkAvg { } impl AggregateUDFImpl for SparkAvg { - fn as_any(&self) -> &dyn Any { - self - } - fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(DataType::Float64) } @@ -213,7 +213,7 @@ impl Accumulator for AvgAccumulator { struct AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send, + F: Fn(T::Native, i64) -> Result + Send + 'static, { /// The type of the returned average return_data_type: DataType, @@ -231,7 +231,7 @@ where impl AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send, + F: Fn(T::Native, i64) -> Result + Send + 'static, { pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { Self { @@ -246,13 +246,13 @@ where impl GroupsAccumulator for AvgGroupsAccumulator where T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send, + F: Fn(T::Native, i64) -> Result + Send + 'static, { fn update_batch( &mut self, values: &[ArrayRef], group_indices: &[usize], - _opt_filter: Option<&arrow::array::BooleanArray>, + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -289,26 +289,26 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - _opt_filter: Option<&arrow::array::BooleanArray>, + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); // first batch is partial sums, second is counts let partial_sums = values[0].as_primitive::(); let partial_counts = values[1].as_primitive::(); - // update counts with partial counts - self.counts.resize(total_num_groups, 0); - let iter1 = group_indices.iter().zip(partial_counts.values().iter()); - for (&group_index, &partial_count) in iter1 { - self.counts[group_index] += partial_count; - } - // update sums + self.counts.resize(total_num_groups, 0); self.sums.resize(total_num_groups, T::default_value()); - let iter2 = group_indices.iter().zip(partial_sums.values().iter()); - for (&group_index, &new_value) in iter2 { + + for (idx, &group_index) in group_indices.iter().enumerate() { + // Skip null state entries emitted by convert_to_state for + // filtered / null input rows. + if partial_counts.is_null(idx) || partial_sums.is_null(idx) { + continue; + } + self.counts[group_index] += partial_counts.value(idx); let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); + *sum = sum.add_wrapping(partial_sums.value(idx)); } Ok(()) @@ -347,7 +347,149 @@ where ]) } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let sums = values[0] + .as_primitive::() + .clone() + .with_data_type(self.return_data_type.clone()); + let counts = Int64Array::from_value(1, sums.len()); + + let nulls = filtered_null_mask(opt_filter, &sums); + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + // [sum, count] - must match state() and merge_batch() + Ok(vec![ + Arc::new(sums) as ArrayRef, + Arc::new(counts) as ArrayRef, + ]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + fn size(&self) -> usize { self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::Float64Array; + + fn make_acc() -> AvgGroupsAccumulator Result> { + AvgGroupsAccumulator::::new(&DataType::Float64, |sum, count| { + Ok(sum / count as f64) + }) + } + + #[test] + fn supports_convert_to_state() { + assert!(make_acc().supports_convert_to_state()); + } + + #[test] + fn convert_to_state_basic() { + let acc = make_acc(); + let values: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; + let state = acc.convert_to_state(&values, None).unwrap(); + + assert_eq!(state.len(), 2); + let sums = state[0].as_primitive::(); + let counts = state[1].as_primitive::(); + + assert_eq!(sums.values().as_ref(), &[1.0, 2.0, 3.0]); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + assert_eq!(sums.null_count(), 0); + assert_eq!(counts.null_count(), 0); + } + + #[test] + fn convert_to_state_with_nulls() { + let acc = make_acc(); + let values: Vec = vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + None, + Some(3.0), + ]))]; + let state = acc.convert_to_state(&values, None).unwrap(); + + let sums = state[0].as_primitive::(); + let counts = state[1].as_primitive::(); + + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + + assert_eq!(counts.value(0), 1); + assert!(counts.is_null(1)); + assert_eq!(counts.value(2), 1); + } + + #[test] + fn convert_to_state_with_filter() { + let acc = make_acc(); + let values: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; + let filter = BooleanArray::from(vec![true, false, true]); + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let sums = state[0].as_primitive::(); + let counts = state[1].as_primitive::(); + + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + + assert_eq!(counts.value(0), 1); + assert!(counts.is_null(1)); + assert_eq!(counts.value(2), 1); + } + + #[test] + fn convert_to_state_roundtrips_through_merge() { + let mut acc = make_acc(); + let input: Vec = + vec![Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0]))]; + let state = acc.convert_to_state(&input, None).unwrap(); + + // feed the converted state back through merge_batch + acc.merge_batch( + &state, + &[0, 0, 0], + None, + 1, // single group + ) + .unwrap(); + + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 20.0); // (10+20+30)/3 + } + + #[test] + fn convert_to_state_null_merge_matches_direct() { + // avg([1.0, NULL, 3.0]) must be 2.0 after a convert_to_state → merge_batch + // round-trip. Before the merge-path null fix this leaked the backing + // buffer value at the null slot and produced the wrong average. + let mut acc = make_acc(); + let input: Vec = vec![Arc::new(Float64Array::from(vec![ + Some(1.0), + None, + Some(3.0), + ]))]; + let state = acc.convert_to_state(&input, None).unwrap(); + acc.merge_batch(&state, &[0, 0, 0], None, 1).unwrap(); + + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 2.0); + } +} diff --git a/datafusion/spark/src/function/aggregate/collect.rs b/datafusion/spark/src/function/aggregate/collect.rs new file mode 100644 index 0000000000000..5af0fd39cca07 --- /dev/null +++ b/datafusion/spark/src/function/aggregate/collect.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::utils::SingleRowListArrayBuilder; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_functions_aggregate::array_agg::{ + ArrayAggAccumulator, DistinctArrayAggAccumulator, +}; +use std::sync::Arc; + +// Spark implementation of collect_list/collect_set aggregate function. +// Differs from DataFusion ArrayAgg in the following ways: +// - ignores NULL inputs +// - returns an empty list when all inputs are NULL +// - does not support ordering + +// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCollectList { + signature: Signature, +} + +impl Default for SparkCollectList { + fn default() -> Self { + Self::new() + } +} + +impl SparkCollectList { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SparkCollectList { + fn name(&self) -> &str { + "collect_list" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new_list( + format_state_name(args.name, "collect_list"), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), + true, + ) + .into(), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let field = &acc_args.expr_fields[0]; + let data_type = field.data_type().clone(); + let ignore_nulls = true; + Ok(Box::new(NullToEmptyListAccumulator::new( + ArrayAggAccumulator::try_new(&data_type, ignore_nulls)?, + data_type, + ))) + } +} + +// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCollectSet { + signature: Signature, +} + +impl Default for SparkCollectSet { + fn default() -> Self { + Self::new() + } +} + +impl SparkCollectSet { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SparkCollectSet { + fn name(&self) -> &str { + "collect_set" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new_list( + format_state_name(args.name, "collect_set"), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), + true, + ) + .into(), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let field = &acc_args.expr_fields[0]; + let data_type = field.data_type().clone(); + let ignore_nulls = true; + Ok(Box::new(NullToEmptyListAccumulator::new( + DistinctArrayAggAccumulator::try_new(&data_type, None, ignore_nulls)?, + data_type, + ))) + } +} + +/// Wrapper accumulator that returns an empty list instead of NULL when all inputs are NULL. +/// This implements Spark's behavior for collect_list and collect_set. +#[derive(Debug)] +struct NullToEmptyListAccumulator { + inner: T, + data_type: DataType, +} + +impl NullToEmptyListAccumulator { + pub fn new(inner: T, data_type: DataType) -> Self { + Self { inner, data_type } + } +} + +impl Accumulator for NullToEmptyListAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.inner.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.inner.merge_batch(states) + } + + fn state(&mut self) -> Result> { + self.inner.state() + } + + fn evaluate(&mut self) -> Result { + let result = self.inner.evaluate()?; + if result.is_null() { + let empty_array = arrow::array::new_empty_array(&self.data_type); + Ok(SingleRowListArrayBuilder::new(empty_array).build_list_scalar()) + } else { + Ok(result) + } + } + + fn size(&self) -> usize { + self.inner.size() + self.data_type.size() + } +} diff --git a/datafusion/spark/src/function/aggregate/mod.rs b/datafusion/spark/src/function/aggregate/mod.rs index d765d9c82f068..d6a2fe7a8503e 100644 --- a/datafusion/spark/src/function/aggregate/mod.rs +++ b/datafusion/spark/src/function/aggregate/mod.rs @@ -19,17 +19,44 @@ use datafusion_expr::AggregateUDF; use std::sync::Arc; pub mod avg; +pub mod collect; +pub mod try_sum; + pub mod expr_fn { use datafusion_functions::export_functions; export_functions!((avg, "Returns the average value of a given column", arg1)); + export_functions!(( + try_sum, + "Returns the sum of values for a column, or NULL if overflow occurs", + arg1 + )); + export_functions!(( + collect_list, + "Returns a list created from the values in a column", + arg1 + )); + export_functions!(( + collect_set, + "Returns a set created from the values in a column", + arg1 + )); } // TODO: try use something like datafusion_functions_aggregate::create_func!() pub fn avg() -> Arc { Arc::new(AggregateUDF::new_from_impl(avg::SparkAvg::new())) } +pub fn try_sum() -> Arc { + Arc::new(AggregateUDF::new_from_impl(try_sum::SparkTrySum::new())) +} +pub fn collect_list() -> Arc { + Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectList::new())) +} +pub fn collect_set() -> Arc { + Arc::new(AggregateUDF::new_from_impl(collect::SparkCollectSet::new())) +} pub fn functions() -> Vec> { - vec![avg()] + vec![avg(), try_sum(), collect_list(), collect_set()] } diff --git a/datafusion/spark/src/function/aggregate/try_sum.rs b/datafusion/spark/src/function/aggregate/try_sum.rs new file mode 100644 index 0000000000000..d1f99f4ebc0c3 --- /dev/null +++ b/datafusion/spark/src/function/aggregate/try_sum.rs @@ -0,0 +1,655 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, ArrowNumericType, AsArray, BooleanArray, PrimitiveArray}; +use arrow::datatypes::{ + DECIMAL128_MAX_PRECISION, DataType, Decimal128Type, Field, FieldRef, Float64Type, + Int64Type, +}; +use datafusion_common::{Result, ScalarValue, downcast_value, exec_err, not_impl_err}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; + +#[derive(PartialEq, Eq, Hash)] +pub struct SparkTrySum { + signature: Signature, +} + +impl Default for SparkTrySum { + fn default() -> Self { + Self::new() + } +} + +impl SparkTrySum { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl Debug for SparkTrySum { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SparkTrySum") + .field("signature", &self.signature) + .finish() + } +} + +/// Accumulator for try_sum that detects overflow +struct TrySumAccumulator { + sum: Option, + data_type: DataType, + failed: bool, + // Only used if data_type is Decimal128(p, s) + dec_precision: Option, +} + +impl Debug for TrySumAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "TrySumAccumulator({})", self.data_type) + } +} + +impl TrySumAccumulator { + fn new(data_type: DataType) -> Self { + let dec_precision = match &data_type { + DataType::Decimal128(p, _) => Some(*p), + _ => None, + }; + Self { + sum: None, + data_type, + failed: false, + dec_precision, + } + } +} + +impl Accumulator for TrySumAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + self.evaluate()?, + ScalarValue::Boolean(Some(self.failed)), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + update_batch_internal(self, values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // Check if any partition has failed + if downcast_value!(states[1], BooleanArray) + .iter() + .flatten() + .any(|f| f) + { + self.failed = true; + return Ok(()); + } + + // Merge the sum values using the same logic as update_batch + update_batch_internal(self, states) + } + + fn evaluate(&mut self) -> Result { + evaluate_internal(self) + } + + fn size(&self) -> usize { + size_of_val(self) + } +} + +// Specialized implementations for update_batch for each type + +fn update_batch_internal( + acc: &mut TrySumAccumulator, + values: &[ArrayRef], +) -> Result<()> { + if values.is_empty() || acc.failed { + return Ok(()); + } + + let array: &PrimitiveArray = values[0].as_primitive::(); + + match acc.data_type { + DataType::Int64 => update_int64(acc, array), + DataType::Float64 => update_float64(acc, array), + DataType::Decimal128(_, _) => update_decimal128(acc, array), + _ => exec_err!( + "try_sum: unsupported type in update_batch: {:?}", + acc.data_type + ), + } +} + +fn update_int64( + acc: &mut TrySumAccumulator, + array: &PrimitiveArray, +) -> Result<()> { + for v in array.iter().flatten() { + // Cast to i64 for checked_add + let v_i64 = unsafe { std::mem::transmute_copy::(&v) }; + let sum_i64 = acc + .sum + .map(|s| unsafe { std::mem::transmute_copy::(&s) }); + + let new_sum = match sum_i64 { + None => v_i64, + Some(s) => match s.checked_add(v_i64) { + Some(result) => result, + None => { + acc.failed = true; + return Ok(()); + } + }, + }; + + acc.sum = Some(unsafe { std::mem::transmute_copy::(&new_sum) }); + } + Ok(()) +} + +fn update_float64( + acc: &mut TrySumAccumulator, + array: &PrimitiveArray, +) -> Result<()> { + for v in array.iter().flatten() { + let v_f64 = unsafe { std::mem::transmute_copy::(&v) }; + let sum_f64 = acc + .sum + .map(|s| unsafe { std::mem::transmute_copy::(&s) }) + .unwrap_or(0.0); + let new_sum = sum_f64 + v_f64; + acc.sum = Some(unsafe { std::mem::transmute_copy::(&new_sum) }); + } + Ok(()) +} + +fn update_decimal128( + acc: &mut TrySumAccumulator, + array: &PrimitiveArray, +) -> Result<()> { + let precision = acc.dec_precision.unwrap_or(DECIMAL128_MAX_PRECISION); + + for v in array.iter().flatten() { + let v_i128 = unsafe { std::mem::transmute_copy::(&v) }; + let sum_i128 = acc + .sum + .map(|s| unsafe { std::mem::transmute_copy::(&s) }); + + let new_sum = match sum_i128 { + None => v_i128, + Some(s) => match s.checked_add(v_i128) { + Some(result) => result, + None => { + acc.failed = true; + return Ok(()); + } + }, + }; + + if exceeds_decimal128_precision(new_sum, precision) { + acc.failed = true; + return Ok(()); + } + + acc.sum = Some(unsafe { std::mem::transmute_copy::(&new_sum) }); + } + Ok(()) +} + +fn evaluate_internal( + acc: &mut TrySumAccumulator, +) -> Result { + if acc.failed { + return ScalarValue::new_primitive::(None, &acc.data_type); + } + ScalarValue::new_primitive::(acc.sum, &acc.data_type) +} + +// Helpers to determine if it exceeds decimal precision +fn pow10_i128(p: u8) -> Option { + let mut v: i128 = 1; + for _ in 0..p { + v = v.checked_mul(10)?; + } + Some(v) +} + +fn exceeds_decimal128_precision(sum: i128, p: u8) -> bool { + if let Some(max_plus_one) = pow10_i128(p) { + let max = max_plus_one - 1; + sum > max || sum < -max + } else { + true + } +} + +impl AggregateUDFImpl for SparkTrySum { + fn name(&self) -> &str { + "try_sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + let dt = &arg_types[0]; + let result_type = match dt { + Null => Float64, + Decimal128(p, s) => { + let new_precision = DECIMAL128_MAX_PRECISION.min(p + 10); + Decimal128(new_precision, *s) + } + Int8 | Int16 | Int32 | Int64 => Int64, + Float16 | Float32 | Float64 => Float64, + + other => return exec_err!("try_sum: unsupported type: {other:?}"), + }; + + Ok(result_type) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(TrySumAccumulator::<$t>::new($dt.clone()))) + }; + } + + match acc_args.return_field.data_type() { + DataType::Int64 => helper!(Int64Type, acc_args.return_field.data_type()), + DataType::Float64 => helper!(Float64Type, acc_args.return_field.data_type()), + DataType::Decimal128(_, _) => { + helper!(Decimal128Type, acc_args.return_field.data_type()) + } + _ => not_impl_err!( + "try_sum: unsupported type for accumulator: {}", + acc_args.return_field.data_type() + ), + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let sum_dt = args.return_field.data_type().clone(); + Ok(vec![ + Field::new(format_state_name(args.name, "sum"), sum_dt, true).into(), + Field::new( + format_state_name(args.name, "failed"), + DataType::Boolean, + false, + ) + .into(), + ]) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + use DataType::*; + if arg_types.len() != 1 { + return exec_err!( + "try_sum: exactly 1 argument expected, got {}", + arg_types.len() + ); + } + + let dt = &arg_types[0]; + let coerced = match dt { + Null => Float64, + Decimal128(p, s) => Decimal128(*p, *s), + Int8 | Int16 | Int32 | Int64 => Int64, + Float16 | Float32 | Float64 => Float64, + other => return exec_err!("try_sum: unsupported type: {other:?}"), + }; + Ok(vec![coerced]) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Null) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Decimal128Array, Float64Array, Int64Array}; + use datafusion_common::DataFusionError; + use std::sync::Arc; + + use super::*; + // -------- Helpers -------- + + fn int64(values: Vec>) -> ArrayRef { + Arc::new(Int64Array::from(values)) as ArrayRef + } + + fn f64(values: Vec>) -> ArrayRef { + Arc::new(Float64Array::from(values)) as ArrayRef + } + + fn dec128(p: u8, s: i8, vals: Vec>) -> Result { + let base = Decimal128Array::from(vals); + let arr = base.with_precision_and_scale(p, s).map_err(|e| { + DataFusionError::Execution(format!("invalid precision/scale ({p},{s}): {e}")) + })?; + Ok(Arc::new(arr) as ArrayRef) + } + + // -------- update_batch + evaluate -------- + + #[test] + fn try_sum_int_basic() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Int64); + acc.update_batch(&[int64((0..10).map(Some).collect())])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Int64(Some(45))); + Ok(()) + } + + #[test] + fn try_sum_int_with_nulls() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Int64); + acc.update_batch(&[int64(vec![None, Some(2), Some(3), None, Some(5)])])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Int64(Some(10))); + Ok(()) + } + + #[test] + fn try_sum_float_basic() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Float64); + acc.update_batch(&[f64(vec![Some(1.5), Some(2.5), None, Some(3.0)])])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Float64(Some(7.0))); + Ok(()) + } + + #[test] + fn float_overflow_behaves_like_spark_sum_infinite() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Float64); + acc.update_batch(&[f64(vec![Some(1e308), Some(1e308)])])?; + + let out = acc.evaluate()?; + assert!( + matches!(out, ScalarValue::Float64(Some(v)) if v.is_infinite() && v.is_sign_positive()), + "waiting +Infinity, got: {out:?}" + ); + Ok(()) + } + + #[test] + fn try_sum_float_negative_zero_normalizes_to_positive_zero() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Float64); + // -0.0 + 0.0 should normalize to 0.0 (positive zero), not -0.0 + acc.update_batch(&[f64(vec![Some(-0.0), Some(0.0)])])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Float64(Some(0.0))); + // Verify it's positive zero using is_sign_positive + if let ScalarValue::Float64(Some(v)) = out { + assert!(v.is_sign_positive() || v == 0.0); + } + Ok(()) + } + + #[test] + fn try_sum_decimal_basic() -> Result<()> { + let p = 10u8; + let s = 2i8; + let mut acc = + TrySumAccumulator::::new(DataType::Decimal128(p, s)); + acc.update_batch(&[dec128(p, s, vec![Some(123), Some(477)])?])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Decimal128(Some(600), p, s)); + Ok(()) + } + + #[test] + fn try_sum_decimal_with_nulls() -> Result<()> { + let p = 10u8; + let s = 2i8; + let mut acc = + TrySumAccumulator::::new(DataType::Decimal128(p, s)); + acc.update_batch(&[dec128(p, s, vec![Some(150), None, Some(200)])?])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Decimal128(Some(350), p, s)); + Ok(()) + } + + #[test] + fn try_sum_decimal_overflow_sets_failed() -> Result<()> { + let p = 5u8; + let s = 0i8; + let mut acc = + TrySumAccumulator::::new(DataType::Decimal128(p, s)); + acc.update_batch(&[dec128(p, s, vec![Some(90_000), Some(20_000)])?])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Decimal128(None, p, s)); + assert!(acc.failed); + Ok(()) + } + + #[test] + fn try_sum_decimal_merge_ok_and_failure_propagation() -> Result<()> { + let p = 10u8; + let s = 2i8; + + let mut p_ok = + TrySumAccumulator::::new(DataType::Decimal128(p, s)); + p_ok.update_batch(&[dec128(p, s, vec![Some(100), Some(200)])?])?; + let s_ok = p_ok + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + + let mut p_fail = + TrySumAccumulator::::new(DataType::Decimal128(p, s)); + p_fail.update_batch(&[dec128(p, s, vec![Some(i128::MAX), Some(1)])?])?; + let s_fail = p_fail + .state()? + .into_iter() + .map(|sv| sv.to_array()) + .collect::>>()?; + + let mut final_acc = + TrySumAccumulator::::new(DataType::Decimal128(p, s)); + final_acc.merge_batch(&s_ok)?; + final_acc.merge_batch(&s_fail)?; + + assert!(final_acc.failed); + assert_eq!(final_acc.evaluate()?, ScalarValue::Decimal128(None, p, s)); + Ok(()) + } + + #[test] + fn try_sum_int_overflow_sets_failed() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Int64); + // i64::MAX + 1 => overflow => failed => result NULL + acc.update_batch(&[int64(vec![Some(i64::MAX), Some(1)])])?; + let out = acc.evaluate()?; + assert_eq!(out, ScalarValue::Int64(None)); + assert!(acc.failed); + Ok(()) + } + + #[test] + fn try_sum_int_negative_overflow_sets_failed() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Int64); + // i64::MIN - 1 → overflow negative + acc.update_batch(&[int64(vec![Some(i64::MIN), Some(-1)])])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(None)); + assert!(acc.failed); + Ok(()) + } + + // -------- state + merge_batch -------- + + #[test] + fn try_sum_state_two_fields_and_merge_ok() -> Result<()> { + // acumulador 1 [10, 5] -> sum=15 + let mut acc1 = TrySumAccumulator::::new(DataType::Int64); + acc1.update_batch(&[int64(vec![Some(10), Some(5)])])?; + let state1 = acc1.state()?; // [sum, failed] + assert_eq!(state1.len(), 2); + + // acumulador 2 [20, NULL] -> sum=20 + let mut acc2 = TrySumAccumulator::::new(DataType::Int64); + acc2.update_batch(&[int64(vec![Some(20), None])])?; + let state2 = acc2.state()?; // [sum, failed] + + let state1_arrays: Vec = state1 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + + let state2_arrays: Vec = state2 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + + // final accumulator + let mut final_acc = TrySumAccumulator::::new(DataType::Int64); + + final_acc.merge_batch(&state1_arrays)?; + final_acc.merge_batch(&state2_arrays)?; + + // sum total = 15 + 20 = 35 + assert!(!final_acc.failed); + assert_eq!(final_acc.evaluate()?, ScalarValue::Int64(Some(35))); + Ok(()) + } + + #[test] + fn try_sum_merge_propagates_failure() -> Result<()> { + // sum=NULL, failed=true + let failed_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef; + let failed_flag = Arc::new(BooleanArray::from(vec![Some(true)])) as ArrayRef; + + let mut acc = TrySumAccumulator::::new(DataType::Int64); + acc.merge_batch(&[failed_sum, failed_flag])?; + + assert!(acc.failed); + assert_eq!(acc.evaluate()?, ScalarValue::Int64(None)); + Ok(()) + } + + #[test] + fn try_sum_merge_empty_partition_is_not_failure() -> Result<()> { + // sum=NULL, failed=false + let empty_sum = Arc::new(Int64Array::from(vec![None])) as ArrayRef; + let ok_flag = Arc::new(BooleanArray::from(vec![Some(false)])) as ArrayRef; + + let mut acc = TrySumAccumulator::::new(DataType::Int64); + acc.update_batch(&[int64(vec![Some(7), Some(8)])])?; // 15 + + acc.merge_batch(&[empty_sum, ok_flag])?; + + assert!(!acc.failed); + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(15))); + Ok(()) + } + + // -------- signature -------- + + #[test] + fn try_sum_return_type_matches_input() -> Result<()> { + let f = SparkTrySum::new(); + assert_eq!(f.return_type(&[DataType::Int64])?, DataType::Int64); + assert_eq!(f.return_type(&[DataType::Float64])?, DataType::Float64); + Ok(()) + } + + #[test] + fn try_sum_state_and_evaluate_consistency() -> Result<()> { + let mut acc = TrySumAccumulator::::new(DataType::Float64); + acc.update_batch(&[f64(vec![Some(1.0), Some(2.0)])])?; + let eval = acc.evaluate()?; + let state = acc.state()?; + assert_eq!(state[0], eval); + assert_eq!(state[1], ScalarValue::Boolean(Some(false))); + Ok(()) + } + + // ------------------------- + // DECIMAL + // ------------------------- + + #[test] + fn decimal_10_2_sum_and_schema_widened() -> Result<()> { + // input: DECIMAL(10,2) -> result: DECIMAL(20,2) + let f = SparkTrySum::new(); + assert_eq!( + f.return_type(&[DataType::Decimal128(10, 2)])?, + DataType::Decimal128(20, 2), + "Spark needs +10 more digits of precision" + ); + + let mut acc = + TrySumAccumulator::::new(DataType::Decimal128(20, 2)); + acc.update_batch(&[dec128(10, 2, vec![Some(123), Some(477)])?])?; + assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(Some(600), 20, 2)); + Ok(()) + } + + #[test] + fn decimal_5_0_fits_after_widening() -> Result<()> { + // input: DECIMAL(5,0) -> result: DECIMAL(15,0) + let f = SparkTrySum::new(); + assert_eq!( + f.return_type(&[DataType::Decimal128(5, 0)])?, + DataType::Decimal128(15, 0) + ); + + let mut acc = + TrySumAccumulator::::new(DataType::Decimal128(15, 0)); + acc.update_batch(&[dec128(5, 0, vec![Some(90_000), Some(20_000)])?])?; + assert_eq!( + acc.evaluate()?, + ScalarValue::Decimal128(Some(110_000), 15, 0) + ); + Ok(()) + } + + #[test] + fn decimal_38_0_max_precision_overflows_to_null() -> Result<()> { + let f = SparkTrySum::new(); + assert_eq!( + f.return_type(&[DataType::Decimal128(38, 0)])?, + DataType::Decimal128(38, 0) + ); + let ten_pow_38_minus_1 = { + let p10 = pow10_i128(38) + .ok_or_else(|| DataFusionError::Internal("10^38 overflow".into()))?; + p10 - 1 + }; + let mut acc = + TrySumAccumulator::::new(DataType::Decimal128(38, 0)); + acc.update_batch(&[dec128(38, 0, vec![Some(ten_pow_38_minus_1), Some(1)])?])?; + + assert!(acc.failed, "need fail in overflow p=38"); + assert_eq!(acc.evaluate()?, ScalarValue::Decimal128(None, 38, 0)); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/array/array_contains.rs b/datafusion/spark/src/function/array/array_contains.rs new file mode 100644 index 0000000000000..5c7cb4be6ff9d --- /dev/null +++ b/datafusion/spark/src/function/array/array_contains.rs @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::datatypes::DataType; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::array_has::array_has_udf; +use std::sync::Arc; + +/// Spark-compatible `array_contains` function. +/// +/// Calls DataFusion's `array_has` and then applies Spark's null semantics: +/// - If the result from `array_has` is `true`, return `true`. +/// - If the result is `false` and the input array row contains any null elements, +/// return `null` (because the element might have been the null). +/// - If the result is `false` and the input array row has no null elements, +/// return `false`. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArrayContains { + signature: Signature, +} + +impl Default for SparkArrayContains { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayContains { + pub fn new() -> Self { + Self { + signature: Signature::array_and_element(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayContains { + fn name(&self) -> &str { + "array_contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let haystack = args.args[0].clone(); + let array_has_result = array_has_udf().invoke_with_args(args)?; + + let result_array = array_has_result.to_array(1)?; + let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?; + Ok(ColumnarValue::Array(Arc::new(patched))) + } +} + +/// For each row where `array_has` returned `false`, set the output to null +/// if that row's input array contains any null elements. +fn apply_spark_null_semantics( + result: &BooleanArray, + haystack_arg: &ColumnarValue, +) -> Result { + // happy path + if haystack_arg.data_type() == DataType::Null || !result.has_false() { + return Ok(result.clone()); + } + + let haystack = haystack_arg.to_array_of_size(result.len())?; + + let row_has_nulls = compute_row_has_nulls(&haystack)?; + + // A row keeps its validity when result is true OR the row has no nulls. + let keep_mask = result.values() | &!&row_has_nulls; + let new_validity = match result.nulls() { + Some(n) => n.inner() & &keep_mask, + None => keep_mask, + }; + + Ok(BooleanArray::new( + result.values().clone(), + Some(NullBuffer::new(new_validity)), + )) +} + +/// Returns a per-row bitmap where bit i is set if row i's list contains any null element. +fn compute_row_has_nulls(haystack: &dyn Array) -> Result { + match haystack.data_type() { + DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::()), + DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::()), + DataType::FixedSizeList(_, _) => { + let list = haystack.as_fixed_size_list(); + let buf = match list.values().nulls() { + Some(nulls) => { + let validity = nulls.inner(); + let vl = list.value_length() as usize; + let mut builder = BooleanBufferBuilder::new(list.len()); + for i in 0..list.len() { + builder.append(validity.slice(i * vl, vl).count_set_bits() < vl); + } + builder.finish() + } + None => BooleanBuffer::new_unset(list.len()), + }; + Ok(mask_with_list_nulls(buf, list.nulls())) + } + dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"), + } +} + +/// Computes per-row null presence for `List` and `LargeList` arrays. +fn generic_list_row_has_nulls( + list: &GenericListArray, +) -> Result { + let buf = match list.values().nulls() { + Some(nulls) => { + let validity = nulls.inner(); + let offsets = list.offsets(); + let mut builder = BooleanBufferBuilder::new(list.len()); + for i in 0..list.len() { + let s = offsets[i].as_usize(); + let len = offsets[i + 1].as_usize() - s; + builder.append(validity.slice(s, len).count_set_bits() < len); + } + builder.finish() + } + None => BooleanBuffer::new_unset(list.len()), + }; + Ok(mask_with_list_nulls(buf, list.nulls())) +} + +/// Rows where the list itself is null should not be marked as "has nulls". +fn mask_with_list_nulls( + buf: BooleanBuffer, + list_nulls: Option<&NullBuffer>, +) -> BooleanBuffer { + match list_nulls { + Some(n) => &buf & n.inner(), + None => buf, + } +} diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs index 01056ba952984..6c16e05361641 100644 --- a/datafusion/spark/src/function/array/mod.rs +++ b/datafusion/spark/src/function/array/mod.rs @@ -15,27 +15,54 @@ // specific language governing permissions and limitations // under the License. +pub mod array_contains; +pub mod repeat; pub mod shuffle; +pub mod slice; pub mod spark_array; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(array_contains::SparkArrayContains, spark_array_contains); make_udf_function!(spark_array::SparkArray, array); make_udf_function!(shuffle::SparkShuffle, shuffle); +make_udf_function!(repeat::SparkArrayRepeat, array_repeat); +make_udf_function!(slice::SparkSlice, slice); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + spark_array_contains, + "Returns true if the array contains the element (Spark semantics).", + array element + )); export_functions!((array, "Returns an array with the given elements.", args)); export_functions!(( shuffle, "Returns a random permutation of the given array.", args )); + export_functions!(( + array_repeat, + "returns an array containing element count times.", + element count + )); + export_functions!(( + slice, + "Returns a slice of the array from the start index with the given length.", + array start length + )); } pub fn functions() -> Vec> { - vec![array(), shuffle()] + vec![ + spark_array_contains(), + array(), + shuffle(), + array_repeat(), + slice(), + ] } diff --git a/datafusion/spark/src/function/array/repeat.rs b/datafusion/spark/src/function/array/repeat.rs new file mode 100644 index 0000000000000..da9b19a768680 --- /dev/null +++ b/datafusion/spark/src/function/array/repeat.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::repeat::ArrayRepeat; +use std::sync::Arc; + +use crate::function::null_utils::{ + NullMaskResolution, apply_null_mask, compute_null_mask, +}; + +/// Spark-compatible `array_repeat` expression. The difference with DataFusion's `array_repeat` is the handling of NULL count: in Spark if the count is NULL, the result is NULL. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArrayRepeat { + signature: Signature, +} + +impl Default for SparkArrayRepeat { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayRepeat { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayRepeat { + fn name(&self) -> &str { + "array_repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_array_repeat(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [first_type, second_type] = take_function_args(self.name(), arg_types)?; + + // Coerce the second argument to Int64/UInt64 if it's a numeric type + let second = match second_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + _ => return exec_err!("count must be an integer type"), + }; + + Ok(vec![first_type.clone(), second]) + } +} + +/// This is a Spark-specific wrapper around DataFusion's array_repeat that returns NULL +/// if the count argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs. +fn spark_array_repeat(args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + } = args; + let return_type = return_field.data_type().clone(); + + // A NULL element should be repeated into the array, not cause a NULL result. + let null_mask = compute_null_mask(&arg_values[1..]); + + // If count is null then return NULL immediately + if matches!(null_mask, NullMaskResolution::ReturnNull) { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)); + } + + let array_repeat_func = ArrayRepeat::new(); + let func_args = ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + }; + let result = array_repeat_func.invoke_with_args(func_args)?; + + apply_null_mask(result, null_mask, &return_type) +} diff --git a/datafusion/spark/src/function/array/shuffle.rs b/datafusion/spark/src/function/array/shuffle.rs index 9f345b53b89a7..031dd17177577 100644 --- a/datafusion/spark/src/function/array/shuffle.rs +++ b/datafusion/spark/src/function/array/shuffle.rs @@ -26,15 +26,16 @@ use arrow::datatypes::FieldRef; use datafusion_common::cast::{ as_fixed_size_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue}; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; use datafusion_expr::{ - ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, - Signature, TypeSignature, Volatility, + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use rand::rng; use rand::rngs::StdRng; -use rand::{seq::SliceRandom, Rng, SeedableRng}; -use std::any::Any; +use rand::{Rng, SeedableRng, seq::SliceRandom}; use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Hash)] @@ -75,10 +76,6 @@ impl SparkShuffle { } impl ScalarUDFImpl for SparkShuffle { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "shuffle" } @@ -87,19 +84,21 @@ impl ScalarUDFImpl for SparkShuffle { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") } - fn invoke_with_args( + fn return_field_from_args( &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - if args.args.is_empty() { - return exec_err!("shuffle expects at least 1 argument"); - } - if args.args.len() > 2 { - return exec_err!("shuffle expects at most 2 arguments"); + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + // Shuffle returns an array with the same type and nullability as the input + Ok(Arc::clone(&args.arg_fields[0])) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if args.args.is_empty() || args.args.len() > 2 { + return exec_err!("shuffle expects 1 or 2 argument(s)"); } // Extract seed from second argument if present @@ -121,10 +120,10 @@ fn extract_seed(seed_arg: &ColumnarValue) -> Result> { ColumnarValue::Scalar(scalar) => { let seed = match scalar { ScalarValue::Int64(Some(v)) => Some(*v as u64), - ScalarValue::Null => None, + ScalarValue::Null | ScalarValue::Int64(None) => None, _ => { return exec_err!( - "shuffle seed must be Int64 type, got '{}'", + "shuffle seed must be Int64 type but got '{}'", scalar.data_type() ); } @@ -154,7 +153,10 @@ fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option) -> Result Ok(Arc::clone(input_array)), - array_type => exec_err!("shuffle does not support type '{array_type}'."), + array_type => exec_err!( + "shuffle does not support type '{array_type}'; \ + expected types: List, LargeList, FixedSizeList or Null." + ), } } @@ -263,3 +265,51 @@ fn fixed_size_array_shuffle( Some(nulls.into()), )?)) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Field; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn test_shuffle_nullability() { + let shuffle = SparkShuffle::new(); + + // Test with non-nullable array + let non_nullable_field = Arc::new(Field::new( + "arr", + List(Arc::new(Field::new("item", DataType::Int32, true))), + false, // not nullable + )); + + let result = shuffle + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_field)], + scalar_arguments: &[None], + }) + .unwrap(); + + // The result should not be nullable (same as input) + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), non_nullable_field.data_type()); + + // Test with nullable array + let nullable_field = Arc::new(Field::new( + "arr", + List(Arc::new(Field::new("item", DataType::Int32, true))), + true, // nullable + )); + + let result = shuffle + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_field)], + scalar_arguments: &[None], + }) + .unwrap(); + + // The result should be nullable (same as input) + assert!(result.is_nullable()); + assert_eq!(result.data_type(), nullable_field.data_type()); + } +} diff --git a/datafusion/spark/src/function/array/slice.rs b/datafusion/spark/src/function/array/slice.rs new file mode 100644 index 0000000000000..5c65f899a01b0 --- /dev/null +++ b/datafusion/spark/src/function/array/slice.rs @@ -0,0 +1,249 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, Int64Builder}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{as_int64_array, as_list_array}; +use datafusion_common::utils::ListCoercion; +use datafusion_common::{ + Result, ScalarValue, exec_err, internal_err, utils::take_function_args, +}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions_nested::extract::array_slice_udf; +use std::sync::Arc; + +/// Spark slice function implementation +/// Main difference from DataFusion's array_slice is that the third argument is the length of the slice and not the end index. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSlice { + signature: Signature, +} + +impl Default for SparkSlice { + fn default() -> Self { + Self::new() + } +} + +impl SparkSlice { + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::Array { + arguments: vec![ + ArrayFunctionArgument::Array, + ArrayFunctionArgument::Index, + ArrayFunctionArgument::Index, + ], + array_coercion: Some(ListCoercion::FixedSizedListToList), + }, + ), + volatility: Volatility::Immutable, + parameter_names: None, + }, + } + } +} + +impl ScalarUDFImpl for SparkSlice { + fn name(&self) -> &str { + "slice" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + let data_type = match args.arg_fields[0].data_type() { + DataType::Null => { + DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))) + } + dt => dt.clone(), + }; + + Ok(Arc::new(Field::new("slice", data_type, nullable))) + } + + fn invoke_with_args( + &self, + mut func_args: ScalarFunctionArgs, + ) -> Result { + if func_args.args[0].data_type() == DataType::Null { + return Ok(ColumnarValue::Scalar(ScalarValue::new_null_list( + DataType::Null, + true, + 1, + ))); + } + + let array_len = func_args + .args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(func_args.number_rows); + + let arrays = func_args + .args + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), + }) + .collect::>>()?; + + let (start, end) = calculate_start_end(&arrays)?; + + array_slice_udf().invoke_with_args(ScalarFunctionArgs { + args: vec![ + func_args.args.swap_remove(0), + ColumnarValue::Array(start), + ColumnarValue::Array(end), + ], + arg_fields: func_args.arg_fields, + number_rows: func_args.number_rows, + return_field: func_args.return_field, + config_options: func_args.config_options, + }) + } +} + +fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> { + let [values, start, length] = take_function_args("slice", args)?; + + let values_len = values.len(); + + let start = as_int64_array(&start)?; + let length = as_int64_array(&length)?; + + let values = as_list_array(values)?; + + let mut adjusted_start = Int64Builder::with_capacity(values_len); + let mut end = Int64Builder::with_capacity(values_len); + + for row in 0..values_len { + if values.is_null(row) || start.is_null(row) || length.is_null(row) { + adjusted_start.append_null(); + end.append_null(); + continue; + } + let start = start.value(row); + let length = length.value(row); + let value_length = values.value(row).len() as i64; + + if start == 0 { + return exec_err!("Start index must not be zero"); + } + if length < 0 { + return exec_err!("Length must be non-negative, but got {}", length); + } + + let adjusted_start_value = if start < 0 { + start + value_length + 1 + } else { + start + }; + + // Spark returns an empty array when the adjusted start lands before + // position 1 (e.g. slice([1], -2, 2)). array_slice would otherwise + // treat 0 the same as 1 and return the first element. + if adjusted_start_value < 1 { + adjusted_start.append_value(1); + end.append_value(0); + continue; + } + + adjusted_start.append_value(adjusted_start_value); + end.append_value(adjusted_start_value + (length - 1)); + } + + Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish()))) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::NullArray; + use arrow::datatypes::Field; + use datafusion_common::ScalarValue; + use datafusion_common::cast::as_list_array; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn test_spark_slice_function_when_input_is_null() { + let slice = SparkSlice::new(); + let arg_fields: Vec> = vec![ + Arc::new(Field::new("a", DataType::Null, true)), + Arc::new(Field::new("s", DataType::Int64, true)), + Arc::new(Field::new("l", DataType::Int64, true)), + ]; + let out = slice + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[], + }) + .unwrap(); + assert_eq!( + out.data_type(), + &DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))) + ); + } + + #[test] + fn test_spark_slice_function_when_input_array_is_null() { + let input_args = vec![ + ColumnarValue::Array(Arc::new(NullArray::new(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ]; + + let args = ScalarFunctionArgs { + args: input_args, + arg_fields: vec![Arc::new(Field::new("item", DataType::Null, true))], + number_rows: 1, + return_field: Arc::new(Field::new( + "slice", + DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))), + true, + )), + config_options: Arc::new(Default::default()), + }; + let slice = SparkSlice::new(); + let result = slice.invoke_with_args(args).unwrap(); + let arr = result.to_array(1).unwrap(); + let list = as_list_array(&arr).unwrap(); + assert_eq!( + arr.data_type(), + &DataType::List(Arc::new(Field::new_list_field(DataType::Null, true))) + ); + assert!(list.is_null(0)); + } +} diff --git a/datafusion/spark/src/function/array/spark_array.rs b/datafusion/spark/src/function/array/spark_array.rs index bb9665613de9b..d6d4e7f0ab9f0 100644 --- a/datafusion/spark/src/function/array/spark_array.rs +++ b/datafusion/spark/src/function/array/spark_array.rs @@ -15,21 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, sync::Arc}; +use std::sync::Arc; -use arrow::array::{ - make_array, new_null_array, Array, ArrayData, ArrayRef, Capacities, GenericListArray, - MutableArrayData, NullArray, OffsetSizeTrait, -}; -use arrow::buffer::OffsetBuffer; +use arrow::array::{Array, ArrayRef, new_null_array}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::utils::SingleRowListArrayBuilder; -use datafusion_common::{internal_err, plan_datafusion_err, plan_err, Result}; -use datafusion_expr::type_coercion::binary::comparison_coercion; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature, Volatility, + Volatility, }; +use datafusion_functions_nested::make_array::{array_array, coerce_types_inner}; use crate::function::functions_nested_utils::make_scalar_function; @@ -38,7 +34,6 @@ const ARRAY_FIELD_DEFAULT_NAME: &str = "element"; #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkArray { signature: Signature, - aliases: Vec, } impl Default for SparkArray { @@ -50,20 +45,12 @@ impl Default for SparkArray { impl SparkArray { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Nullary], - Volatility::Immutable, - ), - aliases: vec![String::from("spark_make_array")], + signature: Signature::user_defined(Volatility::Immutable), } } } impl ScalarUDFImpl for SparkArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "array" } @@ -92,10 +79,6 @@ impl ScalarUDFImpl for SparkArray { } } - if expr_type.is_null() { - expr_type = DataType::Int32; - } - let return_type = DataType::List(Arc::new(Field::new( ARRAY_FIELD_DEFAULT_NAME, expr_type, @@ -114,31 +97,12 @@ impl ScalarUDFImpl for SparkArray { make_scalar_function(make_array_inner)(args.as_slice()) } - fn aliases(&self) -> &[String] { - &self.aliases - } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let first_type = arg_types.first().ok_or_else(|| { - plan_datafusion_err!("Spark array function requires at least one argument") - })?; - let new_type = - arg_types - .iter() - .skip(1) - .try_fold(first_type.clone(), |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. - let coerced_type = comparison_coercion(&acc, x); - if let Some(coerced_type) = coerced_type { - Ok(coerced_type) - } else { - plan_err!("Coercion from {acc} to {x} failed.") - } - })?; - Ok(vec![new_type; arg_types.len()]) + if arg_types.is_empty() { + Ok(vec![]) + } else { + coerce_types_inner(arg_types, self.name()) + } } } @@ -160,7 +124,7 @@ pub fn make_array_inner(arrays: &[ArrayRef]) -> Result { DataType::Null => { let length = arrays.iter().map(|a| a.len()).sum(); // By default Int32 - let array = new_null_array(&DataType::Int32, length); + let array = new_null_array(&DataType::Null, length); Ok(Arc::new( SingleRowListArrayBuilder::new(array) .with_nullable(true) @@ -168,98 +132,6 @@ pub fn make_array_inner(arrays: &[ArrayRef]) -> Result { .build_list_array(), )) } - _ => array_array::(arrays, data_type), - } -} - -/// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` or 'LargeListArray' depending on the offset size. -/// -/// # Example (non nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are non nested -/// would return a single new `ListArray`, where each row was a list -/// of 2 elements: -/// -/// ```text -/// ┌─────────┐ ┌─────────┐ ┌──────────────┐ -/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │ -/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │ -/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │ -/// └─────────┘ └─────────┘ └──────────────┘ -/// col1 col2 output -/// ``` -/// -/// # Example (nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are lists -/// would return a single new `ListArray`, where each row was a list -/// of the corresponding elements of col1 and col2. -/// -/// ``` text -/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ -/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │ -/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │ -/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │ -/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │ -/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │ -/// └──────────────┘ └──────────────┘ └─────────────────────────────┘ -/// col1 col2 output -/// ``` -fn array_array( - args: &[ArrayRef], - data_type: DataType, -) -> Result { - // do not accept 0 arguments. - if args.is_empty() { - return plan_err!("Array requires at least one argument"); - } - - let mut data = vec![]; - let mut total_len = 0; - for arg in args { - let arg_data = if arg.as_any().is::() { - ArrayData::new_empty(&data_type) - } else { - arg.to_data() - }; - total_len += arg_data.len(); - data.push(arg_data); - } - - let mut offsets: Vec = Vec::with_capacity(total_len); - offsets.push(O::usize_as(0)); - - let capacity = Capacities::Array(total_len); - let data_ref = data.iter().collect::>(); - let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); - - let num_rows = args[0].len(); - for row_idx in 0..num_rows { - for (arr_idx, arg) in args.iter().enumerate() { - if !arg.as_any().is::() - && !arg.is_null(row_idx) - && arg.is_valid(row_idx) - { - mutable.extend(arr_idx, row_idx, row_idx + 1); - } else { - mutable.extend_nulls(1); - } - } - offsets.push(O::usize_as(mutable.len())); + _ => array_array::(arrays, data_type, ARRAY_FIELD_DEFAULT_NAME), } - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new(ARRAY_FIELD_DEFAULT_NAME, data_type, true)), - OffsetBuffer::new(offsets.into()), - make_array(data), - None, - )?)) } diff --git a/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs b/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs new file mode 100644 index 0000000000000..49343f49a86d6 --- /dev/null +++ b/datafusion/spark/src/function/bitmap/bitmap_bit_position.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `bitmap_bit_position` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct BitmapBitPosition { + signature: Signature, +} + +impl Default for BitmapBitPosition { + fn default() -> Self { + Self::new() + } +} + +impl BitmapBitPosition { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitmapBitPosition { + fn name(&self) -> &str { + "bitmap_bit_position" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Int64, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(bitmap_bit_position_inner, vec![])(&args.args) + } +} + +pub fn bitmap_bit_position_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bitmap_bit_position", arg)?; + match &array.data_type() { + DataType::Int8 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bit_position(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(bitmap_bit_position)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bitmap_bit_position does not support {data_type}") + } + } +} + +const NUM_BYTES: i64 = 4 * 1024; +const NUM_BITS: i64 = NUM_BYTES * 8; + +fn bitmap_bit_position(value: i64) -> i64 { + if value > 0 { + (value - 1) % NUM_BITS + } else { + (value.wrapping_neg()) % NUM_BITS + } +} diff --git a/datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs b/datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs new file mode 100644 index 0000000000000..e49a9ca3d4f0a --- /dev/null +++ b/datafusion/spark/src/function/bitmap/bitmap_bucket_number.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::Field; +use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `bitmap_bucket_number` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct BitmapBucketNumber { + signature: Signature, +} + +impl Default for BitmapBucketNumber { + fn default() -> Self { + Self::new() + } +} + +impl BitmapBucketNumber { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int8]), + TypeSignature::Exact(vec![DataType::Int16]), + TypeSignature::Exact(vec![DataType::Int32]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitmapBucketNumber { + fn name(&self) -> &str { + "bitmap_bucket_number" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Int64, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(bitmap_bucket_number_inner, vec![])(&args.args) + } +} + +pub fn bitmap_bucket_number_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bitmap_bucket_number", arg)?; + match &array.data_type() { + DataType::Int8 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bucket_number(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int16 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bucket_number(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int32 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(|value| bitmap_bucket_number(value.into()))) + .collect(); + Ok(Arc::new(result)) + } + DataType::Int64 => { + let result: Int64Array = array + .as_primitive::() + .iter() + .map(|opt| opt.map(bitmap_bucket_number)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bitmap_bucket_number does not support {data_type}") + } + } +} + +const NUM_BYTES: i64 = 4 * 1024; +const NUM_BITS: i64 = NUM_BYTES * 8; + +fn bitmap_bucket_number(value: i64) -> i64 { + if value > 0 { + 1 + (value - 1) / NUM_BITS + } else { + value / NUM_BITS + } +} diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs index 56a9c5edb812c..89bea101afbe7 100644 --- a/datafusion/spark/src/function/bitmap/bitmap_count.rs +++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs @@ -15,19 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ - as_dictionary_array, Array, ArrayRef, BinaryArray, BinaryViewArray, - FixedSizeBinaryArray, Int64Array, LargeBinaryArray, + Array, ArrayRef, BinaryArray, BinaryViewArray, FixedSizeBinaryArray, Int64Array, + LargeBinaryArray, as_dictionary_array, }; use arrow::datatypes::DataType::{ Binary, BinaryView, Dictionary, FixedSizeBinary, LargeBinary, }; -use arrow::datatypes::{DataType, Int16Type, Int32Type, Int64Type, Int8Type}; +use arrow::datatypes::{DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type}; use datafusion_common::utils::take_function_args; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, @@ -58,10 +57,6 @@ impl BitmapCount { } impl ScalarUDFImpl for BitmapCount { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bitmap_count" } @@ -71,7 +66,20 @@ impl ScalarUDFImpl for BitmapCount { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int64) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + use arrow::datatypes::Field; + // bitmap_count returns Int64 with the same nullability as the input + Ok(Arc::new(Field::new( + args.arg_fields[0].name(), + DataType::Int64, + args.arg_fields[0].is_nullable(), + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -205,12 +213,17 @@ mod tests { Box::new(ScalarValue::Binary(Some(vec![0xFFu8, 0xFFu8]))), )); - let arg_fields = vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Binary)), - true, - ) - .into()]; + let arg_fields = vec![ + Field::new( + "a", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Binary), + ), + true, + ) + .into(), + ]; let args = ScalarFunctionArgs { args: vec![dict.clone()], arg_fields, @@ -224,4 +237,37 @@ mod tests { assert_eq!(*actual.into_array(1)?, *expect.into_array(1)?); Ok(()) } + + #[test] + fn test_bitmap_count_nullability() -> Result<()> { + use datafusion_expr::ReturnFieldArgs; + + let bitmap_count = BitmapCount::new(); + + // Test with non-nullable binary field + let non_nullable_field = Arc::new(Field::new("bin", DataType::Binary, false)); + + let result = bitmap_count.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_field)], + scalar_arguments: &[None], + })?; + + // The result should not be nullable (same as input) + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), &Int64); + + // Test with nullable binary field + let nullable_field = Arc::new(Field::new("bin", DataType::Binary, true)); + + let result = bitmap_count.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_field)], + scalar_arguments: &[None], + })?; + + // The result should be nullable (same as input) + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &Int64); + + Ok(()) + } } diff --git a/datafusion/spark/src/function/bitmap/mod.rs b/datafusion/spark/src/function/bitmap/mod.rs index 8532c32ac9c5f..4992992aeae8b 100644 --- a/datafusion/spark/src/function/bitmap/mod.rs +++ b/datafusion/spark/src/function/bitmap/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +pub mod bitmap_bit_position; +pub mod bitmap_bucket_number; pub mod bitmap_count; use datafusion_expr::ScalarUDF; @@ -22,6 +24,11 @@ use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(bitmap_count::BitmapCount, bitmap_count); +make_udf_function!(bitmap_bit_position::BitmapBitPosition, bitmap_bit_position); +make_udf_function!( + bitmap_bucket_number::BitmapBucketNumber, + bitmap_bucket_number +); pub mod expr_fn { use datafusion_functions::export_functions; @@ -31,8 +38,22 @@ pub mod expr_fn { "Returns the number of set bits in the input bitmap.", arg )); + export_functions!(( + bitmap_bit_position, + "Returns the bit position for the given input child expression.", + arg + )); + export_functions!(( + bitmap_bucket_number, + "Returns the bucket number for the given input child expression.", + arg + )); } pub fn functions() -> Vec> { - vec![bitmap_count()] + vec![ + bitmap_count(), + bitmap_bit_position(), + bitmap_bucket_number(), + ] } diff --git a/datafusion/spark/src/function/bitwise/bit_count.rs b/datafusion/spark/src/function/bitwise/bit_count.rs index 1af5598a1d6a7..3a91fea7a90c0 100644 --- a/datafusion/spark/src/function/bitwise/bit_count.rs +++ b/datafusion/spark/src/function/bitwise/bit_count.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, Int32Array}; use arrow::datatypes::{ - DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, - UInt64Type, UInt8Type, + DataType, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type, UInt16Type, + UInt32Type, UInt64Type, }; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -64,10 +63,6 @@ impl SparkBitCount { } impl ScalarUDFImpl for SparkBitCount { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bit_count" } @@ -77,7 +72,20 @@ impl ScalarUDFImpl for SparkBitCount { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int32) // Spark returns int (Int32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + use arrow::datatypes::Field; + // bit_count returns Int32 with the same nullability as the input + Ok(Arc::new(Field::new( + args.arg_fields[0].name(), + DataType::Int32, + args.arg_fields[0].is_nullable(), + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -160,10 +168,10 @@ fn spark_bit_count(value_array: &[ArrayRef]) -> Result { mod tests { use super::*; use arrow::array::{ - Array, BooleanArray, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, - UInt32Array, UInt64Array, UInt8Array, + Array, BooleanArray, Int8Array, Int16Array, Int32Array, Int64Array, UInt8Array, + UInt16Array, UInt32Array, UInt64Array, }; - use arrow::datatypes::Int32Type; + use arrow::datatypes::Field; #[test] fn test_bit_count_basic() { @@ -336,4 +344,37 @@ mod tests { assert!(arr.is_null(1)); assert_eq!(arr.value(2), 3); // 0b111 } + + #[test] + fn test_bit_count_nullability() -> Result<()> { + use datafusion_expr::ReturnFieldArgs; + + let bit_count = SparkBitCount::new(); + + // Test with non-nullable Int32 field + let non_nullable_field = Arc::new(Field::new("num", DataType::Int32, false)); + + let result = bit_count.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_field)], + scalar_arguments: &[None], + })?; + + // The result should not be nullable (same as input) + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), &DataType::Int32); + + // Test with nullable Int32 field + let nullable_field = Arc::new(Field::new("num", DataType::Int32, true)); + + let result = bit_count.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_field)], + scalar_arguments: &[None], + })?; + + // The result should be nullable (same as input) + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &DataType::Int32); + + Ok(()) + } } diff --git a/datafusion/spark/src/function/bitwise/bit_get.rs b/datafusion/spark/src/function/bitwise/bit_get.rs index bc8d8cdbd1f9f..0de6498a0c37e 100644 --- a/datafusion/spark/src/function/bitwise/bit_get.rs +++ b/datafusion/spark/src/function/bitwise/bit_get.rs @@ -15,22 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::mem::size_of; use std::sync::Arc; use arrow::array::{ - downcast_integer_array, Array, ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, - Int8Array, PrimitiveArray, + Array, ArrayRef, ArrowPrimitiveType, AsArray, Int8Array, Int32Array, PrimitiveArray, + downcast_integer_array, }; use arrow::compute::try_binary; -use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int8Type}; -use datafusion_common::types::{logical_int32, NativeType}; +use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef, Int8Type, Int32Type}; +use datafusion_common::types::{NativeType, logical_int32}; use datafusion_common::utils::take_function_args; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -66,10 +65,6 @@ impl SparkBitGet { } impl ScalarUDFImpl for SparkBitGet { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bit_get" } @@ -83,7 +78,13 @@ impl ScalarUDFImpl for SparkBitGet { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int8) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // Spark derives nullability for BinaryExpression from its children + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Int8, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -118,3 +119,42 @@ fn spark_bit_get(args: &[ArrayRef]) -> Result { )?; Ok(Arc::new(ret)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bit_get_nullability_non_nullable_inputs() { + let func = SparkBitGet::new(); + let value_field = Arc::new(Field::new("value", DataType::Int32, false)); + let pos_field = Arc::new(Field::new("pos", DataType::Int32, false)); + + let out_field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[value_field, pos_field], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert_eq!(out_field.data_type(), &DataType::Int8); + assert!(!out_field.is_nullable()); + } + + #[test] + fn test_bit_get_nullability_nullable_inputs() { + let func = SparkBitGet::new(); + let value_field = Arc::new(Field::new("value", DataType::Int32, true)); + let pos_field = Arc::new(Field::new("pos", DataType::Int32, false)); + + let out_field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[value_field, pos_field], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert_eq!(out_field.data_type(), &DataType::Int8); + assert!(out_field.is_nullable()); + } +} diff --git a/datafusion/spark/src/function/bitwise/bit_shift.rs b/datafusion/spark/src/function/bitwise/bit_shift.rs index 65df04858077f..b78f1890832cb 100644 --- a/datafusion/spark/src/function/bitwise/bit_shift.rs +++ b/datafusion/spark/src/function/bitwise/bit_shift.rs @@ -15,23 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, Int32Array, PrimitiveArray}; use arrow::compute; use arrow::datatypes::{ - ArrowNativeType, DataType, Int32Type, Int64Type, UInt32Type, UInt64Type, + ArrowNativeType, DataType, Field, FieldRef, Int32Type, Int64Type, UInt32Type, + UInt64Type, }; use datafusion_common::types::{ - logical_int16, logical_int32, logical_int64, logical_int8, logical_uint16, - logical_uint32, logical_uint64, logical_uint8, NativeType, + NativeType, logical_int8, logical_int16, logical_int32, logical_int64, logical_uint8, + logical_uint16, logical_uint32, logical_uint64, }; use datafusion_common::utils::take_function_args; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -263,10 +263,6 @@ impl SparkBitShift { } impl ScalarUDFImpl for SparkBitShift { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { self.name } @@ -275,8 +271,14 @@ impl ScalarUDFImpl for SparkBitShift { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let data_type = args.arg_fields[0].data_type().clone(); + Ok(Arc::new(Field::new(self.name(), data_type, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -286,3 +288,56 @@ impl ScalarUDFImpl for SparkBitShift { make_scalar_function(inner, vec![])(&args.args) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bit_shift_nullability() -> Result<()> { + let func = SparkBitShift::left(); + + let non_nullable_value: FieldRef = + Arc::new(Field::new("value", DataType::Int64, false)); + let non_nullable_shift: FieldRef = + Arc::new(Field::new("shift", DataType::Int32, false)); + + let out = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&non_nullable_value), + Arc::clone(&non_nullable_shift), + ], + scalar_arguments: &[None, None], + })?; + + assert_eq!(out.data_type(), non_nullable_value.data_type()); + assert!( + !out.is_nullable(), + "shift result should be non-nullable when both inputs are non-nullable" + ); + + let nullable_value: FieldRef = + Arc::new(Field::new("value", DataType::Int64, true)); + let out_nullable_value = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_value), Arc::clone(&non_nullable_shift)], + scalar_arguments: &[None, None], + })?; + assert!( + out_nullable_value.is_nullable(), + "shift result should be nullable when value is nullable" + ); + + let nullable_shift: FieldRef = + Arc::new(Field::new("shift", DataType::Int32, true)); + let out_nullable_shift = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[non_nullable_value, nullable_shift], + scalar_arguments: &[None, None], + })?; + assert!( + out_nullable_shift.is_nullable(), + "shift result should be nullable when shift is nullable" + ); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/bitwise/bitwise_not.rs b/datafusion/spark/src/function/bitwise/bitwise_not.rs index 2f3fe227833b0..9252e1fb606da 100644 --- a/datafusion/spark/src/function/bitwise/bitwise_not.rs +++ b/datafusion/spark/src/function/bitwise/bitwise_not.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::*; use arrow::compute::kernels::bitwise; -use arrow::datatypes::{Int16Type, Int32Type, Int64Type, Int8Type}; -use arrow::{array::*, datatypes::DataType}; -use datafusion_common::{plan_err, Result}; +use arrow::datatypes::{ + DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type, Int64Type, +}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::{ColumnarValue, TypeSignature, Volatility}; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; +use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_functions::utils::make_scalar_function; -use std::{any::Any, sync::Arc}; +use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkBitwiseNot { @@ -52,10 +54,6 @@ impl SparkBitwiseNot { } impl ScalarUDFImpl for SparkBitwiseNot { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "bitwise_not" } @@ -64,8 +62,18 @@ impl ScalarUDFImpl for SparkBitwiseNot { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!( + "SparkBitwiseNot: return_type() is not used; return_field_from_args() is implemented" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + args.arg_fields[0].is_nullable(), + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -107,3 +115,64 @@ pub fn spark_bitwise_not(args: &[ArrayRef]) -> Result { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn test_bitwise_not_nullability() { + let bitwise_not = SparkBitwiseNot::new(); + + // --- non-nullable Int32 input --- + let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false)); + let out_non_null = bitwise_not + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_i32)], + // single-argument function -> one scalar_argument slot (None) + scalar_arguments: &[None], + }) + .unwrap(); + + // result should be non-nullable and the same DataType as input + assert!(!out_non_null.is_nullable()); + assert_eq!(out_non_null.data_type(), &DataType::Int32); + + // --- nullable Int32 input --- + let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true)); + let out_nullable = bitwise_not + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_i32)], + scalar_arguments: &[None], + }) + .unwrap(); + + // result should be nullable and the same DataType as input + assert!(out_nullable.is_nullable()); + assert_eq!(out_nullable.data_type(), &DataType::Int32); + + // --- also test another integer type (Int64) for completeness --- + let non_nullable_i64 = Arc::new(Field::new("c", DataType::Int64, false)); + let out_i64 = bitwise_not + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_i64)], + scalar_arguments: &[None], + }) + .unwrap(); + + assert!(!out_i64.is_nullable()); + assert_eq!(out_i64.data_type(), &DataType::Int64); + + let nullable_i64 = Arc::new(Field::new("c", DataType::Int64, true)); + let out_i64_null = bitwise_not + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_i64)], + scalar_arguments: &[None], + }) + .unwrap(); + + assert!(out_i64_null.is_nullable()); + assert_eq!(out_i64_null.data_type(), &DataType::Int64); + } +} diff --git a/datafusion/spark/src/function/collection/mod.rs b/datafusion/spark/src/function/collection/mod.rs index a87df9a2c87a0..6871e3aba6469 100644 --- a/datafusion/spark/src/function/collection/mod.rs +++ b/datafusion/spark/src/function/collection/mod.rs @@ -15,11 +15,20 @@ // specific language governing permissions and limitations // under the License. +pub mod size; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(size::SparkSize, size); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((size, "Return the size of an array or map.", arg)); +} pub fn functions() -> Vec> { - vec![] + vec![size()] } diff --git a/datafusion/spark/src/function/collection/size.rs b/datafusion/spark/src/function/collection/size.rs new file mode 100644 index 0000000000000..e53bbf86d78cb --- /dev/null +++ b/datafusion/spark/src/function/collection/size.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, Int32Array}; +use arrow::compute::kernels::length::length as arrow_length; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, plan_err}; +use datafusion_expr::{ + ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `size` function. +/// +/// Returns the number of elements in an array or the number of key-value pairs in a map. +/// Returns -1 for null input (Spark behavior). +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSize { + signature: Signature, +} + +impl Default for SparkSize { + fn default() -> Self { + Self::new() + } +} + +impl SparkSize { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // Array Type + TypeSignature::ArraySignature(ArrayFunctionSignature::Array { + arguments: vec![ArrayFunctionArgument::Array], + array_coercion: None, + }), + // Map Type + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkSize { + fn name(&self) -> &str { + "size" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + // nullable=false for legacy behavior (NULL -> -1); set to input nullability for null-on-null + Ok(Arc::new(Field::new(self.name(), DataType::Int32, false))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_size_inner, vec![])(&args.args) + } +} + +fn spark_size_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + + match array.data_type() { + DataType::List(_) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) + } else { + let list_array = array.as_list::(); + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::FixedSizeList(_, size) => { + if array.null_count() == 0 { + Ok(arrow_length(array)?) + } else { + let length: Vec = (0..array.len()) + .map(|i| if array.is_null(i) { -1 } else { *size }) + .collect(); + Ok(Arc::new(Int32Array::from(length))) + } + } + DataType::LargeList(_) => { + // Arrow length kernel returns Int64 for LargeList + let list_array = array.as_list::(); + if array.null_count() == 0 { + let lengths: Vec = list_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } else { + let lengths: Vec = list_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect(); + Ok(Arc::new(Int32Array::from(lengths))) + } + } + DataType::Map(_, _) => { + let map_array = array.as_map(); + let length: Vec = if array.null_count() == 0 { + map_array + .offsets() + .lengths() + .map(|len| len as i32) + .collect() + } else { + map_array + .offsets() + .lengths() + .enumerate() + .map(|(i, len)| if array.is_null(i) { -1 } else { len as i32 }) + .collect() + }; + Ok(Arc::new(Int32Array::from(length))) + } + DataType::Null => Ok(Arc::new(Int32Array::from(vec![-1; array.len()]))), + dt => { + plan_err!("size function does not support type: {}", dt) + } + } +} diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index aee43dd8d0a58..b185a5187055d 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -16,10 +16,10 @@ // under the License. use arrow::datatypes::DataType; -use datafusion_common::{internal_err, plan_err, Result}; +use datafusion_common::{Result, internal_err, plan_err}; use datafusion_expr::{ - binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue, - Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, }; #[derive(Debug, PartialEq, Eq, Hash)] @@ -42,10 +42,6 @@ impl SparkIf { } impl ScalarUDFImpl for SparkIf { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "if" } @@ -86,7 +82,7 @@ impl ScalarUDFImpl for SparkIf { fn simplify( &self, args: Vec, - _info: &dyn datafusion_expr::simplify::SimplifyInfo, + _info: &datafusion_expr::simplify::SimplifyContext, ) -> Result { let condition = args[0].clone(); let then_expr = args[1].clone(); diff --git a/datafusion/spark/src/function/conversion/cast.rs b/datafusion/spark/src/function/conversion/cast.rs new file mode 100644 index 0000000000000..45d1b336261d7 --- /dev/null +++ b/datafusion/spark/src/function/conversion/cast.rs @@ -0,0 +1,1007 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, AsArray, TimestampMicrosecondBuilder}; +use arrow::datatypes::{ + ArrowPrimitiveType, DataType, Field, FieldRef, Float32Type, Float64Type, Int8Type, + Int16Type, Int32Type, Int64Type, TimeUnit, +}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::types::{ + logical_float32, logical_float64, logical_int8, logical_int16, logical_int32, + logical_int64, logical_string, +}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; +use datafusion_expr::{Coercion, TypeSignatureClass}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, TypeSignature, Volatility, +}; +use std::sync::Arc; + +const MICROS_PER_SECOND: i64 = 1_000_000; + +/// Convert integer seconds to microseconds with saturating overflow behavior +#[inline] +fn secs_to_micros(secs: i64) -> i64 { + secs.saturating_mul(MICROS_PER_SECOND) +} + +/// Convert float seconds to microseconds +/// Returns None for NaN/Infinity in non-ANSI mode, error in ANSI mode +/// Saturates to i64::MAX/MIN for overflow +#[inline] +fn float_secs_to_micros(val: f64, enable_ansi_mode: bool) -> Result> { + if val.is_nan() || val.is_infinite() { + if enable_ansi_mode { + let display_val = if val.is_nan() { + "NaN" + } else if val.is_sign_positive() { + "Infinity" + } else { + "-Infinity" + }; + return exec_err!("Cannot cast {} to TIMESTAMP", display_val); + } + return Ok(None); + } + let micros = val * MICROS_PER_SECOND as f64; + + // Bounds check for i64 range. + // Note on precision: i64::MIN (-2^63) is exactly representable in f64, + // but i64::MAX (2^63 - 1) is not - it rounds up to 2^63 (i64::MAX + 1). + // We use strict `<` for the upper bound to reject values >= 2^63, + // which correctly handles the precision loss edge case. + if micros >= i64::MIN as f64 && micros < i64::MAX as f64 { + Ok(Some(micros as i64)) + } else { + if enable_ansi_mode { + return exec_err!("Overflow casting {} to TIMESTAMP", val); + } + // Saturate to i64::MAX or i64::MIN like Spark does for overflow + if micros.is_sign_negative() { + Ok(Some(i64::MIN)) + } else { + Ok(Some(i64::MAX)) + } + } +} + +/// Spark-compatible `cast` function for type conversions +/// +/// This implements Spark's CAST expression with a target type parameter +/// +/// # Usage +/// ```sql +/// SELECT spark_cast(value, 'timestamp') +/// ``` +/// +/// # Currently supported conversions +/// - Int8/Int16/Int32/Int64/Float32/Float64 -> Timestamp (target_type = 'timestamp') +/// +/// The integer value is interpreted as seconds since the Unix epoch (1970-01-01 00:00:00 UTC) +/// and converted to a timestamp with microsecond precision (matches spark's spec). Same is the case +/// with Float but with higher precision to support micro / nanoseconds. +/// +/// # Overflow behavior +/// Uses saturating multiplication to handle overflow - values that would overflow +/// i64 when multiplied by 1,000,000 are clamped to i64::MAX or i64::MIN +/// +/// # References +/// - +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCast { + signature: Signature, + timezone: Option>, +} + +impl Default for SparkCast { + fn default() -> Self { + Self::new() + } +} + +impl SparkCast { + pub fn new() -> Self { + Self::new_with_config(&ConfigOptions::default()) + } + + pub fn new_with_config(config: &ConfigOptions) -> Self { + // First arg: value to cast + // Second arg: target datatype as Utf8 string literal (ex : 'timestamp') + let string_arg = + Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + + // Supported input types: signed integers and floats + let input_type_signatures = [ + logical_int8(), + logical_int16(), + logical_int32(), + logical_int64(), + logical_float32(), + logical_float64(), + ] + .map(|input_type| { + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(input_type)), + string_arg.clone(), + ]) + }); + + Self { + signature: Signature::new( + TypeSignature::OneOf(Vec::from(input_type_signatures)), + Volatility::Stable, + ), + timezone: config + .execution + .time_zone + .as_ref() + .map(|tz| Arc::from(tz.as_str())) + .or_else(|| Some(Arc::from("UTC"))), + } + } +} + +/// Parse target type string into a DataType +fn parse_target_type(type_str: &str, timezone: Option>) -> Result { + match type_str.to_lowercase().as_str() { + // further data type support in future + "timestamp" => Ok(DataType::Timestamp(TimeUnit::Microsecond, timezone)), + other => exec_err!( + "Unsupported spark_cast target type '{}'. Supported types: timestamp", + other + ), + } +} + +/// Extract target type string from scalar arguments +fn get_target_type_from_scalar_args( + scalar_args: &[Option<&ScalarValue>], + timezone: Option>, +) -> Result { + let type_arg = scalar_args.get(1).and_then(|opt| *opt); + + match type_arg { + Some(ScalarValue::Utf8(Some(s))) + | Some(ScalarValue::LargeUtf8(Some(s))) + | Some(ScalarValue::Utf8View(Some(s))) => parse_target_type(s, timezone), + _ => exec_err!( + "spark_cast requires second argument to be a string of target data type ex: timestamp" + ), + } +} + +fn cast_int_to_timestamp( + array: &ArrayRef, + timezone: Option>, +) -> Result +where + T::Native: Into, +{ + let arr = array.as_primitive::(); + let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len()); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + // spark saturates to i64 min/max + let micros = secs_to_micros(arr.value(i).into()); + builder.append_value(micros); + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(timezone))) +} + +/// Cast float to timestamp +/// Float value represents seconds (with fractional part) since Unix epoch +/// NaN and Infinity: error in ANSI mode, NULL in non-ANSI mode +fn cast_float_to_timestamp( + array: &ArrayRef, + timezone: Option>, + enable_ansi_mode: bool, +) -> Result +where + T::Native: Into, +{ + let arr = array.as_primitive::(); + let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len()); + + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + let val: f64 = arr.value(i).into(); + match float_secs_to_micros(val, enable_ansi_mode)? { + Some(micros) => builder.append_value(micros), + None => builder.append_null(), + } + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(timezone))) +} + +impl ScalarUDFImpl for SparkCast { + fn name(&self) -> &str { + "spark_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn with_updated_config(&self, config: &ConfigOptions) -> Option { + Some(ScalarUDF::from(Self::new_with_config(config))) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let return_type = get_target_type_from_scalar_args( + args.scalar_arguments, + self.timezone.clone(), + )?; + Ok(Arc::new(Field::new(self.name(), return_type, true))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let enable_ansi_mode = args.config_options.execution.enable_ansi_mode; + let target_type = args.return_field.data_type(); + match target_type { + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + cast_to_timestamp(&args.args[0], tz.clone(), enable_ansi_mode) + } + other => exec_err!("Unsupported spark_cast target type: {:?}", other), + } + } +} + +/// Cast value to timestamp internal function +fn cast_to_timestamp( + input: &ColumnarValue, + timezone: Option>, + enable_ansi_mode: bool, +) -> Result { + match input { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(ColumnarValue::Array(Arc::new( + arrow::array::TimestampMicrosecondArray::new_null(array.len()) + .with_timezone_opt(timezone), + ))), + DataType::Int8 => Ok(ColumnarValue::Array( + cast_int_to_timestamp::(array, timezone)?, + )), + DataType::Int16 => Ok(ColumnarValue::Array(cast_int_to_timestamp::< + Int16Type, + >(array, timezone)?)), + DataType::Int32 => Ok(ColumnarValue::Array(cast_int_to_timestamp::< + Int32Type, + >(array, timezone)?)), + DataType::Int64 => Ok(ColumnarValue::Array(cast_int_to_timestamp::< + Int64Type, + >(array, timezone)?)), + DataType::Float32 => Ok(ColumnarValue::Array(cast_float_to_timestamp::< + Float32Type, + >( + array, + timezone, + enable_ansi_mode, + )?)), + DataType::Float64 => Ok(ColumnarValue::Array(cast_float_to_timestamp::< + Float64Type, + >( + array, + timezone, + enable_ansi_mode, + )?)), + other => exec_err!("Unsupported cast from {:?} to timestamp", other), + }, + ColumnarValue::Scalar(scalar) => { + let micros = match scalar { + ScalarValue::Null + | ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Float32(None) + | ScalarValue::Float64(None) => None, + ScalarValue::Int8(Some(v)) => Some(secs_to_micros((*v).into())), + ScalarValue::Int16(Some(v)) => Some(secs_to_micros((*v).into())), + ScalarValue::Int32(Some(v)) => Some(secs_to_micros((*v).into())), + ScalarValue::Int64(Some(v)) => Some(secs_to_micros(*v)), + ScalarValue::Float32(Some(v)) => { + float_secs_to_micros(*v as f64, enable_ansi_mode)? + } + ScalarValue::Float64(Some(v)) => { + float_secs_to_micros(*v, enable_ansi_mode)? + } + other => { + return exec_err!("Unsupported cast from {:?} to timestamp", other); + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + micros, timezone, + ))) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array, + }; + use arrow::datatypes::TimestampMicrosecondType; + + // helpers to make testing easier + fn make_args(input: ColumnarValue, target_type: &str) -> ScalarFunctionArgs { + make_args_with_timezone(input, target_type, Some("UTC")) + } + + fn make_args_with_timezone( + input: ColumnarValue, + target_type: &str, + timezone: Option<&str>, + ) -> ScalarFunctionArgs { + let return_field = Arc::new(Field::new( + "result", + DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::from(timezone.unwrap())), + ), + true, + )); + let mut config = ConfigOptions::default(); + if let Some(tz) = timezone { + config.execution.time_zone = Some(tz.to_string()); + } + ScalarFunctionArgs { + args: vec![ + input, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_type.to_string()))), + ], + arg_fields: vec![], + number_rows: 0, + return_field, + config_options: Arc::new(config), + } + } + + fn assert_scalar_timestamp(result: ColumnarValue, expected: i64) { + assert_scalar_timestamp_with_tz(result, expected, "UTC"); + } + + fn assert_scalar_timestamp_with_tz( + result: ColumnarValue, + expected: i64, + expected_tz: &str, + ) { + match result { + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(val), + Some(tz), + )) => { + assert_eq!(val, expected); + assert_eq!(tz.as_ref(), expected_tz); + } + _ => { + panic!( + "Expected scalar timestamp with value {expected} and {expected_tz} timezone" + ) + } + } + } + + fn assert_scalar_null(result: ColumnarValue) { + assert_scalar_null_with_tz(result, "UTC"); + } + + fn assert_scalar_null_with_tz(result: ColumnarValue, expected_tz: &str) { + match result { + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(None, Some(tz))) => { + assert_eq!(tz.as_ref(), expected_tz); + } + _ => panic!("Expected null scalar timestamp with {expected_tz} timezone"), + } + } + + #[test] + fn test_cast_int8_array_to_timestamp() { + let array: ArrayRef = Arc::new(Int8Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(127), + Some(-128), + None, + ])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 127_000_000); + assert_eq!(ts_array.value(4), -128_000_000); + assert!(ts_array.is_null(5)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_int16_array_to_timestamp() { + let array: ArrayRef = Arc::new(Int16Array::from(vec![ + Some(0), + Some(32767), + Some(-32768), + None, + ])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 32_767_000_000); + assert_eq!(ts_array.value(2), -32_768_000_000); + assert!(ts_array.is_null(3)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_int32_array_to_timestamp() { + let array: ArrayRef = + Arc::new(Int32Array::from(vec![Some(0), Some(1704067200), None])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_704_067_200_000_000); + assert!(ts_array.is_null(2)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_int64_array_overflow() { + let array: ArrayRef = + Arc::new(Int64Array::from(vec![Some(i64::MAX), Some(i64::MIN)])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + // saturating_mul clamps to i64::MAX/MIN + assert_eq!(ts_array.value(0), i64::MAX); + assert_eq!(ts_array.value(1), i64::MIN); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_int64_array_to_timestamp() { + let array: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(0), + Some(1704067200), + Some(-86400), + None, + ])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_704_067_200_000_000); + assert_eq!(ts_array.value(2), -86_400_000_000); // -1 day + assert!(ts_array.is_null(3)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_scalar_int8() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int8(Some(100))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 100_000_000); + } + + #[test] + fn test_cast_scalar_int16() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int16(Some(100))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 100_000_000); + } + + #[test] + fn test_cast_scalar_int32() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int32(Some(1704067200))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 1_704_067_200_000_000); + } + + #[test] + fn test_cast_scalar_int64() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int64(Some(1704067200))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 1_704_067_200_000_000); + } + + #[test] + fn test_cast_scalar_negative() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int32(Some(-86400))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + // -86400 seconds = -1 day before epoch + assert_scalar_timestamp(result, -86_400_000_000); + } + + #[test] + fn test_cast_scalar_null() { + let cast = SparkCast::new(); + let args = + make_args(ColumnarValue::Scalar(ScalarValue::Int64(None)), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + } + + #[test] + fn test_cast_scalar_int64_overflow() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + // saturating_mul clamps to i64::MAX + assert_scalar_timestamp(result, i64::MAX); + } + + #[test] + fn test_unsupported_target_type() { + let cast = SparkCast::new(); + // invoke_with_args uses return_field which would be set correctly by planning + // For this test, we need to check return_field_from_args + let arg_fields: Vec = + vec![Arc::new(Field::new("a", DataType::Int64, true))]; + let target_type = ScalarValue::Utf8(Some("string".to_string())); + let scalar_arguments: Vec> = vec![None, Some(&target_type)]; + let return_field_args = ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + }; + let result = cast.return_field_from_args(return_field_args); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Unsupported spark_cast target type") + ); + } + + #[test] + fn test_unsupported_source_type() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024-01-01".to_string()))), + "timestamp", + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Unsupported cast from") + ); + } + + #[test] + fn test_cast_null_to_timestamp() { + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Scalar(ScalarValue::Null), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + } + + #[test] + fn test_cast_null_array_to_timestamp() { + let array: ArrayRef = Arc::new(arrow::array::NullArray::new(3)); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.len(), 3); + assert!(ts_array.is_null(0)); + assert!(ts_array.is_null(1)); + assert!(ts_array.is_null(2)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_int_to_timestamp_with_timezones() { + // Test with various timezones like Comet does + let timezones = [ + "UTC", + "America/New_York", + "America/Los_Angeles", + "Europe/London", + "Asia/Tokyo", + "Australia/Sydney", + ]; + + let cast = SparkCast::new(); + let test_value: i64 = 1704067200; // 2024-01-01 00:00:00 UTC + let expected_micros = test_value * MICROS_PER_SECOND; + + for tz in timezones { + // scalar + let args = make_args_with_timezone( + ColumnarValue::Scalar(ScalarValue::Int64(Some(test_value))), + "timestamp", + Some(tz), + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp_with_tz(result, expected_micros, tz); + + // array input + let array: ArrayRef = + Arc::new(Int64Array::from(vec![Some(test_value), None])); + let args = make_args_with_timezone( + ColumnarValue::Array(array), + "timestamp", + Some(tz), + ); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = + result_array.as_primitive::(); + assert_eq!(ts_array.value(0), expected_micros); + assert!(ts_array.is_null(1)); + assert_eq!(ts_array.timezone(), Some(tz)); + } + _ => panic!("Expected array result for timezone {tz}"), + } + } + } + + #[test] + fn test_cast_int_to_timestamp_default_timezone() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + // Defaults to UTC + assert_scalar_timestamp_with_tz(result, 0, "UTC"); + } + + fn make_args_with_ansi_mode( + input: ColumnarValue, + target_type: &str, + enable_ansi_mode: bool, + ) -> ScalarFunctionArgs { + let return_field = Arc::new(Field::new( + "result", + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + true, + )); + let mut config = ConfigOptions::default(); + config.execution.time_zone = Some("UTC".to_string()); + config.execution.enable_ansi_mode = enable_ansi_mode; + ScalarFunctionArgs { + args: vec![ + input, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_type.to_string()))), + ], + arg_fields: vec![], + number_rows: 0, + return_field, + config_options: Arc::new(config), + } + } + + #[test] + fn test_cast_float64_array_to_timestamp() { + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(0.0), + Some(1.5), + Some(-1.5), + Some(1704067200.123456), + None, + ])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_500_000); // 1.5 seconds + assert_eq!(ts_array.value(2), -1_500_000); // -1.5 seconds + assert_eq!(ts_array.value(3), 1_704_067_200_123_456); // with fractional + assert!(ts_array.is_null(4)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_float32_array_to_timestamp() { + let array: ArrayRef = Arc::new(Float32Array::from(vec![ + Some(0.0f32), + Some(1.5f32), + Some(-1.5f32), + None, + ])); + + let cast = SparkCast::new(); + let args = make_args(ColumnarValue::Array(array), "timestamp"); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_500_000); // 1.5 seconds + assert_eq!(ts_array.value(2), -1_500_000); // -1.5 seconds + assert!(ts_array.is_null(3)); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_scalar_float64() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.5))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 1_500_000); + } + + #[test] + fn test_cast_scalar_float32() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.5f32))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_timestamp(result, 1_500_000); + } + + #[test] + fn test_cast_float_nan_non_ansi_mode() { + // In non-ANSI mode, NaN should return NULL + let cast = SparkCast::new(); + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + } + + #[test] + fn test_cast_float_infinity_non_ansi_mode() { + // In non-ANSI mode, Infinity should return NULL + let cast = SparkCast::new(); + + // Positive infinity + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::INFINITY))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + + // Negative infinity + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NEG_INFINITY))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + assert_scalar_null(result); + } + + #[test] + fn test_cast_float_nan_ansi_mode() { + // In ANSI mode, NaN should error + let cast = SparkCast::new(); + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::NAN))), + "timestamp", + true, + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Cannot cast NaN")); + } + + #[test] + fn test_cast_float_infinity_ansi_mode() { + // In ANSI mode, Infinity should error + let cast = SparkCast::new(); + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(f64::INFINITY))), + "timestamp", + true, + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Cannot cast Infinity") + ); + } + + #[test] + fn test_cast_float_overflow_non_ansi_mode() { + // Value too large to fit in i64 microseconds - should saturate to i64::MAX like Spark + let cast = SparkCast::new(); + let large_value = 1e19; // Way too large for i64 microseconds + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(large_value))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + // Spark saturates overflow to i64::MAX + assert_scalar_timestamp(result, i64::MAX); + } + + #[test] + fn test_cast_float_negative_overflow_non_ansi_mode() { + // Large negative value - should saturate to i64::MIN like Spark + let cast = SparkCast::new(); + let large_value = -1e19; // Way too large negative for i64 microseconds + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(large_value))), + "timestamp", + false, + ); + let result = cast.invoke_with_args(args).unwrap(); + // Spark saturates negative overflow to i64::MIN + assert_scalar_timestamp(result, i64::MIN); + } + + #[test] + fn test_cast_float_overflow_ansi_mode() { + // Value too large to fit in i64 microseconds - should error in ANSI mode + let cast = SparkCast::new(); + let large_value = 1e19; // Way too large for i64 microseconds + let args = make_args_with_ansi_mode( + ColumnarValue::Scalar(ScalarValue::Float64(Some(large_value))), + "timestamp", + true, + ); + let result = cast.invoke_with_args(args); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Overflow")); + } + + #[test] + fn test_cast_float_array_with_nan_and_infinity() { + // Array with NaN and Infinity in non-ANSI mode + let array: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(2.0), + ])); + + let cast = SparkCast::new(); + let args = + make_args_with_ansi_mode(ColumnarValue::Array(array), "timestamp", false); + let result = cast.invoke_with_args(args).unwrap(); + + match result { + ColumnarValue::Array(result_array) => { + let ts_array = result_array.as_primitive::(); + assert_eq!(ts_array.value(0), 1_000_000); + assert!(ts_array.is_null(1)); // NaN -> NULL + assert!(ts_array.is_null(2)); // Infinity -> NULL + assert!(ts_array.is_null(3)); // -Infinity -> NULL + assert_eq!(ts_array.value(4), 2_000_000); + } + _ => panic!("Expected array result"), + } + } + + #[test] + fn test_cast_float_negative_values() { + let cast = SparkCast::new(); + let args = make_args( + ColumnarValue::Scalar(ScalarValue::Float64(Some(-86400.5))), + "timestamp", + ); + let result = cast.invoke_with_args(args).unwrap(); + // -86400.5 seconds = -86400500000 microseconds (1 day and 0.5 seconds before epoch) + assert_scalar_timestamp(result, -86_400_500_000); + } +} diff --git a/datafusion/spark/src/function/conversion/mod.rs b/datafusion/spark/src/function/conversion/mod.rs index a87df9a2c87a0..e8a89fa8c0616 100644 --- a/datafusion/spark/src/function/conversion/mod.rs +++ b/datafusion/spark/src/function/conversion/mod.rs @@ -15,11 +15,26 @@ // specific language governing permissions and limitations // under the License. +mod cast; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function_with_config; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function_with_config!(cast::SparkCast, spark_cast); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + spark_cast, + "Casts given value to the specified type following Spark-compatible semantics", + @config arg1 arg2 + )); +} pub fn functions() -> Vec> { - vec![] + use datafusion_common::config::ConfigOptions; + let config = ConfigOptions::default(); + vec![spark_cast(&config)] } diff --git a/datafusion/spark/src/function/datetime/add_months.rs b/datafusion/spark/src/function/datetime/add_months.rs new file mode 100644 index 0000000000000..2963cf5880b9c --- /dev/null +++ b/datafusion/spark/src/function/datetime/add_months.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ops::Add; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkAddMonths { + signature: Signature, +} + +impl Default for SparkAddMonths { + fn default() -> Self { + Self::new() + } +} + +impl SparkAddMonths { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Date32, DataType::Int32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkAddMonths { + fn name(&self) -> &str { + "add_months" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [date_arg, months_arg] = take_function_args("add_months", args)?; + let interval = months_arg + .cast_to(&DataType::Interval(IntervalUnit::YearMonth), info.schema())?; + Ok(ExprSimplifyResult::Simplified(date_arg.add(interval))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke should not be called on a simplified add_months() function") + } +} diff --git a/datafusion/spark/src/function/datetime/date_add.rs b/datafusion/spark/src/function/datetime/date_add.rs index 457d4d476dce3..6db0fe3a36cf2 100644 --- a/datafusion/spark/src/function/datetime/date_add.rs +++ b/datafusion/spark/src/function/datetime/date_add.rs @@ -15,21 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::ArrayRef; use arrow::compute; -use arrow::datatypes::{DataType, Date32Type}; -use arrow::error::ArrowError; +use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use datafusion_common::cast::{ - as_date32_array, as_int16_array, as_int32_array, as_int8_array, + as_date32_array, as_int8_array, as_int16_array, as_int32_array, }; use datafusion_common::utils::take_function_args; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -62,10 +60,6 @@ impl SparkDateAdd { } impl ScalarUDFImpl for SparkDateAdd { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "date_add" } @@ -79,7 +73,16 @@ impl ScalarUDFImpl for SparkDateAdd { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Date32) + internal_err!("Use return_field_from_args in this case instead.") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -93,38 +96,26 @@ fn spark_date_add(args: &[ArrayRef]) -> Result { let result = match days_arg.data_type() { DataType::Int8 => { let days_array = as_int8_array(days_arg)?; - compute::try_binary::<_, _, _, Date32Type>( + compute::binary::<_, _, _, Date32Type>( date_array, days_array, - |date, days| { - date.checked_add(days as i32).ok_or_else(|| { - ArrowError::ArithmeticOverflow("date_add".to_string()) - }) - }, + |date, days| date.wrapping_add(days as i32), )? } DataType::Int16 => { let days_array = as_int16_array(days_arg)?; - compute::try_binary::<_, _, _, Date32Type>( + compute::binary::<_, _, _, Date32Type>( date_array, days_array, - |date, days| { - date.checked_add(days as i32).ok_or_else(|| { - ArrowError::ArithmeticOverflow("date_add".to_string()) - }) - }, + |date, days| date.wrapping_add(days as i32), )? } DataType::Int32 => { let days_array = as_int32_array(days_arg)?; - compute::try_binary::<_, _, _, Date32Type>( + compute::binary::<_, _, _, Date32Type>( date_array, days_array, - |date, days| { - date.checked_add(days).ok_or_else(|| { - ArrowError::ArithmeticOverflow("date_add".to_string()) - }) - }, + |date, days| date.wrapping_add(days), )? } _ => { @@ -136,3 +127,46 @@ fn spark_date_add(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(result)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_date_add_non_nullable_inputs() { + let func = SparkDateAdd::new(); + let args = &[ + Arc::new(Field::new("date", DataType::Date32, false)), + Arc::new(Field::new("num", DataType::Int8, false)), + ]; + + let ret_field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: args, + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert_eq!(ret_field.data_type(), &DataType::Date32); + assert!(!ret_field.is_nullable()); + } + + #[test] + fn test_date_add_nullable_inputs() { + let func = SparkDateAdd::new(); + let args = &[ + Arc::new(Field::new("date", DataType::Date32, false)), + Arc::new(Field::new("num", DataType::Int16, true)), + ]; + + let ret_field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: args, + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert_eq!(ret_field.data_type(), &DataType::Date32); + assert!(ret_field.is_nullable()); + } +} diff --git a/datafusion/spark/src/function/datetime/date_diff.rs b/datafusion/spark/src/function/datetime/date_diff.rs new file mode 100644 index 0000000000000..b9793ddc00670 --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_diff.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_date, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, Operator, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, + binary_expr, +}; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateDiff { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDateDiff { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateDiff { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, + ], + NativeType::Date, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![ + TypeSignatureClass::Native(logical_string()), + TypeSignatureClass::Timestamp, + ], + NativeType::Date, + ), + ], + Volatility::Immutable, + ), + aliases: vec!["datediff".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkDateDiff { + fn name(&self) -> &str { + "date_diff" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "Apache Spark `date_diff` should have been simplified to standard subtraction" + ) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [end, start] = take_function_args(self.name(), args)?; + let end = end.cast_to(&DataType::Date32, info.schema())?; + let start = start.cast_to(&DataType::Date32, info.schema())?; + Ok(ExprSimplifyResult::Simplified( + binary_expr(end, Operator::Minus, start) + .cast_to(&DataType::Int32, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/datetime/date_part.rs b/datafusion/spark/src/function/datetime/date_part.rs new file mode 100644 index 0000000000000..91bdb9a55318b --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_part.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::logical_date; +use datafusion_common::{ + Result, ScalarValue, internal_err, types::logical_string, utils::take_function_args, +}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, +}; +use std::sync::Arc; + +/// Wrapper around datafusion date_part function to handle +/// Spark behavior returning day of the week 1-indexed instead of 0-indexed and different part aliases. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDatePart { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkDatePart { + fn default() -> Self { + Self::new() + } +} + +impl SparkDatePart { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Timestamp), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_date())), + ]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("datepart")], + } + } +} + +impl ScalarUDFImpl for SparkDatePart { + fn name(&self) -> &str { + "date_part" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("Use return_field_from_args in this case instead.") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("spark date_part should have been simplified to standard date_part") + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [part_expr, date_expr] = take_function_args(self.name(), args)?; + + let part = match part_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return internal_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific date part aliases to datafusion ones + let part = match part.as_str() { + "yearofweek" | "year_iso" => "isoyear", + "dayofweek" => "dow", + "dayofweek_iso" | "dow_iso" => "isodow", + other => other, + }; + + let part_expr = Expr::Literal(ScalarValue::new_utf8(part), None); + + let date_part_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::date_part(), + vec![part_expr, date_expr], + )); + + match part { + // Spark's `dayofweek` is 1..=7 (Sun=1) but df's `dow` is 0..=6 + // (Sun=0); shift by +1. df's `isodow` already returns the + // PG-correct 1..=7 (Mon=1), which matches Spark's + // `dayofweek_iso`/`dow_iso`, so no shift is needed there. + "dow" => Ok(ExprSimplifyResult::Simplified( + date_part_expr + Expr::Literal(ScalarValue::Int32(Some(1)), None), + )), + _ => Ok(ExprSimplifyResult::Simplified(date_part_expr)), + } + } +} diff --git a/datafusion/spark/src/function/datetime/date_sub.rs b/datafusion/spark/src/function/datetime/date_sub.rs index a3b26661d196c..bc2025c9b2eda 100644 --- a/datafusion/spark/src/function/datetime/date_sub.rs +++ b/datafusion/spark/src/function/datetime/date_sub.rs @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::ArrayRef; use arrow::compute; -use arrow::datatypes::{DataType, Date32Type}; -use arrow::error::ArrowError; +use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use datafusion_common::cast::{ - as_date32_array, as_int16_array, as_int32_array, as_int8_array, + as_date32_array, as_int8_array, as_int16_array, as_int32_array, }; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -59,10 +57,6 @@ impl SparkDateSub { } impl ScalarUDFImpl for SparkDateSub { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "date_sub" } @@ -72,7 +66,16 @@ impl ScalarUDFImpl for SparkDateSub { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Date32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -91,38 +94,26 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result { let result = match days_arg.data_type() { DataType::Int8 => { let days_array = as_int8_array(days_arg)?; - compute::try_binary::<_, _, _, Date32Type>( + compute::binary::<_, _, _, Date32Type>( date_array, days_array, - |date, days| { - date.checked_sub(days as i32).ok_or_else(|| { - ArrowError::ArithmeticOverflow("date_sub".to_string()) - }) - }, + |date, days| date.wrapping_sub(days as i32), )? } DataType::Int16 => { let days_array = as_int16_array(days_arg)?; - compute::try_binary::<_, _, _, Date32Type>( + compute::binary::<_, _, _, Date32Type>( date_array, days_array, - |date, days| { - date.checked_sub(days as i32).ok_or_else(|| { - ArrowError::ArithmeticOverflow("date_sub".to_string()) - }) - }, + |date, days| date.wrapping_sub(days as i32), )? } DataType::Int32 => { let days_array = as_int32_array(days_arg)?; - compute::try_binary::<_, _, _, Date32Type>( + compute::binary::<_, _, _, Date32Type>( date_array, days_array, - |date, days| { - date.checked_sub(days).ok_or_else(|| { - ArrowError::ArithmeticOverflow("date_sub".to_string()) - }) - }, + |date, days| date.wrapping_sub(days), )? } _ => { @@ -134,3 +125,42 @@ fn spark_date_sub(args: &[ArrayRef]) -> Result { }; Ok(Arc::new(result)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_date_sub_nullability_non_nullable_args() { + let udf = SparkDateSub::new(); + let date_field = Arc::new(Field::new("d", DataType::Date32, false)); + let days_field = Arc::new(Field::new("n", DataType::Int32, false)); + + let result = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[date_field, days_field], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), &DataType::Date32); + } + + #[test] + fn test_date_sub_nullability_nullable_arg() { + let udf = SparkDateSub::new(); + let date_field = Arc::new(Field::new("d", DataType::Date32, false)); + let nullable_days_field = Arc::new(Field::new("n", DataType::Int32, true)); + + let result = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[date_field, nullable_days_field], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &DataType::Date32); + } +} diff --git a/datafusion/spark/src/function/datetime/date_trunc.rs b/datafusion/spark/src/function/datetime/date_trunc.rs new file mode 100644 index 0000000000000..c8b0fbca36165 --- /dev/null +++ b/datafusion/spark/src/function/datetime/date_trunc.rs @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Spark date_trunc supports extra format aliases. +/// It also handles timestamps with timezones by converting to session timezone first. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkDateTrunc { + signature: Signature, +} + +impl Default for SparkDateTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkDateTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkDateTrunc { + fn name(&self) -> &str { + "date_trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[1].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "spark date_trunc should have been simplified to standard date_trunc" + ) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [fmt_expr, ts_expr] = take_function_args(self.name(), args)?; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "First argument of `DATE_TRUNC` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific fmt aliases to datafusion ones + let fmt = match fmt.as_str() { + "yy" | "yyyy" => "year", + "mm" | "mon" => "month", + "dd" => "day", + other => other, + }; + + let session_tz = info.config_options().execution.time_zone.clone(); + let ts_type = ts_expr.get_type(info.schema())?; + + // Spark interprets timestamps in the session timezone before truncating, + // then returns a timestamp at microsecond precision. + // See: https://github.com/apache/spark/blob/f310f4fcc95580a6824bc7d22b76006f79b8804a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala#L492 + // + // For sub-second truncations (second, millisecond, microsecond), timezone + // adjustment is unnecessary since timezone offsets are whole seconds. + let ts_expr = match (&ts_type, fmt) { + // Sub-second truncations don't need timezone adjustment + (_, "second" | "millisecond" | "microsecond") => ts_expr, + + // convert to session timezone, strip timezone and convert back to original timezone + (DataType::Timestamp(unit, tz), _) => { + let ts_expr = match &session_tz { + Some(session_tz) => ts_expr.cast_to( + &DataType::Timestamp( + TimeUnit::Microsecond, + Some(Arc::from(session_tz.as_str())), + ), + info.schema(), + )?, + None => ts_expr, + }; + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::to_local_time(), + vec![ts_expr], + )) + .cast_to(&DataType::Timestamp(*unit, tz.clone()), info.schema())? + } + + _ => { + return plan_err!( + "Second argument of `DATE_TRUNC` must be Timestamp, got {}", + ts_type + ); + } + }; + + let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf( + datafusion_functions::datetime::date_trunc(), + vec![fmt_expr, ts_expr], + ), + ))) + } +} diff --git a/datafusion/spark/src/function/datetime/extract.rs b/datafusion/spark/src/function/datetime/extract.rs new file mode 100644 index 0000000000000..70026b18ed5e7 --- /dev/null +++ b/datafusion/spark/src/function/datetime/extract.rs @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::compute::{DatePart, date_part}; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_common::utils::take_function_args; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Creates a signature for datetime extraction functions that accept timestamp types. +fn extract_signature() -> Signature { + Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Timestamp)], + Volatility::Immutable, + ) +} + +// ----------------------------------------------------------------------------- +// SparkHour +// ----------------------------------------------------------------------------- + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkHour { + signature: Signature, +} + +impl Default for SparkHour { + fn default() -> Self { + Self::new() + } +} + +impl SparkHour { + pub fn new() -> Self { + Self { + signature: extract_signature(), + } + } +} + +impl ScalarUDFImpl for SparkHour { + fn name(&self) -> &str { + "hour" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_hour, vec![])(&args.args) + } +} + +fn spark_hour(args: &[ArrayRef]) -> Result { + let [ts_arg] = take_function_args("hour", args)?; + let result = date_part(ts_arg.as_ref(), DatePart::Hour)?; + Ok(result) +} + +// ----------------------------------------------------------------------------- +// SparkMinute +// ----------------------------------------------------------------------------- + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMinute { + signature: Signature, +} + +impl Default for SparkMinute { + fn default() -> Self { + Self::new() + } +} + +impl SparkMinute { + pub fn new() -> Self { + Self { + signature: extract_signature(), + } + } +} + +impl ScalarUDFImpl for SparkMinute { + fn name(&self) -> &str { + "minute" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_minute, vec![])(&args.args) + } +} + +fn spark_minute(args: &[ArrayRef]) -> Result { + let [ts_arg] = take_function_args("minute", args)?; + let result = date_part(ts_arg.as_ref(), DatePart::Minute)?; + Ok(result) +} + +// ----------------------------------------------------------------------------- +// SparkSecond +// ----------------------------------------------------------------------------- + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSecond { + signature: Signature, +} + +impl Default for SparkSecond { + fn default() -> Self { + Self::new() + } +} + +impl SparkSecond { + pub fn new() -> Self { + Self { + signature: extract_signature(), + } + } +} + +impl ScalarUDFImpl for SparkSecond { + fn name(&self) -> &str { + "second" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_second, vec![])(&args.args) + } +} + +fn spark_second(args: &[ArrayRef]) -> Result { + let [ts_arg] = take_function_args("second", args)?; + let result = date_part(ts_arg.as_ref(), DatePart::Second)?; + Ok(result) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Int32Array, TimestampMicrosecondArray}; + use arrow::datatypes::TimeUnit; + use std::sync::Arc; + + #[test] + fn test_spark_hour() { + // Create a timestamp array: 2024-01-15 14:30:45 UTC (in microseconds) + // 14:30:45 -> hour = 14 + let ts_micros = 1_705_329_045_000_000_i64; // 2024-01-15 14:30:45 UTC + let ts_array = TimestampMicrosecondArray::from(vec![Some(ts_micros), None]); + let ts_array = Arc::new(ts_array) as ArrayRef; + + let result = spark_hour(&[ts_array]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.value(0), 14); + assert!(result.is_null(1)); + } + + #[test] + fn test_spark_minute() { + // 14:30:45 -> minute = 30 + let ts_micros = 1_705_329_045_000_000_i64; + let ts_array = TimestampMicrosecondArray::from(vec![Some(ts_micros), None]); + let ts_array = Arc::new(ts_array) as ArrayRef; + + let result = spark_minute(&[ts_array]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.value(0), 30); + assert!(result.is_null(1)); + } + + #[test] + fn test_spark_second() { + // 14:30:45 -> second = 45 + let ts_micros = 1_705_329_045_000_000_i64; + let ts_array = TimestampMicrosecondArray::from(vec![Some(ts_micros), None]); + let ts_array = Arc::new(ts_array) as ArrayRef; + + let result = spark_second(&[ts_array]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.value(0), 45); + assert!(result.is_null(1)); + } + + #[test] + fn test_hour_return_type() { + let func = SparkHour::new(); + let result = func + .return_type(&[DataType::Timestamp(TimeUnit::Microsecond, None)]) + .unwrap(); + assert_eq!(result, DataType::Int32); + } + + #[test] + fn test_minute_return_type() { + let func = SparkMinute::new(); + let result = func + .return_type(&[DataType::Timestamp(TimeUnit::Microsecond, None)]) + .unwrap(); + assert_eq!(result, DataType::Int32); + } + + #[test] + fn test_second_return_type() { + let func = SparkSecond::new(); + let result = func + .return_type(&[DataType::Timestamp(TimeUnit::Microsecond, None)]) + .unwrap(); + assert_eq!(result, DataType::Int32); + } +} diff --git a/datafusion/spark/src/function/datetime/from_utc_timestamp.rs b/datafusion/spark/src/function/datetime/from_utc_timestamp.rs new file mode 100644 index 0000000000000..bfca677c1dcce --- /dev/null +++ b/datafusion/spark/src/function/datetime/from_utc_timestamp.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveBuilder, StringArrayType}; +use arrow::datatypes::TimeUnit; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Field, FieldRef, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_datafusion_err, exec_err, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::datetime::to_local_time::adjust_to_local_time; +use datafusion_functions::utils::make_scalar_function; + +/// Apache Spark `from_utc_timestamp` function. +/// +/// Interprets the given timestamp as UTC and converts it to the given timezone. +/// +/// Timestamp in Apache Spark represents number of microseconds from the Unix epoch, which is not +/// timezone-agnostic. So in Apache Spark this function just shift the timestamp value from UTC timezone to +/// the given timezone. +/// +/// See +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFromUtcTimestamp { + signature: Signature, +} + +impl Default for SparkFromUtcTimestamp { + fn default() -> Self { + Self::new() + } +} + +impl SparkFromUtcTimestamp { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkFromUtcTimestamp { + fn name(&self) -> &str { + "from_utc_timestamp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_from_utc_timestamp, vec![])(&args.args) + } +} + +fn spark_from_utc_timestamp(args: &[ArrayRef]) -> Result { + let [timestamp, timezone] = take_function_args("from_utc_timestamp", args)?; + + match timestamp.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + ts_type => { + exec_err!("`from_utc_timestamp`: unsupported argument types: {ts_type}") + } + } +} + +fn process_timestamp_with_tz_array( + ts_array: &ArrayRef, + tz_array: &ArrayRef, + tz_opt: Option>, +) -> Result { + match tz_array.data_type() { + DataType::Utf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::LargeUtf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::Utf8View => { + process_arrays::(tz_opt, ts_array, tz_array.as_string_view()) + } + other => { + exec_err!("`from_utc_timestamp`: timezone must be a string type, got {other}") + } + } +} + +fn process_arrays<'a, T: ArrowTimestampType, S>( + return_tz_opt: Option>, + ts_array: &ArrayRef, + tz_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let ts_primitive = ts_array.as_primitive::(); + let mut builder = PrimitiveBuilder::::with_capacity(ts_array.len()); + + for (ts_opt, tz_opt) in ts_primitive.iter().zip(tz_array.iter()) { + match (ts_opt, tz_opt) { + (Some(ts), Some(tz_str)) => { + let tz: Tz = tz_str.parse().map_err(|e| { + exec_datafusion_err!( + "`from_utc_timestamp`: invalid timezone '{tz_str}': {e}" + ) + })?; + let val = adjust_to_local_time::(ts, tz)?; + builder.append_value(val); + } + _ => builder.append_null(), + } + } + + builder = builder.with_timezone_opt(return_tz_opt); + Ok(Arc::new(builder.finish())) +} diff --git a/datafusion/spark/src/function/datetime/last_day.rs b/datafusion/spark/src/function/datetime/last_day.rs index b75f10ad5e42e..74c55d911f410 100644 --- a/datafusion/spark/src/function/datetime/last_day.rs +++ b/datafusion/spark/src/function/datetime/last_day.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, Date32Array}; -use arrow::datatypes::{DataType, Date32Type}; +use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use chrono::{Datelike, Duration, NaiveDate}; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_datafusion_err, internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_datafusion_err, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; #[derive(Debug, PartialEq, Eq, Hash)] @@ -47,10 +47,6 @@ impl SparkLastDay { } impl ScalarUDFImpl for SparkLastDay { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "last_day" } @@ -60,7 +56,19 @@ impl ScalarUDFImpl for SparkLastDay { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Date32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let Some(field) = args.arg_fields.first() else { + return internal_err!("Spark `last_day` expects exactly one argument"); + }; + + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + field.is_nullable(), + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -86,7 +94,9 @@ impl ScalarUDFImpl for SparkLastDay { Ok(Arc::new(result) as ArrayRef) } other => { - internal_err!("Unsupported data type {other:?} for Spark function `last_day`") + internal_err!( + "Unsupported data type {other:?} for Spark function `last_day`" + ) } }?; Ok(ColumnarValue::Array(result)) @@ -99,7 +109,11 @@ impl ScalarUDFImpl for SparkLastDay { } fn spark_last_day(days: i32) -> Result { - let date = Date32Type::to_naive_date(days); + let date = Date32Type::to_naive_date_opt(days).ok_or_else(|| { + exec_datafusion_err!( + "Spark `last_day`: Unable to convert days value {days} to date" + ) + })?; let (year, month) = (date.year(), date.month()); let (next_year, next_month) = if month == 12 { @@ -119,3 +133,57 @@ fn spark_last_day(days: i32) -> Result { first_day_next_month - Duration::days(1), )) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::utils::test::test_scalar_function; + use arrow::array::Array; + + #[test] + fn test_last_day_nullability_matches_input() { + let func = SparkLastDay::new(); + + let non_nullable_arg = Arc::new(Field::new("arg", DataType::Date32, false)); + let nullable_arg = Arc::new(Field::new("arg", DataType::Date32, true)); + + let non_nullable_out = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_arg)], + scalar_arguments: &[None], + }) + .expect("non-nullable arg should succeed"); + assert_eq!(non_nullable_out.data_type(), &DataType::Date32); + assert!(!non_nullable_out.is_nullable()); + + let nullable_out = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_arg)], + scalar_arguments: &[None], + }) + .expect("nullable arg should succeed"); + assert_eq!(nullable_out.data_type(), &DataType::Date32); + assert!(nullable_out.is_nullable()); + } + + #[test] + fn test_last_day_scalar_evaluation() { + test_scalar_function!( + SparkLastDay::new(), + vec![ColumnarValue::Scalar(ScalarValue::Date32(Some(0)))], + Ok(Some(30)), + i32, + DataType::Date32, + Date32Array + ); + + test_scalar_function!( + SparkLastDay::new(), + vec![ColumnarValue::Scalar(ScalarValue::Date32(None))], + Ok(None), + i32, + DataType::Date32, + Date32Array + ); + } +} diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs index bbfba44861344..88ccae1b914a4 100644 --- a/datafusion/spark/src/function/datetime/make_dt_interval.rs +++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs @@ -15,19 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ Array, ArrayRef, AsArray, DurationMicrosecondBuilder, PrimitiveArray, }; use arrow::datatypes::TimeUnit::Microsecond; -use arrow::datatypes::{DataType, Float64Type, Int32Type}; +use arrow::datatypes::{DataType, Field, FieldRef, Float64Type, Int32Type}; +use datafusion_common::types::{NativeType, logical_float64, logical_int32}; use datafusion_common::{ - exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, internal_err, plan_datafusion_err, }; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignature, TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -44,17 +45,42 @@ impl Default for SparkMakeDtInterval { impl SparkMakeDtInterval { pub fn new() -> Self { + let int32 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + + let variants = vec![ + TypeSignature::Nullary, + // (days) + TypeSignature::Coercible(vec![int32.clone()]), + // (days, hours) + TypeSignature::Coercible(vec![int32.clone(), int32.clone()]), + // (days, hours, minutes) + TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]), + // (days, hours, minutes, seconds) + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + float64, + ]), + ]; + Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of(variants, Volatility::Immutable), } } } impl ScalarUDFImpl for SparkMakeDtInterval { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "make_dt_interval" } @@ -70,7 +96,28 @@ impl ScalarUDFImpl for SparkMakeDtInterval { /// /// [Sail compatibility doc]: https://github.com/lakehq/sail/blob/dc5368daa24d40a7758a299e1ba8fc985cb29108/docs/guide/dataframe/data-types/compatibility.md?plain=1#L260 fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Duration(Microsecond)) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let has_non_finite_secs = args + .scalar_arguments + .get(3) + .and_then(|arg| { + arg.map(|scalar| match scalar { + ScalarValue::Float64(Some(v)) => !v.is_finite(), + ScalarValue::Float32(Some(v)) => !v.is_finite(), + _ => false, + }) + }) + .unwrap_or(false); + let nullable = + has_non_finite_secs || args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new( + self.name(), + DataType::Duration(Microsecond), + nullable, + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -79,26 +126,13 @@ impl ScalarUDFImpl for SparkMakeDtInterval { Some(0), ))); } - make_scalar_function(make_dt_interval_kernel, vec![])(&args.args) - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() > 4 { - return exec_err!( + if args.args.len() > 4 { + return Err(DataFusionError::Execution(format!( "make_dt_interval expects between 0 and 4 arguments, got {}", - arg_types.len() - ); + args.args.len() + ))); } - - Ok((0..arg_types.len()) - .map(|i| { - if i == 3 { - DataType::Float64 - } else { - DataType::Int32 - } - }) - .collect()) + make_scalar_function(make_dt_interval_kernel, vec![])(&args.args) } } @@ -205,14 +239,11 @@ fn make_interval_dt_nano(day: i32, hour: i32, min: i32, sec: f64) -> Option #[cfg(test)] mod tests { - use std::sync::Arc; use arrow::array::{DurationMicrosecondArray, Float64Array, Int32Array}; use arrow::datatypes::DataType::Duration; - use arrow::datatypes::Field; use arrow::datatypes::TimeUnit::Microsecond; - use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; - use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; + use datafusion_common::internal_datafusion_err; use super::*; @@ -276,6 +307,59 @@ mod tests { Ok(()) } + #[test] + fn return_field_respects_nullability() -> Result<()> { + let udf = SparkMakeDtInterval::new(); + + // All nullable inputs -> nullable output + let arg_fields = vec![ + Arc::new(Field::new("days", DataType::Int32, true)), + Arc::new(Field::new("hours", DataType::Int32, true)), + Arc::new(Field::new("mins", DataType::Int32, true)), + Arc::new(Field::new("secs", DataType::Float64, true)), + ]; + + let out = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &[None, None, None, None], + })?; + assert!(out.is_nullable()); + assert_eq!(out.data_type(), &Duration(Microsecond)); + + // Non-nullable inputs -> non-nullable output + let non_nullable_arg_fields = vec![ + Arc::new(Field::new("days", DataType::Int32, false)), + Arc::new(Field::new("hours", DataType::Int32, false)), + Arc::new(Field::new("mins", DataType::Int32, false)), + Arc::new(Field::new("secs", DataType::Float64, false)), + ]; + + let out = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &non_nullable_arg_fields, + scalar_arguments: &[None, None, None, None], + })?; + assert!(!out.is_nullable()); + + // Non-finite secs scalar should force nullable even if fields are non-nullable + let scalar_values = + [None, None, None, Some(ScalarValue::Float64(Some(f64::NAN)))]; + let scalar_refs = scalar_values.iter().map(|v| v.as_ref()).collect::>(); + let out = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &non_nullable_arg_fields, + scalar_arguments: &scalar_refs, + })?; + assert!(out.is_nullable()); + + // Zero-arg call (defaults) should also be non-nullable + let out = udf.return_field_from_args(ReturnFieldArgs { + arg_fields: &[], + scalar_arguments: &[], + })?; + assert!(!out.is_nullable()); + + Ok(()) + } + #[test] fn error_months_overflow_should_be_null() -> Result<()> { // months = year*12 + month → NULL @@ -465,19 +549,33 @@ mod tests { fn no_more_than_4_params() -> Result<()> { let udf = SparkMakeDtInterval::new(); - let arg_types = vec![ - DataType::Int32, - DataType::Int32, - DataType::Int32, - DataType::Float64, - DataType::Int32, + // Create args with 5 parameters (exceeds the limit of 4) + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(3))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(4.0))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(5))), ]; - let res = udf.coerce_types(&arg_types); + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true).into()) + .collect::>(); + + let func_args = ScalarFunctionArgs { + args, + arg_fields, + number_rows: 1, + return_field: Field::new("f", Duration(Microsecond), true).into(), + config_options: Arc::new(Default::default()), + }; + + let res = udf.invoke_with_args(func_args); assert!( matches!(res, Err(DataFusionError::Execution(_))), - "make_dt_interval should return execution error for too many arguments" + "make_dt_interval should return execution error for more than 4 arguments" ); Ok(()) diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs index d510eacb9aa41..abbf398d53d89 100644 --- a/datafusion/spark/src/function/datetime/make_interval.rs +++ b/datafusion/spark/src/function/datetime/make_interval.rs @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{Array, ArrayRef, IntervalMonthDayNanoBuilder, PrimitiveArray}; use arrow::datatypes::DataType::Interval; use arrow::datatypes::IntervalUnit::MonthDayNano; use arrow::datatypes::{DataType, IntervalMonthDayNano}; -use datafusion_common::{ - exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::types::{NativeType, logical_float64, logical_int32}; +use datafusion_common::{DataFusionError, Result, ScalarValue, plan_datafusion_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -43,17 +42,69 @@ impl Default for SparkMakeInterval { impl SparkMakeInterval { pub fn new() -> Self { + let int32 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + + let variants = vec![ + TypeSignature::Nullary, + // year + TypeSignature::Coercible(vec![int32.clone()]), + // year, month + TypeSignature::Coercible(vec![int32.clone(), int32.clone()]), + // year, month, week + TypeSignature::Coercible(vec![int32.clone(), int32.clone(), int32.clone()]), + // year, month, week, day + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + ]), + // year, month, week, day, hour + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + ]), + // year, month, week, day, hour, minute + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + ]), + // year, month, week, day, hour, minute, second + TypeSignature::Coercible(vec![ + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + int32.clone(), + float64.clone(), + ]), + ]; + Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of(variants, Volatility::Immutable), } } } impl ScalarUDFImpl for SparkMakeInterval { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "make_interval" } @@ -74,27 +125,6 @@ impl ScalarUDFImpl for SparkMakeInterval { } make_scalar_function(make_interval_kernel, vec![])(&args.args) } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let length = arg_types.len(); - match length { - x if x > 7 => { - exec_err!( - "make_interval expects between 0 and 7 arguments, got {}", - arg_types.len() - ) - } - _ => Ok((0..arg_types.len()) - .map(|i| { - if i == 6 { - DataType::Float64 - } else { - DataType::Int32 - } - }) - .collect()), - } - } } fn make_interval_kernel(args: &[ArrayRef]) -> Result { @@ -239,7 +269,7 @@ mod tests { use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - assert_eq_or_internal_err, internal_datafusion_err, internal_err, Result, + Result, assert_eq_or_internal_err, internal_datafusion_err, internal_err, }; use super::*; diff --git a/datafusion/spark/src/function/datetime/mod.rs b/datafusion/spark/src/function/datetime/mod.rs index a6adc99607665..98afa91ddc834 100644 --- a/datafusion/spark/src/function/datetime/mod.rs +++ b/datafusion/spark/src/function/datetime/mod.rs @@ -15,27 +15,74 @@ // specific language governing permissions and limitations // under the License. +pub mod add_months; pub mod date_add; +pub mod date_diff; +pub mod date_part; pub mod date_sub; +pub mod date_trunc; +pub mod extract; +pub mod from_utc_timestamp; pub mod last_day; pub mod make_dt_interval; pub mod make_interval; +pub mod monthname; pub mod next_day; +pub mod time_trunc; +pub mod to_utc_timestamp; +pub mod trunc; +pub mod unix; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; +make_udf_function!(add_months::SparkAddMonths, add_months); make_udf_function!(date_add::SparkDateAdd, date_add); +make_udf_function!(date_diff::SparkDateDiff, date_diff); +make_udf_function!(date_part::SparkDatePart, date_part); make_udf_function!(date_sub::SparkDateSub, date_sub); +make_udf_function!(date_trunc::SparkDateTrunc, date_trunc); +make_udf_function!( + from_utc_timestamp::SparkFromUtcTimestamp, + from_utc_timestamp +); +make_udf_function!(extract::SparkHour, hour); +make_udf_function!(extract::SparkMinute, minute); +make_udf_function!(extract::SparkSecond, second); make_udf_function!(last_day::SparkLastDay, last_day); make_udf_function!(make_dt_interval::SparkMakeDtInterval, make_dt_interval); make_udf_function!(make_interval::SparkMakeInterval, make_interval); +make_udf_function!(monthname::SparkMonthName, monthname); make_udf_function!(next_day::SparkNextDay, next_day); +make_udf_function!(time_trunc::SparkTimeTrunc, time_trunc); +make_udf_function!(to_utc_timestamp::SparkToUtcTimestamp, to_utc_timestamp); +make_udf_function!(trunc::SparkTrunc, trunc); +make_udf_function!(unix::SparkUnixDate, unix_date); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_micros, + unix::SparkUnixTimestamp::microseconds +); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_millis, + unix::SparkUnixTimestamp::milliseconds +); +make_udf_function!( + unix::SparkUnixTimestamp, + unix_seconds, + unix::SparkUnixTimestamp::seconds +); pub mod expr_fn { use datafusion_functions::export_functions; + export_functions!(( + add_months, + "Returns the date that is months months after start. The function returns NULL if at least one of the input parameters is NULL.", + arg1 arg2 + )); export_functions!(( date_add, "Returns the date that is days days after start. The function returns NULL if at least one of the input parameters is NULL.", @@ -46,6 +93,17 @@ pub mod expr_fn { "Returns the date that is days days before start. The function returns NULL if at least one of the input parameters is NULL.", arg1 arg2 )); + export_functions!((hour, "Extracts the hour component of a timestamp.", arg1)); + export_functions!(( + minute, + "Extracts the minute component of a timestamp.", + arg1 + )); + export_functions!(( + second, + "Extracts the second component of a timestamp.", + arg1 + )); export_functions!(( last_day, "Returns the last day of the month which the date belongs to.", @@ -61,6 +119,11 @@ pub mod expr_fn { "Make interval from years, months, weeks, days, hours, mins and secs.", years months weeks days hours mins secs )); + export_functions!(( + monthname, + "Returns the three-letter abbreviated month name from a date or timestamp.", + arg1 + )); // TODO: add once ANSI support is added: // "When both of the input parameters are not NULL and day_of_week is an invalid input, the function throws SparkIllegalArgumentException if spark.sql.ansi.enabled is set to true, otherwise NULL." export_functions!(( @@ -68,15 +131,86 @@ pub mod expr_fn { "Returns the first date which is later than start_date and named as indicated. The function returns NULL if at least one of the input parameters is NULL.", arg1 arg2 )); + export_functions!(( + date_diff, + "Returns the number of days from start `start` to end `end`.", + end start + )); + export_functions!(( + date_trunc, + "Truncates a timestamp `ts` to the unit specified by the format `fmt`.", + fmt ts + )); + export_functions!(( + time_trunc, + "Truncates a time `t` to the unit specified by the format `fmt`.", + fmt t + )); + export_functions!(( + trunc, + "Truncates a date `dt` to the unit specified by the format `fmt`.", + dt fmt + )); + export_functions!(( + date_part, + "Extracts a part of the date or time from a date, time, or timestamp expression.", + arg1 arg2 + )); + export_functions!(( + from_utc_timestamp, + "Interpret a given timestamp `ts` in UTC timezone and then convert it to timezone `tz`.", + ts tz + )); + export_functions!(( + to_utc_timestamp, + "Interpret a given timestamp `ts` in timezone `tz` and then convert it to UTC timezone.", + ts tz + )); + export_functions!(( + unix_date, + "Returns the number of days since epoch (1970-01-01) for the given date `dt`.", + dt + )); + export_functions!(( + unix_micros, + "Returns the number of microseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); + export_functions!(( + unix_millis, + "Returns the number of milliseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); + export_functions!(( + unix_seconds, + "Returns the number of seconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp `ts`.", + ts + )); } pub fn functions() -> Vec> { vec![ + add_months(), date_add(), + date_diff(), + date_part(), date_sub(), + date_trunc(), + from_utc_timestamp(), + hour(), last_day(), make_dt_interval(), make_interval(), + minute(), + monthname(), next_day(), + second(), + time_trunc(), + to_utc_timestamp(), + trunc(), + unix_date(), + unix_micros(), + unix_millis(), + unix_seconds(), ] } diff --git a/datafusion/spark/src/function/datetime/monthname.rs b/datafusion/spark/src/function/datetime/monthname.rs new file mode 100644 index 0000000000000..6cfa9c0a9212e --- /dev/null +++ b/datafusion/spark/src/function/datetime/monthname.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{AsArray, StringArray}; +use arrow::compute::{DatePart, date_part}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_date}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; + +const MONTH_NAMES: [&str; 12] = [ + "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec", +]; + +fn month_number_to_name(month: i32) -> Option<&'static str> { + MONTH_NAMES.get((month - 1) as usize).copied() +} + +/// Spark-compatible `monthname` expression. +/// Returns the three-letter abbreviated month name from a date or timestamp. +/// +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMonthName { + signature: Signature, +} + +impl Default for SparkMonthName { + fn default() -> Self { + Self::new() + } +} + +impl SparkMonthName { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![TypeSignatureClass::Timestamp], + NativeType::Date, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkMonthName { + fn name(&self) -> &str { + "monthname" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Utf8, nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [arg] = take_function_args(self.name(), args.args)?; + match arg { + ColumnarValue::Scalar(scalar) => { + if scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + let arr = scalar.to_array_of_size(1)?; + let month_arr = date_part(&arr, DatePart::Month)?; + let month_val = month_arr + .as_primitive::() + .value(0); + let name = month_number_to_name(month_val).map(|s| s.to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(name))) + } + ColumnarValue::Array(arr) => { + let month_arr = date_part(&arr, DatePart::Month)?; + let int_arr = month_arr.as_primitive::(); + + let result: StringArray = int_arr + .iter() + .map(|maybe_month| maybe_month.and_then(month_number_to_name)) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result))) + } + } + } +} diff --git a/datafusion/spark/src/function/datetime/next_day.rs b/datafusion/spark/src/function/datetime/next_day.rs index 32739f3e2c591..2241043d44cd7 100644 --- a/datafusion/spark/src/function/datetime/next_day.rs +++ b/datafusion/spark/src/function/datetime/next_day.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use arrow::array::{new_null_array, ArrayRef, AsArray, Date32Array, StringArrayType}; -use arrow::datatypes::{DataType, Date32Type}; +use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType}; +use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use chrono::{Datelike, Duration, Weekday}; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; /// @@ -50,10 +50,6 @@ impl SparkNextDay { } impl ScalarUDFImpl for SparkNextDay { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "next_day" } @@ -63,7 +59,13 @@ impl ScalarUDFImpl for SparkNextDay { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Date32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + // Spark marks next_day as always nullable because invalid day_of_week values + // can yield NULL even when inputs are non-null. + Ok(Arc::new(Field::new(self.name(), DataType::Date32, true))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -78,7 +80,12 @@ impl ScalarUDFImpl for SparkNextDay { match (date, day_of_week) { (ColumnarValue::Scalar(date), ColumnarValue::Scalar(day_of_week)) => { match (date, day_of_week) { - (ScalarValue::Date32(days), ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => { + ( + ScalarValue::Date32(days), + ScalarValue::Utf8(day_of_week) + | ScalarValue::LargeUtf8(day_of_week) + | ScalarValue::Utf8View(day_of_week), + ) => { if let Some(days) = days { if let Some(day_of_week) = day_of_week { Ok(ColumnarValue::Scalar(ScalarValue::Date32( @@ -93,25 +100,36 @@ impl ScalarUDFImpl for SparkNextDay { Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) } } - _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"), + _ => exec_err!( + "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}" + ), } } (ColumnarValue::Array(date_array), ColumnarValue::Scalar(day_of_week)) => { match (date_array.data_type(), day_of_week) { - (DataType::Date32, ScalarValue::Utf8(day_of_week) | ScalarValue::LargeUtf8(day_of_week) | ScalarValue::Utf8View(day_of_week)) => { + ( + DataType::Date32, + ScalarValue::Utf8(day_of_week) + | ScalarValue::LargeUtf8(day_of_week) + | ScalarValue::Utf8View(day_of_week), + ) => { if let Some(day_of_week) = day_of_week { let result: Date32Array = date_array .as_primitive::() - .unary_opt(|days| spark_next_day(days, day_of_week.as_str())) + .unary_opt(|days| { + spark_next_day(days, day_of_week.as_str()) + }) .with_data_type(DataType::Date32); Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } else { // TODO: if spark.sql.ansi.enabled is false, // returns NULL instead of an error for a malformed dayOfWeek. - Ok(ColumnarValue::Array(Arc::new(new_null_array(&DataType::Date32, date_array.len())))) + Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) } } - _ => exec_err!("Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}"), + _ => exec_err!( + "Spark `next_day` function: first arg must be date, second arg must be string. Got {args:?}" + ), } } ( @@ -143,7 +161,9 @@ impl ScalarUDFImpl for SparkNextDay { process_next_day_arrays(date_array, day_of_week_array) } other => { - exec_err!("Spark `next_day` function: second arg must be string. Got {other:?}") + exec_err!( + "Spark `next_day` function: second arg must be string. Got {other:?}" + ) } } } @@ -188,7 +208,7 @@ where } fn spark_next_day(days: i32, day_of_week: &str) -> Option { - let date = Date32Type::to_naive_date(days); + let date = Date32Type::to_naive_date_opt(days)?; let day_of_week = day_of_week.trim().to_uppercase(); let day_of_week = match day_of_week.as_str() { @@ -224,3 +244,39 @@ fn spark_next_day(days: i32, day_of_week: &str) -> Option { None } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn return_type_is_not_used() { + let func = SparkNextDay::new(); + let err = func + .return_type(&[DataType::Date32, DataType::Utf8]) + .unwrap_err(); + assert!( + err.to_string() + .contains("return_field_from_args should be used instead") + ); + } + + #[test] + fn next_day_is_always_nullable() { + let func = SparkNextDay::new(); + let date_field: FieldRef = + Arc::new(Field::new("start_date", DataType::Date32, false)); + let day_field: FieldRef = + Arc::new(Field::new("day_of_week", DataType::Utf8, false)); + + let field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&date_field), Arc::clone(&day_field)], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert_eq!(field.data_type(), &DataType::Date32); + assert!(field.is_nullable()); + } +} diff --git a/datafusion/spark/src/function/datetime/time_trunc.rs b/datafusion/spark/src/function/datetime/time_trunc.rs new file mode 100644 index 0000000000000..a66b8e94685aa --- /dev/null +++ b/datafusion/spark/src/function/datetime/time_trunc.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::logical_string; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; + +/// Spark time_trunc function only handles time inputs. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTimeTrunc { + signature: Signature, +} + +impl Default for SparkTimeTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkTimeTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Time), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTimeTrunc { + fn name(&self) -> &str { + "time_trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[1].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!( + "spark time_trunc should have been simplified to standard date_trunc" + ) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let fmt_expr = &args[0]; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "First argument of `TIME_TRUNC` must be non-null scalar Utf8" + ); + } + }; + + if !matches!( + fmt.as_str(), + "hour" | "minute" | "second" | "millisecond" | "microsecond" + ) { + return plan_err!( + "The format argument of `TIME_TRUNC` must be one of: hour, minute, second, millisecond, microsecond" + ); + } + + Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( + ScalarFunction::new_udf(datafusion_functions::datetime::date_trunc(), args), + ))) + } +} diff --git a/datafusion/spark/src/function/datetime/to_utc_timestamp.rs b/datafusion/spark/src/function/datetime/to_utc_timestamp.rs new file mode 100644 index 0000000000000..67910ff33f1af --- /dev/null +++ b/datafusion/spark/src/function/datetime/to_utc_timestamp.rs @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, AsArray, PrimitiveBuilder, StringArrayType}; +use arrow::datatypes::TimeUnit; +use arrow::datatypes::{ + ArrowTimestampType, DataType, Field, FieldRef, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; +use chrono::{DateTime, Offset, TimeZone}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + Result, exec_datafusion_err, exec_err, internal_datafusion_err, internal_err, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +/// Apache Spark `to_utc_timestamp` function. +/// +/// Interprets the given timestamp in the provided timezone and then converts it to UTC. +/// +/// Timestamp in Apache Spark represents number of microseconds from the Unix epoch, which is not +/// timezone-agnostic. So in Apache Spark this function just shift the timestamp value from the given +/// timezone to UTC timezone. +/// +/// See +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkToUtcTimestamp { + signature: Signature, +} + +impl Default for SparkToUtcTimestamp { + fn default() -> Self { + Self::new() + } +} + +impl SparkToUtcTimestamp { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Timestamp, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Timestamp(TimeUnit::Microsecond, None), + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkToUtcTimestamp { + fn name(&self) -> &str { + "to_utc_timestamp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(to_utc_timestamp, vec![])(&args.args) + } +} + +fn to_utc_timestamp(args: &[ArrayRef]) -> Result { + let [timestamp, timezone] = take_function_args("to_utc_timestamp", args)?; + + match timestamp.data_type() { + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + process_timestamp_with_tz_array::( + timestamp, + timezone, + tz_opt.clone(), + ) + } + ts_type => { + exec_err!("`to_utc_timestamp`: unsupported argument types: {ts_type}") + } + } +} + +fn process_timestamp_with_tz_array( + ts_array: &ArrayRef, + tz_array: &ArrayRef, + tz_opt: Option>, +) -> Result { + match tz_array.data_type() { + DataType::Utf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::LargeUtf8 => { + process_arrays::(tz_opt, ts_array, tz_array.as_string::()) + } + DataType::Utf8View => { + process_arrays::(tz_opt, ts_array, tz_array.as_string_view()) + } + other => { + exec_err!("`to_utc_timestamp`: timezone must be a string type, got {other}") + } + } +} + +fn process_arrays<'a, T: ArrowTimestampType, S>( + return_tz_opt: Option>, + ts_array: &ArrayRef, + tz_array: &'a S, +) -> Result +where + &'a S: StringArrayType<'a>, +{ + let ts_primitive = ts_array.as_primitive::(); + let mut builder = PrimitiveBuilder::::with_capacity(ts_array.len()); + + for (ts_opt, tz_opt) in ts_primitive.iter().zip(tz_array.iter()) { + match (ts_opt, tz_opt) { + (Some(ts), Some(tz_str)) => { + let tz: Tz = tz_str.parse().map_err(|e| { + exec_datafusion_err!( + "`to_utc_timestamp`: invalid timezone '{tz_str}': {e}" + ) + })?; + let val = adjust_to_utc_time::(ts, tz)?; + builder.append_value(val); + } + _ => builder.append_null(), + } + } + + builder = builder.with_timezone_opt(return_tz_opt); + Ok(Arc::new(builder.finish())) +} + +fn adjust_to_utc_time(ts: i64, tz: Tz) -> Result { + let dt = match T::UNIT { + TimeUnit::Nanosecond => Some(DateTime::from_timestamp_nanos(ts)), + TimeUnit::Microsecond => DateTime::from_timestamp_micros(ts), + TimeUnit::Millisecond => DateTime::from_timestamp_millis(ts), + TimeUnit::Second => DateTime::from_timestamp(ts, 0), + } + .ok_or_else(|| internal_datafusion_err!("Invalid timestamp"))?; + let naive_dt = dt.naive_utc(); + + let offset_seconds = tz + .offset_from_utc_datetime(&naive_dt) + .fix() + .local_minus_utc() as i64; + + let offset_in_unit = match T::UNIT { + TimeUnit::Nanosecond => offset_seconds.checked_mul(1_000_000_000), + TimeUnit::Microsecond => offset_seconds.checked_mul(1_000_000), + TimeUnit::Millisecond => offset_seconds.checked_mul(1_000), + TimeUnit::Second => Some(offset_seconds), + } + .ok_or_else(|| internal_datafusion_err!("Offset overflow"))?; + + ts.checked_sub(offset_in_unit).ok_or_else(|| { + internal_datafusion_err!("Timestamp overflow during timezone adjustment") + }) +} diff --git a/datafusion/spark/src/function/datetime/trunc.rs b/datafusion/spark/src/function/datetime/trunc.rs new file mode 100644 index 0000000000000..9d7da5969a525 --- /dev/null +++ b/datafusion/spark/src/function/datetime/trunc.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::{NativeType, logical_date, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Spark trunc supports date inputs only and extra format aliases. +/// Also spark trunc's argument order is (date, format). +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkTrunc { + signature: Signature, +} + +impl Default for SparkTrunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkTrunc { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_date()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Date, + ), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkTrunc { + fn name(&self) -> &str { + "trunc" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + self.name(), + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("spark trunc should have been simplified to standard date_trunc") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [dt_expr, fmt_expr] = take_function_args(self.name(), args)?; + + let fmt = match fmt_expr.as_literal() { + Some(ScalarValue::Utf8(Some(v))) + | Some(ScalarValue::Utf8View(Some(v))) + | Some(ScalarValue::LargeUtf8(Some(v))) => v.to_lowercase(), + _ => { + return plan_err!( + "Second argument of `TRUNC` must be non-null scalar Utf8" + ); + } + }; + + // Map Spark-specific fmt aliases to datafusion ones + let fmt = match fmt.as_str() { + "yy" | "yyyy" => "year", + "mm" | "mon" => "month", + "year" | "month" | "day" | "week" | "quarter" => fmt.as_str(), + _ => { + return plan_err!( + "The format argument of `TRUNC` must be one of: year, yy, yyyy, month, mm, mon, day, week, quarter." + ); + } + }; + let return_type = dt_expr.get_type(info.schema())?; + + let fmt_expr = Expr::Literal(ScalarValue::new_utf8(fmt), None); + + // Spark uses Dates so we need to cast to timestamp and back to work with datafusion's date_trunc + Ok(ExprSimplifyResult::Simplified( + Expr::ScalarFunction(ScalarFunction::new_udf( + datafusion_functions::datetime::date_trunc(), + vec![ + fmt_expr, + dt_expr.cast_to( + &DataType::Timestamp(TimeUnit::Nanosecond, None), + info.schema(), + )?, + ], + )) + .cast_to(&return_type, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/datetime/unix.rs b/datafusion/spark/src/function/datetime/unix.rs new file mode 100644 index 0000000000000..6eaf3a08780bc --- /dev/null +++ b/datafusion/spark/src/function/datetime/unix.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef, TimeUnit}; +use datafusion_common::types::logical_date; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{ + Coercion, ColumnarValue, Expr, ExprSchemable, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, +}; + +/// Returns the number of days since epoch (1970-01-01) for the given date. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnixDate { + signature: Signature, +} + +impl Default for SparkUnixDate { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnixDate { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Native( + logical_date(), + ))], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkUnixDate { + fn name(&self) -> &str { + "unix_date" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields[0].is_nullable(); + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke_with_args should not be called on SparkUnixDate") + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [date] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified( + date.cast_to(&DataType::Date32, info.schema())? + .cast_to(&DataType::Int32, info.schema())?, + )) + } +} + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnixTimestamp { + time_unit: TimeUnit, + signature: Signature, + name: &'static str, +} + +impl SparkUnixTimestamp { + pub fn new(name: &'static str, time_unit: TimeUnit) -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_exact(TypeSignatureClass::Timestamp)], + Volatility::Immutable, + ), + time_unit, + name, + } + } + + /// Returns the number of microseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn microseconds() -> Self { + Self::new("unix_micros", TimeUnit::Microsecond) + } + + /// Returns the number of milliseconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn milliseconds() -> Self { + Self::new("unix_millis", TimeUnit::Millisecond) + } + + /// Returns the number of seconds since epoch (1970-01-01 00:00:00 UTC) for the given timestamp. + /// + pub fn seconds() -> Self { + Self::new("unix_seconds", TimeUnit::Second) + } +} + +impl ScalarUDFImpl for SparkUnixTimestamp { + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields[0].is_nullable(); + Ok(Arc::new(Field::new(self.name(), DataType::Int64, nullable))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("invoke_with_args should not be called on `{}`", self.name()) + } + + fn simplify( + &self, + args: Vec, + info: &SimplifyContext, + ) -> Result { + let [ts] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified( + ts.cast_to( + &DataType::Timestamp(self.time_unit, Some("UTC".into())), + info.schema(), + )? + .cast_to(&DataType::Int64, info.schema())?, + )) + } +} diff --git a/datafusion/spark/src/function/error_utils.rs b/datafusion/spark/src/function/error_utils.rs index b972d64ed3e9a..362a32bcd0cc2 100644 --- a/datafusion/spark/src/function/error_utils.rs +++ b/datafusion/spark/src/function/error_utils.rs @@ -18,7 +18,7 @@ // TODO: https://github.com/apache/spark/tree/master/common/utils/src/main/resources/error use arrow::datatypes::DataType; -use datafusion_common::{exec_datafusion_err, internal_datafusion_err, DataFusionError}; +use datafusion_common::{DataFusionError, exec_datafusion_err, internal_datafusion_err}; pub fn invalid_arg_count_exec_err( function_name: &str, @@ -44,7 +44,9 @@ pub fn unsupported_data_type_exec_err( required: &str, provided: &DataType, ) -> DataFusionError { - exec_datafusion_err!("Unsupported Data Type: Spark `{function_name}` function expects {required}, got {provided}") + exec_datafusion_err!( + "Unsupported Data Type: Spark `{function_name}` function expects {required}, got {provided}" + ) } pub fn unsupported_data_types_exec_err( diff --git a/datafusion/spark/src/function/hash/crc32.rs b/datafusion/spark/src/function/hash/crc32.rs index 8280e24b8ef59..2fc376abcb725 100644 --- a/datafusion/spark/src/function/hash/crc32.rs +++ b/datafusion/spark/src/function/hash/crc32.rs @@ -15,22 +15,21 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use crc32fast::Hasher; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_fixed_size_binary_array, as_large_binary_array, }; -use datafusion_common::types::{logical_string, NativeType}; +use datafusion_common::types::{NativeType, logical_string}; use datafusion_common::utils::take_function_args; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; @@ -62,10 +61,6 @@ impl SparkCrc32 { } impl ScalarUDFImpl for SparkCrc32 { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "crc32" } @@ -75,7 +70,12 @@ impl ScalarUDFImpl for SparkCrc32 { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int64) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Int64, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -122,3 +122,33 @@ fn spark_crc32(args: &[ArrayRef]) -> Result { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_crc32_nullability() -> Result<()> { + let crc32_func = SparkCrc32::new(); + + // non-nullable field should produce non-nullable output + let field_not_null = Arc::new(Field::new("data", DataType::Binary, false)); + let result = crc32_func.return_field_from_args(ReturnFieldArgs { + arg_fields: std::slice::from_ref(&field_not_null), + scalar_arguments: &[None], + })?; + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), &DataType::Int64); + + // nullable field should produce nullable output + let field_nullable = Arc::new(Field::new("data", DataType::Binary, true)); + let result = crc32_func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[field_nullable], + scalar_arguments: &[None], + })?; + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &DataType::Int64); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/hash/mod.rs b/datafusion/spark/src/function/hash/mod.rs index 5860596ac70a3..351f5d2d5063c 100644 --- a/datafusion/spark/src/function/hash/mod.rs +++ b/datafusion/spark/src/function/hash/mod.rs @@ -18,6 +18,8 @@ pub mod crc32; pub mod sha1; pub mod sha2; +pub(crate) mod utils; +pub mod xxhash64; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; @@ -26,16 +28,18 @@ use std::sync::Arc; make_udf_function!(crc32::SparkCrc32, crc32); make_udf_function!(sha1::SparkSha1, sha1); make_udf_function!(sha2::SparkSha2, sha2); +make_udf_function!(xxhash64::SparkXxhash64, xxhash64); pub mod expr_fn { use datafusion_functions::export_functions; export_functions!( (crc32, "crc32(expr) - Returns a cyclic redundancy check value of the expr as a bigint.", arg1), (sha1, "sha1(expr) - Returns a SHA-1 hash value of the expr as a hex string.", arg1), - (sha2, "sha2(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of expr. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.", arg1 arg2) + (sha2, "sha2(expr, bitLength) - Returns a checksum of SHA-2 family as a hex string of expr. SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.", arg1 arg2), + (xxhash64, "xxhash64(expr1, expr2, ...) - Returns a 64-bit hash value of the arguments using xxHash.", args) ); } pub fn functions() -> Vec> { - vec![crc32(), sha1(), sha2()] + vec![crc32(), sha1(), sha2(), xxhash64()] } diff --git a/datafusion/spark/src/function/hash/sha1.rs b/datafusion/spark/src/function/hash/sha1.rs index 9e3d96b8031a1..dd9009eb8233f 100644 --- a/datafusion/spark/src/function/hash/sha1.rs +++ b/datafusion/spark/src/function/hash/sha1.rs @@ -15,22 +15,20 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::fmt::Write; use std::sync::Arc; use arrow::array::{ArrayRef, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cast::{ as_binary_array, as_binary_view_array, as_fixed_size_binary_array, as_large_binary_array, }; -use datafusion_common::types::{logical_string, NativeType}; +use datafusion_common::types::{NativeType, logical_string}; use datafusion_common::utils::take_function_args; -use datafusion_common::{internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::{ - Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignatureClass, Volatility, + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, }; use datafusion_functions::utils::make_scalar_function; use sha1::{Digest, Sha1}; @@ -65,10 +63,6 @@ impl SparkSha1 { } impl ScalarUDFImpl for SparkSha1 { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "sha1" } @@ -82,7 +76,12 @@ impl ScalarUDFImpl for SparkSha1 { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), DataType::Utf8, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -90,12 +89,16 @@ impl ScalarUDFImpl for SparkSha1 { } } +/// Hex encoding lookup table for fast byte-to-hex conversion +const HEX_CHARS_LOWER: &[u8; 16] = b"0123456789abcdef"; + +#[inline] fn spark_sha1_digest(value: &[u8]) -> String { let result = Sha1::digest(value); let mut s = String::with_capacity(result.len() * 2); - #[allow(deprecated)] - for b in result.as_slice() { - write!(&mut s, "{b:02x}").unwrap(); + for &b in result.as_slice() { + s.push(HEX_CHARS_LOWER[(b >> 4) as usize] as char); + s.push(HEX_CHARS_LOWER[(b & 0x0f) as usize] as char); } s } @@ -133,3 +136,33 @@ fn spark_sha1(args: &[ArrayRef]) -> Result { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sha1_nullability() -> Result<()> { + let func = SparkSha1::new(); + + // Non-nullable input keeps output non-nullable + let non_nullable: FieldRef = Arc::new(Field::new("col", DataType::Binary, false)); + let out = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable)], + scalar_arguments: &[None], + })?; + assert!(!out.is_nullable()); + assert_eq!(out.data_type(), &DataType::Utf8); + + // Nullable input makes output nullable + let nullable: FieldRef = Arc::new(Field::new("col", DataType::Binary, true)); + let out = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable)], + scalar_arguments: &[None], + })?; + assert!(out.is_nullable()); + assert_eq!(out.data_type(), &DataType::Utf8); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs index b006607d3eeda..38fa0cc643751 100644 --- a/datafusion/spark/src/function/hash/sha2.rs +++ b/datafusion/spark/src/function/hash/sha2.rs @@ -15,26 +15,28 @@ // specific language governing permissions and limitations // under the License. -extern crate datafusion_functions; - -use crate::function::error_utils::{ - invalid_arg_count_exec_err, unsupported_data_type_exec_err, -}; -use crate::function::math::hex::spark_sha2_hex; -use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray}; use arrow::datatypes::{DataType, Int32Type}; -use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; -use datafusion_expr::Signature; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; -pub use datafusion_functions::crypto::basic::{sha224, sha256, sha384, sha512}; -use std::any::Any; +use datafusion_common::types::{ + NativeType, logical_binary, logical_int32, logical_string, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use sha2::{self, Digest}; use std::sync::Arc; +/// Differs from DataFusion version in allowing array input for bit lengths, and +/// also hex encoding the output. +/// /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkSha2 { signature: Signature, - aliases: Vec, } impl Default for SparkSha2 { @@ -46,17 +48,26 @@ impl Default for SparkSha2 { impl SparkSha2 { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), - aliases: vec![], + signature: Signature::coercible( + vec![ + Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + ), + Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ), + ], + Volatility::Immutable, + ), } } } impl ScalarUDFImpl for SparkSha2 { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "sha2" } @@ -65,156 +76,188 @@ impl ScalarUDFImpl for SparkSha2 { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types[1].is_null() { - return Ok(DataType::Null); - } - Ok(match arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary => DataType::Utf8, - DataType::Null => DataType::Null, - _ => { - return exec_err!( - "{} function can only accept strings or binary arrays.", - self.name() - ) - } - }) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| { - internal_datafusion_err!("Expected 2 arguments for function sha2") - })?; + let [values, bit_lengths] = take_function_args(self.name(), args.args.iter())?; - sha2(args) - } + match (values, bit_lengths) { + ( + ColumnarValue::Scalar(value_scalar), + ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))), + ) => { + if value_scalar.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return Err(invalid_arg_count_exec_err( - self.name(), - (2, 2), - arg_types.len(), - )); - } - let expr_type = match &arg_types[0] { - DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Binary - | DataType::BinaryView - | DataType::LargeBinary - | DataType::Null => Ok(arg_types[0].clone()), - _ => Err(unsupported_data_type_exec_err( - self.name(), - "String, Binary", - &arg_types[0], - )), - }?; - let bit_length_type = if arg_types[1].is_numeric() { - Ok(DataType::Int32) - } else if arg_types[1].is_null() { - Ok(DataType::Null) - } else { - Err(unsupported_data_type_exec_err( - self.name(), - "Numeric Type", - &arg_types[1], - )) - }?; - - Ok(vec![expr_type, bit_length_type]) - } -} + // Accept both Binary and Utf8 scalars (depending on coercion) + let bytes = match value_scalar { + ScalarValue::Binary(Some(b)) => b.as_slice(), + ScalarValue::LargeBinary(Some(b)) => b.as_slice(), + ScalarValue::BinaryView(Some(b)) => b.as_slice(), + ScalarValue::Utf8(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) => s.as_bytes(), + other => { + return internal_err!( + "Unsupported scalar datatype for sha2: {}", + other.data_type() + ); + } + }; -pub fn sha2(args: [ColumnarValue; 2]) -> Result { - match args { - [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => { - compute_sha2( - bit_length_arg, - &[ColumnarValue::from(ScalarValue::Utf8(expr_arg))], - ) - } - [ColumnarValue::Array(expr_arg), ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length_arg)))] => { - compute_sha2(bit_length_arg, &[ColumnarValue::from(expr_arg)]) - } - [ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg)), ColumnarValue::Array(bit_length_arg)] => - { - let arr: StringArray = bit_length_arg - .as_primitive::() - .iter() - .map(|bit_length| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(expr_arg.clone())), - ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), + let out = match bit_length { + 224 => { + let mut digest = sha2::Sha224::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) - } - [ColumnarValue::Array(expr_arg), ColumnarValue::Array(bit_length_arg)] => { - let expr_iter = expr_arg.as_string::().iter(); - let bit_length_iter = bit_length_arg.as_primitive::().iter(); - let arr: StringArray = expr_iter - .zip(bit_length_iter) - .map(|(expr, bit_length)| { - match sha2([ - ColumnarValue::Scalar(ScalarValue::Utf8(Some( - expr.unwrap().to_string(), - ))), - ColumnarValue::Scalar(ScalarValue::Int32(bit_length)), - ]) - .unwrap() - { - ColumnarValue::Scalar(ScalarValue::Utf8(str)) => str, - ColumnarValue::Array(arr) => arr - .as_string::() - .iter() - .map(|str| str.unwrap().to_string()) - .next(), // first element - _ => unreachable!(), + 0 | 256 => { + let mut digest = sha2::Sha256::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) + } + 384 => { + let mut digest = sha2::Sha384::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) } - }) - .collect(); - Ok(ColumnarValue::Array(Arc::new(arr) as ArrayRef)) + 512 => { + let mut digest = sha2::Sha512::default(); + digest.update(bytes); + Some(hex_encode(digest.finalize())) + } + _ => None, + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out))) + } + // Array values + scalar bit length (common case: sha2(col, 256)) + ( + ColumnarValue::Array(values_array), + ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))), + ) => { + let output: ArrayRef = match values_array.data_type() { + DataType::Binary => sha2_binary_scalar_bitlen( + &values_array.as_binary::(), + *bit_length, + ), + DataType::LargeBinary => sha2_binary_scalar_bitlen( + &values_array.as_binary::(), + *bit_length, + ), + DataType::BinaryView => sha2_binary_scalar_bitlen( + &values_array.as_binary_view(), + *bit_length, + ), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(ColumnarValue::Array(output)) + } + ( + ColumnarValue::Scalar(_), + ColumnarValue::Scalar(ScalarValue::Int32(None)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + ( + ColumnarValue::Array(_), + ColumnarValue::Scalar(ScalarValue::Int32(None)), + ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + _ => { + // Fallback to existing behavior for any array/mixed cases + make_scalar_function(sha2_impl, vec![])(&args.args) + } } - _ => exec_err!("Unsupported argument types for sha2 function"), } } -fn compute_sha2( - bit_length_arg: i32, - expr_arg: &[ColumnarValue], -) -> Result { - match bit_length_arg { - 0 | 256 => sha256(expr_arg), - 224 => sha224(expr_arg), - 384 => sha384(expr_arg), - 512 => sha512(expr_arg), - _ => { - // Return null for unsupported bit lengths instead of error, because spark sha2 does not - // error out for this. - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); +fn sha2_impl(args: &[ArrayRef]) -> Result { + let [values, bit_lengths] = take_function_args("sha2", args)?; + + let bit_lengths = bit_lengths.as_primitive::(); + let output = match values.data_type() { + DataType::Binary => sha2_binary_impl(&values.as_binary::(), bit_lengths), + DataType::LargeBinary => { + sha2_binary_impl(&values.as_binary::(), bit_lengths) } + DataType::BinaryView => sha2_binary_impl(&values.as_binary_view(), bit_lengths), + dt => return internal_err!("Unsupported datatype for sha2: {dt}"), + }; + Ok(output) +} + +fn sha2_binary_impl<'a, BinaryArrType>( + values: &BinaryArrType, + bit_lengths: &Int32Array, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + sha2_binary_bitlen_iter(values, bit_lengths.iter()) +} + +fn sha2_binary_scalar_bitlen<'a, BinaryArrType>( + values: &BinaryArrType, + bit_length: i32, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, +{ + sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length))) +} + +fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>( + values: &BinaryArrType, + bit_lengths: I, +) -> ArrayRef +where + BinaryArrType: BinaryArrayType<'a>, + I: Iterator>, +{ + let array = values + .iter() + .zip(bit_lengths) + .map(|(value, bit_length)| match (value, bit_length) { + (Some(value), Some(224)) => { + let mut digest = sha2::Sha224::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(0 | 256)) => { + let mut digest = sha2::Sha256::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(384)) => { + let mut digest = sha2::Sha384::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + (Some(value), Some(512)) => { + let mut digest = sha2::Sha512::default(); + digest.update(value); + Some(hex_encode(digest.finalize())) + } + // Unknown bit-lengths go to null, same as in Spark + _ => None, + }) + .collect::(); + Arc::new(array) +} + +const HEX_CHARS: [u8; 16] = *b"0123456789abcdef"; + +#[inline] +fn hex_encode>(data: T) -> String { + let bytes = data.as_ref(); + let mut out = Vec::with_capacity(bytes.len() * 2); + for &b in bytes { + let hi = b >> 4; + let lo = b & 0x0F; + out.push(HEX_CHARS[hi as usize]); + out.push(HEX_CHARS[lo as usize]); } - .map(|hashed| spark_sha2_hex(&[hashed]).unwrap()) + // SAFETY: out contains only ASCII + unsafe { String::from_utf8_unchecked(out) } } diff --git a/datafusion/spark/src/function/hash/utils.rs b/datafusion/spark/src/function/hash/utils.rs new file mode 100644 index 0000000000000..e7918d33ec3e7 --- /dev/null +++ b/datafusion/spark/src/function/hash/utils.rs @@ -0,0 +1,1005 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared helpers for Spark-compatible hash functions. +//! +//! Generic helpers ([`hash_primitive_values`], [`hash_bytes_values`]) cover +//! the cases where a function-based abstraction is clean; less uniform paths +//! (booleans, decimals, list/map dispatch) remain as macros. +//! +//! Ported from Apache DataFusion Comet: +//! + +use arrow::array::{Array, PrimitiveArray}; +use arrow::datatypes::ArrowPrimitiveType; +use datafusion_common::Result; + +/// Hash the values of a primitive array, calling `transform` to convert each +/// value into a byte slice before feeding it to `hash_method`. +#[inline] +pub(crate) fn hash_primitive_values( + array: &PrimitiveArray

, + hashes: &mut [u64], + transform: F, + hash_method: H, +) where + P: ArrowPrimitiveType, + F: Fn(P::Native) -> B, + B: AsRef<[u8]>, + H: Fn(B, u64) -> u64, +{ + let values = array.values(); + if array.null_count() == 0 { + // Fast path: no nulls, skip null checks + for (i, hash) in hashes.iter_mut().enumerate() { + *hash = hash_method(transform(values[i]), *hash); + } + } else { + // Slow path: check each row for null + for (i, hash) in hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = hash_method(transform(values[i]), *hash); + } + } + } +} + +/// Hash an array by calling `value_at(array, i)` to produce a byte slice for +/// each non-null row. Used for byte-slice arrays (Utf8/Binary/FixedSizeBinary +/// and their Large/View variants) and for arrays whose values need a small +/// transform (e.g. booleans widened to `i32`, decimals to `i128` bytes). +#[inline] +pub(crate) fn hash_bytes_values( + array: &A, + hashes: &mut [u64], + value_at: impl Fn(&A, usize) -> B, + hash_method: H, +) where + A: Array, + B: AsRef<[u8]>, + H: Fn(B, u64) -> u64, +{ + if array.null_count() == 0 { + for (i, hash) in hashes.iter_mut().enumerate() { + *hash = hash_method(value_at(array, i), *hash); + } + } else { + for (i, hash) in hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = hash_method(value_at(array, i), *hash); + } + } + } +} + +/// Fallible variant of [`hash_bytes_values`]: `value_at` may return an error +/// (used by the small-decimal path, where values may not fit in `i64`). +#[inline] +pub(crate) fn try_hash_bytes_values( + array: &A, + hashes: &mut [u64], + value_at: impl Fn(&A, usize) -> Result, + hash_method: H, +) -> Result<()> +where + A: Array, + B: AsRef<[u8]>, + H: Fn(B, u64) -> u64, +{ + if array.null_count() == 0 { + for (i, hash) in hashes.iter_mut().enumerate() { + *hash = hash_method(value_at(array, i)?, *hash); + } + } else { + for (i, hash) in hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = hash_method(value_at(array, i)?, *hash); + } + } + } + Ok(()) +} + +/// Downcast `$column` to `$array_type`, panicking with a useful error if the +/// type does not match. +#[macro_export] +#[doc(hidden)] +macro_rules! downcast_array { + ($column:ident, $array_type:ident) => { + $column + .as_any() + .downcast_ref::<$array_type>() + .unwrap_or_else(|| { + panic!( + "Failed to downcast column to {}. Actual data type: {:?}.", + stringify!($array_type), + $column.data_type() + ) + }) + }; +} + +/// Hash a byte-slice-accessor array (strings, binary, fixed-size binary, ...) +/// where `value(i)` returns the item directly. +/// +/// Kept as a macro because `a.value(i)` returns a borrow whose lifetime +/// depends on the array, which Rust's closure inference does not propagate +/// through a `Fn(&A, usize) -> B` bound with a free type parameter `B`. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_array { + ($array_type:ident, $column:ident, $hashes:ident, $hash_method:ident) => { + let array = $crate::downcast_array!($column, $array_type); + if array.null_count() == 0 { + for i in 0..$hashes.len() { + $hashes[i] = $hash_method(&array.value(i), $hashes[i]); + } + } else { + for i in 0..$hashes.len() { + if !array.is_null(i) { + $hashes[i] = $hash_method(&array.value(i), $hashes[i]); + } + } + } + }; +} + +/// Hash a `BooleanArray`, widening each value to `$hash_input_type` first. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_array_boolean { + ($array_type:ident, $column:ident, $hash_input_type:ident, $hashes:ident, $hash_method:ident) => { + let array = $crate::downcast_array!($column, $array_type); + $crate::function::hash::utils::hash_bytes_values( + array, + $hashes, + |a, i| $hash_input_type::from(a.value(i)).to_le_bytes(), + $hash_method, + ); + }; +} + +/// Hash a primitive array, widening each native value to `$ty` before hashing. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_array_primitive { + ($array_type:ident, $column:ident, $ty:ident, $hashes:ident, $hash_method:ident) => { + let array = $crate::downcast_array!($column, $array_type); + $crate::function::hash::utils::hash_primitive_values( + array, + $hashes, + |v| (v as $ty).to_le_bytes(), + $hash_method, + ); + }; +} + +/// Hash a floating-point primitive array, normalizing `-0.0` to `0.0` per Spark. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_array_primitive_float { + ($array_type:ident, $column:ident, $ty:ident, $ty2:ident, $hashes:ident, $hash_method:ident) => { + let array = $crate::downcast_array!($column, $array_type); + $crate::function::hash::utils::hash_primitive_values( + array, + $hashes, + |v| { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if v == 0.0 && v.is_sign_negative() { + (0 as $ty2).to_le_bytes() + } else { + (v as $ty).to_le_bytes() + } + }, + $hash_method, + ); + }; +} + +/// Hash a small (precision <= 18) decimal array by reducing each value to `i64`. +/// Errors via `?` if a value does not fit in `i64`. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_array_small_decimal { + ($array_type:ident, $column:ident, $hashes:ident, $hash_method:ident) => { + let array = $crate::downcast_array!($column, $array_type); + $crate::function::hash::utils::try_hash_bytes_values( + array, + $hashes, + |a, i| { + i64::try_from(a.value(i)) + .map(|v| v.to_le_bytes()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }, + $hash_method, + )?; + }; +} + +/// Hash a wide decimal array using the raw 128-bit byte representation. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_array_decimal { + ($array_type:ident, $column:ident, $hashes:ident, $hash_method:ident) => { + let array = $crate::downcast_array!($column, $array_type); + $crate::function::hash::utils::hash_bytes_values( + array, + $hashes, + |a, i| a.value(i).to_le_bytes(), + $hash_method, + ); + }; +} + +/// Hash a list array with primitive elements by directly accessing the underlying buffer. +/// This avoids the overhead of slicing and recursive calls for common cases. +/// Supports both variable-length lists (with offsets) and fixed-size lists. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_list_primitive { + // Variable-length list variant (List/LargeList) + (offsets: $offsets:expr, $list_array:ident, $elem_array:ident, $hashes:ident, $hash_method:ident, $value_transform:expr) => { + if $list_array.null_count() == 0 && $elem_array.null_count() == 0 { + for (row_idx, hash) in $hashes.iter_mut().enumerate() { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for elem_idx in start..end { + let value = $elem_array.value(elem_idx); + *hash = $hash_method($value_transform(value), *hash); + } + } + } else { + for (row_idx, hash) in $hashes.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for elem_idx in start..end { + if !$elem_array.is_null(elem_idx) { + let value = $elem_array.value(elem_idx); + *hash = $hash_method($value_transform(value), *hash); + } + } + } + } + } + }; + // Fixed-size list variant + (fixed_size: $list_size:expr, $list_array:ident, $elem_array:ident, $hashes:ident, $hash_method:ident, $value_transform:expr) => { + if $list_array.null_count() == 0 && $elem_array.null_count() == 0 { + for (row_idx, hash) in $hashes.iter_mut().enumerate() { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + let value = $elem_array.value(start + elem_idx); + *hash = $hash_method($value_transform(value), *hash); + } + } + } else { + for (row_idx, hash) in $hashes.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + if !$elem_array.is_null(start + elem_idx) { + let value = $elem_array.value(start + elem_idx); + *hash = $hash_method($value_transform(value), *hash); + } + } + } + } + } + }; +} + +/// Hash a list array by recursively hashing each element. +/// For each row, we hash all elements in the list. +/// Spark hashes arrays by recursively hashing each element, where each +/// element's hash is computed using the previous element's hash as the seed. +/// This creates a chain: hash(elem_n, hash(elem_n-1, ... hash(elem_0, seed)...)) +/// Dispatches hash operations for List/LargeList/FixedSizeList arrays with primitive element types. +/// This macro eliminates duplication by handling the type-to-array mapping for all supported primitives. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_list_with_primitive_elements { + // Variant for List/LargeList with offsets + (offsets: $list_array_type:ident, $list_array:ident, $values:ident, $offsets:ident, $field:expr, $hashes_buffer:ident, $hash_method:ident, $recursive_hash_method:ident, $fallback_offset_type:ty, $col:ident) => { + match $field.data_type() { + DataType::Int8 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i8| (v as i32).to_le_bytes()); + } + DataType::Int16 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i16| (v as i32).to_le_bytes()); + } + DataType::Int32 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i32| v.to_le_bytes()); + } + DataType::Int64 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); + } + DataType::Float32 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, + |v: f32| if v == 0.0 && v.is_sign_negative() { (0_i32).to_le_bytes() } else { v.to_le_bytes() }); + } + DataType::Float64 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, + |v: f64| if v == 0.0 && v.is_sign_negative() { (0_i64).to_le_bytes() } else { v.to_le_bytes() }); + } + DataType::Boolean => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: bool| (i32::from(v)).to_le_bytes()); + } + DataType::Utf8 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + if $list_array.null_count() == 0 && elem_array.null_count() == 0 { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for elem_idx in start..end { + *hash = $hash_method(elem_array.value(elem_idx), *hash); + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for elem_idx in start..end { + if !elem_array.is_null(elem_idx) { + *hash = $hash_method(elem_array.value(elem_idx), *hash); + } + } + } + } + } + } + DataType::Binary => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + if $list_array.null_count() == 0 && elem_array.null_count() == 0 { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for elem_idx in start..end { + *hash = $hash_method(elem_array.value(elem_idx), *hash); + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for elem_idx in start..end { + if !elem_array.is_null(elem_idx) { + *hash = $hash_method(elem_array.value(elem_idx), *hash); + } + } + } + } + } + } + DataType::Date32 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i32| v.to_le_bytes()); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(offsets: $offsets, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); + } + _ => { + // Fall back to recursive approach for complex element types + $crate::hash_list_array!($list_array_type, $fallback_offset_type, $col, $hashes_buffer, $recursive_hash_method); + } + } + }; + // Variant for FixedSizeList with fixed size + (fixed_size: $list_array:ident, $values:ident, $list_size:ident, $field:expr, $hashes_buffer:ident, $hash_method:ident, $recursive_hash_method:ident) => { + match $field.data_type() { + DataType::Int8 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i8| (v as i32).to_le_bytes()); + } + DataType::Int16 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i16| (v as i32).to_le_bytes()); + } + DataType::Int32 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i32| v.to_le_bytes()); + } + DataType::Int64 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); + } + DataType::Float32 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, + |v: f32| if v == 0.0 && v.is_sign_negative() { (0_i32).to_le_bytes() } else { v.to_le_bytes() }); + } + DataType::Float64 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, + |v: f64| if v == 0.0 && v.is_sign_negative() { (0_i64).to_le_bytes() } else { v.to_le_bytes() }); + } + DataType::Boolean => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: bool| (i32::from(v)).to_le_bytes()); + } + DataType::Utf8 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + if $list_array.null_count() == 0 && elem_array.null_count() == 0 { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + *hash = $hash_method(elem_array.value(start + elem_idx), *hash); + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + if !elem_array.is_null(start + elem_idx) { + *hash = $hash_method(elem_array.value(start + elem_idx), *hash); + } + } + } + } + } + } + DataType::Binary => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + if $list_array.null_count() == 0 && elem_array.null_count() == 0 { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + *hash = $hash_method(elem_array.value(start + elem_idx), *hash); + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + if !elem_array.is_null(start + elem_idx) { + *hash = $hash_method(elem_array.value(start + elem_idx), *hash); + } + } + } + } + } + } + DataType::Date32 => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i32| v.to_le_bytes()); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let elem_array = $values.as_any().downcast_ref::().unwrap(); + $crate::hash_list_primitive!(fixed_size: $list_size, $list_array, elem_array, $hashes_buffer, $hash_method, |v: i64| v.to_le_bytes()); + } + _ => { + // Fall back to recursive approach for complex element types + if $list_array.null_count() == 0 { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + let elem_array = $values.slice(start + elem_idx, 1); + let mut single_hash = [*hash]; + $recursive_hash_method(&[elem_array], &mut single_hash)?; + *hash = single_hash[0]; + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$list_array.is_null(row_idx) { + let start = row_idx * $list_size; + for elem_idx in 0..$list_size { + let elem_array = $values.slice(start + elem_idx, 1); + let mut single_hash = [*hash]; + $recursive_hash_method(&[elem_array], &mut single_hash)?; + *hash = single_hash[0]; + } + } + } + } + } + } + }; +} + +/// Hash a map array by hashing its key/value entries in order. +/// +/// Specializes for common key/value type combinations so the hot path stays +/// monomorphic; falls back to recursive hashing for arbitrary nested types. +#[macro_export] +#[doc(hidden)] +macro_rules! hash_map_array { + // Specialized variant: typed key and value arrays, byte transforms. + ($map_array:ident, $key_array:ident, $value_array:ident, $offsets:ident, $hashes_buffer:ident, $hash_method:ident, $key_transform:expr, $value_transform:expr) => { + if $map_array.null_count() == 0 + && $key_array.null_count() == 0 + && $value_array.null_count() == 0 + { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for entry_idx in start..end { + *hash = + $hash_method($key_transform($key_array.value(entry_idx)), *hash); + *hash = $hash_method( + $value_transform($value_array.value(entry_idx)), + *hash, + ); + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$map_array.is_null(row_idx) { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for entry_idx in start..end { + if !$key_array.is_null(entry_idx) { + *hash = $hash_method( + $key_transform($key_array.value(entry_idx)), + *hash, + ); + } + if !$value_array.is_null(entry_idx) { + *hash = $hash_method( + $value_transform($value_array.value(entry_idx)), + *hash, + ); + } + } + } + } + } + }; + // Fallback variant: recursively hash each key/value entry. + (recursive: $map_array:ident, $keys:ident, $values:ident, $offsets:ident, $hashes_buffer:ident, $recursive_hash_method:ident) => { + if $map_array.null_count() == 0 { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for entry_idx in start..end { + let key_array = $keys.slice(entry_idx, 1); + let mut single_hash = [*hash]; + $recursive_hash_method(&[key_array], &mut single_hash)?; + *hash = single_hash[0]; + + let value_array = $values.slice(entry_idx, 1); + single_hash = [*hash]; + $recursive_hash_method(&[value_array], &mut single_hash)?; + *hash = single_hash[0]; + } + } + } else { + for (row_idx, hash) in $hashes_buffer.iter_mut().enumerate() { + if !$map_array.is_null(row_idx) { + let start = $offsets[row_idx] as usize; + let end = $offsets[row_idx + 1] as usize; + for entry_idx in start..end { + let key_array = $keys.slice(entry_idx, 1); + let mut single_hash = [*hash]; + $recursive_hash_method(&[key_array], &mut single_hash)?; + *hash = single_hash[0]; + + let value_array = $values.slice(entry_idx, 1); + single_hash = [*hash]; + $recursive_hash_method(&[value_array], &mut single_hash)?; + *hash = single_hash[0]; + } + } + } + } + }; +} + +/// Dispatch over a `MapArray`'s key/value types, calling [`hash_map_array!`] +/// with the right specialization (or the recursive fallback). +#[macro_export] +#[doc(hidden)] +macro_rules! hash_map_with_typed_entries { + ($col:ident, $field:expr, $hashes_buffer:ident, $hash_method:ident, $recursive_hash_method:ident) => { + let map_array = $col.as_any().downcast_ref::().unwrap(); + let keys = map_array.keys(); + let values = map_array.values(); + let offsets = map_array.offsets(); + + if let DataType::Struct(fields) = $field.data_type() { + let key_type = &fields[0].data_type(); + let value_type = &fields[1].data_type(); + + match (key_type, value_type) { + (DataType::Utf8, DataType::Int32) => { + let key_array = keys.as_any().downcast_ref::().unwrap(); + let value_array = values.as_any().downcast_ref::().unwrap(); + $crate::hash_map_array!( + map_array, key_array, value_array, offsets, $hashes_buffer, + $hash_method, |v| v, |v: i32| v.to_le_bytes() + ); + } + (DataType::Int32, DataType::Utf8) => { + let key_array = keys.as_any().downcast_ref::().unwrap(); + let value_array = values.as_any().downcast_ref::().unwrap(); + $crate::hash_map_array!( + map_array, key_array, value_array, offsets, $hashes_buffer, + $hash_method, |v: i32| v.to_le_bytes(), |v| v + ); + } + (DataType::Utf8, DataType::Utf8) => { + let key_array = keys.as_any().downcast_ref::().unwrap(); + let value_array = values.as_any().downcast_ref::().unwrap(); + $crate::hash_map_array!( + map_array, key_array, value_array, offsets, $hashes_buffer, + $hash_method, |v| v, |v| v + ); + } + (DataType::Int32, DataType::Int32) => { + let key_array = keys.as_any().downcast_ref::().unwrap(); + let value_array = values.as_any().downcast_ref::().unwrap(); + $crate::hash_map_array!( + map_array, key_array, value_array, offsets, $hashes_buffer, + $hash_method, |v: i32| v.to_le_bytes(), |v: i32| v.to_le_bytes() + ); + } + _ => { + $crate::hash_map_array!( + recursive: map_array, keys, values, offsets, + $hashes_buffer, $recursive_hash_method + ); + } + } + } else { + return Err(DataFusionError::Internal(format!( + "Map field type must be a struct, got: {}", + $field.data_type() + ))); + } + }; +} + +#[macro_export] +#[doc(hidden)] +macro_rules! hash_list_array { + ($array_type:ident, $offset_type:ty, $column: ident, $hashes: ident, $recursive_hash_method: ident) => { + let list_array = $column + .as_any() + .downcast_ref::<$array_type>() + .unwrap_or_else(|| { + panic!( + "Failed to downcast column to {}. Actual data type: {:?}.", + stringify!($array_type), + $column.data_type() + ) + }); + + let values = list_array.values(); + let offsets = list_array.offsets(); + + if list_array.null_count() == 0 { + // Fast path: no nulls, skip null checks + for (row_idx, hash) in $hashes.iter_mut().enumerate() { + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + let len = end - start; + // Hash each element in sequence, chaining the hash values + for elem_idx in 0..len { + let elem_array = values.slice(start + elem_idx, 1); + let mut single_hash = [*hash]; + $recursive_hash_method(&[elem_array], &mut single_hash)?; + *hash = single_hash[0]; + } + } + } else { + // Slow path: array has nulls, check each row + for (row_idx, hash) in $hashes.iter_mut().enumerate() { + if !list_array.is_null(row_idx) { + let start = offsets[row_idx] as usize; + let end = offsets[row_idx + 1] as usize; + let len = end - start; + // Hash each element in sequence, chaining the hash values + for elem_idx in 0..len { + let elem_array = values.slice(start + elem_idx, 1); + let mut single_hash = [*hash]; + $recursive_hash_method(&[elem_array], &mut single_hash)?; + *hash = single_hash[0]; + } + } + } + } + }; +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +/// +/// `hash_method` is the hash function to use. +/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input. +/// `recursive_hash_method` is the function to call for recursive hashing of complex types. +#[macro_export] +#[doc(hidden)] +macro_rules! create_hashes_internal { + ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident, $recursive_hash_method: ident) => { + use arrow::array::{types::*, *}; + use arrow::datatypes::{DataType, TimeUnit}; + use datafusion_common::DataFusionError; + + for (i, col) in $arrays.iter().enumerate() { + let first_col = i == 0; + match col.data_type() { + DataType::Boolean => { + $crate::hash_array_boolean!( + BooleanArray, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int8 => { + $crate::hash_array_primitive!( + Int8Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int16 => { + $crate::hash_array_primitive!( + Int16Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int32 => { + $crate::hash_array_primitive!( + Int32Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int64 => { + $crate::hash_array_primitive!( + Int64Array, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Float32 => { + $crate::hash_array_primitive_float!( + Float32Array, + col, + f32, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Float64 => { + $crate::hash_array_primitive_float!( + Float64Array, + col, + f64, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Second, _) => { + $crate::hash_array_primitive!( + TimestampSecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + $crate::hash_array_primitive!( + TimestampMillisecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + $crate::hash_array_primitive!( + TimestampMicrosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + $crate::hash_array_primitive!( + TimestampNanosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Date32 => { + $crate::hash_array_primitive!( + Date32Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Date64 => { + $crate::hash_array_primitive!( + Date64Array, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Utf8 => { + $crate::hash_array!(StringArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeUtf8 => { + $crate::hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method); + } + DataType::Utf8View => { + $crate::hash_array!(StringViewArray, col, $hashes_buffer, $hash_method); + } + DataType::Binary => { + $crate::hash_array!(BinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeBinary => { + $crate::hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::BinaryView => { + $crate::hash_array!(BinaryViewArray, col, $hashes_buffer, $hash_method); + } + DataType::FixedSizeBinary(_) => { + $crate::hash_array!(FixedSizeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::Null => { + // Nulls don't update the hash + } + // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it into long and hash it. + // Else, turn it into bytes and hash it. + DataType::Decimal128(precision, _) if *precision <= 18 => { + $crate::hash_array_small_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Decimal128(_, _) => { + $crate::hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Dictionary(index_type, _) => match **index_type { + DataType::Int8 => { + $create_dictionary_hash_method::(col, $hashes_buffer, first_col)?; + } + DataType::Int16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt8 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported dictionary type in hasher hashing: {}", + col.data_type(), + ))) + } + }, + DataType::List(field) => { + let list_array = col.as_any().downcast_ref::().unwrap(); + let values = list_array.values(); + let offsets = list_array.offsets(); + + $crate::hash_list_with_primitive_elements!(offsets: ListArray, list_array, values, offsets, field, $hashes_buffer, $hash_method, $recursive_hash_method, i32, col); + } + DataType::LargeList(field) => { + let list_array = col.as_any().downcast_ref::().unwrap(); + let values = list_array.values(); + let offsets = list_array.offsets(); + + $crate::hash_list_with_primitive_elements!(offsets: LargeListArray, list_array, values, offsets, field, $hashes_buffer, $hash_method, $recursive_hash_method, i64, col); + } + DataType::FixedSizeList(field, size) => { + let list_array = col.as_any().downcast_ref::().unwrap(); + let values = list_array.values(); + let list_size = *size as usize; + + $crate::hash_list_with_primitive_elements!(fixed_size: list_array, values, list_size, field, $hashes_buffer, $hash_method, $recursive_hash_method); + } + DataType::Struct(_) => { + let struct_array = col.as_any().downcast_ref::().unwrap(); + // Hash each field of the struct - Spark hashes all fields recursively + let columns: Vec = struct_array.columns().to_vec(); + if !columns.is_empty() { + $recursive_hash_method(&columns, $hashes_buffer)?; + } + } + DataType::Map(field, _) => { + $crate::hash_map_with_typed_entries!( + col, field, $hashes_buffer, $hash_method, $recursive_hash_method + ); + } + _ => { + // This is internal because we should have caught this before. + return Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {}", + col.data_type() + ))); + } + } + } + }; +} diff --git a/datafusion/spark/src/function/hash/xxhash64.rs b/datafusion/spark/src/function/hash/xxhash64.rs new file mode 100644 index 0000000000000..5dca47bcb8984 --- /dev/null +++ b/datafusion/spark/src/function/hash/xxhash64.rs @@ -0,0 +1,445 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, DictionaryArray, Int64Array, types::ArrowDictionaryKeyType, +}; +use arrow::buffer::{Buffer, ScalarBuffer}; +use arrow::compute::take; +use arrow::datatypes::{ArrowNativeType, DataType, Field, FieldRef}; +use datafusion_common::{Result, ScalarValue, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use twox_hash::XxHash64; + +use crate::create_hashes_internal; + +const DEFAULT_SEED: u64 = 42; + +/// Spark-compatible xxhash64 function. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkXxhash64 { + signature: Signature, +} + +impl Default for SparkXxhash64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkXxhash64 { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkXxhash64 { + fn name(&self) -> &str { + "xxhash64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + // Spark's HashExpression overrides nullable to false: NULL inputs are + // skipped and the seed is used, so the result is never null. + Ok(Arc::new(Field::new(self.name(), DataType::Int64, false))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let num_rows = args.number_rows; + let mut hashes: Vec = vec![DEFAULT_SEED; num_rows]; + + let arrays = ColumnarValue::values_to_arrays(&args.args)?; + create_xxhash64_hashes(&arrays, &mut hashes)?; + + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + // Reinterpret Vec as ScalarBuffer without copying — both + // types have identical layout, and `as i64` is a bitcast. + let buffer = ScalarBuffer::::from(Buffer::from_vec(hashes)); + Ok(ColumnarValue::Array(Arc::new(Int64Array::new( + buffer, None, + )))) + } + } +} + +#[inline] +fn spark_compatible_xxhash64>(data: T, seed: u64) -> u64 { + XxHash64::oneshot(seed, data.as_ref()) +} + +/// Hash the values in a dictionary array using xxhash64. +fn create_xxhash64_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u64], + first_col: bool, +) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_xxhash64_hashes(&[unpacked], hashes_buffer)?; + } else { + // Hash each dictionary value once, then look up by key. This avoids + // redundant hashing of large dictionary entries (e.g. long strings). + let dict_values = Arc::clone(dict_array.values()); + let mut dict_hashes = vec![DEFAULT_SEED; dict_values.len()]; + create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?; + + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + *hash = dict_hashes[key.as_usize()] + } + // No update for Null keys, consistent with other types. + } + } + Ok(()) +} + +/// Create xxhash64 hash values for every row, based on the values in the columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately and seeded with the +/// initial hash value (Spark uses `42`). +fn create_xxhash64_hashes(arrays: &[ArrayRef], hashes_buffer: &mut [u64]) -> Result<()> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_xxhash64, + create_xxhash64_hashes_dictionary, + create_xxhash64_hashes + ); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{FixedSizeBinaryArray, Int32Array, StringArray}; + + #[test] + fn test_xxhash64_nullability() -> Result<()> { + let func = SparkXxhash64::new(); + + // Spark's xxhash64 is never null (NULL args are skipped, seed is returned), + // so the output field is non-nullable regardless of input nullability. + let nullable: FieldRef = Arc::new(Field::new("a", DataType::Int32, true)); + let non_nullable: FieldRef = Arc::new(Field::new("b", DataType::Int32, false)); + + let out = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable), Arc::clone(&non_nullable)], + scalar_arguments: &[None, None], + })?; + assert!(!out.is_nullable()); + assert_eq!(out.data_type(), &DataType::Int64); + + Ok(()) + } + + #[test] + fn test_xxhash64_i32() { + let seed = 42u64; + assert_eq!( + spark_compatible_xxhash64(1i32.to_le_bytes(), seed), + 0xa309b38455455929 + ); + assert_eq!( + spark_compatible_xxhash64(0i32.to_le_bytes(), seed), + 0x3229fbc4681e48f3 + ); + assert_eq!( + spark_compatible_xxhash64((-1i32).to_le_bytes(), seed), + 0x1bfdda8861c06e45 + ); + } + + #[test] + fn test_xxhash64_i32_boundary() { + let seed = 42u64; + let h = spark_compatible_xxhash64(i32::MAX.to_le_bytes(), seed); + assert_ne!(h, seed); + let h = spark_compatible_xxhash64(i32::MIN.to_le_bytes(), seed); + assert_ne!(h, seed); + } + + #[test] + fn test_xxhash64_i8() { + let seed = 42u64; + // i8 is widened to i32 before hashing + assert_eq!( + spark_compatible_xxhash64((1i8 as i32).to_le_bytes(), seed), + spark_compatible_xxhash64(1i32.to_le_bytes(), seed), + ); + } + + #[test] + fn test_xxhash64_i64() { + let seed = 42u64; + assert_eq!( + spark_compatible_xxhash64(1i64.to_le_bytes(), seed), + 0x9ed50fd59358d232 + ); + assert_eq!( + spark_compatible_xxhash64(0i64.to_le_bytes(), seed), + 0xb71b47ebda15746c + ); + assert_eq!( + spark_compatible_xxhash64((-1i64).to_le_bytes(), seed), + 0x358ae035bfb46fd2 + ); + } + + #[test] + fn test_xxhash64_i64_boundary() { + let seed = 42u64; + let h = spark_compatible_xxhash64(i64::MAX.to_le_bytes(), seed); + assert_ne!(h, seed); + let h = spark_compatible_xxhash64(i64::MIN.to_le_bytes(), seed); + assert_ne!(h, seed); + } + + /// Spark normalizes `-0.0` to `0.0` before hashing, so both produce + /// the same hash. Exercise the dispatch through `create_xxhash64_hashes` + /// to cover the `hash_array_primitive_float!` normalization path. + #[test] + fn test_xxhash64_negative_zero_f32() { + use arrow::array::Float32Array; + let array: ArrayRef = Arc::new(Float32Array::from(vec![0.0f32, -0.0f32])); + let mut hashes = vec![DEFAULT_SEED; 2]; + create_xxhash64_hashes(&[array], &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + assert_eq!(hashes[0], spark_compatible_xxhash64(0i32.to_le_bytes(), 42)); + } + + #[test] + fn test_xxhash64_negative_zero_f64() { + use arrow::array::Float64Array; + let array: ArrayRef = Arc::new(Float64Array::from(vec![0.0f64, -0.0f64])); + let mut hashes = vec![DEFAULT_SEED; 2]; + create_xxhash64_hashes(&[array], &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); + assert_eq!(hashes[0], spark_compatible_xxhash64(0i64.to_le_bytes(), 42)); + } + + #[test] + fn test_xxhash64_string() { + let seed = 42u64; + assert_eq!(spark_compatible_xxhash64("hello", seed), 0xc3629e6318d53932); + assert_eq!(spark_compatible_xxhash64("", seed), 0x98b1582b0977e704); + assert_eq!(spark_compatible_xxhash64("abc", seed), 0x13c1d910702770e6); + } + + #[test] + fn test_xxhash64_string_emoji_cjk() { + let seed = 42u64; + let h1 = spark_compatible_xxhash64("😁", seed); + assert_ne!(h1, seed); + let h2 = spark_compatible_xxhash64("天地", seed); + assert_ne!(h2, seed); + assert_ne!(h1, h2); + } + + #[test] + fn test_xxhash64_dictionary_string() { + use arrow::array::DictionaryArray; + use arrow::datatypes::Int32Type; + + let dict_array: DictionaryArray = + vec!["hello", "world", "abc", "hello", "world"] + .into_iter() + .collect(); + let array_ref: ArrayRef = Arc::new(dict_array); + + let mut hashes = vec![DEFAULT_SEED; 5]; + create_xxhash64_hashes(&[array_ref], &mut hashes).unwrap(); + + assert_eq!(hashes[0], spark_compatible_xxhash64("hello", 42)); + assert_eq!(hashes[1], spark_compatible_xxhash64("world", 42)); + assert_eq!(hashes[2], spark_compatible_xxhash64("abc", 42)); + assert_eq!(hashes[3], hashes[0]); + assert_eq!(hashes[4], hashes[1]); + } + + #[test] + fn test_xxhash64_dictionary_int() { + use arrow::array::DictionaryArray; + use arrow::datatypes::Int32Type; + + let keys = Int32Array::from(vec![0, 1, 2, 0, 1]); + let values = Int32Array::from(vec![100, 200, 300]); + let dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let array_ref: ArrayRef = Arc::new(dict_array); + + let mut hashes = vec![DEFAULT_SEED; 5]; + create_xxhash64_hashes(&[array_ref], &mut hashes).unwrap(); + + assert_eq!( + hashes[0], + spark_compatible_xxhash64(100i32.to_le_bytes(), 42) + ); + assert_eq!( + hashes[1], + spark_compatible_xxhash64(200i32.to_le_bytes(), 42) + ); + assert_eq!( + hashes[2], + spark_compatible_xxhash64(300i32.to_le_bytes(), 42) + ); + assert_eq!(hashes[3], hashes[0]); + assert_eq!(hashes[4], hashes[1]); + } + + #[test] + fn test_xxhash64_dictionary_with_nulls() { + use arrow::array::DictionaryArray; + use arrow::datatypes::Int32Type; + + let keys = Int32Array::from(vec![Some(0), None, Some(1), Some(0), None]); + let values = StringArray::from(vec!["hello", "world"]); + let dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let array_ref: ArrayRef = Arc::new(dict_array); + + let mut hashes = vec![DEFAULT_SEED; 5]; + create_xxhash64_hashes(&[array_ref], &mut hashes).unwrap(); + + assert_eq!(hashes[0], spark_compatible_xxhash64("hello", 42)); + assert_eq!(hashes[2], spark_compatible_xxhash64("world", 42)); + assert_eq!(hashes[3], spark_compatible_xxhash64("hello", 42)); + assert_eq!(hashes[1], DEFAULT_SEED); + assert_eq!(hashes[4], DEFAULT_SEED); + } + + #[test] + fn test_xxhash64_dictionary_non_first_column() { + use arrow::array::DictionaryArray; + use arrow::datatypes::Int32Type; + + let dict_array: DictionaryArray = + vec!["hello", "world", "abc"].into_iter().collect(); + let array_ref: ArrayRef = Arc::new(dict_array); + + let mut hashes = vec![123u64, 456u64, 789u64]; + create_xxhash64_hashes_dictionary::(&array_ref, &mut hashes, false) + .unwrap(); + + assert_eq!(hashes[0], spark_compatible_xxhash64("hello", 123)); + assert_eq!(hashes[1], spark_compatible_xxhash64("world", 456)); + assert_eq!(hashes[2], spark_compatible_xxhash64("abc", 789)); + } + + #[test] + fn test_xxhash64_fixed_size_binary() { + let array = FixedSizeBinaryArray::from(vec![ + Some(&[0x01, 0x02, 0x03, 0x04][..]), + Some(&[0x05, 0x06, 0x07, 0x08][..]), + None, + Some(&[0x00, 0x00, 0x00, 0x00][..]), + ]); + let array_ref: ArrayRef = Arc::new(array); + + let mut hashes = vec![DEFAULT_SEED; 4]; + create_xxhash64_hashes(&[array_ref], &mut hashes).unwrap(); + + assert_eq!( + hashes[0], + spark_compatible_xxhash64([0x01, 0x02, 0x03, 0x04], 42) + ); + assert_eq!( + hashes[1], + spark_compatible_xxhash64([0x05, 0x06, 0x07, 0x08], 42) + ); + assert_eq!(hashes[2], DEFAULT_SEED); + assert_eq!( + hashes[3], + spark_compatible_xxhash64([0x00, 0x00, 0x00, 0x00], 42) + ); + } + + #[test] + fn test_xxhash64_struct() { + use arrow::array::StructArray; + use arrow::datatypes::Field; + + let int_array = Int32Array::from(vec![1, 2, 3]); + let str_array = StringArray::from(vec!["a", "b", "c"]); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Int32, false)), + Arc::new(int_array) as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Utf8, false)), + Arc::new(str_array) as ArrayRef, + ), + ]); + let array_ref: ArrayRef = Arc::new(struct_array); + + let mut hashes = vec![DEFAULT_SEED; 3]; + create_xxhash64_hashes(&[array_ref], &mut hashes).unwrap(); + + for hash in &hashes { + assert_ne!(*hash, DEFAULT_SEED); + } + assert_ne!(hashes[0], hashes[1]); + assert_ne!(hashes[1], hashes[2]); + } + + #[test] + fn test_xxhash64_list() { + use arrow::array::ListArray; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::Field; + + let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + let offsets = OffsetBuffer::new(vec![0i32, 2, 3, 6].into()); + let list_array = ListArray::new( + Arc::new(Field::new_list_field(DataType::Int32, false)), + offsets, + Arc::new(values), + None, + ); + let array_ref: ArrayRef = Arc::new(list_array); + + let mut hashes = vec![DEFAULT_SEED; 3]; + create_xxhash64_hashes(&[array_ref], &mut hashes).unwrap(); + + for hash in &hashes { + assert_ne!(*hash, DEFAULT_SEED); + } + assert_ne!(hashes[0], hashes[1]); + } +} diff --git a/datafusion/spark/src/function/json/json_tuple.rs b/datafusion/spark/src/function/json/json_tuple.rs new file mode 100644 index 0000000000000..3496f979ffe06 --- /dev/null +++ b/datafusion/spark/src/function/json/json_tuple.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, NullBufferBuilder, StringBuilder, StructArray}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields}; +use datafusion_common::cast::as_string_array; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +/// Spark-compatible `json_tuple` expression +/// +/// +/// +/// Extracts top-level fields from a JSON string and returns them as a struct. +/// +/// `json_tuple(json_string, field1, field2, ...) -> Struct` +/// +/// Note: In Spark, `json_tuple` is a Generator that produces multiple columns directly. +/// In DataFusion, a ScalarUDF can only return one value per row, so the result is wrapped +/// in a Struct. The caller (e.g. Comet) is expected to destructure the struct fields. +/// +/// - Returns NULL for each field that is missing from the JSON object +/// - Returns NULL for all fields if the input is NULL or not valid JSON +/// - Non-string JSON values are converted to their JSON string representation +/// - JSON `null` values are returned as NULL (not the string "null") +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct JsonTuple { + signature: Signature, +} + +impl Default for JsonTuple { + fn default() -> Self { + Self::new() + } +} + +impl JsonTuple { + pub fn new() -> Self { + Self { + signature: Signature::variadic(vec![DataType::Utf8], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for JsonTuple { + fn name(&self) -> &str { + "json_tuple" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() < 2 { + return exec_err!( + "json_tuple requires at least 2 arguments (json_string, field1), got {}", + args.arg_fields.len() + ); + } + + let num_fields = args.arg_fields.len() - 1; + let fields: Fields = (0..num_fields) + .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true)) + .collect::>() + .into(); + + Ok(Arc::new(Field::new( + self.name(), + DataType::Struct(fields), + true, + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { + args: arg_values, + return_field, + .. + } = args; + let arrays = ColumnarValue::values_to_arrays(&arg_values)?; + let result = json_tuple_inner(&arrays, return_field.data_type())?; + + Ok(ColumnarValue::Array(result)) + } +} + +fn json_tuple_inner(args: &[ArrayRef], return_type: &DataType) -> Result { + let num_rows = args[0].len(); + let num_fields = args.len() - 1; + + let json_array = as_string_array(&args[0])?; + + let field_arrays = args[1..] + .iter() + .map(|arg| as_string_array(arg)) + .collect::>>()?; + + let mut builders: Vec = + (0..num_fields).map(|_| StringBuilder::new()).collect(); + + let mut null_buffer = NullBufferBuilder::new(num_rows); + + for row_idx in 0..num_rows { + if json_array.is_null(row_idx) { + for builder in &mut builders { + builder.append_null(); + } + null_buffer.append_null(); + continue; + } + + let json_str = json_array.value(row_idx); + match serde_json::from_str::(json_str) { + Ok(serde_json::Value::Object(map)) => { + null_buffer.append_non_null(); + for (field_idx, builder) in builders.iter_mut().enumerate() { + if field_arrays[field_idx].is_null(row_idx) { + builder.append_null(); + continue; + } + let field_name = field_arrays[field_idx].value(row_idx); + match map.get(field_name) { + Some(serde_json::Value::Null) => { + builder.append_null(); + } + Some(serde_json::Value::String(s)) => { + builder.append_value(s); + } + Some(other) => { + builder.append_value(other.to_string()); + } + None => { + builder.append_null(); + } + } + } + } + _ => { + for builder in &mut builders { + builder.append_null(); + } + null_buffer.append_null(); + } + } + } + + let struct_fields = match return_type { + DataType::Struct(fields) => fields.clone(), + _ => { + return internal_err!( + "json_tuple requires a Struct return type, got {:?}", + return_type + ); + } + }; + + let arrays: Vec = builders + .into_iter() + .map(|mut builder| Arc::new(builder.finish()) as ArrayRef) + .collect(); + + let struct_array = StructArray::try_new(struct_fields, arrays, null_buffer.finish())?; + + Ok(Arc::new(struct_array)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_return_field_shape() { + let func = JsonTuple::new(); + let fields = vec![ + Arc::new(Field::new("json", DataType::Utf8, false)), + Arc::new(Field::new("f1", DataType::Utf8, false)), + Arc::new(Field::new("f2", DataType::Utf8, false)), + ]; + let result = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &[None, None, None], + }) + .unwrap(); + + match result.data_type() { + DataType::Struct(inner) => { + assert_eq!(inner.len(), 2); + assert_eq!(inner[0].name(), "c0"); + assert_eq!(inner[1].name(), "c1"); + assert_eq!(inner[0].data_type(), &DataType::Utf8); + assert!(inner[0].is_nullable()); + } + other => panic!("Expected Struct, got {other:?}"), + } + } + + #[test] + fn test_too_few_args() { + let func = JsonTuple::new(); + let fields = vec![Arc::new(Field::new("json", DataType::Utf8, false))]; + let result = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &[None], + }); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("at least 2 arguments") + ); + } +} diff --git a/datafusion/spark/src/function/json/mod.rs b/datafusion/spark/src/function/json/mod.rs index a87df9a2c87a0..01378235d7c64 100644 --- a/datafusion/spark/src/function/json/mod.rs +++ b/datafusion/spark/src/function/json/mod.rs @@ -15,11 +15,24 @@ // specific language governing permissions and limitations // under the License. +pub mod json_tuple; + use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +make_udf_function!(json_tuple::JsonTuple, json_tuple); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!(( + json_tuple, + "Extracts top-level fields from a JSON string and returns them as a struct.", + args, + )); +} pub fn functions() -> Vec> { - vec![] + vec![json_tuple()] } diff --git a/datafusion/spark/src/function/map/map_from_arrays.rs b/datafusion/spark/src/function/map/map_from_arrays.rs index 987548e353e44..92dea2720fbfc 100644 --- a/datafusion/spark/src/function/map/map_from_arrays.rs +++ b/datafusion/spark/src/function/map/map_from_arrays.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use crate::function::map::utils::{ get_element_type, get_list_offsets, get_list_values, map_from_keys_values_offsets_nulls, map_type_from_key_value_types, }; use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::kernels::cast; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::MapKeyDedupPolicy; use datafusion_common::utils::take_function_args; -use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; /// Spark-compatible `map_from_arrays` expression /// @@ -51,10 +54,6 @@ impl MapFromArrays { } impl ScalarUDFImpl for MapFromArrays { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "map_from_arrays" } @@ -63,28 +62,39 @@ impl ScalarUDFImpl for MapFromArrays { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [key_type, value_type] = take_function_args("map_from_arrays", arg_types)?; - Ok(map_type_from_key_value_types( - get_element_type(key_type)?, - get_element_type(value_type)?, - )) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(map_from_arrays_inner, vec![])(&args.args) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [keys_field, values_field] = args.arg_fields else { + return internal_err!("map_from_arrays expects exactly 2 arguments"); + }; + + let map_type = map_type_from_key_value_types( + get_element_type(keys_field.data_type())?, + get_element_type(values_field.data_type())?, + ); + // Spark marks map_from_arrays as null intolerant, so the output is + // nullable if either input is nullable. + let nullable = keys_field.is_nullable() || values_field.is_nullable(); + Ok(Arc::new(Field::new(self.name(), map_type, nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let last_value_wins = + args.config_options.spark.map_key_dedup_policy == MapKeyDedupPolicy::LastWin; + make_scalar_function( + move |args: &[ArrayRef]| map_from_arrays_inner(args, last_value_wins), + vec![], + )(&args.args) } } -fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { +fn map_from_arrays_inner(args: &[ArrayRef], last_value_wins: bool) -> Result { let [keys, values] = take_function_args("map_from_arrays", args)?; - if matches!(keys.data_type(), DataType::Null) - || matches!(values.data_type(), DataType::Null) - { + if *keys.data_type() == DataType::Null || *values.data_type() == DataType::Null { return Ok(cast( &NullArray::new(keys.len()), &map_type_from_key_value_types( @@ -101,5 +111,60 @@ fn map_from_arrays_inner(args: &[ArrayRef]) -> Result { &get_list_offsets(values)?, keys.nulls(), values.nulls(), + last_value_wins, ) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_map_from_arrays_nullability_and_type() { + let func = MapFromArrays::new(); + + let keys_field: FieldRef = Arc::new(Field::new( + "keys", + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + false, + )); + let values_field: FieldRef = Arc::new(Field::new( + "values", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + )); + + let out = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&keys_field), Arc::clone(&values_field)], + scalar_arguments: &[None, None], + }) + .expect("return_field_from_args should succeed"); + + let expected_type = + map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8); + assert_eq!(out.data_type(), &expected_type); + assert!( + !out.is_nullable(), + "map_from_arrays should be non-nullable when both inputs are non-nullable" + ); + + let nullable_keys: FieldRef = Arc::new(Field::new( + "keys", + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + true, + )); + + let out_nullable = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_keys, values_field], + scalar_arguments: &[None, None], + }) + .expect("return_field_from_args should succeed"); + + assert!( + out_nullable.is_nullable(), + "map_from_arrays should be nullable when any input is nullable" + ); + } +} diff --git a/datafusion/spark/src/function/map/map_from_entries.rs b/datafusion/spark/src/function/map/map_from_entries.rs index 6648979c5dd23..69ce352694bd1 100644 --- a/datafusion/spark/src/function/map/map_from_entries.rs +++ b/datafusion/spark/src/function/map/map_from_entries.rs @@ -15,18 +15,22 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +use std::sync::Arc; use crate::function::map::utils::{ - get_element_type, get_list_offsets, get_list_values, - map_from_keys_values_offsets_nulls, map_type_from_key_value_types, + get_list_offsets, get_list_values, map_from_keys_values_offsets_nulls, + map_type_from_key_value_types, }; use arrow::array::{Array, ArrayRef, NullBufferBuilder, StructArray}; use arrow::buffer::NullBuffer; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::config::MapKeyDedupPolicy; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; use datafusion_functions::utils::make_scalar_function; /// Spark-compatible `map_from_entries` expression @@ -51,10 +55,6 @@ impl MapFromEntries { } impl ScalarUDFImpl for MapFromEntries { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "map_from_entries" } @@ -63,9 +63,28 @@ impl ScalarUDFImpl for MapFromEntries { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - let [entries_type] = take_function_args("map_from_entries", arg_types)?; - let entries_element_type = get_element_type(entries_type)?; + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [entries_field] = args.arg_fields else { + return exec_err!("map_from_entries: expected one argument"); + }; + + let (entries_element_field, entries_element_type) = + match entries_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => { + Ok((field.as_ref(), field.data_type())) + } + wrong_type => exec_err!( + "map_from_entries: expected array>, got {:?}", + wrong_type + ), + }?; + let (keys_type, values_type) = match entries_element_type { DataType::Struct(fields) if fields.len() == 2 => { Ok((fields[0].data_type(), fields[1].data_type())) @@ -75,18 +94,24 @@ impl ScalarUDFImpl for MapFromEntries { wrong_type ), }?; - Ok(map_type_from_key_value_types(keys_type, values_type)) + + let map_type = map_type_from_key_value_types(keys_type, values_type); + let nullable = entries_field.is_nullable() || entries_element_field.is_nullable(); + + Ok(Arc::new(Field::new(self.name(), map_type, nullable))) } - fn invoke_with_args( - &self, - args: datafusion_expr::ScalarFunctionArgs, - ) -> Result { - make_scalar_function(map_from_entries_inner, vec![])(&args.args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let last_value_wins = + args.config_options.spark.map_key_dedup_policy == MapKeyDedupPolicy::LastWin; + make_scalar_function( + move |args: &[ArrayRef]| map_from_entries_inner(args, last_value_wins), + vec![], + )(&args.args) } } -fn map_from_entries_inner(args: &[ArrayRef]) -> Result { +fn map_from_entries_inner(args: &[ArrayRef], last_value_wins: bool) -> Result { let [entries] = take_function_args("map_from_entries", args)?; let entries_offsets = get_list_offsets(entries)?; let entries_values = get_list_values(entries)?; @@ -129,5 +154,64 @@ fn map_from_entries_inner(args: &[ArrayRef]) -> Result { &entries_offsets, None, res_nulls.as_ref(), + last_value_wins, ) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::Fields; + + fn make_entries_field(array_nullable: bool, element_nullable: bool) -> FieldRef { + let struct_type = DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, false), + Field::new("value", DataType::Utf8, true), + ])); + Arc::new(Field::new( + "entries", + DataType::List(Arc::new(Field::new("item", struct_type, element_nullable))), + array_nullable, + )) + } + + #[test] + fn test_map_from_entries_nullability_matches_input() { + let func = MapFromEntries::new(); + let expected_type = + map_type_from_key_value_types(&DataType::Int32, &DataType::Utf8); + + // Non-nullable array and elements => non-nullable result + let non_nullable_field = make_entries_field(false, false); + let result = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_field)], + scalar_arguments: &[None], + }) + .expect("should infer field"); + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), &expected_type); + + // Nullable elements should make result nullable even if array is non-nullable + let element_nullable_field = make_entries_field(false, true); + let result = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&element_nullable_field)], + scalar_arguments: &[None], + }) + .expect("should infer field"); + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &expected_type); + + // Nullable array should also yield nullable result + let array_nullable_field = make_entries_field(true, false); + let result = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&array_nullable_field)], + scalar_arguments: &[None], + }) + .expect("should infer field"); + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &expected_type); + } +} diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs index 2f596b19b422f..c9ebed6f612e1 100644 --- a/datafusion/spark/src/function/map/mod.rs +++ b/datafusion/spark/src/function/map/mod.rs @@ -17,6 +17,7 @@ pub mod map_from_arrays; pub mod map_from_entries; +pub mod str_to_map; mod utils; use datafusion_expr::ScalarUDF; @@ -25,6 +26,7 @@ use std::sync::Arc; make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays); make_udf_function!(map_from_entries::MapFromEntries, map_from_entries); +make_udf_function!(str_to_map::SparkStrToMap, str_to_map); pub mod expr_fn { use datafusion_functions::export_functions; @@ -40,8 +42,14 @@ pub mod expr_fn { "Creates a map from array>.", arg1 )); + + export_functions!(( + str_to_map, + "Creates a map after splitting the text into key/value pairs using delimiters.", + text pair_delim key_value_delim + )); } pub fn functions() -> Vec> { - vec![map_from_arrays(), map_from_entries()] + vec![map_from_arrays(), map_from_entries(), str_to_map()] } diff --git a/datafusion/spark/src/function/map/str_to_map.rs b/datafusion/spark/src/function/map/str_to_map.rs new file mode 100644 index 0000000000000..abb4bd04762a3 --- /dev/null +++ b/datafusion/spark/src/function/map/str_to_map.rs @@ -0,0 +1,306 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; + +use crate::function::map::utils::map_type_from_key_value_types; +use datafusion_common::config::MapKeyDedupPolicy; + +const DEFAULT_PAIR_DELIM: &str = ","; +const DEFAULT_KV_DELIM: &str = ":"; + +/// Spark-compatible `str_to_map` expression +/// +/// +/// Creates a map from a string by splitting on delimiters. +/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map +/// +/// - text: The input string +/// - pairDelim: Delimiter between key-value pairs (default: ',') +/// - keyValueDelim: Delimiter between key and value (default: ':') +/// +/// # Duplicate Key Handling +/// Mirrors Spark's [`spark.sql.mapKeyDedupPolicy`](https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4502-L4511), +/// wired through DataFusion's `datafusion.spark.map_key_dedup_policy`: +/// - `EXCEPTION` (default): error on duplicate keys. +/// - `LAST_WIN`: keep the last occurrence of each duplicate key. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkStrToMap { + signature: Signature, +} + +impl Default for SparkStrToMap { + fn default() -> Self { + Self::new() + } +} + +impl SparkStrToMap { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // str_to_map(text) + TypeSignature::String(1), + // str_to_map(text, pairDelim) + TypeSignature::String(2), + // str_to_map(text, pairDelim, keyValueDelim) + TypeSignature::String(3), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkStrToMap { + fn name(&self) -> &str { + "str_to_map" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8); + Ok(Arc::new(Field::new(self.name(), map_type, nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let last_value_wins = + args.config_options.spark.map_key_dedup_policy == MapKeyDedupPolicy::LastWin; + let arrays: Vec = ColumnarValue::values_to_arrays(&args.args)?; + let result = str_to_map_inner(&arrays, last_value_wins)?; + Ok(ColumnarValue::Array(result)) + } +} + +fn str_to_map_inner(args: &[ArrayRef], last_value_wins: bool) -> Result { + match args.len() { + 1 => match args[0].data_type() { + DataType::Utf8 => { + str_to_map_impl(as_string_array(&args[0])?, None, None, last_value_wins) + } + DataType::LargeUtf8 => str_to_map_impl( + as_large_string_array(&args[0])?, + None, + None, + last_value_wins, + ), + DataType::Utf8View => str_to_map_impl( + as_string_view_array(&args[0])?, + None, + None, + last_value_wins, + ), + other => exec_err!( + "Unsupported data type {other:?} for str_to_map, \ + expected Utf8, LargeUtf8, or Utf8View" + ), + }, + 2 => match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => str_to_map_impl( + as_string_array(&args[0])?, + Some(as_string_array(&args[1])?), + None, + last_value_wins, + ), + (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl( + as_large_string_array(&args[0])?, + Some(as_large_string_array(&args[1])?), + None, + last_value_wins, + ), + (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl( + as_string_view_array(&args[0])?, + Some(as_string_view_array(&args[1])?), + None, + last_value_wins, + ), + (t1, t2) => exec_err!( + "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \ + expected matching Utf8, LargeUtf8, or Utf8View" + ), + }, + 3 => match ( + args[0].data_type(), + args[1].data_type(), + args[2].data_type(), + ) { + (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl( + as_string_array(&args[0])?, + Some(as_string_array(&args[1])?), + Some(as_string_array(&args[2])?), + last_value_wins, + ), + (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { + str_to_map_impl( + as_large_string_array(&args[0])?, + Some(as_large_string_array(&args[1])?), + Some(as_large_string_array(&args[2])?), + last_value_wins, + ) + } + (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { + str_to_map_impl( + as_string_view_array(&args[0])?, + Some(as_string_view_array(&args[1])?), + Some(as_string_view_array(&args[2])?), + last_value_wins, + ) + } + (t1, t2, t3) => exec_err!( + "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \ + expected matching Utf8, LargeUtf8, or Utf8View" + ), + }, + n => exec_err!("str_to_map expects 1-3 arguments, got {n}"), + } +} + +fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>( + text_array: V, + pair_delim_array: Option, + kv_delim_array: Option, + last_value_wins: bool, +) -> Result { + let num_rows = text_array.len(); + + // Precompute combined null buffer from all input arrays. + // NullBuffer::union_many performs a bitmap-level AND, which is more + // efficient than checking per-row nullability inline. + let combined_nulls = NullBuffer::union_many([ + text_array.nulls(), + pair_delim_array.as_ref().and_then(|a| a.nulls()), + kv_delim_array.as_ref().and_then(|a| a.nulls()), + ]); + + // Use field names matching map_type_from_key_value_types: "key" and "value" + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut map_builder = MapBuilder::new( + Some(field_names), + StringBuilder::new(), + StringBuilder::new(), + ); + + let mut seen_keys = HashSet::new(); + // LAST_WIN buffers pairs to support in-place value overwrite at the key's + // first-seen position — matches Spark's `ArrayBasedMapBuilder`. + let mut pairs: Vec<(&str, Option<&str>)> = Vec::new(); + let mut key_positions: HashMap<&str, usize> = HashMap::new(); + for row_idx in 0..num_rows { + if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) { + map_builder.append(false)?; + continue; + } + + // Per-row delimiter extraction + let pair_delim = + pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx)); + let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx)); + + let text = text_array.value(row_idx); + if text.is_empty() { + // Empty string -> map with empty key and NULL value (Spark behavior) + map_builder.keys().append_value(""); + map_builder.values().append_null(); + map_builder.append(true)?; + continue; + } + + if last_value_wins { + pairs.clear(); + key_positions.clear(); + for pair in text.split(pair_delim) { + if pair.is_empty() { + continue; + } + let mut kv_iter = pair.splitn(2, kv_delim); + let key = kv_iter.next().unwrap_or(""); + let value = kv_iter.next(); + match key_positions.get(key) { + Some(&idx) => pairs[idx].1 = value, + None => { + key_positions.insert(key, pairs.len()); + pairs.push((key, value)); + } + } + } + for (key, value) in &pairs { + map_builder.keys().append_value(key); + match value { + Some(v) => map_builder.values().append_value(v), + None => map_builder.values().append_null(), + } + } + } else { + seen_keys.clear(); + for pair in text.split(pair_delim) { + if pair.is_empty() { + continue; + } + + let mut kv_iter = pair.splitn(2, kv_delim); + let key = kv_iter.next().unwrap_or(""); + let value = kv_iter.next(); + + if !seen_keys.insert(key) { + return exec_err!( + "[DUPLICATED_MAP_KEY] Duplicate map key '{key}' was found, \ + please check the input data. To allow duplicate keys with \ + last-value-wins semantics, set \ + `datafusion.spark.map_key_dedup_policy` to `LAST_WIN`." + ); + } + + map_builder.keys().append_value(key); + match value { + Some(v) => map_builder.values().append_value(v), + None => map_builder.values().append_null(), + } + } + } + map_builder.append(true)?; + } + + Ok(Arc::new(map_builder.finish())) +} diff --git a/datafusion/spark/src/function/map/utils.rs b/datafusion/spark/src/function/map/utils.rs index b568f45403c30..fa6b2a960dabb 100644 --- a/datafusion/spark/src/function/map/utils.rs +++ b/datafusion/spark/src/function/map/utils.rs @@ -16,14 +16,16 @@ // under the License. use std::borrow::Cow; -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, AsArray, BooleanBuilder, MapArray, StructArray}; +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanBuilder, Int32Array, MapArray, StructArray, +}; use arrow::buffer::{NullBuffer, OffsetBuffer}; -use arrow::compute::filter; +use arrow::compute::{filter, take}; use arrow::datatypes::{DataType, Field, Fields}; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err}; /// Helper function to get element [`DataType`] /// from [`List`](DataType::List)/[`LargeList`](DataType::LargeList)/[`FixedSizeList`](DataType::FixedSizeList)
@@ -64,14 +66,15 @@ pub fn get_list_offsets(array: &ArrayRef) -> Result> { match array.data_type() { DataType::List(_) => Ok(Cow::Borrowed(array.as_list::().offsets().as_ref())), DataType::LargeList(_) => Ok(Cow::Owned( - array.as_list::() + array + .as_list::() .offsets() .iter() .map(|i| *i as i32) .collect::>(), )), DataType::FixedSizeList(_, size) => Ok(Cow::Owned( - (0..=array.len() as i32).map(|i| size * i).collect() + (0..=array.len() as i32).map(|i| size * i).collect(), )), wrong_type => exec_err!( "get_list_offsets expects List/LargeList/FixedSizeList as argument, got {wrong_type:?}" @@ -110,13 +113,13 @@ pub fn map_type_from_key_value_types( /// So the inputs can be [`ListArray`](`arrow::array::ListArray`)/[`LargeListArray`](`arrow::array::LargeListArray`)/[`FixedSizeListArray`](`arrow::array::FixedSizeListArray`)
/// To preserve the row info, [`offsets`](arrow::array::ListArray::offsets) and [`nulls`](arrow::array::ListArray::nulls) for both keys and values need to be provided
/// [`FixedSizeListArray`](`arrow::array::FixedSizeListArray`) has no `offsets`, so they can be generated as a cumulative sum of it's `Size` -/// 2. Spark provides [spark.sql.mapKeyDedupPolicy](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961) -/// to handle duplicate keys
-/// For now, configurable functions are not supported by Datafusion
-/// So more permissive `LAST_WIN` option is used in this implementation (instead of `EXCEPTION`)
-/// `EXCEPTION` behaviour can still be achieved externally in cost of performance:
-/// `when(array_length(array_distinct(keys)) == array_length(keys), constructed_map)`
-/// `.otherwise(raise_error("duplicate keys occurred during map construction"))` +/// 2. Duplicate-key handling mirrors Spark's +/// [spark.sql.mapKeyDedupPolicy](https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961) +/// and is driven by `last_value_wins`: +/// - `false` (Spark's default `EXCEPTION`): raise `[DUPLICATED_MAP_KEY]` on any duplicate. +/// - `true` (`LAST_WIN`): keep the last occurrence of each duplicate key. +/// +/// Callers wire this from `datafusion.spark.map_key_dedup_policy`. pub fn map_from_keys_values_offsets_nulls( flat_keys: &ArrayRef, flat_values: &ArrayRef, @@ -124,6 +127,7 @@ pub fn map_from_keys_values_offsets_nulls( values_offsets: &[i32], keys_nulls: Option<&NullBuffer>, values_nulls: Option<&NullBuffer>, + last_value_wins: bool, ) -> Result { let (keys, values, offsets) = map_deduplicate_keys( flat_keys, @@ -132,6 +136,7 @@ pub fn map_from_keys_values_offsets_nulls( values_offsets, keys_nulls, values_nulls, + last_value_wins, )?; let nulls = NullBuffer::union(keys_nulls, values_nulls); @@ -146,6 +151,7 @@ pub fn map_from_keys_values_offsets_nulls( )?)) } +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // ScalarValue has interior mutability but is intentionally used as hash key fn map_deduplicate_keys( flat_keys: &ArrayRef, flat_values: &ArrayRef, @@ -153,6 +159,7 @@ fn map_deduplicate_keys( values_offsets: &[i32], keys_nulls: Option<&NullBuffer>, values_nulls: Option<&NullBuffer>, + last_value_wins: bool, ) -> Result<(ArrayRef, ArrayRef, OffsetBuffer)> { let offsets_len = keys_offsets.len(); let mut new_offsets = Vec::with_capacity(offsets_len); @@ -169,8 +176,14 @@ fn map_deduplicate_keys( let mut new_last_offset = 0; new_offsets.push(new_last_offset); + // Mirror Spark's `ArrayBasedMapBuilder`: the first occurrence of a key + // fixes its position in the output; under LAST_WIN a later duplicate + // overwrites that slot's value. `keys_mask` selects the first-seen keys, + // `value_indices` records the source index in `flat_values` to materialize + // for each output slot (updated in place on overwrite). let mut keys_mask_builder = BooleanBuilder::new(); - let mut values_mask_builder = BooleanBuilder::new(); + let mut value_indices: Vec = Vec::new(); + let mut key_to_output_idx: HashMap = HashMap::new(); for (row_idx, (next_keys_offset, next_values_offset)) in keys_offsets .iter() .zip(values_offsets.iter()) @@ -180,52 +193,183 @@ fn map_deduplicate_keys( let num_keys_entries = *next_keys_offset as usize - cur_keys_offset; let num_values_entries = *next_values_offset as usize - cur_values_offset; - let mut keys_mask_one = [false].repeat(num_keys_entries); - let mut values_mask_one = [false].repeat(num_values_entries); - let key_is_valid = keys_nulls.is_none_or(|buf| buf.is_valid(row_idx)); let value_is_valid = values_nulls.is_none_or(|buf| buf.is_valid(row_idx)); if key_is_valid && value_is_valid { if num_keys_entries != num_values_entries { - return exec_err!("map_deduplicate_keys: keys and values lists in the same row must have equal lengths"); - } else if num_keys_entries != 0 { - let mut seen_keys = HashSet::new(); - - for cur_entry_idx in (0..num_keys_entries).rev() { - let key = ScalarValue::try_from_array( - &flat_keys, - cur_keys_offset + cur_entry_idx, - )? - .compacted(); - if seen_keys.contains(&key) { - // TODO: implement configuration and logic for spark.sql.mapKeyDedupPolicy=EXCEPTION (this is default spark-config) - // exec_err!("invalid argument: duplicate keys in map") - // https://github.com/apache/spark/blob/cf3a34e19dfcf70e2d679217ff1ba21302212472/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala#L4961 - } else { - // This code implements deduplication logic for spark.sql.mapKeyDedupPolicy=LAST_WIN (this is NOT default spark-config) - keys_mask_one[cur_entry_idx] = true; - values_mask_one[cur_entry_idx] = true; - seen_keys.insert(key); - new_last_offset += 1; + return exec_err!( + "map_deduplicate_keys: keys and values lists in the same row must have equal lengths" + ); + } + key_to_output_idx.clear(); + for cur_entry_idx in 0..num_keys_entries { + let key = ScalarValue::try_from_array( + &flat_keys, + cur_keys_offset + cur_entry_idx, + )? + .compacted(); + let abs_value_idx = (cur_values_offset + cur_entry_idx) as i32; + + if let Some(&output_idx) = key_to_output_idx.get(&key) { + if last_value_wins { + value_indices[output_idx] = abs_value_idx; + keys_mask_builder.append_value(false); + continue; } + return exec_err!( + "[DUPLICATED_MAP_KEY] Duplicate map key {key} was found, \ + please check the input data. To allow duplicate keys with \ + last-value-wins semantics, set \ + `datafusion.spark.map_key_dedup_policy` to `LAST_WIN`." + ); } + keys_mask_builder.append_value(true); + key_to_output_idx.insert(key, value_indices.len()); + value_indices.push(abs_value_idx); + new_last_offset += 1; } } else { - // the result entry is NULL - // both current row offsets are skipped - // keys or values in the current row are marked false in the masks + // The result entry is NULL — no keys/values emitted. Still pad the + // mask so it stays aligned with `flat_keys`. + keys_mask_builder.append_n(num_keys_entries, false); } - keys_mask_builder.append_array(&keys_mask_one.into()); - values_mask_builder.append_array(&values_mask_one.into()); new_offsets.push(new_last_offset); cur_keys_offset += num_keys_entries; cur_values_offset += num_values_entries; } let keys_mask = keys_mask_builder.finish(); - let values_mask = values_mask_builder.finish(); let needed_keys = filter(&flat_keys, &keys_mask)?; - let needed_values = filter(&flat_values, &values_mask)?; + let value_indices_array = Int32Array::from(value_indices); + let needed_values = take(&flat_values, &value_indices_array, None)?; let offsets = OffsetBuffer::new(new_offsets.into()); Ok((needed_keys, needed_values, offsets)) } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, StringArray}; + + fn int32_utf8_inputs( + keys: Vec, + values: Vec>, + ) -> (ArrayRef, ArrayRef) { + let keys: ArrayRef = Arc::new(Int32Array::from(keys)); + let values: ArrayRef = Arc::new(StringArray::from(values)); + (keys, values) + } + + #[test] + fn happy_path_two_rows_no_duplicates() { + let (keys, values) = + int32_utf8_inputs(vec![1, 2, 3], vec![Some("a"), Some("b"), Some("c")]); + let offsets = [0i32, 2, 3]; + + let result = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 2); + assert_eq!(map.value_offsets(), &[0, 2, 3]); + } + + #[test] + fn single_row_duplicate_errors_under_exception() { + let (keys, values) = + int32_utf8_inputs(vec![1, 2, 1], vec![Some("a"), Some("b"), Some("c")]); + let offsets = [0i32, 3]; + + let err = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("[DUPLICATED_MAP_KEY]"), "{err}"); + assert!(err.contains("map_key_dedup_policy"), "{err}"); + } + + #[test] + fn last_win_keeps_final_occurrence() { + let (keys, values) = int32_utf8_inputs( + vec![1, 2, 1, 3, 2], + vec![Some("a"), Some("b"), Some("c"), Some("d"), Some("e")], + ); + let offsets = [0i32, 5]; + + let result = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, true, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 1); + // 5 entries in, 3 unique keys -> offsets [0, 3] + assert_eq!(map.value_offsets(), &[0, 3]); + } + + #[test] + fn duplicate_in_later_row_still_errors() { + let (keys, values) = int32_utf8_inputs( + vec![1, 2, 1, 1], + vec![Some("a"), Some("b"), Some("x"), Some("y")], + ); + let offsets = [0i32, 2, 4]; + + let err = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("[DUPLICATED_MAP_KEY]"), "{err}"); + } + + #[test] + fn empty_row_does_not_trigger_dedup() { + let (keys, values) = int32_utf8_inputs(vec![], vec![]); + let offsets = [0i32, 0]; + + let result = map_from_keys_values_offsets_nulls( + &keys, &values, &offsets, &offsets, None, None, false, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 1); + assert_eq!(map.value_offsets(), &[0, 0]); + } + + #[test] + fn null_row_is_skipped_and_not_checked() { + // Row 0 is NULL (keys null). Its duplicate keys should be ignored; + // row 1 is a clean row. + let (keys, values) = int32_utf8_inputs( + vec![1, 1, 2, 3], + vec![Some("dup-a"), Some("dup-b"), Some("x"), Some("y")], + ); + let offsets = [0i32, 2, 4]; + let keys_nulls = NullBuffer::from(vec![false, true]); + + let result = map_from_keys_values_offsets_nulls( + &keys, + &values, + &offsets, + &offsets, + Some(&keys_nulls), + None, + false, + ) + .unwrap(); + + let map = result.as_map(); + assert_eq!(map.len(), 2); + // First row is NULL (no entries emitted), second row keeps both entries. + assert_eq!(map.value_offsets(), &[0, 0, 2]); + assert!(map.is_null(0)); + assert!(!map.is_null(1)); + } +} diff --git a/datafusion/spark/src/function/math/abs.rs b/datafusion/spark/src/function/math/abs.rs index f48f8964c28c9..0d6c7f3285a18 100644 --- a/datafusion/spark/src/function/math/abs.rs +++ b/datafusion/spark/src/function/math/abs.rs @@ -16,15 +16,17 @@ // under the License. use arrow::array::*; -use arrow::datatypes::DataType; -use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::error::ArrowError; +use datafusion_common::{DataFusionError, Result, ScalarValue, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::{ - downcast_named_arg, make_abs_function, make_wrapping_abs_function, + downcast_named_arg, make_abs_function, make_try_abs_function, + make_wrapping_abs_function, }; -use std::any::Any; use std::sync::Arc; /// Spark-compatible `abs` expression @@ -33,8 +35,10 @@ use std::sync::Arc; /// Returns the absolute value of input /// Returns NULL if input is NULL, returns NaN if input is NaN. /// -/// TODOs: +/// Differences with DataFusion abs: /// - Spark's ANSI-compliant dialect, when off (i.e. `spark.sql.ansi.enabled=false`), taking absolute value on the minimal value of a signed integer returns the value as is. DataFusion's abs throws "DataFusion error: Arrow error: Compute error" on arithmetic overflow +/// +/// TODOs: /// - Spark's abs also supports ANSI interval types: YearMonthIntervalType and DayTimeIntervalType. DataFusion's abs doesn't. /// #[derive(Debug, PartialEq, Eq, Hash)] @@ -57,10 +61,6 @@ impl SparkAbs { } impl ScalarUDFImpl for SparkAbs { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "abs" } @@ -69,24 +69,54 @@ impl ScalarUDFImpl for SparkAbs { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!( + "SparkAbs: return_type() is not used; return_field_from_args() is implemented" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let input_field = &args.arg_fields[0]; + let out_dt = input_field.data_type().clone(); + let out_nullable = input_field.is_nullable(); + + Ok(Arc::new(Field::new(self.name(), out_dt, out_nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_abs(&args.args) + spark_abs(&args.args, args.config_options.execution.enable_ansi_mode) } } macro_rules! scalar_compute_op { - ($INPUT:ident, $SCALAR_TYPE:ident) => {{ - let result = $INPUT.wrapping_abs(); + ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $SCALAR_TYPE:ident) => {{ + let result = if $ENABLE_ANSI_MODE { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + } else { + $INPUT.wrapping_abs() + }; Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE(Some( result, )))) }}; - ($INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ - let result = $INPUT.wrapping_abs(); + ($ENABLE_ANSI_MODE:expr, $INPUT:ident, $PRECISION:expr, $SCALE:expr, $SCALAR_TYPE:ident) => {{ + let result = if $ENABLE_ANSI_MODE { + $INPUT.checked_abs().ok_or_else(|| { + ArrowError::ComputeError(format!( + "{} overflow on abs({:?})", + stringify!($SCALAR_TYPE), + $INPUT + )) + })? + } else { + $INPUT.wrapping_abs() + }; Ok(ColumnarValue::Scalar(ScalarValue::$SCALAR_TYPE( Some(result), $PRECISION, @@ -95,7 +125,10 @@ macro_rules! scalar_compute_op { }}; } -pub fn spark_abs(args: &[ColumnarValue]) -> Result { +pub fn spark_abs( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { if args.len() != 1 { return internal_err!("abs takes exactly 1 argument, but got: {}", args.len()); } @@ -108,19 +141,35 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), DataType::Int8 => { - let abs_fun = make_wrapping_abs_function!(Int8Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int8Array) + } else { + make_wrapping_abs_function!(Int8Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int16 => { - let abs_fun = make_wrapping_abs_function!(Int16Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int16Array) + } else { + make_wrapping_abs_function!(Int16Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int32 => { - let abs_fun = make_wrapping_abs_function!(Int32Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int32Array) + } else { + make_wrapping_abs_function!(Int32Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Int64 => { - let abs_fun = make_wrapping_abs_function!(Int64Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Int64Array) + } else { + make_wrapping_abs_function!(Int64Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Float32 => { @@ -132,11 +181,19 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - let abs_fun = make_wrapping_abs_function!(Decimal128Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Decimal128Array) + } else { + make_wrapping_abs_function!(Decimal128Array) + }; abs_fun(array).map(ColumnarValue::Array) } DataType::Decimal256(_, _) => { - let abs_fun = make_wrapping_abs_function!(Decimal256Array); + let abs_fun = if enable_ansi_mode { + make_try_abs_function!(Decimal256Array) + } else { + make_wrapping_abs_function!(Decimal256Array) + }; abs_fun(array).map(ColumnarValue::Array) } dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), @@ -148,10 +205,10 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result Ok(args[0].clone()), sv if sv.is_null() => Ok(args[0].clone()), - ScalarValue::Int8(Some(v)) => scalar_compute_op!(v, Int8), - ScalarValue::Int16(Some(v)) => scalar_compute_op!(v, Int16), - ScalarValue::Int32(Some(v)) => scalar_compute_op!(v, Int32), - ScalarValue::Int64(Some(v)) => scalar_compute_op!(v, Int64), + ScalarValue::Int8(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int8), + ScalarValue::Int16(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int16), + ScalarValue::Int32(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int32), + ScalarValue::Int64(Some(v)) => scalar_compute_op!(enable_ansi_mode, v, Int64), ScalarValue::Float32(Some(v)) => { Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(v.abs())))) } @@ -159,10 +216,10 @@ pub fn spark_abs(args: &[ColumnarValue]) -> Result { - scalar_compute_op!(v, *precision, *scale, Decimal128) + scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal128) } ScalarValue::Decimal256(Some(v), precision, scale) => { - scalar_compute_op!(v, *precision, *scale, Decimal256) + scalar_compute_op!(enable_ansi_mode, v, *precision, *scale, Decimal256) } dt => internal_err!("Not supported datatype for Spark ABS: {dt}"), }, @@ -174,100 +231,12 @@ mod tests { use super::*; use arrow::datatypes::i256; - macro_rules! eval_legacy_mode { - ($TYPE:ident, $VAL:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $VAL); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $RESULT:expr) => {{ - let args = ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL))); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE(Some(result)))) => { - assert_eq!(result, $RESULT); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $VAL); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - ($TYPE:ident, $VAL:expr, $PRECISION:expr, $SCALE:expr, $RESULT:expr) => {{ - let args = - ColumnarValue::Scalar(ScalarValue::$TYPE(Some($VAL), $PRECISION, $SCALE)); - match spark_abs(&[args]) { - Ok(ColumnarValue::Scalar(ScalarValue::$TYPE( - Some(result), - precision, - scale, - ))) => { - assert_eq!(result, $RESULT); - assert_eq!(precision, $PRECISION); - assert_eq!(scale, $SCALE); - } - _ => unreachable!(), - } - }}; - } - - #[test] - fn test_abs_scalar_legacy_mode() { - // NumericType MIN - eval_legacy_mode!(UInt8, u8::MIN); - eval_legacy_mode!(UInt16, u16::MIN); - eval_legacy_mode!(UInt32, u32::MIN); - eval_legacy_mode!(UInt64, u64::MIN); - eval_legacy_mode!(Int8, i8::MIN); - eval_legacy_mode!(Int16, i16::MIN); - eval_legacy_mode!(Int32, i32::MIN); - eval_legacy_mode!(Int64, i64::MIN); - eval_legacy_mode!(Float32, f32::MIN, f32::MAX); - eval_legacy_mode!(Float64, f64::MIN, f64::MAX); - eval_legacy_mode!(Decimal128, i128::MIN, 18, 10); - eval_legacy_mode!(Decimal256, i256::MIN, 10, 2); - - // NumericType not MIN - eval_legacy_mode!(Int8, -1i8, 1i8); - eval_legacy_mode!(Int16, -1i16, 1i16); - eval_legacy_mode!(Int32, -1i32, 1i32); - eval_legacy_mode!(Int64, -1i64, 1i64); - eval_legacy_mode!(Decimal128, -1i128, 18, 10, 1i128); - eval_legacy_mode!(Decimal256, i256::from(-1i8), 10, 2, i256::from(1i8)); - - // Float32, Float64 - eval_legacy_mode!(Float32, f32::NEG_INFINITY, f32::INFINITY); - eval_legacy_mode!(Float32, f32::INFINITY, f32::INFINITY); - eval_legacy_mode!(Float32, 0.0f32, 0.0f32); - eval_legacy_mode!(Float32, -0.0f32, 0.0f32); - eval_legacy_mode!(Float64, f64::NEG_INFINITY, f64::INFINITY); - eval_legacy_mode!(Float64, f64::INFINITY, f64::INFINITY); - eval_legacy_mode!(Float64, 0.0f64, 0.0f64); - eval_legacy_mode!(Float64, -0.0f64, 0.0f64); - } - macro_rules! eval_array_legacy_mode { ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ let input = $INPUT; let args = ColumnarValue::Array(Arc::new(input)); let expected = $OUTPUT; - match spark_abs(&[args]) { + match spark_abs(&[args], false) { Ok(ColumnarValue::Array(result)) => { let actual = datafusion_common::cast::$FUNC(&result).unwrap(); assert_eq!(actual, &expected); @@ -356,23 +325,241 @@ mod tests { ); eval_array_legacy_mode!( - Decimal128Array::from(vec![Some(i128::MIN), None]) + Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MIN + 1), None]) .with_precision_and_scale(38, 37) .unwrap(), - Decimal128Array::from(vec![Some(i128::MIN), None]) + Decimal128Array::from(vec![Some(i128::MIN), Some(i128::MAX), None]) .with_precision_and_scale(38, 37) .unwrap(), as_decimal128_array ); eval_array_legacy_mode!( - Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::MINUS_ONE), + Some(i256::MIN + i256::from(1)), + None + ]) + .with_precision_and_scale(5, 2) + .unwrap(), + Decimal256Array::from(vec![ + Some(i256::MIN), + Some(i256::ONE), + Some(i256::MAX), + None + ]) + .with_precision_and_scale(5, 2) + .unwrap(), + as_decimal256_array + ); + } + + macro_rules! eval_array_ansi_mode { + ($INPUT:expr) => {{ + let input = $INPUT; + let args = ColumnarValue::Array(Arc::new(input)); + match spark_abs(&[args], true) { + Err(e) => { + assert!( + e.to_string().contains("overflow on abs"), + "Error message did not match. Actual message: {e}" + ); + } + _ => unreachable!(), + } + }}; + ($INPUT:expr, $OUTPUT:expr, $FUNC:ident) => {{ + let input = $INPUT; + let args = ColumnarValue::Array(Arc::new(input)); + let expected = $OUTPUT; + match spark_abs(&[args], true) { + Ok(ColumnarValue::Array(result)) => { + let actual = datafusion_common::cast::$FUNC(&result).unwrap(); + assert_eq!(actual, &expected); + } + _ => unreachable!(), + } + }}; + } + #[test] + fn test_abs_array_ansi_mode() { + eval_array_ansi_mode!( + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + UInt64Array::from(vec![Some(u64::MIN), Some(u64::MAX), None]), + as_uint64_array + ); + + eval_array_ansi_mode!(Int8Array::from(vec![ + Some(-1), + Some(i8::MIN), + Some(i8::MAX), + None + ])); + eval_array_ansi_mode!(Int16Array::from(vec![ + Some(-1), + Some(i16::MIN), + Some(i16::MAX), + None + ])); + eval_array_ansi_mode!(Int32Array::from(vec![ + Some(-1), + Some(i32::MIN), + Some(i32::MAX), + None + ])); + eval_array_ansi_mode!(Int64Array::from(vec![ + Some(-1), + Some(i64::MIN), + Some(i64::MAX), + None + ])); + eval_array_ansi_mode!( + Float32Array::from(vec![ + Some(-1f32), + Some(f32::MIN), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float32Array::from(vec![ + Some(1f32), + Some(f32::MAX), + Some(f32::MAX), + None, + Some(f32::NAN), + Some(f32::INFINITY), + Some(f32::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float32_array + ); + + eval_array_ansi_mode!( + Float64Array::from(vec![ + Some(-1f64), + Some(f64::MIN), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(0.0), + Some(-0.0), + ]), + Float64Array::from(vec![ + Some(1f64), + Some(f64::MAX), + Some(f64::MAX), + None, + Some(f64::NAN), + Some(f64::INFINITY), + Some(f64::INFINITY), + Some(0.0), + Some(0.0), + ]), + as_float64_array + ); + + // decimal: no arithmetic overflow + eval_array_ansi_mode!( + Decimal128Array::from(vec![Some(-1), Some(-2), Some(i128::MIN + 1)]) + .with_precision_and_scale(38, 37) .unwrap(), - Decimal256Array::from(vec![Some(i256::MIN), None]) - .with_precision_and_scale(5, 2) + Decimal128Array::from(vec![Some(1), Some(2), Some(i128::MAX)]) + .with_precision_and_scale(38, 37) .unwrap(), + as_decimal128_array + ); + + eval_array_ansi_mode!( + Decimal256Array::from(vec![ + Some(i256::MINUS_ONE), + Some(i256::from(-2)), + Some(i256::MIN + i256::from(1)) + ]) + .with_precision_and_scale(18, 7) + .unwrap(), + Decimal256Array::from(vec![ + Some(i256::ONE), + Some(i256::from(2)), + Some(i256::MAX) + ]) + .with_precision_and_scale(18, 7) + .unwrap(), as_decimal256_array ); + + // decimal: arithmetic overflow + eval_array_ansi_mode!( + Decimal128Array::from(vec![Some(i128::MIN), None]) + .with_precision_and_scale(38, 37) + .unwrap() + ); + eval_array_ansi_mode!( + Decimal256Array::from(vec![Some(i256::MIN), None]) + .with_precision_and_scale(5, 2) + .unwrap() + ); + } + + #[test] + fn test_abs_nullability() { + let abs = SparkAbs::new(); + + // --- non-nullable Int32 input --- + let non_nullable_i32 = Arc::new(Field::new("c", DataType::Int32, false)); + let out_non_null = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_i32)], + scalar_arguments: &[None], + }) + .unwrap(); + + // result should be non-nullable and the same DataType as input + assert!(!out_non_null.is_nullable()); + assert_eq!(out_non_null.data_type(), &DataType::Int32); + + // --- nullable Int32 input --- + let nullable_i32 = Arc::new(Field::new("c", DataType::Int32, true)); + let out_nullable = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_i32)], + scalar_arguments: &[None], + }) + .unwrap(); + + // result should be nullable and the same DataType as input + assert!(out_nullable.is_nullable()); + assert_eq!(out_nullable.data_type(), &DataType::Int32); + + // --- non-nullable Float64 input --- + let non_nullable_f64 = Arc::new(Field::new("c", DataType::Float64, false)); + let out_f64 = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&non_nullable_f64)], + scalar_arguments: &[None], + }) + .unwrap(); + + assert!(!out_f64.is_nullable()); + assert_eq!(out_f64.data_type(), &DataType::Float64); + + // --- nullable Float64 input --- + let nullable_f64 = Arc::new(Field::new("c", DataType::Float64, true)); + let out_f64_null = abs + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_f64)], + scalar_arguments: &[None], + }) + .unwrap(); + + assert!(out_f64_null.is_nullable()); + assert_eq!(out_f64_null.data_type(), &DataType::Float64); } } diff --git a/datafusion/spark/src/function/math/bin.rs b/datafusion/spark/src/function/math/bin.rs new file mode 100644 index 0000000000000..82afd48e8dc9f --- /dev/null +++ b/datafusion/spark/src/function/math/bin.rs @@ -0,0 +1,106 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray, StringArray}; +use arrow::datatypes::{DataType, Field, FieldRef, Int64Type}; +use datafusion_common::types::{NativeType, logical_int64}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `bin` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBin { + signature: Signature, +} + +impl Default for SparkBin { + fn default() -> Self { + Self::new() + } +} + +impl SparkBin { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Numeric], + NativeType::Int64, + )])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBin { + fn name(&self) -> &str { + "bin" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + Ok(Arc::new(Field::new( + self.name(), + DataType::Utf8, + args.arg_fields[0].is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_bin_inner, vec![])(&args.args) + } +} + +fn spark_bin_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("bin", arg)?; + match &array.data_type() { + DataType::Int64 => { + let result: StringArray = array + .as_primitive::() + .iter() + .map(|opt| opt.map(spark_bin)) + .collect(); + Ok(Arc::new(result)) + } + data_type => { + internal_err!("bin does not support: {data_type}") + } + } +} + +fn spark_bin(value: i64) -> String { + format!("{value:b}") +} diff --git a/datafusion/spark/src/function/math/ceil.rs b/datafusion/spark/src/function/math/ceil.rs new file mode 100644 index 0000000000000..5096914a1eba8 --- /dev/null +++ b/datafusion/spark/src/function/math/ceil.rs @@ -0,0 +1,304 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrowNativeTypeOp, AsArray, Decimal128Array}; +use arrow::datatypes::{DataType, Decimal128Type, Float32Type, Float64Type, Int64Type}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +/// Spark-compatible `ceil` expression +/// +/// +/// Differences with DataFusion ceil: +/// - Spark's ceil returns Int64 for float inputs; DataFusion preserves +/// the input type (Float32→Float32, Float64→Float64) +/// - Spark's ceil on Decimal128(p, s) returns Decimal128(p−s+1, 0), reducing scale +/// to 0; DataFusion preserves the original precision and scale +/// - Spark only supports Decimal128; DataFusion also supports Decimal32/64/256 +/// - Spark does not check for decimal overflow; DataFusion errors on overflow +/// +/// 2-argument ceil(value, scale) is not yet implemented +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkCeil { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkCeil { + fn default() -> Self { + Self::new() + } +} + +impl SparkCeil { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + aliases: vec!["ceiling".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkCeil { + fn name(&self) -> &str { + "ceil" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Decimal128(p, s) => { + if *s > 0 { + Ok(DataType::Decimal128(decimal128_ceil_precision(*p, *s), 0)) + } else { + // scale <= 0 means the value is already a whole number + // (or represents multiples of 10^(-scale)), so ceil is a no-op + Ok(DataType::Decimal128(*p, *s)) + } + } + dt if matches!(dt, DataType::Float32 | DataType::Float64) + || dt.is_integer() => + { + Ok(DataType::Int64) + } + other => exec_err!("Unsupported data type {other:?} for function ceil"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_ceil(&args.args) + } +} + +fn spark_ceil(args: &[ColumnarValue]) -> Result { + let [input] = take_function_args("ceil", args)?; + + match input { + ColumnarValue::Scalar(value) => spark_ceil_scalar(value), + ColumnarValue::Array(input) => spark_ceil_array(input), + } +} + +/// Compute ceil for a single decimal128 value with the given scale. +#[inline] +fn decimal128_ceil(value: i128, scale: u32) -> i128 { + let div = 10_i128.pow_wrapping(scale); + let d = value / div; + let r = value % div; + if r > 0 { d + 1 } else { d } +} + +/// Compute the return precision for a decimal128 ceil result. +#[inline] +fn decimal128_ceil_precision(precision: u8, scale: i8) -> u8 { + ((precision as i64) - (scale as i64) + 1).clamp(1, 38) as u8 +} + +fn spark_ceil_scalar(value: &ScalarValue) -> Result { + let result = match value { + ScalarValue::Float32(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)), + ScalarValue::Float64(v) => ScalarValue::Int64(v.map(|x| x.ceil() as i64)), + v if v.data_type().is_integer() => v.cast_to(&DataType::Int64)?, + ScalarValue::Decimal128(v, p, s) if *s > 0 => { + let new_p = decimal128_ceil_precision(*p, *s); + ScalarValue::Decimal128(v.map(|x| decimal128_ceil(x, *s as u32)), new_p, 0) + } + ScalarValue::Decimal128(_, _, _) => value.clone(), + other => { + return exec_err!( + "Unsupported data type {:?} for function ceil", + other.data_type() + ); + } + }; + Ok(ColumnarValue::Scalar(result)) +} + +fn spark_ceil_array(input: &Arc) -> Result { + let result = match input.data_type() { + DataType::Float32 => Arc::new( + input + .as_primitive::() + .unary::<_, Int64Type>(|x| x.ceil() as i64), + ) as _, + DataType::Float64 => Arc::new( + input + .as_primitive::() + .unary::<_, Int64Type>(|x| x.ceil() as i64), + ) as _, + dt if dt.is_integer() => arrow::compute::cast(input, &DataType::Int64)?, + DataType::Decimal128(p, s) if *s > 0 => { + let new_p = decimal128_ceil_precision(*p, *s); + let result: Decimal128Array = input + .as_primitive::() + .unary(|x| decimal128_ceil(x, *s as u32)); + Arc::new(result.with_data_type(DataType::Decimal128(new_p, 0))) + } + DataType::Decimal128(_, _) => Arc::clone(input), + other => return exec_err!("Unsupported data type {other:?} for function ceil"), + }; + + Ok(ColumnarValue::Array(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Decimal128Array, Float32Array, Float64Array, Int64Array}; + use datafusion_common::ScalarValue; + + #[test] + fn test_ceil_float64() { + let input = Float64Array::from(vec![ + Some(125.2345), + Some(15.0001), + Some(0.1), + Some(-0.9), + Some(-1.1), + Some(123.0), + None, + ]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!( + result, + &Int64Array::from(vec![ + Some(126), + Some(16), + Some(1), + Some(0), + Some(-1), + Some(123), + None, + ]) + ); + } + + #[test] + fn test_ceil_float32() { + let input = Float32Array::from(vec![ + Some(125.2345f32), + Some(15.0001f32), + Some(0.1f32), + Some(-0.9f32), + Some(-1.1f32), + Some(123.0f32), + None, + ]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!( + result, + &Int64Array::from(vec![ + Some(126), + Some(16), + Some(1), + Some(0), + Some(-1), + Some(123), + None, + ]) + ); + } + + #[test] + fn test_ceil_int64() { + let input = Int64Array::from(vec![Some(1), Some(-1), None]); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + assert_eq!(result, &Int64Array::from(vec![Some(1), Some(-1), None])); + } + + #[test] + fn test_ceil_decimal128() { + // Decimal128(10, 2): 150 = 1.50, -150 = -1.50, 100 = 1.00 + let return_type = DataType::Decimal128(9, 0); + let input = Decimal128Array::from(vec![Some(150), Some(-150), Some(100), None]) + .with_data_type(DataType::Decimal128(10, 2)); + let args = vec![ColumnarValue::Array(Arc::new(input))]; + let result = spark_ceil(&args).unwrap(); + let result = match result { + ColumnarValue::Array(arr) => arr, + _ => panic!("Expected array"), + }; + let result = result.as_primitive::(); + let expected = Decimal128Array::from(vec![Some(2), Some(-1), Some(1), None]) + .with_data_type(return_type); + assert_eq!(result, &expected); + } + + #[test] + fn test_ceil_float64_scalar() { + let input = ScalarValue::Float64(Some(-1.1)); + let args = vec![ColumnarValue::Scalar(input)]; + let result = match spark_ceil(&args).unwrap() { + ColumnarValue::Scalar(v) => v, + _ => panic!("Expected scalar"), + }; + assert_eq!(result, ScalarValue::Int64(Some(-1))); + } + + #[test] + fn test_ceil_float32_scalar() { + let input = ScalarValue::Float32(Some(125.2345f32)); + let args = vec![ColumnarValue::Scalar(input)]; + let result = match spark_ceil(&args).unwrap() { + ColumnarValue::Scalar(v) => v, + _ => panic!("Expected scalar"), + }; + assert_eq!(result, ScalarValue::Int64(Some(126))); + } + + #[test] + fn test_ceil_int64_scalar() { + let input = ScalarValue::Int64(Some(48)); + let args = vec![ColumnarValue::Scalar(input)]; + let result = match spark_ceil(&args).unwrap() { + ColumnarValue::Scalar(v) => v, + _ => panic!("Expected scalar"), + }; + assert_eq!(result, ScalarValue::Int64(Some(48))); + } +} diff --git a/datafusion/spark/src/function/math/expm1.rs b/datafusion/spark/src/function/math/expm1.rs index b0b2b1a0865cd..a1090072f4909 100644 --- a/datafusion/spark/src/function/math/expm1.rs +++ b/datafusion/spark/src/function/math/expm1.rs @@ -23,7 +23,6 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use std::any::Any; use std::sync::Arc; /// @@ -47,10 +46,6 @@ impl SparkExpm1 { } impl ScalarUDFImpl for SparkExpm1 { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "expm1" } diff --git a/datafusion/spark/src/function/math/factorial.rs b/datafusion/spark/src/function/math/factorial.rs index 5cf33d6073e53..c9405273e823b 100644 --- a/datafusion/spark/src/function/math/factorial.rs +++ b/datafusion/spark/src/function/math/factorial.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{Array, Int64Array}; @@ -23,7 +22,7 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int32, Int64}; use datafusion_common::cast::as_int32_array; use datafusion_common::{ - exec_err, utils::take_function_args, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, exec_err, utils::take_function_args, }; use datafusion_expr::Signature; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; @@ -51,10 +50,6 @@ impl SparkFactorial { } impl ScalarUDFImpl for SparkFactorial { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "factorial" } @@ -136,8 +131,8 @@ fn compute_factorial(num: Option) -> Option { mod test { use crate::function::math::factorial::spark_factorial; use arrow::array::{Int32Array, Int64Array}; - use datafusion_common::cast::as_int64_array; use datafusion_common::ScalarValue; + use datafusion_common::cast::as_int64_array; use datafusion_expr::ColumnarValue; use std::sync::Arc; diff --git a/datafusion/spark/src/function/math/floor.rs b/datafusion/spark/src/function/math/floor.rs new file mode 100644 index 0000000000000..703f81a2c2065 --- /dev/null +++ b/datafusion/spark/src/function/math/floor.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::cast::AsArray; +use arrow::array::types::Decimal128Type; +use arrow::array::{ArrowNativeTypeOp, Decimal128Array, Int64Array}; +use arrow::compute::kernels::arity::unary; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{DataFusionError, ScalarValue, exec_err, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use std::sync::Arc; + +/// Spark-compatible `floor` function. +/// +/// Differences from DataFusion's floor: +/// - Returns Int64 for float and integer inputs (while DataFusion preserves input type) +/// - For Decimal128(p, s), returns Decimal128(p-s+1, 0) with scale 0 +/// (DataFusion preserves original precision and scale) +/// +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFloor { + signature: Signature, +} + +impl Default for SparkFloor { + fn default() -> Self { + Self::new() + } +} + +impl SparkFloor { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkFloor { + fn name(&self) -> &str { + "floor" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + internal_err!("return_field_from_args should be called instead") + } + + fn return_field_from_args( + &self, + args: ReturnFieldArgs, + ) -> datafusion_common::Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + let return_type = match args.arg_fields[0].data_type() { + DataType::Decimal128(p, s) if *s > 0 => { + let new_p = (*p - *s as u8 + 1).clamp(1, 38); + DataType::Decimal128(new_p, 0) + } + DataType::Decimal128(p, s) => DataType::Decimal128(*p, *s), + DataType::Float32 + | DataType::Float64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => DataType::Int64, + _ => exec_err!( + "found unsupported return type {:?}", + args.arg_fields[0].data_type() + )?, + }; + Ok(Arc::new(Field::new(self.name(), return_type, nullable))) + } + + fn invoke_with_args( + &self, + args: ScalarFunctionArgs, + ) -> datafusion_common::Result { + spark_floor(&args.args, args.return_field.data_type()) + } +} + +macro_rules! apply_int64 { + ($value:expr, $arr_type:ty, $scalar_variant:path, $f:expr) => { + match $value { + ColumnarValue::Array(array) => { + let result: Int64Array = unary(array.as_primitive::<$arr_type>(), $f); + Ok(ColumnarValue::Array(Arc::new(result))) + } + ColumnarValue::Scalar($scalar_variant(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v.map($f)))) + } + other => internal_err!( + "floor: data type mismatch — expected scalar of type {} but got {:?}", + stringify!($scalar_variant), + other.data_type() + ), + } + }; +} + +fn spark_floor( + args: &[ColumnarValue], + return_type: &DataType, +) -> Result { + let value = &args[0]; + match value.data_type() { + DataType::Float32 => apply_int64!( + value, + arrow::datatypes::Float32Type, + ScalarValue::Float32, + |x| x.floor() as i64 + ), + DataType::Float64 => apply_int64!( + value, + arrow::datatypes::Float64Type, + ScalarValue::Float64, + |x| x.floor() as i64 + ), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + value.cast_to(&DataType::Int64, None) + } + DataType::Decimal128(_, scale) if scale > 0 => { + let divisor = 10_i128.pow_wrapping(scale as u32); + let floor_decimal = |x: i128| { + let (d, r) = (x / divisor, x % divisor); + if r < 0 { d - 1 } else { d } + }; + match value { + ColumnarValue::Array(array) => { + let result: Decimal128Array = + unary(array.as_primitive::(), floor_decimal); + Ok(ColumnarValue::Array(Arc::new( + result.with_data_type(return_type.clone()), + ))) + } + ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => { + let DataType::Decimal128(new_p, new_s) = return_type else { + return internal_err!( + "floor: data type mismatch — expected Decimal128 return type but got {:?}", + return_type + ); + }; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + v.map(floor_decimal), + *new_p, + *new_s, + ))) + } + other => internal_err!( + "floor: data type mismatch — expected Decimal128 scalar but got {:?}", + other.data_type() + ), + } + } + DataType::Decimal128(_, _) => Ok(value.clone()), + other => exec_err!("Unsupported data type {other:?} for function floor"), + } +} diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 7029b5e434909..55c9cda63c888 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -15,28 +15,28 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +use std::str::from_utf8_unchecked; use std::sync::Arc; -use crate::function::error_utils::{ - invalid_arg_count_exec_err, unsupported_data_type_exec_err, -}; -use arrow::array::{Array, StringArray}; +use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, datatypes::Int32Type, }; +use datafusion_common::cast::as_large_binary_array; use datafusion_common::cast::as_string_view_array; +use datafusion_common::types::{NativeType, logical_int64, logical_string}; use datafusion_common::utils::take_function_args; use datafusion_common::{ + DataFusionError, cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, - exec_err, DataFusionError, + exec_datafusion_err, exec_err, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; -use datafusion_expr::Signature; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Volatility}; -use std::fmt::Write; - /// #[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkHex { @@ -52,18 +52,33 @@ impl Default for SparkHex { impl SparkHex { pub fn new() -> Self { + let int64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Numeric], + NativeType::Int64, + ); + + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + + let binary = Coercion::new_exact(TypeSignatureClass::Binary); + + let variants = vec![ + // accepts numeric types + TypeSignature::Coercible(vec![int64]), + // accepts string types (Utf8, Utf8View, LargeUtf8) + TypeSignature::Coercible(vec![string]), + // accepts binary types (Binary, FixedSizeBinary, LargeBinary) + TypeSignature::Coercible(vec![binary]), + ]; + Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::one_of(variants, Volatility::Immutable), aliases: vec![], } } } impl ScalarUDFImpl for SparkHex { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "hex" } @@ -72,11 +87,13 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match &arg_types[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }) } fn invoke_with_args( @@ -89,86 +106,123 @@ impl ScalarUDFImpl for SparkHex { fn aliases(&self) -> &[String] { &self.aliases } +} - fn coerce_types( - &self, - arg_types: &[DataType], - ) -> datafusion_common::Result> { - if arg_types.len() != 1 { - return Err(invalid_arg_count_exec_err("hex", (1, 1), arg_types.len())); - } - match &arg_types[0] { - DataType::Int64 - | DataType::Utf8 - | DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Int64 - | DataType::Utf8 - | DataType::Utf8View - | DataType::LargeUtf8 - | DataType::Binary - | DataType::LargeBinary => Ok(vec![arg_types[0].clone()]), - other => { - if other.is_numeric() { - Ok(vec![DataType::Dictionary( - key_type.clone(), - Box::new(DataType::Int64), - )]) - } else { - Err(unsupported_data_type_exec_err( - "hex", - "Numeric, String, or Binary", - &arg_types[0], - )) - } - } - }, - other => { - if other.is_numeric() { - Ok(vec![DataType::Int64]) - } else { - Err(unsupported_data_type_exec_err( - "hex", - "Numeric, String, or Binary", - &arg_types[0], - )) - } - } - } +/// Hex encoding lookup tables for fast byte-to-hex conversion. +/// +/// Each entry maps a full byte to its two-character hex encoding so the +/// hot loop becomes one load + one two-byte extend per input byte instead +/// of two nibble lookups and two pushes. +const HEX_CHARS_UPPER_NIBBLES: &[u8; 16] = b"0123456789ABCDEF"; +const HEX_CHARS_LOWER_NIBBLES: &[u8; 16] = b"0123456789abcdef"; + +const HEX_LOOKUP_UPPER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_UPPER_NIBBLES); +const HEX_LOOKUP_LOWER: [[u8; 2]; 256] = build_hex_lookup(HEX_CHARS_LOWER_NIBBLES); + +const fn build_hex_lookup(nibbles: &[u8; 16]) -> [[u8; 2]; 256] { + let mut table = [[0u8; 2]; 256]; + let mut i = 0; + while i < 256 { + table[i][0] = nibbles[(i >> 4) & 0xF]; + table[i][1] = nibbles[i & 0xF]; + i += 1; } + table } -fn hex_int64(num: i64) -> String { - format!("{num:X}") +#[inline] +fn hex_int64(num: i64, buffer: &mut [u8; 16]) -> &[u8] { + if num == 0 { + return b"0"; + } + + // Walk the value two nibbles (one full byte) at a time. The buffer is + // filled from the right so the high-order nibbles end up first; the + // returned slice trims leading zeros automatically. + let mut n = num as u64; + let mut i = 16; + while n >= 0x10 { + i -= 2; + let pair = HEX_LOOKUP_UPPER[(n & 0xFF) as usize]; + buffer[i] = pair[0]; + buffer[i + 1] = pair[1]; + n >>= 8; + } + if n > 0 { + // Single remaining high nibble (value 0x1..=0xF). + i -= 1; + buffer[i] = HEX_CHARS_UPPER_NIBBLES[n as usize]; + } + &buffer[i..] } -#[inline(always)] -fn hex_encode>(data: T, lower_case: bool) -> String { - let mut s = String::with_capacity(data.as_ref().len() * 2); - if lower_case { - for b in data.as_ref() { - // Writing to a string never errors, so we can unwrap here. - write!(&mut s, "{b:02x}").unwrap(); - } +/// Generic hex encoding for byte array types +fn hex_encode_bytes<'a, I, T>( + iter: I, + lowercase: bool, + len: usize, +) -> Result +where + I: Iterator>, + T: AsRef<[u8]> + 'a, +{ + let mut builder = StringBuilder::with_capacity(len, len * 64); + let mut buffer = Vec::with_capacity(64); + let lookup = if lowercase { + &HEX_LOOKUP_LOWER } else { - for b in data.as_ref() { - // Writing to a string never errors, so we can unwrap here. - write!(&mut s, "{b:02X}").unwrap(); + &HEX_LOOKUP_UPPER + }; + + for v in iter { + if let Some(b) = v { + let bytes = b.as_ref(); + buffer.clear(); + let additional = bytes + .len() + .checked_mul(2) + .ok_or_else(|| exec_datafusion_err!("hex output size overflow"))?; + buffer.try_reserve(additional).map_err(|e| { + exec_datafusion_err!( + "failed to reserve {additional} bytes for hex output: {e}" + ) + })?; + for &byte in bytes { + buffer.extend_from_slice(&lookup[byte as usize]); + } + // SAFETY: buffer contains only ASCII hex digits, which are valid UTF-8. + unsafe { + builder.append_value(from_utf8_unchecked(&buffer)); + } + } else { + builder.append_null(); } } - s + + Ok(Arc::new(builder.finish())) } -#[inline(always)] -fn hex_bytes>( - bytes: T, - lowercase: bool, -) -> Result { - let hex_string = hex_encode(bytes, lowercase); - Ok(hex_string) +/// Generic hex encoding for int64 type +fn hex_encode_int64( + iter: impl Iterator>, + len: usize, +) -> Result { + let mut builder = StringBuilder::with_capacity(len, len * 16); + + for v in iter { + if let Some(num) = v { + let mut temp = [0u8; 16]; + let slice = hex_int64(num, &mut temp); + // SAFETY: slice contains only ASCII hex digests, which are valid UTF-8 + unsafe { + builder.append_value(from_utf8_unchecked(slice)); + } + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) } /// Spark-compatible `hex` function @@ -194,93 +248,109 @@ pub fn compute_hex( ColumnarValue::Array(array) => match array.data_type() { DataType::Int64 => { let array = as_int64_array(array)?; - - let hexed_array: StringArray = - array.iter().map(|v| v.map(hex_int64)).collect(); - - Ok(ColumnarValue::Array(Arc::new(hexed_array))) + Ok(ColumnarValue::Array(hex_encode_int64( + array.iter(), + array.len(), + )?)) } DataType::Utf8 => { let array = as_string_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Utf8View => { let array = as_string_view_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeUtf8 => { let array = as_largestring_array(array); - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Binary => { let array = as_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) + } + DataType::LargeBinary => { + let array = as_large_binary_array(array)?; + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - - let hexed: StringArray = array - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?; - - Ok(ColumnarValue::Array(Arc::new(hexed))) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } - DataType::Dictionary(_, value_type) => { + DataType::Dictionary(key_type, _) => { + if **key_type != DataType::Int32 { + return exec_err!( + "hex only supports Int32 dictionary keys, get: {}", + key_type + ); + } + let dict = as_dictionary_array::(&array); + let dict_values = dict.values(); - let values = match **value_type { - DataType::Int64 => as_int64_array(dict.values())? - .iter() - .map(|v| v.map(hex_int64)) - .collect::>(), - DataType::Utf8 => as_string_array(dict.values()) - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - DataType::Binary => as_binary_array(dict.values())? - .iter() - .map(|v| v.map(|b| hex_bytes(b, lowercase)).transpose()) - .collect::>()?, - _ => exec_err!( - "hex got an unexpected argument type: {}", - array.data_type() - )?, + let encoded_values = match dict_values.data_type() { + DataType::Int64 => { + let arr = as_int64_array(dict_values)?; + hex_encode_int64(arr.iter(), arr.len())? + } + DataType::Utf8 => { + let arr = as_string_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeUtf8 => { + let arr = as_largestring_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Utf8View => { + let arr = as_string_view_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Binary => { + let arr = as_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeBinary => { + let arr = as_large_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::FixedSizeBinary(_) => { + let arr = as_fixed_size_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + _ => { + return exec_err!( + "hex got an unexpected argument type: {}", + dict_values.data_type() + ); + } }; - let new_values: Vec> = dict - .keys() - .iter() - .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) - .collect(); - - let string_array_values = StringArray::from(new_values); - - Ok(ColumnarValue::Array(Arc::new(string_array_values))) + let new_dict = dict.with_values(encoded_values); + Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -290,16 +360,20 @@ pub fn compute_hex( #[cfg(test)] mod test { + use std::str::from_utf8_unchecked; use std::sync::Arc; - use arrow::array::{Int64Array, StringArray}; + use arrow::array::{ + BinaryArray, DictionaryArray, Int32Array, Int64Array, StringArray, + }; use arrow::{ array::{ - as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, - StringBuilder, StringDictionaryBuilder, + BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, + as_string_array, }, datatypes::{Int32Type, Int64Type}, }; + use datafusion_common::cast::as_dictionary_array; use datafusion_expr::ColumnarValue; #[test] @@ -311,12 +385,12 @@ mod test { input_builder.append_value("rust"); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("6869"); - string_builder.append_value("627965"); - string_builder.append_null(); - string_builder.append_value("72757374"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("6869"); + expected_builder.append_value("627965"); + expected_builder.append_null(); + expected_builder.append_value("72757374"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -326,7 +400,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -340,12 +414,12 @@ mod test { input_builder.append_value(3); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("1"); - string_builder.append_value("2"); - string_builder.append_null(); - string_builder.append_value("3"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("1"); + expected_builder.append_value("2"); + expected_builder.append_null(); + expected_builder.append_value("3"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -355,7 +429,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -369,7 +443,7 @@ mod test { input_builder.append_value("3"); let input = input_builder.finish(); - let mut expected_builder = StringBuilder::new(); + let mut expected_builder = StringDictionaryBuilder::::new(); expected_builder.append_value("31"); expected_builder.append_value("6A"); expected_builder.append_null(); @@ -384,20 +458,79 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } #[test] fn test_hex_int64() { - let num = 1234; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "4D2".to_string()); + let test_cases = vec![ + (0_i64, "0"), + (1, "1"), + (15, "F"), + (16, "10"), + (255, "FF"), + (256, "100"), + (1234, "4D2"), + (i64::MAX, "7FFFFFFFFFFFFFFF"), + (i64::MIN, "8000000000000000"), + (-1, "FFFFFFFFFFFFFFFF"), + ]; + + for (num, expected) in test_cases { + let mut cache = [0u8; 16]; + let slice = super::hex_int64(num, &mut cache); + + unsafe { + let result = from_utf8_unchecked(slice); + assert_eq!(expected, result, "hex_int64({num}) mismatch"); + } + } + } - let num = -1; - let hexed = super::hex_int64(num); - assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + #[test] + fn test_hex_lookup_table_covers_all_bytes() { + // Cross-check the precomputed table against an independent encoder + // for every possible byte value and both casings. + for byte in 0u8..=255 { + let upper = format!("{byte:02X}"); + let lower = format!("{byte:02x}"); + let upper_pair = super::HEX_LOOKUP_UPPER[byte as usize]; + let lower_pair = super::HEX_LOOKUP_LOWER[byte as usize]; + assert_eq!( + upper.as_bytes(), + &upper_pair, + "upper encoding mismatch for byte 0x{byte:02X}" + ); + assert_eq!( + lower.as_bytes(), + &lower_pair, + "lower encoding mismatch for byte 0x{byte:02X}" + ); + } + } + + #[test] + fn test_spark_hex_binary_round_trip_all_bytes() { + // Single-row binary input containing every byte value, encoded in + // a single column. Catches per-byte regressions in the bytes path. + let payload: Vec = (0u8..=255).collect(); + let bin_array = BinaryArray::from(vec![Some(payload.as_slice())]); + + let result = + super::spark_hex(&[ColumnarValue::Array(Arc::new(bin_array))]).unwrap(); + let array = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + let strings = as_string_array(&array); + let mut expected = String::with_capacity(512); + for byte in 0u8..=255 { + use std::fmt::Write; + write!(expected, "{byte:02X}").unwrap(); + } + assert_eq!(strings.value(0), expected); } #[test] @@ -421,4 +554,28 @@ mod test { assert_eq!(string_array, &expected_array); } + + #[test] + fn test_dict_values_null() { + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = Int64Array::from(vec![Some(32), None]); + // [32, null, null] + let dict = DictionaryArray::new(keys, Arc::new(vals)); + + let columnar_value = ColumnarValue::Array(Arc::new(dict)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_dictionary_array(&result).unwrap(); + + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = StringArray::from(vec![Some("20"), None]); + let expected = DictionaryArray::new(keys, Arc::new(vals)); + + assert_eq!(&expected, result); + } } diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs index 74fa4cf37ca55..0079ef0fc97cd 100644 --- a/datafusion/spark/src/function/math/mod.rs +++ b/datafusion/spark/src/function/math/mod.rs @@ -16,12 +16,19 @@ // under the License. pub mod abs; +pub mod bin; +pub mod ceil; pub mod expm1; pub mod factorial; +pub mod floor; pub mod hex; pub mod modulus; +pub mod negative; +pub mod pow; pub mod rint; +pub mod round; pub mod trigonometry; +pub mod unhex; pub mod width_bucket; use datafusion_expr::ScalarUDF; @@ -29,46 +36,87 @@ use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(abs::SparkAbs, abs); +make_udf_function!(ceil::SparkCeil, ceil); make_udf_function!(expm1::SparkExpm1, expm1); make_udf_function!(factorial::SparkFactorial, factorial); +make_udf_function!(floor::SparkFloor, floor); make_udf_function!(hex::SparkHex, hex); make_udf_function!(modulus::SparkMod, modulus); make_udf_function!(modulus::SparkPmod, pmod); +make_udf_function!(pow::SparkPow, pow); make_udf_function!(rint::SparkRint, rint); +make_udf_function!(round::SparkRound, round); +make_udf_function!(unhex::SparkUnhex, unhex); make_udf_function!(width_bucket::SparkWidthBucket, width_bucket); make_udf_function!(trigonometry::SparkCsc, csc); make_udf_function!(trigonometry::SparkSec, sec); +make_udf_function!(negative::SparkNegative, negative); +make_udf_function!(bin::SparkBin, bin); pub mod expr_fn { use datafusion_functions::export_functions; export_functions!((abs, "Returns abs(expr)", arg1)); + export_functions!((ceil, "Returns the ceiling of expr.", arg1)); export_functions!((expm1, "Returns exp(expr) - 1 as a Float64.", arg1)); export_functions!(( factorial, "Returns the factorial of expr. expr is [0..20]. Otherwise, null.", arg1 )); + export_functions!((floor, "Returns floor of expr.", arg1)); export_functions!((hex, "Computes hex value of the given column.", arg1)); export_functions!((modulus, "Returns the remainder of division of the first argument by the second argument.", arg1 arg2)); export_functions!((pmod, "Returns the positive remainder of division of the first argument by the second argument.", arg1 arg2)); - export_functions!((rint, "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1)); + export_functions!(( + pow, + "Returns base raised to the power of exponent. Returns Infinity for pow(0, negative).", + arg1 arg2 + )); + export_functions!(( + rint, + "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", + arg1 + )); + export_functions!(( + round, + "Rounds the value of expr to scale decimal places using HALF_UP rounding mode.", + arg1 arg2 + )); + export_functions!((unhex, "Converts hexadecimal string to binary.", arg1)); export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4)); export_functions!((csc, "Returns the cosecant of expr.", arg1)); export_functions!((sec, "Returns the secant of expr.", arg1)); + export_functions!(( + negative, + "Returns the negation of expr (unary minus).", + arg1 + )); + export_functions!(( + bin, + "Returns the string representation of the long value represented in binary.", + arg1 + )); } pub fn functions() -> Vec> { vec![ abs(), + ceil(), expm1(), factorial(), + floor(), hex(), modulus(), pmod(), + pow(), rint(), + round(), + unhex(), width_bucket(), csc(), sec(), + negative(), + bin(), ] } diff --git a/datafusion/spark/src/function/math/modulus.rs b/datafusion/spark/src/function/math/modulus.rs index 60d45baa7f380..97f59c2cbb0cb 100644 --- a/datafusion/spark/src/function/math/modulus.rs +++ b/datafusion/spark/src/function/math/modulus.rs @@ -15,37 +15,73 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{Scalar, new_null_array}; use arrow::compute::kernels::numeric::add; -use arrow::compute::kernels::{cmp::lt, numeric::rem, zip::zip}; +use arrow::compute::kernels::{ + cmp::{eq, lt}, + numeric::rem, + zip::zip, +}; use arrow::datatypes::DataType; -use datafusion_common::{assert_eq_or_internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, assert_eq_or_internal_err}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use std::any::Any; + +/// Computes `rem(left, right)` with divide-by-zero handling. +/// In ANSI mode, any zero divisor causes an error. +/// In legacy mode (ANSI off), zero divisors are replaced with NULL before +/// computing the remainder, so those positions return NULL while others +/// compute normally. +fn try_rem( + left: &arrow::array::ArrayRef, + right: &arrow::array::ArrayRef, + enable_ansi_mode: bool, +) -> Result { + if enable_ansi_mode { + Ok(rem(left, right)?) + } else { + // In legacy mode, null out zero divisors so that division by zero + // returns NULL instead of erroring (integers) or returning NaN (floats). + let zero = ScalarValue::new_zero(right.data_type())?.to_array()?; + let zero = Scalar::new(zero); + let null = Scalar::new(new_null_array(right.data_type(), 1)); + let is_zero = eq(right, &zero)?; + let safe_right = zip(&is_zero, &null, right)?; + Ok(rem(left, &safe_right)?) + } +} /// Spark-compatible `mod` function -/// This function directly uses Arrow's arithmetic_op function for modulo operations -pub fn spark_mod(args: &[ColumnarValue]) -> Result { +/// In ANSI mode, division by zero throws an error. +/// In legacy mode, division by zero returns NULL (Spark behavior). +pub fn spark_mod( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { assert_eq_or_internal_err!(args.len(), 2, "mod expects exactly two arguments"); let args = ColumnarValue::values_to_arrays(args)?; - let result = rem(&args[0], &args[1])?; + let result = try_rem(&args[0], &args[1], enable_ansi_mode)?; Ok(ColumnarValue::Array(result)) } /// Spark-compatible `pmod` function -/// This function directly uses Arrow's arithmetic_op function for modulo operations -pub fn spark_pmod(args: &[ColumnarValue]) -> Result { +/// In ANSI mode, division by zero throws an error. +/// In legacy mode, division by zero returns NULL (Spark behavior). +pub fn spark_pmod( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { assert_eq_or_internal_err!(args.len(), 2, "pmod expects exactly two arguments"); let args = ColumnarValue::values_to_arrays(args)?; let left = &args[0]; let right = &args[1]; let zero = ScalarValue::new_zero(left.data_type())?.to_array_of_size(left.len())?; - let result = rem(left, right)?; + let result = try_rem(left, right, enable_ansi_mode)?; let neg = lt(&result, &zero)?; let plus = zip(&neg, right, &zero)?; let result = add(&plus, &result)?; - let result = rem(&result, right)?; + let result = try_rem(&result, right, enable_ansi_mode)?; Ok(ColumnarValue::Array(result)) } @@ -70,10 +106,6 @@ impl SparkMod { } impl ScalarUDFImpl for SparkMod { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "mod" } @@ -95,7 +127,7 @@ impl ScalarUDFImpl for SparkMod { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_mod(&args.args) + spark_mod(&args.args, args.config_options.execution.enable_ansi_mode) } } @@ -120,10 +152,6 @@ impl SparkPmod { } impl ScalarUDFImpl for SparkPmod { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "pmod" } @@ -145,7 +173,7 @@ impl ScalarUDFImpl for SparkPmod { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - spark_pmod(&args.args) + spark_pmod(&args.args, args.config_options.execution.enable_ansi_mode) } } @@ -165,7 +193,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -187,7 +215,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int64 = @@ -212,6 +240,8 @@ mod test { Some(5.0), Some(f64::NAN), Some(f64::INFINITY), + Some(10.5), + Some(15.8), ]); let right = Float64Array::from(vec![ Some(3.0), @@ -223,12 +253,14 @@ mod test { Some(f64::INFINITY), Some(f64::INFINITY), Some(f64::NAN), + Some(0.0), + Some(0.0), ]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float64 = result_array @@ -239,7 +271,7 @@ mod test { assert!((result_float64.value(0) - 1.5).abs() < f64::EPSILON); // 10.5 % 3.0 = 1.5 assert!((result_float64.value(1) - 2.2).abs() < f64::EPSILON); // 7.2 % 2.5 = 2.2 assert!((result_float64.value(2) - 3.2).abs() < f64::EPSILON); // 15.8 % 4.2 = 3.2 - // nan % 2.0 = nan + // nan % 2.0 = nan assert!(result_float64.value(3).is_nan()); // inf % 2.0 = nan (IEEE 754) assert!(result_float64.value(4).is_nan()); @@ -251,6 +283,9 @@ mod test { assert!(result_float64.value(7).is_nan()); // inf % nan = nan assert!(result_float64.value(8).is_nan()); + // Division by zero returns NULL + assert!(result_float64.is_null(9)); // 10.5 % 0.0 = NULL + assert!(result_float64.is_null(10)); // 15.8 % 0.0 = NULL } else { panic!("Expected array result"); } @@ -268,6 +303,8 @@ mod test { Some(5.0), Some(f32::NAN), Some(f32::INFINITY), + Some(10.5), + Some(15.8), ]); let right = Float32Array::from(vec![ Some(3.0), @@ -279,12 +316,14 @@ mod test { Some(f32::INFINITY), Some(f32::INFINITY), Some(f32::NAN), + Some(0.0), + Some(0.0), ]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float32 = result_array @@ -295,7 +334,7 @@ mod test { assert!((result_float32.value(0) - 1.5).abs() < f32::EPSILON); // 10.5 % 3.0 = 1.5 assert!((result_float32.value(1) - 2.2).abs() < f32::EPSILON * 3.0); // 7.2 % 2.5 = 2.2 assert!((result_float32.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 % 4.2 = 3.2 - // nan % 2.0 = nan + // nan % 2.0 = nan assert!(result_float32.value(3).is_nan()); // inf % 2.0 = nan (IEEE 754) assert!(result_float32.value(4).is_nan()); @@ -307,6 +346,9 @@ mod test { assert!(result_float32.value(7).is_nan()); // inf % nan = nan assert!(result_float32.value(8).is_nan()); + // Division by zero returns NULL + assert!(result_float32.is_null(9)); // 10.5 % 0.0 = NULL + assert!(result_float32.is_null(10)); // 15.8 % 0.0 = NULL } else { panic!("Expected array result"); } @@ -319,7 +361,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_mod(&[left_value, right_value]).unwrap(); + let result = spark_mod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -337,20 +379,43 @@ mod test { let left = Int32Array::from(vec![Some(10)]); let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_mod(&[left_value]); + let result = spark_mod(&[left_value], false); assert!(result.is_err()); } #[test] - fn test_mod_zero_division() { + fn test_mod_zero_division_legacy() { + // In legacy mode (ANSI off), division by zero returns NULL per-element let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]); let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_mod(&[left_value, right_value]); - assert!(result.is_err()); // Division by zero should error + let result = spark_mod(&[left_value, right_value], false).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert!(result_int32.is_null(0)); // 10 % 0 = NULL + assert_eq!(result_int32.value(1), 1); // 7 % 2 = 1 + assert_eq!(result_int32.value(2), 3); // 15 % 4 = 3 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_mod_zero_division_ansi() { + // In ANSI mode, division by zero should error + let left = Int32Array::from(vec![Some(10), Some(7), Some(15)]); + let right = Int32Array::from(vec![Some(0), Some(2), Some(4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_mod(&[left_value, right_value], true); + assert!(result.is_err()); } // PMOD tests @@ -362,7 +427,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -385,7 +450,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int64 = @@ -410,6 +475,8 @@ mod test { Some(f64::INFINITY), Some(5.0), Some(-5.0), + Some(10.5), + Some(-7.2), ]); let right = Float64Array::from(vec![ Some(3.0), @@ -420,12 +487,14 @@ mod test { Some(2.0), Some(f64::INFINITY), Some(f64::INFINITY), + Some(0.0), + Some(0.0), ]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float64 = result_array @@ -437,7 +506,7 @@ mod test { assert!((result_float64.value(1) - 1.8).abs() < f64::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive) assert!((result_float64.value(2) - 3.2).abs() < f64::EPSILON * 3.0); // 15.8 pmod 4.2 = 3.2 assert!((result_float64.value(3) - 1.0).abs() < f64::EPSILON * 3.0); // -15.8 pmod 4.2 = 1.0 (positive) - // nan pmod 2.0 = nan + // nan pmod 2.0 = nan assert!(result_float64.value(4).is_nan()); // inf pmod 2.0 = nan (IEEE 754) assert!(result_float64.value(5).is_nan()); @@ -445,6 +514,9 @@ mod test { assert!((result_float64.value(6) - 5.0).abs() < f64::EPSILON); // -5.0 pmod inf = NaN assert!(result_float64.value(7).is_nan()); + // Division by zero returns NULL + assert!(result_float64.is_null(8)); // 10.5 pmod 0.0 = NULL + assert!(result_float64.is_null(9)); // -7.2 pmod 0.0 = NULL } else { panic!("Expected array result"); } @@ -461,6 +533,8 @@ mod test { Some(f32::INFINITY), Some(5.0), Some(-5.0), + Some(10.5), + Some(-7.2), ]); let right = Float32Array::from(vec![ Some(3.0), @@ -471,12 +545,14 @@ mod test { Some(2.0), Some(f32::INFINITY), Some(f32::INFINITY), + Some(0.0), + Some(0.0), ]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_float32 = result_array @@ -488,7 +564,7 @@ mod test { assert!((result_float32.value(1) - 1.8).abs() < f32::EPSILON * 3.0); // -7.2 pmod 3.0 = 1.8 (positive) assert!((result_float32.value(2) - 3.2).abs() < f32::EPSILON * 10.0); // 15.8 pmod 4.2 = 3.2 assert!((result_float32.value(3) - 1.0).abs() < f32::EPSILON * 10.0); // -15.8 pmod 4.2 = 1.0 (positive) - // nan pmod 2.0 = nan + // nan pmod 2.0 = nan assert!(result_float32.value(4).is_nan()); // inf pmod 2.0 = nan (IEEE 754) assert!(result_float32.value(5).is_nan()); @@ -496,6 +572,9 @@ mod test { assert!((result_float32.value(6) - 5.0).abs() < f32::EPSILON * 10.0); // -5.0 pmod inf = NaN assert!(result_float32.value(7).is_nan()); + // Division by zero returns NULL + assert!(result_float32.is_null(8)); // 10.5 pmod 0.0 = NULL + assert!(result_float32.is_null(9)); // -7.2 pmod 0.0 = NULL } else { panic!("Expected array result"); } @@ -508,7 +587,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -527,20 +606,43 @@ mod test { let left = Int32Array::from(vec![Some(10)]); let left_value = ColumnarValue::Array(Arc::new(left)); - let result = spark_pmod(&[left_value]); + let result = spark_pmod(&[left_value], false); assert!(result.is_err()); } #[test] - fn test_pmod_zero_division() { + fn test_pmod_zero_division_legacy() { + // In legacy mode (ANSI off), division by zero returns NULL per-element + let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]); + let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]); + + let left_value = ColumnarValue::Array(Arc::new(left)); + let right_value = ColumnarValue::Array(Arc::new(right)); + + let result = spark_pmod(&[left_value, right_value], false).unwrap(); + + if let ColumnarValue::Array(result_array) = result { + let result_int32 = + result_array.as_any().downcast_ref::().unwrap(); + assert!(result_int32.is_null(0)); // 10 pmod 0 = NULL + assert!(result_int32.is_null(1)); // -7 pmod 0 = NULL + assert_eq!(result_int32.value(2), 3); // 15 pmod 4 = 3 + } else { + panic!("Expected array result"); + } + } + + #[test] + fn test_pmod_zero_division_ansi() { + // In ANSI mode, division by zero should error let left = Int32Array::from(vec![Some(10), Some(-7), Some(15)]); let right = Int32Array::from(vec![Some(0), Some(0), Some(4)]); let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]); - assert!(result.is_err()); // Division by zero should error + let result = spark_pmod(&[left_value, right_value], true); + assert!(result.is_err()); } #[test] @@ -552,7 +654,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = @@ -590,7 +692,7 @@ mod test { let left_value = ColumnarValue::Array(Arc::new(left)); let right_value = ColumnarValue::Array(Arc::new(right)); - let result = spark_pmod(&[left_value, right_value]).unwrap(); + let result = spark_pmod(&[left_value, right_value], false).unwrap(); if let ColumnarValue::Array(result_array) = result { let result_int32 = diff --git a/datafusion/spark/src/function/math/negative.rs b/datafusion/spark/src/function/math/negative.rs new file mode 100644 index 0000000000000..51e2418b85167 --- /dev/null +++ b/datafusion/spark/src/function/math/negative.rs @@ -0,0 +1,472 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::types::*; +use arrow::array::*; +use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit}; +use bigdecimal::num_traits::WrappingNeg; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use std::sync::Arc; + +/// Spark-compatible `negative` expression +/// +/// +/// Returns the negation of input (equivalent to unary minus) +/// Returns NULL if input is NULL, returns NaN if input is NaN. +/// +/// ANSI mode support: +/// - When ANSI mode is disabled (`spark.sql.ansi.enabled=false`), negating the minimal +/// value of a signed integer wraps around. For example: negative(i32::MIN) returns +/// i32::MIN (wraps instead of error). +/// - When ANSI mode is enabled (`spark.sql.ansi.enabled=true`), overflow conditions +/// throw an ARITHMETIC_OVERFLOW error instead of wrapping. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkNegative { + signature: Signature, +} + +impl Default for SparkNegative { + fn default() -> Self { + Self::new() + } +} + +impl SparkNegative { + pub fn new() -> Self { + Self { + signature: Signature { + type_signature: TypeSignature::OneOf(vec![ + // Numeric types: signed integers, float, decimals + TypeSignature::Numeric(1), + // Interval types: YearMonth, DayTime, MonthDayNano + TypeSignature::Uniform( + 1, + vec![ + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Interval(IntervalUnit::MonthDayNano), + ], + ), + ]), + volatility: Volatility::Immutable, + parameter_names: None, + }, + } + } +} + +impl ScalarUDFImpl for SparkNegative { + fn name(&self) -> &str { + "negative" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_negative(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Macro to implement negation for integer array types +macro_rules! impl_integer_array_negative { + ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = if $enable_ansi_mode { + array.try_unary(|x| { + x.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({x})", $type_name) + as Result<(), _>) + .unwrap_err() + }) + })? + } else { + array.unary(|x| x.wrapping_neg()) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for float array types +macro_rules! impl_float_array_negative { + ($array:expr, $type:ty) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = array.unary(|x| -x); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for decimal array types +macro_rules! impl_decimal_array_negative { + ($array:expr, $type:ty, $type_name:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$type>(); + let result: PrimitiveArray<$type> = if $enable_ansi_mode { + array + .try_unary(|x| { + x.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({x})", $type_name) + as Result<(), _>) + .unwrap_err() + }) + })? + .with_data_type(array.data_type().clone()) + } else { + array.unary(|x| x.wrapping_neg()) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +/// Macro to implement negation for integer scalar types +macro_rules! impl_integer_scalar_negative { + ($v:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{ + let result = if $enable_ansi_mode { + $v.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({})", $type_name, $v) + as Result<(), _>) + .unwrap_err() + })? + } else { + $v.wrapping_neg() + }; + Ok(ColumnarValue::Scalar(ScalarValue::$variant(Some(result)))) + }}; +} + +/// Macro to implement negation for decimal scalar types +macro_rules! impl_decimal_scalar_negative { + ($v:expr, $precision:expr, $scale:expr, $type_name:expr, $variant:ident, $enable_ansi_mode:expr) => {{ + let result = if $enable_ansi_mode { + $v.checked_neg().ok_or_else(|| { + (exec_err!("{} overflow on negative({})", $type_name, $v) + as Result<(), _>) + .unwrap_err() + })? + } else { + $v.wrapping_neg() + }; + Ok(ColumnarValue::Scalar(ScalarValue::$variant( + Some(result), + *$precision, + *$scale, + ))) + }}; +} + +/// Core implementation of Spark's negative function +fn spark_negative( + args: &[ColumnarValue], + enable_ansi_mode: bool, +) -> Result { + let [arg] = take_function_args("negative", args)?; + + match arg { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(arg.clone()), + + // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Int8 => { + impl_integer_array_negative!(array, Int8Type, "Int8", enable_ansi_mode) + } + DataType::Int16 => { + impl_integer_array_negative!(array, Int16Type, "Int16", enable_ansi_mode) + } + DataType::Int32 => { + impl_integer_array_negative!(array, Int32Type, "Int32", enable_ansi_mode) + } + DataType::Int64 => { + impl_integer_array_negative!(array, Int64Type, "Int64", enable_ansi_mode) + } + + // Floating point - simple negation (no overflow possible) + DataType::Float16 => impl_float_array_negative!(array, Float16Type), + DataType::Float32 => impl_float_array_negative!(array, Float32Type), + DataType::Float64 => impl_float_array_negative!(array, Float64Type), + + // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Decimal32(_, _) => impl_decimal_array_negative!( + array, + Decimal32Type, + "Decimal32", + enable_ansi_mode + ), + DataType::Decimal64(_, _) => impl_decimal_array_negative!( + array, + Decimal64Type, + "Decimal64", + enable_ansi_mode + ), + DataType::Decimal128(_, _) => impl_decimal_array_negative!( + array, + Decimal128Type, + "Decimal128", + enable_ansi_mode + ), + DataType::Decimal256(_, _) => impl_decimal_array_negative!( + array, + Decimal256Type, + "Decimal256", + enable_ansi_mode + ), + + // interval type - use checked negation in ANSI mode, wrapping in legacy mode + DataType::Interval(IntervalUnit::YearMonth) => { + impl_integer_array_negative!( + array, + IntervalYearMonthType, + "IntervalYearMonth", + enable_ansi_mode + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + let array = array.as_primitive::(); + let result: PrimitiveArray = if enable_ansi_mode { + array.try_unary(|x| { + let days = x.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (days: {})", + x.days + ) as Result<(), _>) + .unwrap_err() + })?; + let milliseconds = + x.milliseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (milliseconds: {})", + x.milliseconds + ) as Result<(), _>) + .unwrap_err() + })?; + Ok::<_, arrow::error::ArrowError>(IntervalDayTime { + days, + milliseconds, + }) + })? + } else { + array.unary(|x| IntervalDayTime { + days: x.days.wrapping_neg(), + milliseconds: x.milliseconds.wrapping_neg(), + }) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + let array = array.as_primitive::(); + let result: PrimitiveArray = if enable_ansi_mode + { + array.try_unary(|x| { + let months = x.months.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (months: {})", + x.months + ) as Result<(), _>) + .unwrap_err() + })?; + let days = x.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (days: {})", + x.days + ) as Result<(), _>) + .unwrap_err() + })?; + let nanoseconds = x.nanoseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (nanoseconds: {})", + x.nanoseconds + ) as Result<(), _>) + .unwrap_err() + })?; + Ok::<_, arrow::error::ArrowError>(IntervalMonthDayNano { + months, + days, + nanoseconds, + }) + })? + } else { + array.unary(|x| IntervalMonthDayNano { + months: x.months.wrapping_neg(), + days: x.days.wrapping_neg(), + nanoseconds: x.nanoseconds.wrapping_neg(), + }) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + } + + dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null => Ok(arg.clone()), + _ if sv.is_null() => Ok(arg.clone()), + + // Signed integers - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::Int8(Some(v)) => { + impl_integer_scalar_negative!(v, "Int8", Int8, enable_ansi_mode) + } + ScalarValue::Int16(Some(v)) => { + impl_integer_scalar_negative!(v, "Int16", Int16, enable_ansi_mode) + } + ScalarValue::Int32(Some(v)) => { + impl_integer_scalar_negative!(v, "Int32", Int32, enable_ansi_mode) + } + ScalarValue::Int64(Some(v)) => { + impl_integer_scalar_negative!(v, "Int64", Int64, enable_ansi_mode) + } + + // Floating point - simple negation + ScalarValue::Float16(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v)))) + } + ScalarValue::Float32(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v)))) + } + ScalarValue::Float64(Some(v)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v)))) + } + + // Decimal types - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::Decimal32(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal32", + Decimal32, + enable_ansi_mode + ) + } + ScalarValue::Decimal64(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal64", + Decimal64, + enable_ansi_mode + ) + } + ScalarValue::Decimal128(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal128", + Decimal128, + enable_ansi_mode + ) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + impl_decimal_scalar_negative!( + v, + precision, + scale, + "Decimal256", + Decimal256, + enable_ansi_mode + ) + } + + //interval type - use checked negation in ANSI mode, wrapping in legacy mode + ScalarValue::IntervalYearMonth(Some(v)) => { + impl_integer_scalar_negative!( + v, + "IntervalYearMonth", + IntervalYearMonth, + enable_ansi_mode + ) + } + ScalarValue::IntervalDayTime(Some(v)) => { + let result = if enable_ansi_mode { + let days = v.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (days: {})", + v.days + ) as Result<(), _>) + .unwrap_err() + })?; + let milliseconds = v.milliseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalDayTime overflow on negative (milliseconds: {})", + v.milliseconds + ) as Result<(), _>) + .unwrap_err() + })?; + IntervalDayTime { days, milliseconds } + } else { + IntervalDayTime { + days: v.days.wrapping_neg(), + milliseconds: v.milliseconds.wrapping_neg(), + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + result, + )))) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + let result = if enable_ansi_mode { + let months = v.months.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (months: {})", + v.months + ) as Result<(), _>) + .unwrap_err() + })?; + let days = v.days.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (days: {})", + v.days + ) as Result<(), _>) + .unwrap_err() + })?; + let nanoseconds = v.nanoseconds.checked_neg().ok_or_else(|| { + (exec_err!( + "IntervalMonthDayNano overflow on negative (nanoseconds: {})", + v.nanoseconds + ) as Result<(), _>) + .unwrap_err() + })?; + IntervalMonthDayNano { + months, + days, + nanoseconds, + } + } else { + IntervalMonthDayNano { + months: v.months.wrapping_neg(), + days: v.days.wrapping_neg(), + nanoseconds: v.nanoseconds.wrapping_neg(), + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano( + Some(result), + ))) + } + + dt => not_impl_err!("Not supported datatype for Spark negative(): {dt}"), + }, + } +} diff --git a/datafusion/spark/src/function/math/pow.rs b/datafusion/spark/src/function/math/pow.rs new file mode 100644 index 0000000000000..8655d71e42c9a --- /dev/null +++ b/datafusion/spark/src/function/math/pow.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Spark-compatible `pow` / `power` function. +//! +//! Unlike the default DataFusion (PostgreSQL) implementation, Spark returns +//! `Infinity` for `pow(0, )` rather than raising an error. + +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::datatypes::DataType; + +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, +}; +use datafusion_functions::math::power::PowerFunc; + +/// Spark-compatible implementation of `pow` / `power`. +/// +/// Behavioural difference from the DataFusion default: +/// - `pow(0, )` → `Infinity` (IEEE 754 / Spark semantics) +/// The default raises `"zero raised to a negative power is undefined"` to +/// match PostgreSQL. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkPow { + inner: PowerFunc, + aliases: Vec, +} + +impl Default for SparkPow { + fn default() -> Self { + Self::new() + } +} + +impl SparkPow { + pub fn new() -> Self { + Self { + inner: PowerFunc::new(), + // SparkPow is named "pow"; expose "power" as an alias so that + // both names resolve to Spark semantics when this crate is active. + aliases: vec!["power".to_string()], + } + } +} + +impl ScalarUDFImpl for SparkPow { + fn name(&self) -> &str { + "pow" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + // Only Float64 × Float64 needs the Spark override. + // Decimal / integer / mixed-type paths are delegated to the standard + // PowerFunc which already handles them correctly (decimal can't + // represent Infinity anyway). + match args.args.as_slice() { + [base, exponent] + if matches!(base.data_type(), DataType::Float64) + && matches!(exponent.data_type(), DataType::Float64) => {} + _ => return self.inner.invoke_with_args(args), + } + + let num_rows = args.number_rows; + + // ── Scalar × Scalar fast path ──────────────────────────────────────── + // Pattern-match on the slice to avoid any ownership issues. + if let [ + ColumnarValue::Scalar(ScalarValue::Float64(base)), + ColumnarValue::Scalar(ScalarValue::Float64(exp)), + ] = args.args.as_slice() + { + // base and exp are &Option; Option is Copy. + let result = (*base).zip(*exp).map(|(base, exp)| { + if base == 0.0 && exp < 0.0 { + f64::INFINITY + } else { + base.powf(exp) + } + }); + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(result))); + } + + // ── Array path ─────────────────────────────────────────────────────── + let [base, exponent] = take_function_args(self.name(), &args.args)?; + + let base_arr: ArrayRef = base.to_array(num_rows)?; + let exp_arr: ArrayRef = exponent.to_array(num_rows)?; + + let base_f64 = base_arr + .as_any() + .downcast_ref::() + .expect("base must be Float64Array"); + let exp_f64 = exp_arr + .as_any() + .downcast_ref::() + .expect("exponent must be Float64Array"); + + // Spark: 0^negative = +Infinity (covers both 0.0 and -0.0) + // IEEE 754: 0.0^-1.0 = +Infinity, -0.0^-1.0 = -Infinity + // Thus we need an explicit guard for base == 0.0 to ensure +Infinity. + let result: Float64Array = base_f64 + .iter() + .zip(exp_f64.iter()) + .map(|(base, exp)| match (base, exp) { + (Some(base), Some(exp)) => { + if base == 0.0 && exp < 0.0 { + Some(f64::INFINITY) + } else { + Some(base.powf(exp)) + } + } + _ => None, + }) + .collect(); + + Ok(ColumnarValue::Array(Arc::new(result))) + } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } +} diff --git a/datafusion/spark/src/function/math/rint.rs b/datafusion/spark/src/function/math/rint.rs index 3271be38f8338..3bca93b13241b 100644 --- a/datafusion/spark/src/function/math/rint.rs +++ b/datafusion/spark/src/function/math/rint.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{Array, ArrayRef, AsArray}; use arrow::compute::cast; use arrow::datatypes::DataType::{ - Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, + Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, }; use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{assert_eq_or_internal_err, exec_err, Result}; +use datafusion_common::{Result, assert_eq_or_internal_err, exec_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -51,10 +50,6 @@ impl SparkRint { } impl ScalarUDFImpl for SparkRint { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "rint" } diff --git a/datafusion/spark/src/function/math/round.rs b/datafusion/spark/src/function/math/round.rs new file mode 100644 index 0000000000000..05745666183d3 --- /dev/null +++ b/datafusion/spark/src/function/math/round.rs @@ -0,0 +1,654 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::*; +use arrow::datatypes::{ + ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type, + Decimal256Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, + Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type, +}; +use datafusion_common::types::{ + NativeType, logical_float32, logical_float64, logical_int32, +}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, +}; + +/// Spark-compatible `round` expression +/// +/// +/// Rounds the value of `expr` to `scale` decimal places using HALF_UP rounding mode. +/// Returns the same type as the input expression. +/// +/// - `round(expr)` rounds to 0 decimal places (default scale = 0) +/// - `round(expr, scale)` rounds to `scale` decimal places +/// - For integer types with negative scale: `round(25, -1)` → `30` +/// - Uses HALF_UP rounding: 2.5 → 3, -2.5 → -3 (away from zero) +/// +/// Supported types: Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, +/// Float16, Float32, Float64, Decimal32, Decimal64, Decimal128, Decimal256 +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkRound { + signature: Signature, +} + +impl Default for SparkRound { + fn default() -> Self { + Self::new() + } +} + +impl SparkRound { + pub fn new() -> Self { + let decimal = Coercion::new_exact(TypeSignatureClass::Decimal); + let integer = Coercion::new_exact(TypeSignatureClass::Integer); + let decimal_places = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int32()), + vec![TypeSignatureClass::Integer], + NativeType::Int32, + ); + let float32 = Coercion::new_exact(TypeSignatureClass::Native(logical_float32())); + let float64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + Self { + signature: Signature::one_of( + vec![ + // round(decimal, scale) + TypeSignature::Coercible(vec![ + decimal.clone(), + decimal_places.clone(), + ]), + // round(decimal) + TypeSignature::Coercible(vec![decimal]), + // round(integer, scale) + TypeSignature::Coercible(vec![ + integer.clone(), + decimal_places.clone(), + ]), + // round(integer) + TypeSignature::Coercible(vec![integer]), + // round(float32, scale) + TypeSignature::Coercible(vec![ + float32.clone(), + decimal_places.clone(), + ]), + // round(float32) + TypeSignature::Coercible(vec![float32]), + // round(float64, scale) + TypeSignature::Coercible(vec![float64.clone(), decimal_places]), + // round(float64) + TypeSignature::Coercible(vec![float64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkRound { + fn name(&self) -> &str { + "round" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_round(&args.args, args.config_options.execution.enable_ansi_mode) + } +} + +/// Extract the scale (decimal places) from the second argument. +/// Returns `Some(0)` if no second argument is provided. +/// Returns `None` if the scale argument is NULL (Spark returns NULL for `round(expr, NULL)`). +fn get_scale(args: &[ColumnarValue]) -> Result> { + if args.len() < 2 { + return Ok(Some(0)); + } + + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => Ok(Some(*v)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt8(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt16(Some(v))) => Ok(Some(i32::from(*v))), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(ScalarValue::UInt64(Some(v))) => { + i32::try_from(*v).map(Some).map_err(|_| { + (exec_err!("round scale {v} is out of supported i32 range") + as Result<(), _>) + .unwrap_err() + }) + } + ColumnarValue::Scalar(sv) if sv.is_null() => Ok(None), + other => exec_err!("Unsupported type for round scale: {}", other.data_type()), + } +} + +/// Round a floating-point value to the given number of decimal places using +/// HALF_UP rounding mode (ties round away from zero). +/// +/// This matches Spark's `RoundBase` behaviour for `FloatType` / `DoubleType`, +/// which internally converts the value to `BigDecimal` and rounds with +/// `RoundingMode.HALF_UP`. +/// +/// # Arguments +/// * `value` – the floating-point number to round +/// * `scale` – number of decimal places to keep. +/// - `scale >= 0`: rounds to that many fractional digits +/// (e.g. `round_float(2.345, 2) == 2.35`) +/// - `scale < 0`: rounds to the left of the decimal point +/// (e.g. `round_float(125.0, -1) == 130.0`) +/// +/// # Examples +/// ```text +/// round_float(2.5, 0) → 3.0 // half rounds up +/// round_float(-2.5, 0) → -3.0 // half rounds away from zero +/// round_float(1.4, 0) → 1.0 +/// round_float(125.0, -1) → 130.0 +/// ``` +fn round_float(value: T, scale: i32) -> T { + if scale >= 0 { + let factor = T::from(10.0f64.powi(scale)).unwrap_or_else(T::infinity); + if factor.is_infinite() { + // Very large positive scale — value is already precise enough, return as-is + return value; + } + (value * factor).round() / factor + } else { + let factor = T::from(10.0f64.powi(-scale)).unwrap_or_else(T::infinity); + if factor.is_infinite() { + // Very large negative scale — any finite value rounds to 0 + return T::zero(); + } + (value / factor).round() * factor + } +} + +/// Round an integer value to the given scale using HALF_UP rounding mode. +/// +/// Only meaningful when `scale` is negative — a non-negative scale leaves +/// the integer unchanged because integers have no fractional part. +/// +/// This matches Spark's `RoundBase` behaviour for `ByteType`, `ShortType`, +/// `IntegerType`, and `LongType`, which round to the nearest power-of-ten +/// boundary and return the same integer type. +/// +/// In ANSI mode, overflow conditions return an error instead of wrapping. +/// +/// # Arguments +/// * `value` – the integer to round (widened to `i64` by callers) +/// * `scale` – rounding position relative to the ones digit. +/// - `scale >= 0`: returns `value` as-is +/// - `scale == -1`: rounds to the nearest 10 +/// - `scale == -2`: rounds to the nearest 100 +/// - If `10^|scale|` overflows `i64`, returns `0` +/// * `enable_ansi_mode` – when true, overflow returns an error +/// +/// # Examples +/// ```text +/// round_integer(25, -1, false) → Ok(30) +/// round_integer(-25, -1, false) → Ok(-30) +/// round_integer(123, -1, false) → Ok(120) +/// round_integer(150, -2, false) → Ok(200) +/// round_integer(42, 2, false) → Ok(42) // no-op for positive scale +/// round_integer(42, -10, false) → Ok(0) // factor overflows → 0 +/// ``` +fn round_integer(value: i64, scale: i32, enable_ansi_mode: bool) -> Result { + if scale >= 0 { + return Ok(value); + } + let abs_scale = (-scale) as u32; + let Some(factor) = 10_i64.checked_pow(abs_scale) else { + return Ok(0); + }; + let remainder = value % factor; + let threshold = factor / 2; + let result = if remainder >= threshold { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_add(factor)) + .ok_or_else(|| { + (exec_err!("Int64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_add(factor) + } + } else if remainder <= -threshold { + if enable_ansi_mode { + value + .checked_sub(remainder) + .and_then(|v| v.checked_sub(factor)) + .ok_or_else(|| { + (exec_err!("Int64 overflow on round({value}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + value.wrapping_sub(remainder).wrapping_sub(factor) + } + } else { + value - remainder + }; + Ok(result) +} + +// --------------------------------------------------------------------------- +// Decimal rounding using ArrowNativeTypeOp (HALF_UP) +// --------------------------------------------------------------------------- + +/// Round a decimal value represented as its unscaled integer using HALF_UP +/// rounding mode (ties round away from zero). +/// +/// This matches Spark's `RoundBase` behaviour for `DecimalType`, which calls +/// `BigDecimal.setScale(scale, RoundingMode.HALF_UP)`. +/// +/// Decimals are stored as `(unscaled_value, precision, scale)` where the real +/// value equals `unscaled_value * 10^(-scale)`. This function operates on the +/// unscaled integer directly: +/// +/// 1. Compute `diff = input_scale - decimal_places`. +/// If `diff <= 0` the requested precision is finer than (or equal to) the +/// stored scale, so nothing needs to be rounded — return as-is. +/// 2. Divide by `10^diff` to shift the rounding boundary into the ones digit. +/// 3. Inspect the remainder to decide whether to round up or down (HALF_UP). +/// 4. Multiply back by `10^diff` so the result is expressed at the original +/// `input_scale`. +/// +/// # Arguments +/// * `value` – unscaled decimal value +/// * `input_scale` – scale of the incoming decimal +/// * `decimal_places` – number of fractional digits to keep (may be negative) +/// +/// # Returns +/// The rounded unscaled value at the same `input_scale`, or an error +/// on overflow. +/// +/// # Examples +/// ```text +/// // 2.5 (unscaled 25, scale 1) rounded to 0 places → 3.0 (unscaled 30) +/// round_decimal(25_i128, 1, 0) → Ok(30) +/// +/// // 2.345 (unscaled 2345, scale 3) rounded to 2 places → 2.350 (unscaled 2350) +/// round_decimal(2345_i128, 3, 2) → Ok(2350) +/// ``` +fn round_decimal( + value: V, + input_scale: i8, + decimal_places: i32, +) -> Result { + let diff = i64::from(input_scale) - i64::from(decimal_places); + if diff <= 0 { + // Nothing to round – the requested precision is finer than (or equal to) the + // stored scale. + return Ok(value); + } + + let diff = diff as u32; + + let one = V::ONE; + let two = V::from_usize(2).ok_or_else(|| { + (exec_err!("Internal error: could not create constant 2") as Result<(), _>) + .unwrap_err() + })?; + let ten = V::from_usize(10).ok_or_else(|| { + (exec_err!("Internal error: could not create constant 10") as Result<(), _>) + .unwrap_err() + })?; + + let Ok(factor) = ten.pow_checked(diff) else { + // 10^diff overflows the decimal type — the rounding position is beyond + // the representable range, so any value rounds to 0. + // This matches Spark's BigDecimal.setScale behavior where rounding to a + // scale far beyond the number's magnitude yields 0. + return Ok(V::ZERO); + }; + + let mut quotient = value.div_wrapping(factor); + let remainder = value.mod_wrapping(factor); + + // HALF_UP: round away from zero when remainder is exactly half + let threshold = factor.div_wrapping(two); + if remainder >= threshold { + quotient = quotient.add_checked(one).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + })?; + } else if remainder <= threshold.neg_wrapping() { + quotient = quotient.sub_checked(one).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + })?; + } + + // Re-scale the quotient back to `input_scale` so the returned unscaled integer is + // at the original scale. `factor` is already `10^diff` which is exactly the shift + // we need. + quotient.mul_checked(factor).map_err(|_| { + (exec_err!("Overflow while rounding decimal") as Result<(), _>).unwrap_err() + }) +} + +// --------------------------------------------------------------------------- +// Macros for array dispatch +// --------------------------------------------------------------------------- + +macro_rules! impl_integer_array_round { + ($array:expr, $arrow_type:ty, $scale:expr, $enable_ansi_mode:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + type Native = <$arrow_type as arrow::datatypes::ArrowPrimitiveType>::Native; + let result: PrimitiveArray<$arrow_type> = if $enable_ansi_mode { + array.try_unary(|x| { + let v = round_integer(x as i64, $scale, true)?; + Native::try_from(v).map_err(|_| { + (exec_err!( + "{} overflow on round({x}, {})", + stringify!($arrow_type), + $scale + ) as Result<(), _>) + .unwrap_err() + }) + })? + } else { + array.unary(|x| round_integer(x as i64, $scale, false).unwrap() as Native) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! impl_float_array_round { + ($array:expr, $arrow_type:ty, $scale:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + let result: PrimitiveArray<$arrow_type> = array.unary(|x| round_float(x, $scale)); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! impl_decimal_array_round { + ($array:expr, $arrow_type:ty, $input_scale:expr, $scale:expr) => {{ + let array = $array.as_primitive::<$arrow_type>(); + let result: PrimitiveArray<$arrow_type> = array + .try_unary(|x| round_decimal(x, $input_scale, $scale))? + .with_data_type($array.data_type().clone()); + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +// --------------------------------------------------------------------------- +// Core dispatch +// --------------------------------------------------------------------------- + +fn spark_round(args: &[ColumnarValue], enable_ansi_mode: bool) -> Result { + if args.is_empty() || args.len() > 2 { + return exec_err!("round requires 1 or 2 arguments, got {}", args.len()); + } + + let scale = match get_scale(args)? { + Some(s) => s, + None => { + // NULL scale → return NULL with the same data type as the first argument + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + args[0].data_type(), + )?)); + } + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Null => Ok(args[0].clone()), + + // Integer types + DataType::Int8 => { + impl_integer_array_round!(array, Int8Type, scale, enable_ansi_mode) + } + DataType::Int16 => { + impl_integer_array_round!(array, Int16Type, scale, enable_ansi_mode) + } + DataType::Int32 => { + impl_integer_array_round!(array, Int32Type, scale, enable_ansi_mode) + } + DataType::Int64 => { + impl_integer_array_round!(array, Int64Type, scale, enable_ansi_mode) + } + + // Unsigned integer types + DataType::UInt8 => { + impl_integer_array_round!(array, UInt8Type, scale, enable_ansi_mode) + } + DataType::UInt16 => { + impl_integer_array_round!(array, UInt16Type, scale, enable_ansi_mode) + } + DataType::UInt32 => { + impl_integer_array_round!(array, UInt32Type, scale, enable_ansi_mode) + } + DataType::UInt64 => { + let array = array.as_primitive::(); + let result: PrimitiveArray = array.try_unary(|x| { + let v_i64 = i64::try_from(x).map_err(|_| { + (exec_err!( + "round: UInt64 value {x} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + round_integer(v_i64, scale, enable_ansi_mode) + .map(|v| v as u64) + })?; + Ok(ColumnarValue::Array(Arc::new(result))) + } + + // Float types + DataType::Float16 => impl_float_array_round!(array, Float16Type, scale), + DataType::Float32 => impl_float_array_round!(array, Float32Type, scale), + DataType::Float64 => impl_float_array_round!(array, Float64Type, scale), + + // Decimal types + DataType::Decimal32(_, input_scale) => { + impl_decimal_array_round!(array, Decimal32Type, *input_scale, scale) + } + DataType::Decimal64(_, input_scale) => { + impl_decimal_array_round!(array, Decimal64Type, *input_scale, scale) + } + DataType::Decimal128(_, input_scale) => { + impl_decimal_array_round!(array, Decimal128Type, *input_scale, scale) + } + DataType::Decimal256(_, input_scale) => { + impl_decimal_array_round!(array, Decimal256Type, *input_scale, scale) + } + + dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"), + }, + + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Null => Ok(args[0].clone()), + _ if sv.is_null() => Ok(args[0].clone()), + + // Integer scalars + ScalarValue::Int8(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i8::try_from(r).map_err(|_| { + (exec_err!("Int8 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i8 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result)))) + } + ScalarValue::Int16(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i16::try_from(r).map_err(|_| { + (exec_err!("Int16 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i16 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result)))) + } + ScalarValue::Int32(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + i32::try_from(r).map_err(|_| { + (exec_err!("Int32 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as i32 + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } + ScalarValue::Int64(Some(v)) => { + let result = round_integer(*v, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result)))) + } + + // Unsigned integer scalars + ScalarValue::UInt8(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u8::try_from(r).map_err(|_| { + (exec_err!("UInt8 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u8 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt8(Some(result)))) + } + ScalarValue::UInt16(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u16::try_from(r).map_err(|_| { + (exec_err!("UInt16 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u16 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt16(Some(result)))) + } + ScalarValue::UInt32(Some(v)) => { + let r = round_integer(i64::from(*v), scale, enable_ansi_mode)?; + let result = if enable_ansi_mode { + u32::try_from(r).map_err(|_| { + (exec_err!("UInt32 overflow on round({v}, {scale})") + as Result<(), _>) + .unwrap_err() + })? + } else { + r as u32 + }; + Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(result)))) + } + ScalarValue::UInt64(Some(v)) => { + let v_i64 = i64::try_from(*v).map_err(|_| { + (exec_err!( + "round: UInt64 value {v} exceeds i64::MAX and cannot be rounded" + ) as Result<(), _>) + .unwrap_err() + })?; + let result = round_integer(v_i64, scale, enable_ansi_mode)?; + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(Some( + result as u64, + )))) + } + + // Float scalars + ScalarValue::Float16(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(result)))) + } + ScalarValue::Float32(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(result)))) + } + ScalarValue::Float64(Some(v)) => { + let result = round_float(*v, scale); + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) + } + + // Decimal scalars + ScalarValue::Decimal32(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal32( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal64(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal64( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal128(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(rounded), + *precision, + *input_scale, + ))) + } + ScalarValue::Decimal256(Some(v), precision, input_scale) => { + let rounded = round_decimal(*v, *input_scale, scale)?; + Ok(ColumnarValue::Scalar(ScalarValue::Decimal256( + Some(rounded), + *precision, + *input_scale, + ))) + } + + dt => not_impl_err!("Unsupported data type for Spark round(): {dt}"), + }, + } +} diff --git a/datafusion/spark/src/function/math/trigonometry.rs b/datafusion/spark/src/function/math/trigonometry.rs index 85b10f5b998c6..b3853d66d9be1 100644 --- a/datafusion/spark/src/function/math/trigonometry.rs +++ b/datafusion/spark/src/function/math/trigonometry.rs @@ -23,7 +23,6 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; -use std::any::Any; use std::sync::Arc; static CSC_FUNCTION_NAME: &str = "csc"; @@ -49,10 +48,6 @@ impl SparkCsc { } impl ScalarUDFImpl for SparkCsc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { CSC_FUNCTION_NAME } @@ -119,10 +114,6 @@ impl SparkSec { } impl ScalarUDFImpl for SparkSec { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { SEC_FUNCTION_NAME } diff --git a/datafusion/spark/src/function/math/unhex.rs b/datafusion/spark/src/function/math/unhex.rs new file mode 100644 index 0000000000000..6739e6a15c582 --- /dev/null +++ b/datafusion/spark/src/function/math/unhex.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, BinaryBuilder}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::types::logical_string; +use datafusion_common::utils::take_function_args; +use datafusion_common::{ + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, +}; +use datafusion_expr::{ + Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignatureClass, Volatility, +}; +use std::sync::Arc; + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnhex { + signature: Signature, +} + +impl Default for SparkUnhex { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnhex { + pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + + Self { + signature: Signature::coercible(vec![string], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkUnhex { + fn name(&self) -> &str { + "unhex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Binary) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_unhex(&args.args) + } +} + +#[inline] +fn hex_nibble(c: u8) -> Option { + match c { + b'0'..=b'9' => Some(c - b'0'), + b'a'..=b'f' => Some(c - b'a' + 10), + b'A'..=b'F' => Some(c - b'A' + 10), + _ => None, + } +} + +/// Decodes a hex-encoded byte slice into binary data. +/// Returns `true` if decoding succeeded, `false` if the input contains invalid hex characters. +fn unhex_common(bytes: &[u8], out: &mut Vec) -> bool { + if bytes.is_empty() { + return true; + } + + let mut i = 0usize; + + // If the hex string length is odd, implicitly left-pad with '0'. + if (bytes.len() & 1) == 1 { + match hex_nibble(bytes[0]) { + // Equivalent to (0 << 4) | lo + Some(lo) => out.push(lo), + None => return false, + } + i = 1; + } + + while i + 1 < bytes.len() { + match (hex_nibble(bytes[i]), hex_nibble(bytes[i + 1])) { + (Some(hi), Some(lo)) => out.push((hi << 4) | lo), + _ => return false, + } + i += 2; + } + + true +} + +/// Converts an iterator of hex strings to a binary array. +fn unhex_array( + iter: I, + len: usize, + capacity: usize, +) -> Result +where + I: Iterator>, + T: AsRef, +{ + let mut builder = BinaryBuilder::with_capacity(len, capacity); + let mut buffer = Vec::new(); + + for v in iter { + if let Some(s) = v { + buffer.clear(); + let additional = s.as_ref().len().div_ceil(2); + buffer.try_reserve(additional).map_err(|e| { + exec_datafusion_err!( + "failed to reserve {additional} bytes for unhex output: {e}" + ) + })?; + if unhex_common(s.as_ref().as_bytes(), &mut buffer) { + builder.append_value(&buffer); + } else { + builder.append_null(); + } + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) +} + +/// Convert a single hex string to binary +fn unhex_scalar(s: &str) -> Option> { + let mut buffer = Vec::with_capacity(s.len().div_ceil(2)); + if unhex_common(s.as_bytes(), &mut buffer) { + Some(buffer) + } else { + None + } +} + +fn spark_unhex(args: &[ColumnarValue]) -> Result { + let [args] = take_function_args("unhex", args)?; + + match args { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let array = as_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::Utf8View => { + let array = as_string_view_array(array)?; + // Estimate capacity since StringViewArray data can be scattered or inlined. + let capacity = array.len() * 32; + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + DataType::LargeUtf8 => { + let array = as_large_string_array(array)?; + let capacity = array.values().len().div_ceil(2); + Ok(ColumnarValue::Array(unhex_array( + array.iter(), + array.len(), + capacity, + )?)) + } + _ => exec_err!( + "unhex only supports string argument, but got: {}", + array.data_type() + ), + }, + ColumnarValue::Scalar(sv) => match sv { + ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + ScalarValue::Utf8(Some(s)) + | ScalarValue::Utf8View(Some(s)) + | ScalarValue::LargeUtf8(Some(s)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(unhex_scalar(s)))) + } + _ => { + exec_err!( + "unhex only supports string argument, but got: {}", + sv.data_type() + ) + } + }, + } +} diff --git a/datafusion/spark/src/function/math/width_bucket.rs b/datafusion/spark/src/function/math/width_bucket.rs index 45a0d843b7ed7..79da924116d2e 100644 --- a/datafusion/spark/src/function/math/width_bucket.rs +++ b/datafusion/spark/src/function/math/width_bucket.rs @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; -use crate::function::error_utils::unsupported_data_types_exec_err; use arrow::array::{ Array, ArrayRef, DurationMicrosecondArray, Float64Array, IntervalMonthDayNanoArray, IntervalYearMonthArray, @@ -27,17 +25,24 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Duration, Float64, Int32, Interval}; use arrow::datatypes::IntervalUnit::{MonthDayNano, YearMonth}; use datafusion_common::cast::{ - as_duration_microsecond_array, as_float64_array, as_int32_array, + as_duration_microsecond_array, as_float64_array, as_int64_array, as_interval_mdn_array, as_interval_ym_array, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::types::{ + NativeType, logical_duration_microsecond, logical_float64, logical_int64, + logical_interval_mdn, logical_interval_year_month, +}; +use datafusion_common::{Result, exec_err, internal_err}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::type_coercion::is_signed_numeric; -use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, +}; use datafusion_functions::utils::make_scalar_function; -use arrow::array::{Int32Array, Int32Builder}; +use arrow::array::{Int32Array, Int32Builder, Int64Array}; use arrow::datatypes::TimeUnit::Microsecond; +use datafusion_expr::Coercion; use datafusion_expr::Volatility::Immutable; #[derive(Debug, PartialEq, Eq, Hash)] @@ -53,17 +58,64 @@ impl Default for SparkWidthBucket { impl SparkWidthBucket { pub fn new() -> Self { + let numeric = Coercion::new_implicit( + TypeSignatureClass::Native(logical_float64()), + vec![TypeSignatureClass::Numeric], + NativeType::Float64, + ); + let duration = Coercion::new_implicit( + TypeSignatureClass::Native(logical_duration_microsecond()), + vec![TypeSignatureClass::Duration], + NativeType::Duration(Microsecond), + ); + let interval_ym = Coercion::new_exact(TypeSignatureClass::Native( + logical_interval_year_month(), + )); + let interval_mdn = + Coercion::new_exact(TypeSignatureClass::Native(logical_interval_mdn())); + let bucket = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Integer], + NativeType::Int64, + ); + let type_signature = Signature::one_of( + vec![ + TypeSignature::Coercible(vec![ + numeric.clone(), + numeric.clone(), + numeric.clone(), + bucket.clone(), + ]), + TypeSignature::Coercible(vec![ + duration.clone(), + duration.clone(), + duration.clone(), + bucket.clone(), + ]), + TypeSignature::Coercible(vec![ + interval_ym.clone(), + interval_ym.clone(), + interval_ym.clone(), + bucket.clone(), + ]), + TypeSignature::Coercible(vec![ + interval_mdn.clone(), + interval_mdn.clone(), + interval_mdn.clone(), + bucket.clone(), + ]), + ], + Immutable, + ) + .with_parameter_names(vec!["expr", "min", "max", "num_buckets"]) + .expect("valid parameter names"); Self { - signature: Signature::user_defined(Immutable), + signature: type_signature, } } } impl ScalarUDFImpl for SparkWidthBucket { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "width_bucket" } @@ -88,63 +140,6 @@ impl ScalarUDFImpl for SparkWidthBucket { Ok(SortProperties::default()) } } - - fn coerce_types(&self, types: &[DataType]) -> Result> { - use DataType::*; - - let (v, lo, hi, n) = (&types[0], &types[1], &types[2], &types[3]); - - match (v, lo, hi, n) { - (a, b, c, &(Int8 | Int16 | Int32 | Int64)) - if is_signed_numeric(a) - && is_signed_numeric(b) - && is_signed_numeric(c) => - { - Ok(vec![Float64, Float64, Float64, Int32]) - } - ( - &Duration(_), - &Duration(_), - &Duration(_), - &(Int8 | Int16 | Int32 | Int64), - ) => Ok(vec![ - Duration(Microsecond), - Duration(Microsecond), - Duration(Microsecond), - Int32, - ]), - ( - &Interval(MonthDayNano), - &Interval(MonthDayNano), - &Interval(MonthDayNano), - &(Int8 | Int16 | Int32 | Int64), - ) => Ok(vec![ - Interval(MonthDayNano), - Interval(MonthDayNano), - Interval(MonthDayNano), - Int32, - ]), - ( - &Interval(YearMonth), - &Interval(YearMonth), - &Interval(YearMonth), - &(Int8 | Int16 | Int32 | Int64), - ) => Ok(vec![ - Interval(YearMonth), - Interval(YearMonth), - Interval(YearMonth), - Int32, - ]), - - _ => exec_err!( - "width_bucket expects a numeric argument, got {} {} {} {}", - types[0], - types[1], - types[2], - types[3] - ), - } - } } fn width_bucket_kern(args: &[ArrayRef]) -> Result { @@ -160,42 +155,40 @@ fn width_bucket_kern(args: &[ArrayRef]) -> Result { let v = as_float64_array(v)?; let min = as_float64_array(minv)?; let max = as_float64_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_float64(v, min, max, n_bucket))) } Duration(Microsecond) => { let v = as_duration_microsecond_array(v)?; let min = as_duration_microsecond_array(minv)?; let max = as_duration_microsecond_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_i64_as_float(v, min, max, n_bucket))) } Interval(YearMonth) => { let v = as_interval_ym_array(v)?; let min = as_interval_ym_array(minv)?; let max = as_interval_ym_array(maxv)?; - let n_bucket = as_int32_array(nb)?; + let n_bucket = as_int64_array(nb)?; Ok(Arc::new(width_bucket_i32_as_float(v, min, max, n_bucket))) } Interval(MonthDayNano) => { let v = as_interval_mdn_array(v)?; let min = as_interval_mdn_array(minv)?; let max = as_interval_mdn_array(maxv)?; - let n_bucket = as_int32_array(nb)?; - Ok(Arc::new(width_bucket_interval_mdn_exact(v, min, max, n_bucket))) + let n_bucket = as_int64_array(nb)?; + Ok(Arc::new(width_bucket_interval_mdn_exact( + v, min, max, n_bucket, + ))) } - - other => Err(unsupported_data_types_exec_err( - "width_bucket", - "Float/Decimal OR Duration OR Interval(YearMonth) for first 3 args; Int for 4th", - &[ - other.clone(), - minv.data_type().clone(), - maxv.data_type().clone(), - nb.data_type().clone(), - ], - )), + other => internal_err!( + "width_bucket received unexpected data types: {:?}, {:?}, {:?}, {:?}", + other, + minv.data_type(), + maxv.data_type(), + nb.data_type() + ), } } @@ -205,7 +198,7 @@ macro_rules! width_bucket_kernel_impl { v: &$arr_ty, min: &$arr_ty, max: &$arr_ty, - n_bucket: &Int32Array, + n_bucket: &Int64Array, ) -> Int32Array { let len = v.len(); let mut b = Int32Builder::with_capacity(len); @@ -225,6 +218,7 @@ macro_rules! width_bucket_kernel_impl { b.append_null(); continue; } + let next_bucket = (buckets + 1) as i32; if $check_nan { if !x.is_finite() || !l.is_finite() || !h.is_finite() { b.append_null(); @@ -239,11 +233,11 @@ macro_rules! width_bucket_kernel_impl { continue; } }; - if matches!(ord, std::cmp::Ordering::Equal) { + if ord == std::cmp::Ordering::Equal { b.append_null(); continue; } - let asc = matches!(ord, std::cmp::Ordering::Less); + let asc = ord == std::cmp::Ordering::Less; if asc { if x < l { @@ -251,7 +245,7 @@ macro_rules! width_bucket_kernel_impl { continue; } if x >= h { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -260,7 +254,7 @@ macro_rules! width_bucket_kernel_impl { continue; } if x <= h { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -274,8 +268,8 @@ macro_rules! width_bucket_kernel_impl { if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); @@ -311,7 +305,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( v: &IntervalMonthDayNanoArray, lo: &IntervalMonthDayNanoArray, hi: &IntervalMonthDayNanoArray, - n: &Int32Array, + n: &Int64Array, ) -> Int32Array { let len = v.len(); let mut b = Int32Builder::with_capacity(len); @@ -326,6 +320,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( b.append_null(); continue; } + let next_bucket = (buckets + 1) as i32; let x = v.value(i); let l = lo.value(i); @@ -351,7 +346,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_m >= h_m { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -360,7 +355,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_m <= h_m { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -375,8 +370,8 @@ pub(crate) fn width_bucket_interval_mdn_exact( if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); continue; @@ -402,7 +397,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_f >= h_f { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } else { @@ -411,7 +406,7 @@ pub(crate) fn width_bucket_interval_mdn_exact( continue; } if x_f <= h_f { - b.append_value(buckets + 1); + b.append_value(next_bucket); continue; } } @@ -426,8 +421,8 @@ pub(crate) fn width_bucket_interval_mdn_exact( if bucket < 1 { bucket = 1; } - if bucket > buckets + 1 { - bucket = buckets + 1; + if bucket > next_bucket { + bucket = next_bucket; } b.append_value(bucket); continue; @@ -442,18 +437,17 @@ pub(crate) fn width_bucket_interval_mdn_exact( #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; use arrow::array::{ - ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, + ArrayRef, DurationMicrosecondArray, Float64Array, Int32Array, Int64Array, IntervalYearMonthArray, }; use arrow::datatypes::IntervalMonthDayNano; // --- Helpers ------------------------------------------------------------- - fn i32_array_all(len: usize, val: i32) -> Arc { - Arc::new(Int32Array::from(vec![val; len])) + fn i64_array_all(len: usize, val: i64) -> Arc { + Arc::new(Int64Array::from(vec![val; len])) } fn f64_array(vals: &[f64]) -> Arc { @@ -491,7 +485,7 @@ mod tests { let v = f64_array(&[0.5, 1.0, 9.9, -1.0, 10.0]); let lo = f64_array(&[0.0, 0.0, 0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0, 10.0, 10.0]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -503,7 +497,7 @@ mod tests { let v = f64_array(&[9.9, 10.0, 0.0, -0.1, 10.1]); let lo = f64_array(&[10.0; 5]); let hi = f64_array(&[0.0; 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -515,7 +509,7 @@ mod tests { let v = f64_array(&[0.0, 9.999999999, 10.0]); let lo = f64_array(&[0.0; 3]); let hi = f64_array(&[10.0; 3]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -527,7 +521,7 @@ mod tests { let v = f64_array(&[10.0, 0.0, -0.000001]); let lo = f64_array(&[10.0; 3]); let hi = f64_array(&[0.0; 3]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -539,7 +533,7 @@ mod tests { let v = f64_array(&[1.0, 5.0, 9.0]); let lo = f64_array(&[0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0]); - let n = Arc::new(Int32Array::from(vec![0, -1, 10])); + let n = Arc::new(Int64Array::from(vec![0, -1, 10])); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -549,7 +543,7 @@ mod tests { let v = f64_array(&[1.0]); let lo = f64_array(&[5.0]); let hi = f64_array(&[5.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -557,7 +551,7 @@ mod tests { let v = f64_array_opt(&[Some(f64::NAN)]); let lo = f64_array(&[0.0]); let hi = f64_array(&[10.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -568,7 +562,7 @@ mod tests { let v = f64_array_opt(&[None, Some(1.0), Some(2.0), Some(3.0)]); let lo = f64_array(&[0.0; 4]); let hi = f64_array(&[10.0; 4]); - let n = i32_array_all(4, 10); + let n = i64_array_all(4, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -580,7 +574,7 @@ mod tests { let v = f64_array(&[1.0]); let lo = f64_array_opt(&[None]); let hi = f64_array(&[10.0]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); assert!(out.is_null(0)); @@ -593,7 +587,7 @@ mod tests { let v = dur_us_array(&[1_000_000, 0, -1]); let lo = dur_us_array(&[0, 0, 0]); let hi = dur_us_array(&[2_000_000, 2_000_000, 2_000_000]); - let n = i32_array_all(3, 2); + let n = i64_array_all(3, 2); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -605,7 +599,7 @@ mod tests { let v = dur_us_array(&[0]); let lo = dur_us_array(&[1]); let hi = dur_us_array(&[1]); - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); } @@ -617,7 +611,7 @@ mod tests { let v = ym_array(&[0, 5, 11, 12, 13]); let lo = ym_array(&[0; 5]); let hi = ym_array(&[12; 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -629,7 +623,7 @@ mod tests { let v = ym_array(&[11, 12, 0, -1, 13]); let lo = ym_array(&[12; 5]); let hi = ym_array(&[0; 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -643,7 +637,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0), (5, 0, 0), (11, 0, 0), (12, 0, 0), (13, 0, 0)]); let lo = mdn_array(&[(0, 0, 0); 5]); let hi = mdn_array(&[(12, 0, 0); 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -655,7 +649,7 @@ mod tests { let v = mdn_array(&[(11, 0, 0), (12, 0, 0), (0, 0, 0), (-1, 0, 0), (13, 0, 0)]); let lo = mdn_array(&[(12, 0, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 12); + let n = i64_array_all(5, 12); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -675,7 +669,7 @@ mod tests { ]); let lo = mdn_array(&[(0, 0, 0); 6]); let hi = mdn_array(&[(0, 10, 0); 6]); - let n = i32_array_all(6, 10); + let n = i64_array_all(6, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -688,7 +682,7 @@ mod tests { let v = mdn_array(&[(0, 9, 0), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); let lo = mdn_array(&[(0, 10, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -700,7 +694,7 @@ mod tests { let v = mdn_array(&[(0, 9, 1), (0, 10, 0), (0, 0, 0), (0, -1, 0), (0, 11, 0)]); let lo = mdn_array(&[(0, 10, 0); 5]); let hi = mdn_array(&[(0, 0, 0); 5]); - let n = i32_array_all(5, 10); + let n = i64_array_all(5, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -713,7 +707,7 @@ mod tests { let v = mdn_array(&[(0, 1, 0)]); let lo = mdn_array(&[(0, 0, 0)]); let hi = mdn_array(&[(1, 1, 0)]); - let n = i32_array_all(1, 4); + let n = i64_array_all(1, 4); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -725,7 +719,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0)]); let lo = mdn_array(&[(1, 2, 3)]); let hi = mdn_array(&[(1, 2, 3)]); // lo == hi - let n = i32_array_all(1, 10); + let n = i64_array_all(1, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); @@ -736,7 +730,7 @@ mod tests { let v = mdn_array(&[(0, 0, 0)]); let lo = mdn_array(&[(0, 0, 0)]); let hi = mdn_array(&[(0, 10, 0)]); - let n = Arc::new(Int32Array::from(vec![0])); // n <= 0 + let n = Arc::new(Int64Array::from(vec![0])); // n <= 0 let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); assert!(downcast_i32(&out).is_null(0)); @@ -750,7 +744,7 @@ mod tests { ])); let lo = mdn_array(&[(0, 0, 0), (0, 0, 0)]); let hi = mdn_array(&[(0, 10, 0), (0, 10, 0)]); - let n = i32_array_all(2, 10); + let n = i64_array_all(2, 10); let out = width_bucket_kern(&[v, lo, hi, n]).unwrap(); let out = downcast_i32(&out); @@ -775,13 +769,12 @@ mod tests { let v: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); let lo = f64_array(&[0.0, 0.0, 0.0]); let hi = f64_array(&[10.0, 10.0, 10.0]); - let n = i32_array_all(3, 10); + let n = i64_array_all(3, 10); let err = width_bucket_kern(&[v, lo, hi, n]).unwrap_err(); let msg = format!("{err}"); assert!( - msg.contains("unsupported data types") - || msg.contains("Float/Decimal OR Duration OR Interval(YearMonth)"), + msg.contains("width_bucket received unexpected data types"), "unexpected error: {msg}" ); } diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs index 3f4f94cfaaf8c..d5dd60c3545a5 100644 --- a/datafusion/spark/src/function/mod.rs +++ b/datafusion/spark/src/function/mod.rs @@ -33,6 +33,7 @@ pub mod lambda; pub mod map; pub mod math; pub mod misc; +mod null_utils; pub mod predicate; pub mod string; pub mod r#struct; diff --git a/datafusion/spark/src/function/null_utils.rs b/datafusion/spark/src/function/null_utils.rs new file mode 100644 index 0000000000000..886b45d746510 --- /dev/null +++ b/datafusion/spark/src/function/null_utils.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +pub(crate) enum NullMaskResolution { + /// All inputs are scalars and at least one is NULL -> return NULL + ReturnNull, + /// All inputs are non-NULL -> no null mask needed + NoMask, + /// Null mask to apply for arrays + Apply(NullBuffer), +} + +pub(crate) fn compute_null_mask(args: &[ColumnarValue]) -> NullMaskResolution { + let mut array_len = None; + let mut has_null_scalar = false; + + for arg in args { + match arg { + ColumnarValue::Array(array) => { + array_len.get_or_insert_with(|| array.len()); + } + ColumnarValue::Scalar(scalar) => { + has_null_scalar |= scalar.is_null(); + } + } + } + + let Some(array_len) = array_len else { + // All arguments are scalars + return if has_null_scalar { + NullMaskResolution::ReturnNull + } else { + NullMaskResolution::NoMask + }; + }; + + if has_null_scalar { + return NullMaskResolution::Apply(NullBuffer::new_null(array_len)); + } + + let combined_nulls = + NullBuffer::union_many(args.iter().filter_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.nulls()), + ColumnarValue::Scalar(_) => None, + })); + + match combined_nulls { + Some(nulls) => NullMaskResolution::Apply(nulls), + None => NullMaskResolution::NoMask, + } +} + +/// Apply NULL mask to the result using NullBuffer::union +pub(crate) fn apply_null_mask( + result: ColumnarValue, + null_mask: NullMaskResolution, + return_type: &DataType, +) -> Result { + match (result, null_mask) { + // Scalar with ReturnNull mask means return NULL of the correct type + (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { + Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)) + } + // Scalar without mask, return as-is + (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), + // Array with NULL mask - use NullBuffer::union to combine nulls + (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { + // Combine the result's existing nulls with our computed null mask + let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); + + // Create new array with combined nulls + let new_array = array + .into_data() + .into_builder() + .nulls(combined_nulls) + .build()?; + + Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( + new_array, + )))) + } + // Array without NULL mask, return as-is + (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), + // Edge cases that shouldn't happen in practice + (scalar, _) => Ok(scalar), + } +} diff --git a/datafusion/spark/src/function/string/ascii.rs b/datafusion/spark/src/function/string/ascii.rs index f14a66d4e484d..7846d3c681c71 100644 --- a/datafusion/spark/src/function/string/ascii.rs +++ b/datafusion/spark/src/function/string/ascii.rs @@ -15,13 +15,17 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; -use datafusion_common::Result; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, + Signature, TypeSignatureClass, Volatility, +}; use datafusion_functions::string::ascii::ascii; use datafusion_functions::utils::make_scalar_function; -use std::any::Any; /// Spark compatible version of the [ascii] function. Differs from the [default ascii function] /// in that it is more permissive of input types, for example casting numeric input to string @@ -42,17 +46,22 @@ impl Default for SparkAscii { impl SparkAscii { pub fn new() -> Self { + // Spark's ascii uses ImplicitCastInputTypes with StringType, + // which allows numeric types to be implicitly cast to String. + // See: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala + let string_coercion = Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Numeric], + NativeType::String, + ); + Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::coercible(vec![string_coercion], Volatility::Immutable), } } } impl ScalarUDFImpl for SparkAscii { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "ascii" } @@ -62,14 +71,60 @@ impl ScalarUDFImpl for SparkAscii { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + // ascii returns an Int32 value + // The result is nullable only if any of the input arguments is nullable + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new("ascii", DataType::Int32, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { make_scalar_function(ascii, vec![])(&args.args) } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_return_field_nullable_input() { + let ascii_func = SparkAscii::new(); + let nullable_field = Arc::new(Field::new("input", DataType::Utf8, true)); + + let result = ascii_func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_field], + scalar_arguments: &[], + }) + .unwrap(); + + assert_eq!(result.data_type(), &DataType::Int32); + assert!( + result.is_nullable(), + "Output should be nullable when input is nullable" + ); + } + + #[test] + fn test_return_field_non_nullable_input() { + let ascii_func = SparkAscii::new(); + let non_nullable_field = Arc::new(Field::new("input", DataType::Utf8, false)); + + let result = ascii_func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[non_nullable_field], + scalar_arguments: &[], + }) + .unwrap(); - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { - Ok(vec![DataType::Utf8]) + assert_eq!(result.data_type(), &DataType::Int32); + assert!( + !result.is_nullable(), + "Output should not be nullable when input is not nullable" + ); } } diff --git a/datafusion/spark/src/function/string/base64.rs b/datafusion/spark/src/function/string/base64.rs new file mode 100644 index 0000000000000..95607f374b32f --- /dev/null +++ b/datafusion/spark/src/function/string/base64.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::{Field, FieldRef}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_err, internal_err}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::{Coercion, Expr, ReturnFieldArgs, TypeSignatureClass, lit}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::expr_fn::{decode, encode}; + +/// Apache Spark base64 uses padded base64 encoding. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkBase64 { + signature: Signature, +} + +impl Default for SparkBase64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkBase64 { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkBase64 { + fn name(&self) -> &str { + "base64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type should not be called for {}", self.name()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + let [bin] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match bin.data_type() { + DataType::LargeBinary => DataType::LargeUtf8, + _ => DataType::Utf8, + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + bin.is_nullable(), + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + exec_err!( + "invoke should not be called on a simplified {} function", + self.name() + ) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [bin] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified(encode( + bin, + lit("base64pad"), + ))) + } +} + +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkUnBase64 { + signature: Signature, +} + +impl Default for SparkUnBase64 { + fn default() -> Self { + Self::new() + } +} + +impl SparkUnBase64 { + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![Coercion::new_implicit( + TypeSignatureClass::Binary, + vec![TypeSignatureClass::Native(logical_string())], + NativeType::Binary, + )], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkUnBase64 { + fn name(&self) -> &str { + "unbase64" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type should not be called for {}", self.name()) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + let [str] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match str.data_type() { + DataType::LargeBinary => DataType::LargeBinary, + _ => DataType::Binary, + }; + Ok(Arc::new(Field::new( + self.name(), + return_type, + str.is_nullable(), + ))) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + exec_err!("{} should have been simplified", self.name()) + } + + fn simplify( + &self, + args: Vec, + _info: &SimplifyContext, + ) -> Result { + let [bin] = take_function_args(self.name(), args)?; + Ok(ExprSimplifyResult::Simplified(decode( + bin, + lit("base64pad"), + ))) + } +} diff --git a/datafusion/spark/src/function/string/char.rs b/datafusion/spark/src/function/string/char.rs index a1813373c65ff..15b00ee98f5c7 100644 --- a/datafusion/spark/src/function/string/char.rs +++ b/datafusion/spark/src/function/string/char.rs @@ -17,14 +17,15 @@ use arrow::array::ArrayRef; use arrow::array::GenericStringBuilder; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; -use std::{any::Any, sync::Arc}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use std::sync::Arc; -use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, cast::as_int64_array, exec_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; /// Spark-compatible `char` expression @@ -49,10 +50,6 @@ impl CharFunc { } impl ScalarUDFImpl for CharFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "char" } @@ -62,12 +59,19 @@ impl ScalarUDFImpl for CharFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Utf8) + datafusion_common::internal_err!( + "return_type should not be called, use return_field_from_args instead" + ) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { spark_chr(&args.args) } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new(self.name(), Utf8, nullable))) + } } /// Returns the ASCII character having the binary equivalent to the input expression. @@ -119,7 +123,7 @@ fn chr(args: &[ArrayRef]) -> Result { None => { return exec_err!( "requested character not compatible for encoding." - ) + ); } } } @@ -130,3 +134,48 @@ fn chr(args: &[ArrayRef]) -> Result { Ok(Arc::new(builder.finish()) as ArrayRef) } + +#[test] +fn test_char_nullability() -> Result<()> { + use arrow::datatypes::{DataType::Utf8, Field, FieldRef}; + use datafusion_expr::ReturnFieldArgs; + use std::sync::Arc; + + let func = CharFunc::new(); + + let nullable_field: FieldRef = Arc::new(Field::new("col", Int64, true)); + + let out_nullable = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_field], + scalar_arguments: &[None], + })?; + + assert!( + out_nullable.is_nullable(), + "char(col) should be nullable when input column is nullable" + ); + assert_eq!( + out_nullable.data_type(), + &Utf8, + "char always returns Utf8 regardless of input type" + ); + + let non_nullable_field: FieldRef = Arc::new(Field::new("col", Int64, false)); + + let out_non_nullable = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[non_nullable_field], + scalar_arguments: &[None], + })?; + + assert!( + !out_non_nullable.is_nullable(), + "char(col) should NOT be nullable when input column is NOT nullable" + ); + assert_eq!( + out_non_nullable.data_type(), + &Utf8, + "char always returns Utf8 regardless of input type" + ); + + Ok(()) +} diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 0dcc58d5bb8ed..57fd6cadd9dde 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -15,18 +15,20 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Array; -use arrow::buffer::NullBuffer; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::arrow::datatypes::FieldRef; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::ReturnFieldArgs; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::string::concat::ConcatFunc; -use std::any::Any; use std::sync::Arc; +use crate::function::null_utils::{ + NullMaskResolution, apply_null_mask, compute_null_mask, +}; + /// Spark-compatible `concat` expression /// /// @@ -50,19 +52,12 @@ impl Default for SparkConcat { impl SparkConcat { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![TypeSignature::UserDefined, TypeSignature::Nullary], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } impl ScalarUDFImpl for SparkConcat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "concat" } @@ -71,10 +66,6 @@ impl ScalarUDFImpl for SparkConcat { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) - } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { spark_concat(args) } @@ -83,16 +74,28 @@ impl ScalarUDFImpl for SparkConcat { // Accept any string types, including zero arguments Ok(arg_types.to_vec()) } -} + fn return_type(&self, _arg_types: &[DataType]) -> Result { + datafusion_common::internal_err!( + "return_type should not be called for Spark concat" + ) + } + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + use DataType::*; + + // Spark semantics: concat returns NULL if ANY input is NULL + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + // Determine return type: Utf8View > LargeUtf8 > Utf8 + let mut dt = &Utf8; + for field in args.arg_fields { + let data_type = field.data_type(); + if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) { + dt = data_type; + } + } -/// Represents the null state for Spark concat -enum NullMaskResolution { - /// Return NULL as the result (e.g., scalar inputs with at least one NULL) - ReturnNull, - /// No null mask needed (e.g., all scalar inputs are non-NULL) - NoMask, - /// Null mask to apply for arrays - Apply(NullBuffer), + Ok(Arc::new(Field::new("concat", dt.clone(), nullable))) + } } /// Concatenates strings, returning NULL if any input is NULL @@ -109,21 +112,38 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // Handle zero-argument case: return empty string if arg_values.is_empty() { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - Some(String::new()), - ))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::new(), + )))), + DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8( + Some(String::new()), + ))), + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8( + Some(String::new()), + ))), + }; } // Step 1: Check for NULL mask in incoming args - let null_mask = compute_null_mask(&arg_values, number_rows)?; + let null_mask = compute_null_mask(&arg_values); // If all scalars and any is NULL, return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + let return_type = return_field.data_type(); + return match return_type { + DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))), + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + } + _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }; } // Step 2: Delegate to DataFusion's concat let concat_func = ConcatFunc::new(); + let return_type = return_field.data_type().clone(); let func_args = ScalarFunctionArgs { args: arg_values, arg_fields, @@ -134,105 +154,14 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { let result = concat_func.invoke_with_args(func_args)?; // Step 3: Apply NULL mask to result - apply_null_mask(result, null_mask) -} - -/// Compute NULL mask for the arguments using NullBuffer::union -fn compute_null_mask( - args: &[ColumnarValue], - number_rows: usize, -) -> Result { - // Check if all arguments are scalars - let all_scalars = args - .iter() - .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); - - if all_scalars { - // For scalars, check if any is NULL - for arg in args { - if let ColumnarValue::Scalar(scalar) = arg { - if scalar.is_null() { - return Ok(NullMaskResolution::ReturnNull); - } - } - } - // No NULLs in scalars - Ok(NullMaskResolution::NoMask) - } else { - // For arrays, compute NULL mask for each row using NullBuffer::union - let array_len = args - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .unwrap_or(number_rows); - - // Convert all scalars to arrays for uniform processing - let arrays: Result> = args - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => Ok(Arc::clone(array)), - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), - }) - .collect(); - let arrays = arrays?; - - // Use NullBuffer::union to combine all null buffers - let combined_nulls = arrays - .iter() - .map(|arr| arr.nulls()) - .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); - - match combined_nulls { - Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), - None => Ok(NullMaskResolution::NoMask), - } - } -} - -/// Apply NULL mask to the result using NullBuffer::union -fn apply_null_mask( - result: ColumnarValue, - null_mask: NullMaskResolution, -) -> Result { - match (result, null_mask) { - // Scalar with ReturnNull mask means return NULL - (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - // Scalar without mask, return as-is - (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), - // Array with NULL mask - use NullBuffer::union to combine nulls - (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { - // Combine the result's existing nulls with our computed null mask - let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); - - // Create new array with combined nulls - let new_array = array - .into_data() - .into_builder() - .nulls(combined_nulls) - .build()?; - - Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( - new_array, - )))) - } - // Array without NULL mask, return as-is - (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), - // Edge cases that shouldn't happen in practice - (scalar, _) => Ok(scalar), - } + apply_null_mask(result, null_mask, &return_type) } #[cfg(test)] mod tests { use super::*; use crate::function::utils::test::test_scalar_function; - use arrow::array::StringArray; - use arrow::datatypes::DataType; - use datafusion_common::Result; + use arrow::array::{Array, StringArray}; #[test] fn test_concat_basic() -> Result<()> { @@ -266,4 +195,51 @@ mod tests { ); Ok(()) } + + #[test] + fn test_spark_concat_return_field_non_nullable() -> Result<()> { + let func = SparkConcat::new(); + + let fields = vec![ + Arc::new(Field::new("a", DataType::Utf8, false)), + Arc::new(Field::new("b", DataType::Utf8, false)), + ]; + + let args = ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &[], + }; + + let field = func.return_field_from_args(args)?; + + assert!( + !field.is_nullable(), + "Expected concat result to be non-nullable when all inputs are non-nullable" + ); + + Ok(()) + } + #[test] + fn test_spark_concat_return_field_nullable() -> Result<()> { + let func = SparkConcat::new(); + + let fields = vec![ + Arc::new(Field::new("a", DataType::Utf8, false)), + Arc::new(Field::new("b", DataType::Utf8, true)), + ]; + + let args = ReturnFieldArgs { + arg_fields: &fields, + scalar_arguments: &[], + }; + + let field = func.return_field_from_args(args)?; + + assert!( + field.is_nullable(), + "Expected concat result to be nullable when any input is nullable" + ); + + Ok(()) + } } diff --git a/datafusion/spark/src/function/string/elt.rs b/datafusion/spark/src/function/string/elt.rs index 35a22fe5edb6f..c37ecd1d3fc39 100644 --- a/datafusion/spark/src/function/string/elt.rs +++ b/datafusion/spark/src/function/string/elt.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ @@ -25,7 +24,7 @@ use arrow::compute::{can_cast_types, cast}; use arrow::datatypes::DataType::{Int64, Utf8}; use arrow::datatypes::{DataType, Int64Type}; use datafusion_common::cast::as_string_array; -use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, plan_datafusion_err}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; @@ -51,10 +50,6 @@ impl SparkElt { } impl ScalarUDFImpl for SparkElt { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "elt" } @@ -150,11 +145,6 @@ fn elt(args: &[ArrayRef]) -> Result { mod tests { use super::*; use arrow::array::Int64Array; - use datafusion_common::Result; - - use arrow::array::{ArrayRef, StringArray}; - use datafusion_common::DataFusionError; - use std::sync::Arc; fn run_elt_arrays(arrs: Vec) -> Result> { let arr = elt(&arrs)?; diff --git a/datafusion/spark/src/function/string/format_string.rs b/datafusion/spark/src/function/string/format_string.rs index adb0eb2f09951..68b8fe52338d4 100644 --- a/datafusion/spark/src/function/string/format_string.rs +++ b/datafusion/spark/src/function/string/format_string.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::fmt::Write; use std::sync::Arc; @@ -23,19 +22,19 @@ use core::num::FpCategory; use arrow::{ array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}, - datatypes::DataType, + datatypes::{DataType, Field, FieldRef}, }; use bigdecimal::{ - num_bigint::{BigInt, Sign}, BigDecimal, ToPrimitive, + num_bigint::{BigInt, Sign}, }; use chrono::{DateTime, Datelike, Timelike, Utc}; use datafusion_common::{ - exec_datafusion_err, exec_err, plan_err, DataFusionError, Result, ScalarValue, + DataFusionError, Result, ScalarValue, exec_datafusion_err, exec_err, plan_err, }; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; /// Spark-compatible `format_string` expression @@ -62,10 +61,6 @@ impl FormatStringFunc { } impl ScalarUDFImpl for FormatStringFunc { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "format_string" } @@ -78,11 +73,24 @@ impl ScalarUDFImpl for FormatStringFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - match arg_types[0] { - DataType::Null => Ok(DataType::Utf8), - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(arg_types[0].clone()), - _ => plan_err!("The format_string function expects the first argument to be Utf8, LargeUtf8 or Utf8View") + fn return_type(&self, _arg_types: &[DataType]) -> Result { + datafusion_common::internal_err!( + "return_type should not be called, use return_field_from_args instead" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + match args.arg_fields[0].data_type() { + DataType::Null => { + Ok(Arc::new(Field::new("format_string", DataType::Utf8, true))) + } + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + Ok(Arc::clone(&args.arg_fields[0])) + } + _ => exec_err!( + "format_string expects the first argument to be Utf8, LargeUtf8 or Utf8View, got {} instead", + args.arg_fields[0].data_type() + ), } } @@ -317,7 +325,7 @@ impl<'a> Formatter<'a> { (index as usize, &rest2[1..]) } (NumericParam::FromArgument, true) => { - return exec_err!("Invalid numeric parameter") + return exec_err!("Invalid numeric parameter"); } (_, false) => { argument_index += 1; @@ -584,15 +592,15 @@ impl TryFrom for TimeFormat { impl ConversionType { pub fn validate(&self, arg_type: &DataType) -> Result<()> { match self { - ConversionType::BooleanLower | ConversionType::BooleanUpper => { - if !matches!(arg_type, DataType::Boolean) { - return exec_err!( - "Invalid argument type for boolean conversion: {:?}", - arg_type - ); - } + ConversionType::BooleanLower | ConversionType::BooleanUpper + if *arg_type != DataType::Boolean => + { + return exec_err!( + "Invalid argument type for boolean conversion: {:?}", + arg_type + ); } - ConversionType::CharLower | ConversionType::CharUpper => { + ConversionType::CharLower | ConversionType::CharUpper if !matches!( arg_type, DataType::Int8 @@ -603,23 +611,23 @@ impl ConversionType { | DataType::UInt32 | DataType::Int64 | DataType::UInt64 - ) { - return exec_err!( - "Invalid argument type for char conversion: {:?}", - arg_type - ); - } + ) => + { + return exec_err!( + "Invalid argument type for char conversion: {:?}", + arg_type + ); } ConversionType::DecInt | ConversionType::OctInt | ConversionType::HexIntLower - | ConversionType::HexIntUpper => { - if !arg_type.is_integer() { - return exec_err!( - "Invalid argument type for integer conversion: {:?}", - arg_type - ); - } + | ConversionType::HexIntUpper + if !arg_type.is_integer() => + { + return exec_err!( + "Invalid argument type for integer conversion: {:?}", + arg_type + ); } ConversionType::SciFloatLower | ConversionType::SciFloatUpper @@ -627,21 +635,21 @@ impl ConversionType { | ConversionType::CompactFloatLower | ConversionType::CompactFloatUpper | ConversionType::HexFloatLower - | ConversionType::HexFloatUpper => { - if !arg_type.is_numeric() { - return exec_err!( - "Invalid argument type for float conversion: {:?}", - arg_type - ); - } + | ConversionType::HexFloatUpper + if !arg_type.is_numeric() => + { + return exec_err!( + "Invalid argument type for float conversion: {:?}", + arg_type + ); } - ConversionType::TimeLower(_) | ConversionType::TimeUpper(_) => { - if !arg_type.is_temporal() { - return exec_err!( - "Invalid argument type for time conversion: {:?}", - arg_type - ); - } + ConversionType::TimeLower(_) | ConversionType::TimeUpper(_) + if !arg_type.is_temporal() => + { + return exec_err!( + "Invalid argument type for time conversion: {:?}", + arg_type + ); } _ => {} } @@ -853,7 +861,138 @@ fn take_numeric_param(s: &str, zero: bool) -> (NumericParam, &str) { } } +/// Convert a `u32` to a [`char`] for the `%c` conversion. Returns an error if +/// the value is not a valid Unicode scalar value (i.e. is in the surrogate +/// range `0xD800..=0xDFFF` or above `0x10FFFF`). +fn codepoint_to_char(value: u32) -> Result { + char::from_u32(value).ok_or_else(|| { + exec_datafusion_err!("invalid Unicode scalar value for %c: {value:#x}") + }) +} + +/// Convert a signed integer to a [`char`] for the `%c` conversion. Returns an +/// error if the value is negative or is not a valid Unicode scalar value (i.e. +/// is in the surrogate range `0xD800..=0xDFFF` or above `0x10FFFF`). +fn signed_to_char(value: i64) -> Result { + let codepoint = u32::try_from(value).map_err(|_| { + exec_datafusion_err!("invalid Unicode scalar value for %c: {value}") + })?; + codepoint_to_char(codepoint) +} + +/// Convert an unsigned integer to a [`char`] for the `%c` conversion. Returns +/// an error if the value does not fit in a `u32` or is not a valid Unicode +/// scalar value (i.e. is in the surrogate range `0xD800..=0xDFFF` or above +/// `0x10FFFF`). +fn unsigned_to_char(value: u64) -> Result { + let codepoint = u32::try_from(value).map_err(|_| { + exec_datafusion_err!("invalid Unicode scalar value for %c: {value:#x}") + })?; + codepoint_to_char(codepoint) +} + +/// Formatting operations that differ between signed and unsigned integer +/// primitives. Signed values format as decimal for `%d` / `%s` / `%c`, but use +/// their original bit width for `%x` / `%o` via `unsigned_bits`. +trait IntegerFormatValue { + fn unsigned_bits(self) -> u64; + + fn to_char(self) -> Result; + + fn format_decimal( + self, + spec: &ConversionSpecifier, + writer: &mut String, + ) -> Result<()>; + + fn decimal_string(self) -> String; +} + +macro_rules! signed_integer_value { + ($source:ty, $unsigned:ty) => { + impl IntegerFormatValue for $source { + fn unsigned_bits(self) -> u64 { + (self as $unsigned) as u64 + } + + fn to_char(self) -> Result { + signed_to_char(self as i64) + } + + fn format_decimal( + self, + spec: &ConversionSpecifier, + writer: &mut String, + ) -> Result<()> { + spec.format_signed(writer, self as i64) + } + + fn decimal_string(self) -> String { + self.to_string() + } + } + }; +} + +signed_integer_value!(i8, u8); +signed_integer_value!(i16, u16); +signed_integer_value!(i32, u32); +signed_integer_value!(i64, u64); + +macro_rules! unsigned_integer_value { + ($source:ty) => { + impl IntegerFormatValue for $source { + fn unsigned_bits(self) -> u64 { + self as u64 + } + + fn to_char(self) -> Result { + unsigned_to_char(self as u64) + } + + fn format_decimal( + self, + spec: &ConversionSpecifier, + writer: &mut String, + ) -> Result<()> { + spec.format_unsigned(writer, self as u64) + } + + fn decimal_string(self) -> String { + self.to_string() + } + } + }; +} + +unsigned_integer_value!(u8); +unsigned_integer_value!(u16); +unsigned_integer_value!(u32); +unsigned_integer_value!(u64); + impl ConversionSpecifier { + /// Validates that the grouping separator flag is not used with scientific + /// notation conversions, matching Java/Spark behavior which throws + /// `FormatFlagsConversionMismatchException` for `%,e` / `%,E`. + fn validate_grouping_separator(&self) -> Result<()> { + if self.grouping_separator + && matches!( + self.conversion_type, + ConversionType::SciFloatLower | ConversionType::SciFloatUpper + ) + { + return exec_err!( + "Grouping separator ',' flag is not compatible with scientific notation conversion '{}'", + if self.conversion_type == ConversionType::SciFloatUpper { + 'E' + } else { + 'e' + } + ); + } + Ok(()) + } + pub fn format(&self, string: &mut String, value: &ScalarValue) -> Result<()> { match value { ScalarValue::Boolean(value) => match self.conversion_type { @@ -863,204 +1002,14 @@ impl ConversionSpecifier { _ => self.format_boolean(string, value), }, - ScalarValue::Int8(value) => match (self.conversion_type, value) { - (ConversionType::DecInt, Some(value)) => { - self.format_signed(string, *value as i64) - } - ( - ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, (*value as u8) as u64), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char(string, *value as u8 as char) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for Int8", - self.conversion_type - ) - } - }, - ScalarValue::Int16(value) => match (self.conversion_type, value) { - (ConversionType::DecInt, Some(value)) => { - self.format_signed(string, *value as i64) - } - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char( - string, - char::from_u32((*value as u16) as u32).unwrap(), - ) - } - ( - ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, (*value as u16) as u64), - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for Int16", - self.conversion_type - ) - } - }, - ScalarValue::Int32(value) => match (self.conversion_type, value) { - (ConversionType::DecInt, Some(value)) => { - self.format_signed(string, *value as i64) - } - ( - ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, (*value as u32) as u64), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char(string, char::from_u32(*value as u32).unwrap()) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for Int32", - self.conversion_type - ) - } - }, - ScalarValue::Int64(value) => match (self.conversion_type, value) { - (ConversionType::DecInt, Some(value)) => { - self.format_signed(string, *value) - } - ( - ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, *value as u64), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char( - string, - char::from_u32((*value as u64) as u32).unwrap(), - ) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for Int64", - self.conversion_type - ) - } - }, - ScalarValue::UInt8(value) => match (self.conversion_type, value) { - ( - ConversionType::DecInt - | ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, *value as u64), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char(string, *value as char) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for UInt8", - self.conversion_type - ) - } - }, - ScalarValue::UInt16(value) => match (self.conversion_type, value) { - ( - ConversionType::DecInt - | ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, *value as u64), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char(string, char::from_u32(*value as u32).unwrap()) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for UInt16", - self.conversion_type - ) - } - }, - ScalarValue::UInt32(value) => match (self.conversion_type, value) { - ( - ConversionType::DecInt - | ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, *value as u64), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char(string, char::from_u32(*value).unwrap()) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for UInt32", - self.conversion_type - ) - } - }, - ScalarValue::UInt64(value) => match (self.conversion_type, value) { - ( - ConversionType::DecInt - | ConversionType::HexIntLower - | ConversionType::HexIntUpper - | ConversionType::OctInt, - Some(value), - ) => self.format_unsigned(string, *value), - (ConversionType::CharLower | ConversionType::CharUpper, Some(value)) => { - self.format_char(string, char::from_u32(*value as u32).unwrap()) - } - ( - ConversionType::StringLower | ConversionType::StringUpper, - Some(value), - ) => self.format_string(string, &value.to_string()), - (t, None) if t.supports_integer() => self.format_string(string, "null"), - _ => { - exec_err!( - "Invalid conversion type: {:?} for UInt64", - self.conversion_type - ) - } - }, + ScalarValue::Int8(value) => self.format_integer(string, value, "Int8"), + ScalarValue::Int16(value) => self.format_integer(string, value, "Int16"), + ScalarValue::Int32(value) => self.format_integer(string, value, "Int32"), + ScalarValue::Int64(value) => self.format_integer(string, value, "Int64"), + ScalarValue::UInt8(value) => self.format_integer(string, value, "UInt8"), + ScalarValue::UInt16(value) => self.format_integer(string, value, "UInt16"), + ScalarValue::UInt32(value) => self.format_integer(string, value, "UInt32"), + ScalarValue::UInt64(value) => self.format_integer(string, value, "UInt64"), ScalarValue::Float16(value) => match (self.conversion_type, value) { ( ConversionType::DecFloatLower @@ -1418,10 +1367,52 @@ impl ConversionSpecifier { let value = "null".to_string(); self.format_string(string, &value) } - _ => exec_err!("Invalid scalar value: {:?}", value), + _ => exec_err!("Invalid scalar value: {value}"), + } + } + + fn format_integer( + &self, + writer: &mut String, + value: &Option, + type_name: &str, + ) -> Result<()> + where + T: Copy + IntegerFormatValue, + { + let Some(value) = *value else { + return if self.conversion_type.supports_integer() { + self.format_string(writer, "null") + } else { + self.invalid_integer_conversion(type_name) + }; + }; + + match self.conversion_type { + ConversionType::DecInt => value.format_decimal(self, writer), + ConversionType::HexIntLower + | ConversionType::HexIntUpper + | ConversionType::OctInt => { + self.format_unsigned(writer, value.unsigned_bits()) + } + ConversionType::CharLower | ConversionType::CharUpper => { + self.format_char(writer, value.to_char()?) + } + ConversionType::StringLower | ConversionType::StringUpper => { + self.format_string(writer, &value.decimal_string()) + } + _ => self.invalid_integer_conversion(type_name), } } + fn invalid_integer_conversion(&self, type_name: &str) -> Result { + exec_err!( + "Invalid conversion type: {:?} for {}", + self.conversion_type, + type_name + ) + } + fn format_hex_float(&self, writer: &mut String, value: f64) -> Result<()> { // Handle special cases first let (sign, raw_exponent, mantissa) = value.to_parts(); @@ -1675,13 +1666,15 @@ impl ConversionSpecifier { return exec_err!( "Invalid conversion type: {:?} for boolean array", self.conversion_type - ) + ); } }; self.format_str(writer, formatted) } fn format_float(&self, writer: &mut String, value: f64) -> Result<()> { + self.validate_grouping_separator()?; + let mut prefix = String::new(); let mut suffix = String::new(); let mut number = String::new(); @@ -1744,7 +1737,7 @@ impl ConversionSpecifier { return exec_err!( "Invalid conversion type: {:?} for float", self.conversion_type - ) + ); } } @@ -1762,6 +1755,9 @@ impl ConversionSpecifier { if strip_trailing_0s { number = trim_trailing_0s(&number).to_owned(); } + if self.grouping_separator { + number = insert_thousands_separator(&number); + } } if self.alt_form && !number.contains('.') { number += "."; @@ -1789,7 +1785,7 @@ impl ConversionSpecifier { return exec_err!( "Invalid conversion type: {:?} for float", self.conversion_type - ) + ); } } } @@ -1874,20 +1870,11 @@ impl ConversionSpecifier { match self.conversion_type { ConversionType::DecInt => { let num_str = format!("{value}"); - if self.grouping_separator { - // Add thousands separators - let mut result = String::new(); - let chars: Vec = num_str.chars().collect(); - for (i, c) in chars.iter().enumerate() { - if i > 0 && (chars.len() - i).is_multiple_of(3) { - result.push(','); - } - result.push(*c); - } - s = result; + s = if self.grouping_separator { + insert_thousands_separator(&num_str) } else { - s = num_str; - } + num_str + }; } ConversionType::HexIntLower => { alt_prefix = "0x"; @@ -1908,7 +1895,7 @@ impl ConversionSpecifier { return exec_err!( "Invalid conversion type: {:?} for u64", self.conversion_type - ) + ); } } let mut prefix = if self.alt_form { @@ -1992,6 +1979,8 @@ impl ConversionSpecifier { } fn format_decimal(&self, writer: &mut String, value: &str, scale: i64) -> Result<()> { + self.validate_grouping_separator()?; + let mut prefix = String::new(); let upper = self.conversion_type.is_upper(); @@ -2002,6 +1991,10 @@ impl ConversionSpecifier { let decimal = BigDecimal::from_bigint(decimal, scale); // Handle sign + // TODO: `negative_in_parentheses` (the `(` flag) is not implemented here. + // Java/Spark wrap negative values in parentheses when this flag is set + // (e.g. `%(,.2f` with -1234.5 → "(1,234.50)"), but this path always + // uses a minus sign. See `format_float` for the correct implementation. let is_negative = decimal.sign() == Sign::Minus; let abs_decimal = decimal.abs(); @@ -2025,7 +2018,15 @@ impl ConversionSpecifier { let number = match self.conversion_type { ConversionType::DecFloatLower => { // Format as fixed-point decimal - self.format_decimal_fixed(&abs_decimal, precision, strip_trailing_0s)? + let mut n = self.format_decimal_fixed( + &abs_decimal, + precision, + strip_trailing_0s, + )?; + if self.grouping_separator { + n = insert_thousands_separator(&n); + } + n } ConversionType::SciFloatLower => self.format_decimal_scientific( &abs_decimal, @@ -2054,18 +2055,22 @@ impl ConversionSpecifier { strip_trailing_0s, )? } else { - self.format_decimal_fixed( + let mut n = self.format_decimal_fixed( &abs_decimal, precision - 1 - log10_val.floor() as i32, strip_trailing_0s, - )? + )?; + if self.grouping_separator { + n = insert_thousands_separator(&n); + } + n } } _ => { return exec_err!( "Invalid conversion type: {:?} for decimal", self.conversion_type - ) + ); } }; @@ -2324,6 +2329,24 @@ impl FloatBits for f64 { } } +/// Inserts thousands separators (`,`) into the integer part of a numeric string. +/// For example, `"1234567.89"` becomes `"1,234,567.89"`. +fn insert_thousands_separator(number: &str) -> String { + let (int_part, frac_part) = match number.find('.') { + Some(pos) => (&number[..pos], &number[pos..]), + None => (number, ""), + }; + let mut result = String::with_capacity(number.len() + number.len() / 3); + for (i, c) in int_part.char_indices() { + if i > 0 && (int_part.len() - i) % 3 == 0 { + result.push(','); + } + result.push(c); + } + result.push_str(frac_part); + result +} + fn trim_trailing_0s(number: &str) -> &str { if number.contains('.') { for (i, c) in number.chars().rev().enumerate() { @@ -2343,3 +2366,551 @@ fn trim_trailing_0s_hex(number: &str) -> &str { } number } + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::utils::test::test_scalar_function; + use arrow::array::StringArray; + use arrow::datatypes::DataType::Utf8; + + #[test] + fn test_format_string_nullability() -> Result<()> { + let func = FormatStringFunc::new(); + let nullable_format: FieldRef = Arc::new(Field::new("fmt", Utf8, true)); + + let out_nullable = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_format], + scalar_arguments: &[None], + })?; + + assert!( + out_nullable.is_nullable(), + "format_string(fmt, ...) should be nullable when fmt is nullable" + ); + let non_nullable_format: FieldRef = Arc::new(Field::new("fmt", Utf8, false)); + + let out_non_nullable = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[non_nullable_format], + scalar_arguments: &[None], + })?; + + assert!( + !out_non_nullable.is_nullable(), + "format_string(fmt, ...) should NOT be nullable when fmt is NOT nullable" + ); + + Ok(()) + } + + #[test] + fn test_format_char_invalid_codepoint_errors() { + use arrow::datatypes::Field; + use datafusion_common::config::ConfigOptions; + + let func = FormatStringFunc::new(); + // Spark/Java reject any negative integer or any value outside + // `0..=0x10FFFF` (and the surrogate range) regardless of integer + // width, so all of these inputs must surface a SQL error rather than + // panicking or silently reinterpreting the bits as unsigned. + let cases: Vec<(&str, ScalarValue)> = vec![ + ("Int8(-1)", ScalarValue::Int8(Some(-1))), + ("Int16(-1)", ScalarValue::Int16(Some(-1))), + ("Int16(-10000)", ScalarValue::Int16(Some(-10000))), + ("Int32(-1)", ScalarValue::Int32(Some(-1))), + ("Int32(0x110000)", ScalarValue::Int32(Some(0x110000))), + ("Int64(0x1FFFFFFFF)", ScalarValue::Int64(Some(0x1FFFFFFFF))), + ("Int64(-1)", ScalarValue::Int64(Some(-1))), + ("UInt16(0xD800)", ScalarValue::UInt16(Some(0xD800))), + ("UInt32(0x110000)", ScalarValue::UInt32(Some(0x110000))), + ( + "UInt64(0x1_0000_0000)", + ScalarValue::UInt64(Some(0x1_0000_0000)), + ), + ]; + + for (label, value) in cases { + let fmt = ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))); + let arg_data_type = value.data_type(); + let arg = ColumnarValue::Scalar(value); + let arg_fields = vec![ + Arc::new(Field::new("fmt", Utf8, false)), + Arc::new(Field::new("v", arg_data_type, false)), + ]; + let res = func.invoke_with_args(ScalarFunctionArgs { + args: vec![fmt, arg], + number_rows: 1, + arg_fields, + return_field: Arc::new(Field::new("o", Utf8, false)), + config_options: Arc::new(ConfigOptions::default()), + }); + assert!( + res.is_err(), + "format_string('[%c]', {label}) should error, got Ok" + ); + let err = res.unwrap_err().to_string(); + assert!( + err.contains("invalid Unicode scalar value for %c"), + "unexpected error for {label}: {err}" + ); + } + } + + #[test] + fn test_format_char_valid_codepoint_succeeds() { + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(0x1F680))), + ], + Ok(Some("[\u{1F680}]")), + &str, + Utf8, + StringArray + ); + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))), + ColumnarValue::Scalar(ScalarValue::UInt32(Some(0x10FFFF))), + ], + Ok(Some("[\u{10FFFF}]")), + &str, + Utf8, + StringArray + ); + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int16(Some(65))), + ], + Ok(Some("[A]")), + &str, + Utf8, + StringArray + ); + // Int8 / UInt8 can never produce an invalid codepoint for non-negative + // values, but they must still flow through the validating helper. + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))), + ColumnarValue::Scalar(ScalarValue::Int8(Some(97))), + ], + Ok(Some("[a]")), + &str, + Utf8, + StringArray + ); + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("[%c]".to_string()))), + ColumnarValue::Scalar(ScalarValue::UInt8(Some(255))), + ], + Ok(Some("[\u{00FF}]")), + &str, + Utf8, + StringArray + ); + } + + #[test] + fn test_integer_formatting_across_widths() -> Result<()> { + let cases = [ + ( + ScalarValue::Int8(Some(-1)), + "%d|%x|%o|%s", + 4, + "-1|ff|377|-1", + ), + ( + ScalarValue::Int16(Some(-1)), + "%d|%x|%o|%s", + 4, + "-1|ffff|177777|-1", + ), + ( + ScalarValue::Int32(Some(-1)), + "%d|%x|%o|%s", + 4, + "-1|ffffffff|37777777777|-1", + ), + ( + ScalarValue::Int64(Some(-1)), + "%d|%x|%o|%s", + 4, + "-1|ffffffffffffffff|1777777777777777777777|-1", + ), + ( + ScalarValue::UInt8(Some(255)), + "%d|%x|%o|%s|%c", + 5, + "255|ff|377|255|ÿ", + ), + ( + ScalarValue::UInt16(Some(65535)), + "%d|%x|%o|%s", + 4, + "65535|ffff|177777|65535", + ), + ( + ScalarValue::UInt32(Some(u32::MAX)), + "%d|%x|%o|%s", + 4, + "4294967295|ffffffff|37777777777|4294967295", + ), + ( + ScalarValue::UInt64(Some(u64::MAX)), + "%d|%x|%o|%s", + 4, + "18446744073709551615|ffffffffffffffff|1777777777777777777777|18446744073709551615", + ), + ( + ScalarValue::Int32(None), + "%d|%x|%o|%s|%c", + 5, + "null|null|null|null|null", + ), + ]; + + for (value, fmt, arg_count, expected) in cases { + let data_types = vec![value.data_type(); arg_count]; + let formatter = Formatter::parse(fmt, &data_types)?; + let args = vec![value; arg_count]; + assert_eq!(formatter.format(&args)?, expected, "{fmt}"); + } + Ok(()) + } + + #[test] + fn test_insert_thousands_separator() { + assert_eq!(insert_thousands_separator("1234567.89"), "1,234,567.89"); + assert_eq!(insert_thousands_separator("123.45"), "123.45"); + assert_eq!(insert_thousands_separator("1234"), "1,234"); + assert_eq!(insert_thousands_separator("12"), "12"); + assert_eq!(insert_thousands_separator("0.5"), "0.5"); + assert_eq!( + insert_thousands_separator("1234567890.1234"), + "1,234,567,890.1234" + ); + assert_eq!(insert_thousands_separator("1000"), "1,000"); + assert_eq!(insert_thousands_separator("100"), "100"); + } + + #[test] + fn test_grouping_separator_float() -> Result<()> { + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some("1,234,567.89")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_decimal() -> Result<()> { + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Ok(Some("1,234,567.89")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_scientific_float() -> Result<()> { + // %,e — Java/Spark reject grouping separator with scientific notation + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,e".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Err(DataFusionError::Execution( + "Grouping separator ',' flag is not compatible with scientific notation conversion 'e'".to_string(), + )), + &str, + Utf8, + StringArray + ); + // %,E — uppercase scientific also rejected + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,E".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Err(DataFusionError::Execution( + "Grouping separator ',' flag is not compatible with scientific notation conversion 'E'".to_string(), + )), + &str, + Utf8, + StringArray + ); + // %,.0e — precision 0 scientific with grouping also rejected + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,.0e".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Err(DataFusionError::Execution( + "Grouping separator ',' flag is not compatible with scientific notation conversion 'e'".to_string(), + )), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_compact_float() -> Result<()> { + // %,g with large number — triggers scientific, no commas + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,g".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some("1.23457e+06")), + &str, + Utf8, + StringArray + ); + // %,g with small number — triggers fixed-point, commas in integer part + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,g".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(12345.6))), + ], + Ok(Some("12,345.6")), + &str, + Utf8, + StringArray + ); + // %,.0g — precision 0 compact with grouping (large number, scientific) + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,.0g".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some("1e+06")), + &str, + Utf8, + StringArray + ); + // %,G — uppercase compact + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,G".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some("1.23457E+06")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_scientific_decimal() -> Result<()> { + // %,e on decimal — Java/Spark reject grouping separator with scientific notation + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,e".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Err(DataFusionError::Execution( + "Grouping separator ',' flag is not compatible with scientific notation conversion 'e'".to_string(), + )), + &str, + Utf8, + StringArray + ); + // %,.0e on decimal — also rejected + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,.0e".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Err(DataFusionError::Execution( + "Grouping separator ',' flag is not compatible with scientific notation conversion 'e'".to_string(), + )), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_compact_decimal() -> Result<()> { + // %,g on decimal — large number triggers scientific, no commas + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,g".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Ok(Some("1.23457e+06")), + &str, + Utf8, + StringArray + ); + // %,g on decimal — small number triggers fixed-point, commas expected + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,g".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(1234560), 10, 2)), + ], + Ok(Some("12,345.6")), + &str, + Utf8, + StringArray + ); + // %,.0g on decimal — precision 0 compact with grouping (scientific) + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%,.0g".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Ok(Some("1e+06")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_width_sign_float() -> Result<()> { + // %0,15.2f — zero-pad + grouping + width + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%0,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some("0001,234,567.89")), + &str, + Utf8, + StringArray + ); + // %+,15.2f — force-sign + grouping + width (space-padded) + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%+,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some(" +1,234,567.89")), + &str, + Utf8, + StringArray + ); + // %-,15.2f — left-adjust + grouping + width + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%-,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1234567.89))), + ], + Ok(Some("1,234,567.89 ")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_width_sign_decimal() -> Result<()> { + // %0,15.2f — zero-pad + grouping + width on decimal + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%0,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Ok(Some("0001,234,567.89")), + &str, + Utf8, + StringArray + ); + // %+,15.2f — force-sign + grouping + width on decimal + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%+,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(123456789), 10, 2)), + ], + Ok(Some(" +1,234,567.89")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_parentheses_float() -> Result<()> { + // %(,15.2f with negative — parentheses + grouping + width + // Java: String.format("%(,15.2f", -1234.5) → " (1,234.50)" + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%(,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(-1234.5))), + ], + Ok(Some(" (1,234.50)")), + &str, + Utf8, + StringArray + ); + Ok(()) + } + + #[test] + fn test_grouping_separator_parentheses_decimal() -> Result<()> { + // %(,15.2f on negative decimal — format_decimal ignores negative_in_parentheses, + // always uses '-'. Check TODO in fn format_decimal + // Java: String.format("%(,15.2f", -1234.5) → " (1,234.50)" + // Ours: " -1,234.50" (minus sign, no parens) + test_scalar_function!( + FormatStringFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("%(,15.2f".to_string()))), + ColumnarValue::Scalar(ScalarValue::Decimal128(Some(-123450), 10, 2)), + ], + Ok(Some(" -1,234.50")), + &str, + Utf8, + StringArray + ); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/ilike.rs b/datafusion/spark/src/function/string/ilike.rs index a160749523f1e..3be63955c0447 100644 --- a/datafusion/spark/src/function/string/ilike.rs +++ b/datafusion/spark/src/function/string/ilike.rs @@ -17,12 +17,13 @@ use arrow::array::ArrayRef; use arrow::compute::ilike; -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, Result}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::{Result, exec_err, internal_err}; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_functions::utils::make_scalar_function; -use std::any::Any; use std::sync::Arc; /// ILIKE function for case-insensitive pattern matching @@ -47,10 +48,6 @@ impl SparkILike { } impl ScalarUDFImpl for SparkILike { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "ilike" } @@ -60,7 +57,14 @@ impl ScalarUDFImpl for SparkILike { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Boolean) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result> { + // ILIKE returns a boolean value + // The result is nullable if any of the input arguments is nullable + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new("ilike", DataType::Boolean, nullable))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -84,8 +88,7 @@ mod tests { use crate::function::utils::test::test_scalar_function; use arrow::array::{Array, BooleanArray}; use arrow::datatypes::DataType::Boolean; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_common::ScalarValue; macro_rules! test_ilike_string_invoke { ($INPUT1:expr, $INPUT2:expr, $EXPECTED:expr) => { @@ -170,4 +173,73 @@ mod tests { Ok(()) } + + #[test] + fn test_ilike_nullability() { + let ilike = SparkILike::new(); + + // Test with non-nullable arguments + let non_nullable_field1 = Arc::new(Field::new("str", DataType::Utf8, false)); + let non_nullable_field2 = Arc::new(Field::new("pattern", DataType::Utf8, false)); + + let result = ilike + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&non_nullable_field1), + Arc::clone(&non_nullable_field2), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should not be nullable when both inputs are non-nullable + assert!(!result.is_nullable()); + assert_eq!(result.data_type(), &Boolean); + + // Test with first argument nullable + let nullable_field1 = Arc::new(Field::new("str", DataType::Utf8, true)); + + let result = ilike + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&nullable_field1), + Arc::clone(&non_nullable_field2), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should be nullable when first input is nullable + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &Boolean); + + // Test with second argument nullable + let nullable_field2 = Arc::new(Field::new("pattern", DataType::Utf8, true)); + + let result = ilike + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&non_nullable_field1), + Arc::clone(&nullable_field2), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should be nullable when second input is nullable + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &Boolean); + + // Test with both arguments nullable + let result = ilike + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_field1), Arc::clone(&nullable_field2)], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should be nullable when both inputs are nullable + assert!(result.is_nullable()); + assert_eq!(result.data_type(), &Boolean); + } } diff --git a/datafusion/spark/src/function/string/is_valid_utf8.rs b/datafusion/spark/src/function/string/is_valid_utf8.rs new file mode 100644 index 0000000000000..591801cef290c --- /dev/null +++ b/datafusion/spark/src/function/string/is_valid_utf8.rs @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; + +use arrow::array::{Array, ArrayRef, BooleanArray}; +use arrow::buffer::BooleanBuffer; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_large_binary_array, +}; +use datafusion_common::utils::take_function_args; +use datafusion_functions::utils::make_scalar_function; + +use std::sync::Arc; + +/// Spark-compatible `is_valid_utf8` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkIsValidUtf8 { + signature: Signature, +} + +impl Default for SparkIsValidUtf8 { + fn default() -> Self { + Self::new() + } +} + +impl SparkIsValidUtf8 { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Utf8, + DataType::LargeUtf8, + DataType::Utf8View, + DataType::Binary, + DataType::BinaryView, + DataType::LargeBinary, + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkIsValidUtf8 { + fn name(&self) -> &str { + "is_valid_utf8" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Arc::new(Field::new(self.name(), DataType::Boolean, true))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_is_valid_utf8_inner, vec![])(&args.args) + } +} + +fn spark_is_valid_utf8_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("is_valid_utf8", args)?; + match array.data_type() { + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => { + Ok(Arc::new(BooleanArray::new( + BooleanBuffer::new_set(array.len()), + array.nulls().cloned(), + ))) + } + DataType::Binary => Ok(Arc::new( + as_binary_array(array)? + .iter() + .map(|x| x.map(|y| str::from_utf8(y).is_ok())) + .collect::(), + )), + DataType::LargeBinary => Ok(Arc::new( + as_large_binary_array(array)? + .iter() + .map(|x| x.map(|y| str::from_utf8(y).is_ok())) + .collect::(), + )), + DataType::BinaryView => Ok(Arc::new( + as_binary_view_array(array)? + .iter() + .map(|x| x.map(|y| str::from_utf8(y).is_ok())) + .collect::(), + )), + data_type => { + internal_err!("is_valid_utf8 does not support: {data_type}") + } + } +} diff --git a/datafusion/spark/src/function/string/length.rs b/datafusion/spark/src/function/string/length.rs index ac6030770fe27..8c5539a0577d8 100644 --- a/datafusion/spark/src/function/string/length.rs +++ b/datafusion/spark/src/function/string/length.rs @@ -18,10 +18,11 @@ use arrow::array::{ Array, ArrayRef, AsArray, BinaryArrayType, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{DataType, Int32Type}; +use arrow::datatypes::{DataType, Field, FieldRef, Int32Type}; use datafusion_common::exec_err; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions::utils::make_scalar_function; use std::sync::Arc; @@ -65,10 +66,6 @@ impl SparkLengthFunc { } impl ScalarUDFImpl for SparkLengthFunc { - fn as_any(&self) -> &dyn std::any::Any { - self - } - fn name(&self) -> &str { "length" } @@ -78,8 +75,9 @@ impl ScalarUDFImpl for SparkLengthFunc { } fn return_type(&self, _args: &[DataType]) -> datafusion_common::Result { - // spark length always returns Int32 - Ok(DataType::Int32) + datafusion_common::internal_err!( + "return_type should not be called, use return_field_from_args instead" + ) } fn invoke_with_args( @@ -92,6 +90,15 @@ impl ScalarUDFImpl for SparkLengthFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn return_field_from_args( + &self, + args: ReturnFieldArgs, + ) -> datafusion_common::Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + // spark length always returns Int32 + Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable))) + } } fn spark_length(args: &[ArrayRef]) -> datafusion_common::Result { @@ -191,10 +198,9 @@ where mod tests { use super::*; use crate::function::utils::test::test_scalar_function; - use arrow::array::{Array, Int32Array}; + use arrow::array::Int32Array; use arrow::datatypes::DataType::Int32; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; macro_rules! test_spark_length_string { ($INPUT:expr, $EXPECTED:expr) => { @@ -279,4 +285,36 @@ mod tests { Ok(()) } + + #[test] + fn test_spark_length_nullability() -> Result<()> { + let func = SparkLengthFunc::new(); + + let nullable_field: FieldRef = Arc::new(Field::new("col", DataType::Utf8, true)); + + let out_nullable = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[nullable_field], + scalar_arguments: &[None], + })?; + + assert!( + out_nullable.is_nullable(), + "length(col) should be nullable when child is nullable" + ); + + let non_nullable_field: FieldRef = + Arc::new(Field::new("col", DataType::Utf8, false)); + + let out_non_nullable = func.return_field_from_args(ReturnFieldArgs { + arg_fields: &[non_nullable_field], + scalar_arguments: &[None], + })?; + + assert!( + !out_non_nullable.is_nullable(), + "length(col) should NOT be nullable when child is NOT nullable" + ); + + Ok(()) + } } diff --git a/datafusion/spark/src/function/string/like.rs b/datafusion/spark/src/function/string/like.rs index df8eaef7cecbc..50f0822fcbf7c 100644 --- a/datafusion/spark/src/function/string/like.rs +++ b/datafusion/spark/src/function/string/like.rs @@ -17,12 +17,13 @@ use arrow::array::ArrayRef; use arrow::compute::like; -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, Result}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::{Result, exec_err, internal_err}; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_functions::utils::make_scalar_function; -use std::any::Any; use std::sync::Arc; /// LIKE function for case-sensitive pattern matching @@ -47,10 +48,6 @@ impl SparkLike { } impl ScalarUDFImpl for SparkLike { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "like" } @@ -60,7 +57,16 @@ impl ScalarUDFImpl for SparkLike { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Boolean) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + Ok(Arc::new(Field::new( + self.name(), + DataType::Boolean, + nullable, + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -84,8 +90,7 @@ mod tests { use crate::function::utils::test::test_scalar_function; use arrow::array::{Array, BooleanArray}; use arrow::datatypes::DataType::Boolean; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_common::ScalarValue; macro_rules! test_like_string_invoke { ($INPUT1:expr, $INPUT2:expr, $EXPECTED:expr) => { @@ -175,4 +180,73 @@ mod tests { Ok(()) } + + #[test] + fn test_like_nullability() { + let like = SparkLike::new(); + + // Test with non-nullable arguments + let non_nullable_field1 = Arc::new(Field::new("str", DataType::Utf8, false)); + let non_nullable_field2 = Arc::new(Field::new("pattern", DataType::Utf8, false)); + + let both_non_nullable = like + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&non_nullable_field1), + Arc::clone(&non_nullable_field2), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should not be nullable when both inputs are non-nullable + assert!(!both_non_nullable.is_nullable()); + assert_eq!(both_non_nullable.data_type(), &Boolean); + + // Test with first argument nullable + let nullable_field1 = Arc::new(Field::new("str", DataType::Utf8, true)); + + let first_nullable = like + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&nullable_field1), + Arc::clone(&non_nullable_field2), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should be nullable when first input is nullable + assert!(first_nullable.is_nullable()); + assert_eq!(first_nullable.data_type(), &Boolean); + + // Test with second argument nullable + let nullable_field2 = Arc::new(Field::new("pattern", DataType::Utf8, true)); + + let second_nullable = like + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&non_nullable_field1), + Arc::clone(&nullable_field2), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should be nullable when second input is nullable + assert!(second_nullable.is_nullable()); + assert_eq!(second_nullable.data_type(), &Boolean); + + // Test with both arguments nullable + let first_second_nullable = like + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[Arc::clone(&nullable_field1), Arc::clone(&nullable_field2)], + scalar_arguments: &[None, None], + }) + .unwrap(); + + // The result should be nullable when both inputs are nullable + assert!(first_second_nullable.is_nullable()); + assert_eq!(first_second_nullable.data_type(), &Boolean); + } } diff --git a/datafusion/spark/src/function/string/luhn_check.rs b/datafusion/spark/src/function/string/luhn_check.rs index 090b16e34b8f1..9241f5e70d085 100644 --- a/datafusion/spark/src/function/string/luhn_check.rs +++ b/datafusion/spark/src/function/string/luhn_check.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::{any::Any, sync::Arc}; +use std::sync::Arc; use arrow::array::{Array, AsArray, BooleanArray}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Boolean; use datafusion_common::utils::take_function_args; -use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue, exec_err}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -56,10 +56,6 @@ impl SparkLuhnCheck { } impl ScalarUDFImpl for SparkLuhnCheck { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "luhn_check" } diff --git a/datafusion/spark/src/function/string/make_valid_utf8.rs b/datafusion/spark/src/function/string/make_valid_utf8.rs new file mode 100644 index 0000000000000..d2c2ae8b00051 --- /dev/null +++ b/datafusion/spark/src/function/string/make_valid_utf8.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, LargeStringArray, StringArray}; +use arrow::datatypes::{DataType, Field, FieldRef}; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_large_binary_array, +}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, internal_err}; +use datafusion_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `make_valid_utf8` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkMakeValidUtf8 { + signature: Signature, +} + +impl Default for SparkMakeValidUtf8 { + fn default() -> Self { + Self::new() + } +} + +impl SparkMakeValidUtf8 { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Utf8, + DataType::LargeUtf8, + DataType::Utf8View, + DataType::Binary, + DataType::BinaryView, + DataType::LargeBinary, + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkMakeValidUtf8 { + fn name(&self) -> &str { + "make_valid_utf8" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [make_valid_utf8] = take_function_args(self.name(), args.arg_fields)?; + let return_type = match make_valid_utf8.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + Ok(make_valid_utf8.data_type().clone()) + } + DataType::Binary | DataType::BinaryView => Ok(DataType::Utf8), + DataType::LargeBinary => Ok(DataType::LargeUtf8), + data_type => internal_err!("make_valid_utf8 does not support: {data_type}"), + }?; + Ok(Arc::new(Field::new( + self.name(), + return_type, + make_valid_utf8.is_nullable(), + ))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_make_valid_utf8_inner, vec![])(&args.args) + } +} + +fn spark_make_valid_utf8_inner(args: &[ArrayRef]) -> Result { + let array = &args[0]; + match &array.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Ok(array.to_owned()), + DataType::Binary => Ok(Arc::new( + as_binary_array(&array)? + .iter() + .map(|x| x.map(String::from_utf8_lossy)) + .collect::(), + )), + DataType::BinaryView => Ok(Arc::new( + as_binary_view_array(&array)? + .iter() + .map(|x| x.map(String::from_utf8_lossy)) + .collect::(), + )), + DataType::LargeBinary => Ok(Arc::new( + as_large_binary_array(&array)? + .iter() + .map(|x| x.map(String::from_utf8_lossy)) + .collect::(), + )), + data_type => { + internal_err!("make_valid_utf8 does not support: {data_type}") + } + } +} diff --git a/datafusion/spark/src/function/string/mod.rs b/datafusion/spark/src/function/string/mod.rs index 480984f02159b..9c90ded5f7e1b 100644 --- a/datafusion/spark/src/function/string/mod.rs +++ b/datafusion/spark/src/function/string/mod.rs @@ -16,20 +16,28 @@ // under the License. pub mod ascii; +pub mod base64; pub mod char; pub mod concat; pub mod elt; pub mod format_string; pub mod ilike; +pub mod is_valid_utf8; pub mod length; pub mod like; pub mod luhn_check; +pub mod make_valid_utf8; +pub mod quote; +pub mod soundex; +pub mod space; +pub mod substring; use datafusion_expr::ScalarUDF; use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(ascii::SparkAscii, ascii); +make_udf_function!(base64::SparkBase64, base64); make_udf_function!(char::CharFunc, char); make_udf_function!(concat::SparkConcat, concat); make_udf_function!(ilike::SparkILike, ilike); @@ -38,6 +46,13 @@ make_udf_function!(elt::SparkElt, elt); make_udf_function!(like::SparkLike, like); make_udf_function!(luhn_check::SparkLuhnCheck, luhn_check); make_udf_function!(format_string::FormatStringFunc, format_string); +make_udf_function!(space::SparkSpace, space); +make_udf_function!(substring::SparkSubstring, substring); +make_udf_function!(base64::SparkUnBase64, unbase64); +make_udf_function!(soundex::SparkSoundex, soundex); +make_udf_function!(make_valid_utf8::SparkMakeValidUtf8, make_valid_utf8); +make_udf_function!(is_valid_utf8::SparkIsValidUtf8, is_valid_utf8); +make_udf_function!(quote::SparkQuote, quote); pub mod expr_fn { use datafusion_functions::export_functions; @@ -47,6 +62,11 @@ pub mod expr_fn { "Returns the ASCII code point of the first character of string.", arg1 )); + export_functions!(( + base64, + "Encodes the input binary `bin` into a base64 string.", + bin + )); export_functions!(( char, "Returns the ASCII character having the binary equivalent to col. If col is larger than 256 the result is equivalent to char(col % 256).", @@ -87,11 +107,39 @@ pub mod expr_fn { "Returns a formatted string from printf-style format strings.", strfmt args )); + export_functions!((space, "Returns a string consisting of n spaces.", arg1)); + export_functions!(( + substring, + "Returns the substring from string `str` starting at position `pos` with length `length.", + str pos length + )); + export_functions!(( + unbase64, + "Decodes the input string `str` from a base64 string into binary data.", + str + )); + export_functions!((soundex, "Returns Soundex code of the string.", str)); + export_functions!(( + is_valid_utf8, + "Returns true if str is a valid UTF-8 string, otherwise returns false", + str + )); + export_functions!(( + make_valid_utf8, + "Returns the original string if str is a valid UTF-8 string, otherwise returns a new string whose invalid UTF8 byte sequences are replaced using the UNICODE replacement character U+FFFD.", + str + )); + export_functions!(( + quote, + "Returns str enclosed by single quotes and each instance of single quote in it is preceded by a backslash", + str + )); } pub fn functions() -> Vec> { vec![ ascii(), + base64(), char(), concat(), elt(), @@ -100,5 +148,12 @@ pub fn functions() -> Vec> { like(), luhn_check(), format_string(), + space(), + substring(), + unbase64(), + soundex(), + make_valid_utf8(), + is_valid_utf8(), + quote(), ] } diff --git a/datafusion/spark/src/function/string/quote.rs b/datafusion/spark/src/function/string/quote.rs new file mode 100644 index 0000000000000..39ad8bf841764 --- /dev/null +++ b/datafusion/spark/src/function/string/quote.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait, StringArray}; +use arrow::datatypes::DataType; +use datafusion::logical_expr::{Coercion, ColumnarValue, Signature, TypeSignatureClass}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::types::{NativeType, logical_string}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Volatility}; +use datafusion_functions::utils::make_scalar_function; + +use std::sync::Arc; + +/// Spark-compatible `quote` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkQuote { + signature: Signature, +} + +impl Default for SparkQuote { + fn default() -> Self { + Self::new() + } +} + +impl SparkQuote { + pub fn new() -> Self { + let str_coercion = Coercion::new_implicit( + TypeSignatureClass::Native(logical_string()), + vec![TypeSignatureClass::Any], + NativeType::String, + ); + Self { + signature: Signature::coercible(vec![str_coercion], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkQuote { + fn name(&self) -> &str { + "quote" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::LargeUtf8 => Ok(DataType::LargeUtf8), + _ => Ok(DataType::Utf8), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_quote_inner, vec![])(&args.args) + } +} + +fn spark_quote_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("quote", arg)?; + match &array.data_type() { + DataType::Utf8 => quote_array::(array), + DataType::LargeUtf8 => quote_array::(array), + DataType::Utf8View => quote_view(array), + other => { + exec_err!("unsupported data type {other:?} for function `quote`") + } + } +} + +fn quote_array(array: &ArrayRef) -> Result { + let str_array = as_generic_string_array::(array)?; + let result = str_array + .iter() + .map(|s| s.map(compute_quote)) + .collect::(); + Ok(Arc::new(result)) +} + +fn quote_view(str_view: &ArrayRef) -> Result { + let str_array = as_string_view_array(str_view)?; + let result = str_array + .iter() + .map(|opt_str| opt_str.map(compute_quote)) + .collect::(); + Ok(Arc::new(result) as ArrayRef) +} + +const QUOTE_CHAR: char = '\''; +const ESCAPE_CHAR: char = '\\'; + +fn compute_quote(s: &str) -> String { + let mut quoted = String::with_capacity(s.len() + 2); + quoted.push(QUOTE_CHAR); + for c in s.chars() { + if c == QUOTE_CHAR { + quoted.push(ESCAPE_CHAR); + } + quoted.push(c); + } + quoted.push(QUOTE_CHAR); + quoted +} diff --git a/datafusion/spark/src/function/string/soundex.rs b/datafusion/spark/src/function/string/soundex.rs new file mode 100644 index 0000000000000..1fef0d5384821 --- /dev/null +++ b/datafusion/spark/src/function/string/soundex.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait, StringArray}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{ColumnarValue, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `soundex` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSoundex { + signature: Signature, +} + +impl Default for SparkSoundex { + fn default() -> Self { + Self::new() + } +} + +impl SparkSoundex { + pub fn new() -> Self { + Self { + signature: Signature::string(1, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkSoundex { + fn name(&self) -> &str { + "soundex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::LargeUtf8 => Ok(DataType::LargeUtf8), + _ => Ok(DataType::Utf8), + } + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_soundex_inner, vec![])(&args.args) + } +} + +fn spark_soundex_inner(arg: &[ArrayRef]) -> Result { + let [array] = take_function_args("soundex", arg)?; + match &array.data_type() { + DataType::Utf8 => soundex_array::(array), + DataType::LargeUtf8 => soundex_array::(array), + DataType::Utf8View => soundex_view(array), + other => { + exec_err!("unsupported data type {other:?} for function `soundex`") + } + } +} + +fn soundex_array(array: &ArrayRef) -> Result { + let str_array = as_generic_string_array::(array)?; + let result = str_array + .iter() + .map(|s| s.map(compute_soundex)) + .collect::(); + Ok(Arc::new(result)) +} + +fn soundex_view(str_view: &ArrayRef) -> Result { + let str_array = as_string_view_array(str_view)?; + let result = str_array + .iter() + .map(|opt_str| opt_str.map(compute_soundex)) + .collect::(); + Ok(Arc::new(result) as ArrayRef) +} + +fn classify_char(c: char) -> Option { + match c.to_ascii_uppercase() { + 'B' | 'F' | 'P' | 'V' => Some('1'), + 'C' | 'G' | 'J' | 'K' | 'Q' | 'S' | 'X' | 'Z' => Some('2'), + 'D' | 'T' => Some('3'), + 'L' => Some('4'), + 'M' | 'N' => Some('5'), + 'R' => Some('6'), + _ => None, + } +} + +fn is_ignored(c: char) -> bool { + matches!(c.to_ascii_uppercase(), 'H' | 'W') +} + +fn compute_soundex(s: &str) -> String { + let mut chars = s.chars(); + + let first_char = match chars.next() { + Some(c) if c.is_ascii_alphabetic() => c.to_ascii_uppercase(), + _ => return s.to_string(), + }; + + let mut soundex_code = String::with_capacity(4); + soundex_code.push(first_char); + let mut last_code = classify_char(first_char); + + for c in chars { + if soundex_code.len() >= 4 { + break; + } + + if is_ignored(c) { + continue; + } + + match classify_char(c) { + Some(code) => { + if last_code != Some(code) { + soundex_code.push(code); + } + last_code = Some(code); + } + None => { + last_code = None; + } + } + } + format!("{soundex_code:0<4}") +} diff --git a/datafusion/spark/src/function/string/space.rs b/datafusion/spark/src/function/string/space.rs new file mode 100644 index 0000000000000..a231401f3eef4 --- /dev/null +++ b/datafusion/spark/src/function/string/space.rs @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, DictionaryArray, Int32Array, StringArray, StringBuilder, + as_dictionary_array, +}; +use arrow::datatypes::{DataType, Int32Type}; +use datafusion_common::cast::as_int32_array; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::Arc; + +/// Spark-compatible `space` expression +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSpace { + signature: Signature, +} + +impl Default for SparkSpace { + fn default() -> Self { + Self::new() + } +} + +impl SparkSpace { + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::Int32, + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Int32), + ), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkSpace { + fn name(&self) -> &str { + "space" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + let return_type = match &args[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }; + Ok(return_type) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + spark_space(&args.args) + } +} + +pub fn spark_space(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!("space function takes exactly one argument"); + } + match &args[0] { + ColumnarValue::Array(array) => { + let result = spark_space_array(array)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar) => { + let result = spark_space_scalar(scalar)?; + Ok(ColumnarValue::Scalar(result)) + } + } +} + +fn spark_space_array(array: &ArrayRef) -> Result { + match array.data_type() { + DataType::Int32 => { + let array = as_int32_array(array)?; + Ok(Arc::new(spark_space_array_inner(array))) + } + DataType::Dictionary(_, _) => { + let dict = as_dictionary_array::(array); + let values = spark_space_array(dict.values())?; + let result = DictionaryArray::try_new(dict.keys().clone(), values)?; + Ok(Arc::new(result)) + } + other => { + exec_err!("Unsupported data type {other:?} for function `space`") + } + } +} + +fn spark_space_scalar(scalar: &ScalarValue) -> Result { + match scalar { + ScalarValue::Int32(value) => { + let result = value.map(|v| { + if v <= 0 { + String::new() + } else { + " ".repeat(v as usize) + } + }); + Ok(ScalarValue::Utf8(result)) + } + other => { + exec_err!("Unsupported data type {other:?} for function `space`") + } + } +} + +fn spark_space_array_inner(array: &Int32Array) -> StringArray { + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut space_buf = String::new(); + for value in array.iter() { + match value { + None => builder.append_null(), + Some(l) if l > 0 => { + let l = l as usize; + if space_buf.len() < l { + space_buf = " ".repeat(l); + } + builder.append_value(&space_buf[..l]); + } + Some(_) => builder.append_value(""), + } + } + builder.finish() +} + +#[cfg(test)] +mod tests { + use crate::function::string::space::spark_space; + use arrow::array::{Array, Int32Array, Int32DictionaryArray}; + use arrow::datatypes::Int32Type; + use datafusion_common::cast::{as_dictionary_array, as_string_array}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::ColumnarValue; + use std::sync::Arc; + + #[test] + fn test_spark_space_int32_array() -> Result<()> { + let int32_array = ColumnarValue::Array(Arc::new(Int32Array::from(vec![ + Some(1), + Some(-3), + Some(0), + Some(5), + None, + ]))); + let ColumnarValue::Array(result) = spark_space(&[int32_array])? else { + unreachable!() + }; + let result = as_string_array(&result)?; + + assert_eq!(result.value(0), " "); + assert_eq!(result.value(1), ""); + assert_eq!(result.value(2), ""); + assert_eq!(result.value(3), " "); + assert!(result.is_null(4)); + Ok(()) + } + + #[test] + fn test_spark_space_dictionary() -> Result<()> { + let dictionary = ColumnarValue::Array(Arc::new(Int32DictionaryArray::new( + Int32Array::from(vec![0, 1, 2, 3, 4]), + Arc::new(Int32Array::from(vec![ + Some(1), + Some(-3), + Some(0), + Some(5), + None, + ])), + ))); + let ColumnarValue::Array(result) = spark_space(&[dictionary])? else { + unreachable!() + }; + let result = + as_string_array(as_dictionary_array::(&result)?.values())?; + assert_eq!(result.value(0), " "); + assert_eq!(result.value(1), ""); + assert_eq!(result.value(2), ""); + assert_eq!(result.value(3), " "); + assert!(result.is_null(4)); + Ok(()) + } + + #[test] + fn test_spark_space_scalar() -> Result<()> { + let scalar = ColumnarValue::Scalar(ScalarValue::Int32(Some(-5))); + let ColumnarValue::Scalar(result) = spark_space(&[scalar])? else { + unreachable!() + }; + match result { + ScalarValue::Utf8(Some(result)) => { + assert_eq!(result, ""); + } + _ => unreachable!(), + } + Ok(()) + } +} diff --git a/datafusion/spark/src/function/string/substring.rs b/datafusion/spark/src/function/string/substring.rs new file mode 100644 index 0000000000000..1c26564e03993 --- /dev/null +++ b/datafusion/spark/src/function/string/substring.rs @@ -0,0 +1,404 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayAccessor, ArrayBuilder, ArrayRef, AsArray, BinaryViewBuilder, + GenericBinaryBuilder, GenericStringBuilder, Int64Array, OffsetSizeTrait, + StringViewBuilder, +}; +use arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::{Field, FieldRef}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{ + NativeType, logical_int32, logical_int64, logical_string, +}; +use datafusion_common::{Result, exec_err}; +use datafusion_expr::{Coercion, ReturnFieldArgs, TypeSignatureClass}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_functions::unicode::substr::{enable_ascii_fast_path, get_true_start_end}; +use datafusion_functions::utils::make_scalar_function; +use std::sync::Arc; + +/// Spark-compatible `substring` expression +/// +/// +/// Returns the substring from string starting at position pos with length len. +/// Position is 1-indexed. If pos is negative, it counts from the end of the string. +/// Returns NULL if any input is NULL. +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkSubstring { + signature: Signature, + aliases: Vec, +} + +impl Default for SparkSubstring { + fn default() -> Self { + Self::new() + } +} + +impl SparkSubstring { + pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + let binary = Coercion::new_exact(TypeSignatureClass::Binary); + let int64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Native(logical_int32())], + NativeType::Int64, + ); + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![string.clone(), int64.clone()]), + TypeSignature::Coercible(vec![ + string.clone(), + int64.clone(), + int64.clone(), + ]), + TypeSignature::Coercible(vec![binary.clone(), int64.clone()]), + TypeSignature::Coercible(vec![ + binary.clone(), + int64.clone(), + int64.clone(), + ]), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "str".to_string(), + "pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), + aliases: vec![String::from("substr")], + } + } +} + +impl ScalarUDFImpl for SparkSubstring { + fn name(&self) -> &str { + "substring" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_substring, vec![])(&args.args) + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + datafusion_common::internal_err!( + "return_type should not be called for Spark substring" + ) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { + // Spark semantics: substring returns NULL if ANY input is NULL + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); + + Ok(Arc::new(Field::new( + "substring", + args.arg_fields[0].data_type().clone(), + nullable, + ))) + } +} + +fn spark_substring(args: &[ArrayRef]) -> Result { + let start_array = as_int64_array(&args[1])?; + let length_array = if args.len() > 2 { + Some(as_int64_array(&args[2])?) + } else { + None + }; + + match args[0].data_type() { + DataType::Utf8 => { + let array = args[0].as_string::(); + let is_ascii = enable_ascii_fast_path(&array, start_array, length_array); + spark_substring_generic( + &array, + start_array, + length_array, + GenericStringBuilder::::new(), + is_ascii, + ) + } + DataType::LargeUtf8 => { + let array = args[0].as_string::(); + let is_ascii = enable_ascii_fast_path(&array, start_array, length_array); + spark_substring_generic( + &array, + start_array, + length_array, + GenericStringBuilder::::new(), + is_ascii, + ) + } + DataType::Utf8View => { + let array = args[0].as_string_view(); + let is_ascii = enable_ascii_fast_path(&array, start_array, length_array); + spark_substring_generic( + &array, + start_array, + length_array, + StringViewBuilder::new(), + is_ascii, + ) + } + // Binary paths always use byte-level indexing, so `is_ascii` is irrelevant + // and set to `true` (its value is ignored by the `[u8]` impl of + // `SubstringItem`). + DataType::Binary => spark_substring_generic( + &args[0].as_binary::(), + start_array, + length_array, + GenericBinaryBuilder::::new(), + true, + ), + DataType::LargeBinary => spark_substring_generic( + &args[0].as_binary::(), + start_array, + length_array, + GenericBinaryBuilder::::new(), + true, + ), + DataType::BinaryView => spark_substring_generic( + &args[0].as_binary_view(), + start_array, + length_array, + BinaryViewBuilder::new(), + true, + ), + other => exec_err!( + "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8, LargeUtf8, Binary, LargeBinary or BinaryView." + ), + } +} + +/// Convert Spark's start position to DataFusion's 1-based start position. +/// +/// Spark semantics: +/// - Positive start: 1-based index from beginning +/// - Zero start: treated as 1 +/// - Negative start: counts from end of string +/// +/// The result may be `<= 0` when a negative start lands before the string +/// (e.g. `start=-10` on a 3-char string gives `-6`). Such values are passed +/// through to `get_true_start_end`, which clamps them and yields an empty +/// slice — matching Spark's behavior for out-of-range negative positions. +#[inline] +fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 { + if start >= 0 { + start.max(1) + } else { + let len_i64 = i64::try_from(len).unwrap_or(i64::MAX); + start + len_i64 + 1 + } +} + +trait SubstringItem { + /// Length used for Spark's negative-position adjustment. + /// For `str` this is characters (or bytes in ASCII mode); for `[u8]` it is + /// always byte count. + fn positional_len(&self, is_ascii: bool) -> usize; + + /// Converts Spark's 1-indexed adjusted start + optional length into a + /// byte range clamped to `[0, byte_len]`. + fn byte_range( + &self, + adjusted_start: i64, + len: Option, + is_ascii: bool, + ) -> Result<(usize, usize)>; + + fn byte_slice(&self, start: usize, end: usize) -> &Self; +} + +impl SubstringItem for str { + fn positional_len(&self, is_ascii: bool) -> usize { + if is_ascii { + self.len() + } else { + self.chars().count() + } + } + + fn byte_range( + &self, + adjusted_start: i64, + len: Option, + is_ascii: bool, + ) -> Result<(usize, usize)> { + get_true_start_end(self, adjusted_start, len, is_ascii) + } + + fn byte_slice(&self, start: usize, end: usize) -> &Self { + &self[start..end] + } +} + +impl SubstringItem for [u8] { + fn positional_len(&self, _is_ascii: bool) -> usize { + self.len() + } + + fn byte_range( + &self, + adjusted_start: i64, + len: Option, + _is_ascii: bool, + ) -> Result<(usize, usize)> { + let byte_len = self.len(); + let start0 = adjusted_start.saturating_sub(1); + let end0 = match len { + Some(l) => start0.saturating_add(l), + None => byte_len as i64, + }; + let byte_len_i64 = byte_len as i64; + Ok(( + start0.clamp(0, byte_len_i64) as usize, + end0.clamp(0, byte_len_i64) as usize, + )) + } + + fn byte_slice(&self, start: usize, end: usize) -> &Self { + &self[start..end] + } +} + +trait SubstringBuilder: ArrayBuilder { + type Item: SubstringItem + ?Sized; + fn append_value(&mut self, val: &Self::Item); + fn append_null(&mut self); + /// Spark's semantic "empty" for this builder's item type, used for the + /// negative-length short-circuit. + fn append_empty(&mut self); +} + +impl SubstringBuilder for GenericStringBuilder { + type Item = str; + fn append_value(&mut self, val: &str) { + GenericStringBuilder::append_value(self, val); + } + fn append_null(&mut self) { + GenericStringBuilder::append_null(self); + } + fn append_empty(&mut self) { + GenericStringBuilder::append_value(self, ""); + } +} + +impl SubstringBuilder for StringViewBuilder { + type Item = str; + fn append_value(&mut self, val: &str) { + StringViewBuilder::append_value(self, val); + } + fn append_null(&mut self) { + StringViewBuilder::append_null(self); + } + fn append_empty(&mut self) { + StringViewBuilder::append_value(self, ""); + } +} + +impl SubstringBuilder for GenericBinaryBuilder { + type Item = [u8]; + fn append_value(&mut self, val: &[u8]) { + GenericBinaryBuilder::append_value(self, val); + } + fn append_null(&mut self) { + GenericBinaryBuilder::append_null(self); + } + fn append_empty(&mut self) { + GenericBinaryBuilder::append_value(self, &[]); + } +} + +impl SubstringBuilder for BinaryViewBuilder { + type Item = [u8]; + fn append_value(&mut self, val: &[u8]) { + BinaryViewBuilder::append_value(self, val); + } + fn append_null(&mut self) { + BinaryViewBuilder::append_null(self); + } + fn append_empty(&mut self) { + BinaryViewBuilder::append_value(self, []); + } +} + +/// Unified implementation of Spark's `substring`, generic over the source +/// array (`StringArrayType`/`BinaryArrayType` via `ArrayAccessor`) and its +/// corresponding builder. Per-row indexing semantics are delegated to +/// [`SubstringItem`], which differs between `str` (char-aware when +/// `is_ascii` is false) and `[u8]` (always byte-level). +fn spark_substring_generic<'a, Source, Item, Builder>( + array: &Source, + start_array: &Int64Array, + length_array: Option<&Int64Array>, + mut builder: Builder, + is_ascii: bool, +) -> Result +where + Source: ArrayAccessor, + Item: SubstringItem + ?Sized + 'a, + Builder: SubstringBuilder, +{ + for i in 0..array.len() { + if array.is_null(i) || start_array.is_null(i) { + builder.append_null(); + continue; + } + + if let Some(len_arr) = length_array + && len_arr.is_null(i) + { + builder.append_null(); + continue; + } + + let value = array.value(i); + let start = start_array.value(i); + let len_opt = length_array.map(|arr| arr.value(i)); + + // Spark: negative length yields an empty value + if let Some(len) = len_opt + && len < 0 + { + builder.append_empty(); + continue; + } + + let positional_len = value.positional_len(is_ascii); + let adjusted_start = spark_start_to_datafusion_start(start, positional_len); + let (byte_start, byte_end) = + value.byte_range(adjusted_start, len_opt, is_ascii)?; + builder.append_value(value.byte_slice(byte_start, byte_end)); + } + + Ok(builder.finish()) +} diff --git a/datafusion/spark/src/function/url/mod.rs b/datafusion/spark/src/function/url/mod.rs index 82bf8a9e09616..1313edaed5347 100644 --- a/datafusion/spark/src/function/url/mod.rs +++ b/datafusion/spark/src/function/url/mod.rs @@ -21,9 +21,15 @@ use std::sync::Arc; pub mod parse_url; pub mod try_parse_url; +pub mod try_url_decode; +pub mod url_decode; +pub mod url_encode; make_udf_function!(parse_url::ParseUrl, parse_url); make_udf_function!(try_parse_url::TryParseUrl, try_parse_url); +make_udf_function!(try_url_decode::TryUrlDecode, try_url_decode); +make_udf_function!(url_decode::UrlDecode, url_decode); +make_udf_function!(url_encode::UrlEncode, url_encode); pub mod expr_fn { use datafusion_functions::export_functions; @@ -38,8 +44,29 @@ pub mod expr_fn { "Same as parse_url but returns NULL if an invalid URL is provided.", args )); + export_functions!(( + url_decode, + "Decodes a URL-encoded string in ‘application/x-www-form-urlencoded’ format to its original format.", + args + )); + export_functions!(( + try_url_decode, + "Same as url_decode but returns NULL if an invalid URL-encoded string is provided", + args + )); + export_functions!(( + url_encode, + "Encodes a string into a URL-encoded string in ‘application/x-www-form-urlencoded’ format.", + args + )); } pub fn functions() -> Vec> { - vec![parse_url(), try_parse_url()] + vec![ + parse_url(), + try_parse_url(), + try_url_decode(), + url_decode(), + url_encode(), + ] } diff --git a/datafusion/spark/src/function/url/parse_url.rs b/datafusion/spark/src/function/url/parse_url.rs index a8afa1d9639f5..18f0bb1e0d78b 100644 --- a/datafusion/spark/src/function/url/parse_url.rs +++ b/datafusion/spark/src/function/url/parse_url.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; use std::sync::Arc; use arrow::array::{ @@ -26,7 +25,7 @@ use arrow::datatypes::DataType; use datafusion_common::cast::{ as_large_string_array, as_string_array, as_string_view_array, }; -use datafusion_common::{exec_datafusion_err, exec_err, Result}; +use datafusion_common::{Result, exec_datafusion_err, exec_err}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, @@ -78,36 +77,57 @@ impl ParseUrl { /// # Returns /// /// * `Ok(Some(String))` - The extracted URL component as a string - /// * `Ok(None)` - If the requested component doesn't exist or is empty + /// * `Ok(None)` - If the requested component doesn't exist /// * `Err(DataFusionError)` - If the URL is malformed and cannot be parsed fn parse(value: &str, part: &str, key: Option<&str>) -> Result> { let url: std::result::Result = Url::parse(value); if let Err(ParseError::RelativeUrlWithoutBase) = url { return if !value.contains("://") { - Ok(None) + // Schemeless URLs are treated as relative URIs (like java.net.URI). + // Manually parse path, query, and fragment components. + let (without_fragment, fragment) = match value.split_once('#') { + Some((before, frag)) => (before, Some(frag)), + None => (value, None), + }; + let (path, query) = match without_fragment.split_once('?') { + Some((p, q)) => (p, Some(q)), + None => (without_fragment, None), + }; + Ok(match part { + "PATH" => Some(path.to_string()), + "QUERY" => match key { + None => query.map(String::from), + Some(key) => Self::query_value(query, key).map(String::from), + }, + "REF" => fragment.map(String::from), + "FILE" => { + // FILE = path + query (without fragment) + Some(without_fragment.to_string()) + } + // HOST, PROTOCOL, AUTHORITY, USERINFO → NULL + _ => None, + }) } else { - Err(exec_datafusion_err!("The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02")) + Err(exec_datafusion_err!( + "The url is invalid: {value}. Use `try_parse_url` to tolerate invalid URL and return NULL instead. SQLSTATE: 22P02" + )) }; }; url.map_err(|e| exec_datafusion_err!("{e:?}")) .map(|url| match part { "HOST" => url.host_str().map(String::from), "PATH" => { - let path: String = url.path().to_string(); - let path: String = if path == "/" { "".to_string() } else { path }; - Some(path) + let path = Self::path(value, &url); + Some(path.to_string()) } "QUERY" => match key { None => url.query().map(String::from), - Some(key) => url - .query_pairs() - .find(|(k, _)| k == key) - .map(|(_, v)| v.into_owned()), + Some(key) => Self::query_value(url.query(), key).map(String::from), }, "REF" => url.fragment().map(String::from), "PROTOCOL" => Some(url.scheme().to_string()), "FILE" => { - let path = url.path(); + let path = Self::path(value, &url); match url.query() { Some(query) => Some(format!("{path}?{query}")), None => Some(path.to_string()), @@ -127,13 +147,39 @@ impl ParseUrl { _ => None, }) } -} -impl ScalarUDFImpl for ParseUrl { - fn as_any(&self) -> &dyn Any { - self + fn path<'a>(value: &str, url: &'a Url) -> &'a str { + let path = url.path(); + if path == "/" && Self::absolute_url_has_empty_path(value) { + "" + } else { + path + } + } + + fn absolute_url_has_empty_path(value: &str) -> bool { + let Some(authority_start) = value.find("://").map(|index| index + 3) else { + return false; + }; + let after_authority = &value[authority_start..]; + match after_authority.find(['/', '?', '#']) { + None => true, + Some(index) => matches!(after_authority.as_bytes()[index], b'?' | b'#'), + } } + fn query_value<'a>(query: Option<&'a str>, key: &str) -> Option<&'a str> { + query.and_then(|query| { + query + .split('&') + .filter_map(|pair| pair.split_once('=')) + .find(|(query_key, _)| *query_key == key) + .map(|(_, value)| value) + }) + } +} + +impl ScalarUDFImpl for ParseUrl { fn name(&self) -> &str { "parse_url" } @@ -186,7 +232,7 @@ pub fn spark_handled_parse_url( let url = &args[0]; let part = &args[1]; - let result = if args.len() == 3 { + if args.len() == 3 { // In this case, the 'key' argument is passed let key = &args[2]; @@ -197,6 +243,7 @@ pub fn spark_handled_parse_url( as_string_array(part)?, as_string_array(key)?, handler_err, + true, ) } (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => { @@ -205,6 +252,7 @@ pub fn spark_handled_parse_url( as_string_view_array(part)?, as_string_view_array(key)?, handler_err, + true, ) } (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => { @@ -213,9 +261,15 @@ pub fn spark_handled_parse_url( as_large_string_array(part)?, as_large_string_array(key)?, handler_err, + true, ) } - _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + _ => exec_err!( + "`parse_url` expects STRING arguments, got ({}, {}, {})", + url.data_type(), + part.data_type(), + key.data_type() + ), } } else { // The 'key' argument is omitted, assume all values are null @@ -233,6 +287,7 @@ pub fn spark_handled_parse_url( as_string_array(part)?, &key, handler_err, + false, ) } (DataType::Utf8View, DataType::Utf8View) => { @@ -241,6 +296,7 @@ pub fn spark_handled_parse_url( as_string_view_array(part)?, &key, handler_err, + false, ) } (DataType::LargeUtf8, DataType::LargeUtf8) => { @@ -249,12 +305,16 @@ pub fn spark_handled_parse_url( as_large_string_array(part)?, &key, handler_err, + false, ) } - _ => exec_err!("{} expects STRING arguments, got {:?}", "`parse_url`", args), + _ => exec_err!( + "`parse_url` expects STRING arguments, got ({}, {})", + url.data_type(), + part.data_type() + ), } - }; - result + } } fn process_parse_url<'a, A, B, C, T>( @@ -262,6 +322,7 @@ fn process_parse_url<'a, A, B, C, T>( part_array: &'a B, key_array: &'a C, handle: impl Fn(Result>) -> Result>, + has_key_arg: bool, ) -> Result where &'a A: StringArrayType<'a>, @@ -274,7 +335,11 @@ where .zip(part_array.iter()) .zip(key_array.iter()) .map(|((url, part), key)| { - if let (Some(url), Some(part), key) = (url, part, key) { + // Spark returns NULL when the third argument is explicitly NULL + if has_key_arg && key.is_none() { + return Ok(None); + } + if let (Some(url), Some(part)) = (url, part) { handle(ParseUrl::parse(url, part, key)) } else { Ok(None) @@ -287,10 +352,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow::array::{ArrayRef, Int32Array, StringArray}; - use datafusion_common::Result; + use arrow::array::Int32Array; use std::array::from_ref; - use std::sync::Arc; fn sa(vals: &[Option<&str>]) -> ArrayRef { Arc::new(StringArray::from(vals.to_vec())) as ArrayRef @@ -340,16 +403,163 @@ mod tests { } #[test] - fn test_parse_path_root_is_empty_string() -> Result<()> { - let got = ParseUrl::parse("https://example.com/", "PATH", None)?; - assert_eq!(got, Some("".to_string())); + fn test_parse_path_empty_vs_root() -> Result<()> { + assert_eq!( + ParseUrl::parse("https://example.com", "PATH", None)?, + Some("".to_string()) + ); + assert_eq!( + ParseUrl::parse("https://example.com/", "PATH", None)?, + Some("/".to_string()) + ); + assert_eq!( + ParseUrl::parse("https://ex.com/dir%20/pa%20th.HTML", "PATH", None)?, + Some("/dir%20/pa%20th.HTML".to_string()) + ); + Ok(()) + } + + #[test] + fn test_parse_query_key_is_raw() -> Result<()> { + let url = "https://use%20r:pas%20s@example.com/dir%20/pa%20th.HTML?query=x%20y&q2=2#Ref%20two"; + assert_eq!( + ParseUrl::parse(url, "QUERY", None)?, + Some("query=x%20y&q2=2".to_string()) + ); + assert_eq!( + ParseUrl::parse(url, "QUERY", Some("query"))?, + Some("x%20y".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://ex.com?key=", "QUERY", Some("key"))?, + Some("".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://ex.com?keyonly", "QUERY", Some("keyonly"))?, + None + ); + assert_eq!( + ParseUrl::parse("http://ex.com?a=1&a=2", "QUERY", Some("a"))?, + Some("1".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://ex.com?a%20b=1", "QUERY", Some("a b"))?, + None + ); + Ok(()) + } + + #[test] + fn test_parse_empty_path_file() -> Result<()> { + assert_eq!(ParseUrl::parse("", "PATH", None)?, Some("".to_string())); + assert_eq!( + ParseUrl::parse("http://example.com", "FILE", None)?, + Some("".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://example.com?foo=bar", "FILE", None)?, + Some("?foo=bar".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://example.com#fragment", "FILE", None)?, + Some("".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://example.com/?foo=bar", "FILE", None)?, + Some("/?foo=bar".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://ex.com/?", "FILE", None)?, + Some("/?".to_string()) + ); + assert_eq!( + ParseUrl::parse("http://ex.com?", "FILE", None)?, + Some("?".to_string()) + ); Ok(()) } #[test] - fn test_parse_malformed_url_returns_error() -> Result<()> { - let got = ParseUrl::parse("notaurl", "HOST", None)?; - assert_eq!(got, None); + fn test_parse_schemeless_url() -> Result<()> { + // Spark's java.net.URI treats schemeless strings as relative URIs. + // Simple schemeless string: no query, no fragment. + assert_eq!( + ParseUrl::parse("notaurl", "PATH", None)?, + Some("notaurl".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl", "FILE", None)?, + Some("notaurl".to_string()) + ); + assert_eq!(ParseUrl::parse("notaurl", "HOST", None)?, None); + assert_eq!(ParseUrl::parse("notaurl", "PROTOCOL", None)?, None); + assert_eq!(ParseUrl::parse("notaurl", "QUERY", None)?, None); + assert_eq!(ParseUrl::parse("notaurl", "REF", None)?, None); + assert_eq!(ParseUrl::parse("notaurl", "AUTHORITY", None)?, None); + assert_eq!(ParseUrl::parse("notaurl", "USERINFO", None)?, None); + + // Schemeless URL with query string + assert_eq!( + ParseUrl::parse("notaurl?key=value", "PATH", None)?, + Some("notaurl".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?key=value", "FILE", None)?, + Some("notaurl?key=value".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?key=value", "QUERY", None)?, + Some("key=value".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?key=value", "QUERY", Some("key"))?, + Some("value".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?key=value", "QUERY", Some("missing"))?, + None + ); + assert_eq!(ParseUrl::parse("notaurl?key=value", "HOST", None)?, None); + assert_eq!( + ParseUrl::parse("notaurl?key=value", "PROTOCOL", None)?, + None + ); + + // Schemeless URL with fragment + assert_eq!( + ParseUrl::parse("notaurl#reference", "REF", None)?, + Some("reference".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl#reference", "PATH", None)?, + Some("notaurl".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl#reference", "FILE", None)?, + Some("notaurl".to_string()) + ); + + // Schemeless URL with both query and fragment + assert_eq!( + ParseUrl::parse("notaurl?a=1&b=2#frag", "PATH", None)?, + Some("notaurl".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?a=1&b=2#frag", "QUERY", None)?, + Some("a=1&b=2".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?a=1&b=2#frag", "QUERY", Some("b"))?, + Some("2".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?a=1&b=2#frag", "REF", None)?, + Some("frag".to_string()) + ); + assert_eq!( + ParseUrl::parse("notaurl?a=1&b=2#frag", "FILE", None)?, + Some("notaurl?a=1&b=2".to_string()) + ); Ok(()) } @@ -363,7 +573,7 @@ mod tests { assert_eq!(out_sa.len(), 2); assert_eq!(out_sa.value(0), "example.com"); - assert_eq!(out_sa.value(1), ""); + assert_eq!(out_sa.value(1), "/"); Ok(()) } diff --git a/datafusion/spark/src/function/url/try_parse_url.rs b/datafusion/spark/src/function/url/try_parse_url.rs index c04850f3a6bf0..c9cafef97ba9f 100644 --- a/datafusion/spark/src/function/url/try_parse_url.rs +++ b/datafusion/spark/src/function/url/try_parse_url.rs @@ -15,9 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - -use crate::function::url::parse_url::{spark_handled_parse_url, ParseUrl}; +use crate::function::url::parse_url::{ParseUrl, spark_handled_parse_url}; use arrow::array::ArrayRef; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -52,10 +50,6 @@ impl TryParseUrl { } impl ScalarUDFImpl for TryParseUrl { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { "try_parse_url" } diff --git a/datafusion/spark/src/function/url/try_url_decode.rs b/datafusion/spark/src/function/url/try_url_decode.rs new file mode 100644 index 0000000000000..78968288fc2f5 --- /dev/null +++ b/datafusion/spark/src/function/url/try_url_decode.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; + +use datafusion_common::Result; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; + +use crate::function::url::url_decode::{UrlDecode, spark_handled_url_decode}; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct TryUrlDecode { + signature: Signature, + url_decoder: UrlDecode, +} + +impl Default for TryUrlDecode { + fn default() -> Self { + Self::new() + } +} + +impl TryUrlDecode { + pub fn new() -> Self { + Self { + signature: Signature::string(1, Volatility::Immutable), + url_decoder: UrlDecode::new(), + } + } +} + +impl ScalarUDFImpl for TryUrlDecode { + fn name(&self) -> &str { + "try_url_decode" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.url_decoder.return_type(arg_types) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + make_scalar_function(spark_try_url_decode, vec![])(&args) + } +} + +fn spark_try_url_decode(args: &[ArrayRef]) -> Result { + spark_handled_url_decode(args, |x| match x { + Err(_) => Ok(None), + result => result, + }) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::StringArray; + use datafusion_common::cast::as_string_array; + + use super::*; + + #[test] + fn test_try_decode_error_handled() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("http%3A%2F%2spark.apache.org"), // '%2s' is not a valid percent encoded character + // Valid cases + Some("https%3A%2F%2Fspark.apache.org"), + None, + ])); + + let expected = + StringArray::from(vec![None, Some("https://spark.apache.org"), None]); + + let result = spark_try_url_decode(&[input as ArrayRef])?; + let result = as_string_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } +} diff --git a/datafusion/spark/src/function/url/url_decode.rs b/datafusion/spark/src/function/url/url_decode.rs new file mode 100644 index 0000000000000..0966cc380e497 --- /dev/null +++ b/datafusion/spark/src/function/url/url_decode.rs @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::borrow::Cow; +use std::sync::Arc; + +use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{Result, exec_datafusion_err, exec_err, plan_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use percent_encoding::percent_decode; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct UrlDecode { + signature: Signature, +} + +impl Default for UrlDecode { + fn default() -> Self { + Self::new() + } +} + +impl UrlDecode { + pub fn new() -> Self { + Self { + signature: Signature::string(1, Volatility::Immutable), + } + } + + /// Decodes a URL-encoded string from application/x-www-form-urlencoded format. + /// Although the `url::form_urlencoded` support decoding, it does not return error when the string is malformed + /// For example: "%2s" is not a valid percent-encoding, the `decode` function from `url::form_urlencoded` + /// will ignore this instead of return error + /// This function reproduce the same decoding process, plus an extra validation step + /// See + /// + /// # Arguments + /// + /// * `value` - The URL-encoded string to decode + /// + /// # Returns + /// + /// * `Ok(String)` - The decoded string + /// * `Err(DataFusionError)` - If the input is malformed or contains invalid UTF-8 + /// + fn decode(value: &str) -> Result { + // Check if the string has valid percent encoding + Self::validate_percent_encoding(value)?; + + let replaced = Self::replace_plus(value.as_bytes()); + percent_decode(&replaced) + .decode_utf8() + .map_err(|e| exec_datafusion_err!("Invalid UTF-8 sequence: {e}")) + .map(|parsed| parsed.into_owned()) + } + + /// Replace b'+' with b' ' + /// See: + fn replace_plus(input: &[u8]) -> Cow<'_, [u8]> { + match input.iter().position(|&b| b == b'+') { + None => Cow::Borrowed(input), + Some(first_position) => { + let mut replaced = input.to_owned(); + replaced[first_position] = b' '; + for byte in &mut replaced[first_position + 1..] { + if *byte == b'+' { + *byte = b' '; + } + } + Cow::Owned(replaced) + } + } + } + + /// Validate percent-encoding of the string + fn validate_percent_encoding(value: &str) -> Result<()> { + let bytes = value.as_bytes(); + let mut i = 0; + + while i < bytes.len() { + if bytes[i] == b'%' { + // Check if we have at least 2 more characters + if i + 2 >= bytes.len() { + return exec_err!( + "Invalid percent-encoding: incomplete sequence at position {}", + i + ); + } + + let hex1 = bytes[i + 1]; + let hex2 = bytes[i + 2]; + + if !hex1.is_ascii_hexdigit() || !hex2.is_ascii_hexdigit() { + return exec_err!( + "Invalid percent-encoding: invalid hex sequence '%{}{}' at position {}", + hex1 as char, + hex2 as char, + i + ); + } + i += 3; + } else { + i += 1; + } + } + Ok(()) + } +} + +impl ScalarUDFImpl for UrlDecode { + fn name(&self) -> &str { + "url_decode" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "{} expects 1 argument, but got {}", + self.name(), + arg_types.len() + ); + } + // As the type signature is already checked, we can safely return the type of the first argument + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + make_scalar_function(spark_url_decode, vec![])(&args) + } +} + +/// Core implementation of URL decoding function. +/// +/// # Arguments +/// +/// * `args` - A slice containing exactly one ArrayRef with the URL-encoded strings to decode +/// +/// # Returns +/// +/// * `Ok(ArrayRef)` - A new array of the same type containing decoded strings +/// * `Err(DataFusionError)` - If validation fails or invalid arguments are provided +/// +fn spark_url_decode(args: &[ArrayRef]) -> Result { + spark_handled_url_decode(args, |x| x) +} + +pub fn spark_handled_url_decode( + args: &[ArrayRef], + err_handle_fn: impl Fn(Result>) -> Result>, +) -> Result { + if args.len() != 1 { + return exec_err!("`url_decode` expects 1 argument"); + } + + match &args[0].data_type() { + DataType::Utf8 => as_string_array(&args[0])? + .iter() + .map(|x| x.map(UrlDecode::decode).transpose()) + .map(&err_handle_fn) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef), + DataType::LargeUtf8 => as_large_string_array(&args[0])? + .iter() + .map(|x| x.map(UrlDecode::decode).transpose()) + .map(&err_handle_fn) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef), + DataType::Utf8View => as_string_view_array(&args[0])? + .iter() + .map(|x| x.map(UrlDecode::decode).transpose()) + .map(&err_handle_fn) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef), + other => exec_err!("`url_decode`: Expr must be STRING, got {other:?}"), + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_decode() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("https%3A%2F%2Fspark.apache.org"), + Some("inva+lid://user:pass@host/file\\;param?query\\;p2"), + Some("inva lid://user:pass@host/file\\;param?query\\;p2"), + Some("%7E%21%40%23%24%25%5E%26%2A%28%29%5F%2B"), + Some("%E4%BD%A0%E5%A5%BD"), + Some(""), + None, + ])); + let expected = StringArray::from(vec![ + Some("https://spark.apache.org"), + Some("inva lid://user:pass@host/file\\;param?query\\;p2"), + Some("inva lid://user:pass@host/file\\;param?query\\;p2"), + Some("~!@#$%^&*()_+"), + Some("你好"), + Some(""), + None, + ]); + + let result = spark_url_decode(&[input as ArrayRef])?; + let result = as_string_array(&result)?; + + assert_eq!(&expected, result); + + Ok(()) + } + + #[test] + fn test_decode_error() -> Result<()> { + let input = Arc::new(StringArray::from(vec![ + Some("http%3A%2F%2spark.apache.org"), // '%2s' is not a valid percent encoded character + // Valid cases + Some("https%3A%2F%2Fspark.apache.org"), + None, + ])); + + let result = spark_url_decode(&[input]); + assert!( + result.is_err_and(|e| e.to_string().contains("Invalid percent-encoding")) + ); + + Ok(()) + } +} diff --git a/datafusion/spark/src/function/url/url_encode.rs b/datafusion/spark/src/function/url/url_encode.rs new file mode 100644 index 0000000000000..1ad2a111851ee --- /dev/null +++ b/datafusion/spark/src/function/url/url_encode.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{ + as_large_string_array, as_string_array, as_string_view_array, +}; +use datafusion_common::{Result, exec_err, plan_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions::utils::make_scalar_function; +use url::form_urlencoded::byte_serialize; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct UrlEncode { + signature: Signature, +} + +impl Default for UrlEncode { + fn default() -> Self { + Self::new() + } +} + +impl UrlEncode { + pub fn new() -> Self { + Self { + signature: Signature::string(1, Volatility::Immutable), + } + } + + /// Encode a string to application/x-www-form-urlencoded format. + /// + /// # Arguments + /// + /// * `value` - The string to encode + /// + /// # Returns + /// + /// * `Ok(String)` - The encoded string + /// + fn encode(value: &str) -> Result { + Ok(byte_serialize(value.as_bytes()).collect::()) + } +} + +impl ScalarUDFImpl for UrlEncode { + fn name(&self) -> &str { + "url_encode" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return plan_err!( + "{} expects 1 argument, but got {}", + self.name(), + arg_types.len() + ); + } + // As the type signature is already checked, we can safely return the type of the first argument + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + make_scalar_function(spark_url_encode, vec![])(&args) + } +} + +/// Core implementation of URL encoding function. +/// +/// # Arguments +/// +/// * `args` - A slice containing exactly one ArrayRef with the strings to encode +/// +/// # Returns +/// +/// * `Ok(ArrayRef)` - A new array of the same type containing encoded strings +/// * `Err(DataFusionError)` - If invalid arguments are provided +/// +fn spark_url_encode(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("`url_encode` expects 1 argument"); + } + + match &args[0].data_type() { + DataType::Utf8 => as_string_array(&args[0])? + .iter() + .map(|x| x.map(UrlEncode::encode).transpose()) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef), + DataType::LargeUtf8 => as_large_string_array(&args[0])? + .iter() + .map(|x| x.map(UrlEncode::encode).transpose()) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef), + DataType::Utf8View => as_string_view_array(&args[0])? + .iter() + .map(|x| x.map(UrlEncode::encode).transpose()) + .collect::>() + .map(|array| Arc::new(array) as ArrayRef), + other => exec_err!("`url_encode`: Expr must be STRING, got {other:?}"), + } +} diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs index 4f3f795add5ec..6cd4678da7560 100644 --- a/datafusion/spark/src/lib.rs +++ b/datafusion/spark/src/lib.rs @@ -22,8 +22,6 @@ #![cfg_attr(docsrs, feature(doc_cfg))] // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -// https://github.com/apache/datafusion/issues/18503 -#![deny(clippy::needless_pass_by_value)] #![cfg_attr(test, allow(clippy::needless_pass_by_value))] //! Spark Expression packages for [DataFusion]. @@ -45,7 +43,7 @@ //! //! ``` //! # use datafusion_execution::FunctionRegistry; -//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF}; +//! # use datafusion_expr::{ScalarUDF, AggregateUDF, WindowUDF, HigherOrderUDF}; //! # use datafusion_expr::planner::ExprPlanner; //! # use datafusion_common::Result; //! # use std::collections::HashSet; @@ -57,9 +55,11 @@ //! # impl FunctionRegistry for SessionContext { //! # fn register_udf(&mut self, _udf: Arc) -> Result>> { Ok (None) } //! # fn udfs(&self) -> HashSet { unimplemented!() } +//! # fn higher_order_function_names(&self) -> HashSet { unimplemented!() } //! # fn udafs(&self) -> HashSet { unimplemented!() } //! # fn udwfs(&self) -> HashSet { unimplemented!() } //! # fn udf(&self, _name: &str) -> Result> { unimplemented!() } +//! # fn higher_order_function(&self, name: &str) -> Result> { unimplemented!() } //! # fn udaf(&self, name: &str) -> Result> {unimplemented!() } //! # fn udwf(&self, name: &str) -> Result> { unimplemented!() } //! # fn expr_planners(&self) -> Vec> { unimplemented!() } @@ -93,9 +93,49 @@ //! let expr = sha2(col("my_data"), lit(256)); //! ``` //! +//! # Example: using the Spark expression planner +//! +//! The [`planner::SparkFunctionPlanner`] provides Spark-compatible expression +//! planning, such as mapping SQL `EXTRACT` expressions to Spark's `date_part` +//! function. To use it, register it with your session context: +//! +//! ```ignore +//! use std::sync::Arc; +//! use datafusion::prelude::SessionContext; +//! use datafusion_spark::planner::SparkFunctionPlanner; +//! +//! let mut ctx = SessionContext::new(); +//! // Register the Spark expression planner +//! ctx.register_expr_planner(Arc::new(SparkFunctionPlanner))?; +//! // Now EXTRACT expressions will use Spark semantics +//! let df = ctx.sql("SELECT EXTRACT(YEAR FROM timestamp_col) FROM my_table").await?; +//! ``` +//! //![`Expr`]: datafusion_expr::Expr +//! +//! # Example: enabling Apache Spark features with SessionStateBuilder +//! +//! The recommended way to enable Apache Spark compatibility is to use the +//! `SessionStateBuilderSpark` extension trait. This registers all +//! Apache Spark functions (scalar, aggregate, window, and table) as well as the Apache Spark +//! expression planner. +//! +//! Enable the `core` feature in your `Cargo.toml`: +//! ```toml +//! datafusion-spark = { version = "X", features = ["core"] } +//! ``` +//! +//! Then use the extension trait - see [`SessionStateBuilderSpark::with_spark_features`] +//! for an example. pub mod function; +pub mod planner; + +#[cfg(feature = "core")] +mod session_state; + +#[cfg(feature = "core")] +pub use session_state::SessionStateBuilderSpark; use datafusion_catalog::TableFunction; use datafusion_common::Result; @@ -105,7 +145,7 @@ use log::debug; use std::sync::Arc; /// Fluent-style API for creating `Expr`s -#[allow(unused)] +#[expect(unused_imports)] pub mod expr_fn { pub use super::function::aggregate::expr_fn::*; pub use super::function::array::expr_fn::*; @@ -124,8 +164,8 @@ pub mod expr_fn { pub use super::function::math::expr_fn::*; pub use super::function::misc::expr_fn::*; pub use super::function::predicate::expr_fn::*; - pub use super::function::r#struct::expr_fn::*; pub use super::function::string::expr_fn::*; + pub use super::function::r#struct::expr_fn::*; pub use super::function::table::expr_fn::*; pub use super::function::url::expr_fn::*; pub use super::function::window::expr_fn::*; diff --git a/datafusion/spark/src/planner.rs b/datafusion/spark/src/planner.rs new file mode 100644 index 0000000000000..2dafbb1f9a570 --- /dev/null +++ b/datafusion/spark/src/planner.rs @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_expr::Expr; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{ExprPlanner, PlannerResult}; + +#[derive(Default, Debug)] +pub struct SparkFunctionPlanner; + +impl ExprPlanner for SparkFunctionPlanner { + fn plan_extract( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::function::datetime::date_part(), args), + ))) + } + + fn plan_substring( + &self, + args: Vec, + ) -> datafusion_common::Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::function::string::substring(), args), + ))) + } +} diff --git a/datafusion/spark/src/session_state.rs b/datafusion/spark/src/session_state.rs new file mode 100644 index 0000000000000..839487772a9b2 --- /dev/null +++ b/datafusion/spark/src/session_state.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::execution::SessionStateBuilder; + +use crate::planner::SparkFunctionPlanner; +use crate::{ + all_default_aggregate_functions, all_default_scalar_functions, + all_default_table_functions, all_default_window_functions, +}; + +/// Extension trait for adding Apache Spark features to [`SessionStateBuilder`]. +/// +/// This trait provides a convenient way to register all Apache Spark-compatible +/// functions and planners with a DataFusion session. +/// +/// # Example +/// +/// ```rust +/// use datafusion::execution::SessionStateBuilder; +/// use datafusion_spark::SessionStateBuilderSpark; +/// +/// // Create a SessionState with Apache Spark features enabled +/// // note: the order matters here, `with_spark_features` should be +/// // called after `with_default_features` to overwrite any existing functions +/// let state = SessionStateBuilder::new() +/// .with_default_features() +/// .with_spark_features() +/// .build(); +/// ``` +pub trait SessionStateBuilderSpark { + /// Adds all expr_planners, scalar, aggregate, window and table functions + /// compatible with Apache Spark. + /// + /// Note: This overwrites any previously registered items with the same name. + fn with_spark_features(self) -> Self; +} + +impl SessionStateBuilderSpark for SessionStateBuilder { + fn with_spark_features(mut self) -> Self { + self.expr_planners() + .get_or_insert_with(Vec::new) + // planners are evaluated in order of insertion. Push Apache Spark function planner to the front + // to take precedence over others + .insert(0, Arc::new(SparkFunctionPlanner)); + + self.scalar_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_scalar_functions()); + + self.aggregate_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_aggregate_functions()); + + self.window_functions() + .get_or_insert_with(Vec::new) + .extend(all_default_window_functions()); + + self.table_functions() + .get_or_insert_with(HashMap::new) + .extend( + all_default_table_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)), + ); + + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion::common::config::Dialect; + use datafusion::prelude::SessionConfig; + use datafusion::prelude::SessionContext; + + #[test] + fn test_session_state_with_spark_features() { + let state = SessionStateBuilder::new().with_spark_features().build(); + + assert!( + state.scalar_functions().contains_key("sha2"), + "Apache Spark scalar function 'sha2' should be registered" + ); + + assert!( + state.aggregate_functions().contains_key("try_sum"), + "Apache Spark aggregate function 'try_sum' should be registered" + ); + + assert!( + !state.expr_planners().is_empty(), + "Apache Spark expr planners should be registered" + ); + } + + #[tokio::test] + async fn test_spark_dialect_with_spark_functions() { + let query = "SELECT sha2('abc', 256), CAST(1 AS LONG)"; + + let mut config = SessionConfig::new(); + config.options_mut().sql_parser.dialect = Dialect::Spark; + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .with_spark_features() + .build(); + let ctx = SessionContext::new_with_state(state); + + let result = ctx.sql(query).await.unwrap().collect().await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].num_rows(), 1); + + let mut config = SessionConfig::new(); + config.options_mut().sql_parser.dialect = Dialect::Generic; + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .with_spark_features() + .build(); + let ctx = SessionContext::new_with_state(state); + + let err = ctx.sql(query).await.unwrap_err().to_string(); + assert!( + err.contains("Unsupported SQL type LONG"), + "unexpected error: {err}" + ); + } +} From 58625983b81d9a60766f64ffc8bdf4379e0d5674 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 10 Jun 2026 13:57:43 +0300 Subject: [PATCH 5/8] add tests --- datafusion/common/src/error.rs | 303 +++++++++++++++++++++++++++++++-- 1 file changed, 288 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index e1d4843039b6d..922c93775df60 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1339,23 +1339,47 @@ mod test { // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] - #[test] - fn test_enabled_backtrace() { + fn ensure_rust_backtrace_enabled() { match std::env::var("RUST_BACKTRACE") { Ok(val) if val == "1" => {} _ => panic!("Environment variable RUST_BACKTRACE must be set to 1"), }; + } + + // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace() { + ensure_rust_backtrace_enabled(); let res: Result<(), DataFusionError> = plan_err!("Err"); - let err = res.unwrap_err().to_string(); - assert!(err.contains(DataFusionError::BACK_TRACE_SEP)); - assert_eq!( - err.split(DataFusionError::BACK_TRACE_SEP) - .collect::>() - .first() - .unwrap(), - &"Error during planning: Err" + assert_error_have_message_and_backtrace(&res.unwrap_err(), "Error during planning: Err"); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace() { + let res: Result<(), DataFusionError> = plan_err!("Err"); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Error during planning: Err", ); + } + + #[cfg(not(feature = "backtrace"))] + fn assert_err_without_backtrace_and_equal( + err: &DataFusionError, + expected_message: &str, + ) { + let err = err.to_string(); + assert!(!err.contains(DataFusionError::BACK_TRACE_SEP)); + assert_eq!(err, expected_message); + } + + #[cfg(feature = "backtrace")] + fn assert_error_have_message_and_backtrace(err: &DataFusionError, message_before_backtrace: &str) { + let err = err.to_string(); + assert!(err.contains(DataFusionError::BACK_TRACE_SEP)); assert!( !err.split(DataFusionError::BACK_TRACE_SEP) .collect::>() @@ -1363,15 +1387,264 @@ mod test { .unwrap() .is_empty() ); + assert_eq!( + err.split(DataFusionError::BACK_TRACE_SEP) + .collect::>() + .first() + .copied() + .unwrap(), + message_before_backtrace + ); } + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_unwrap_or_internal_err() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + let item = None::<()>; + unwrap_or_internal_err!(item); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: item should not be None" + ); + } + + // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(not(feature = "backtrace"))] #[test] - fn test_disabled_backtrace() { - let res: Result<(), DataFusionError> = plan_err!("Err"); - let res = res.unwrap_err().to_string(); - assert!(!res.contains(DataFusionError::BACK_TRACE_SEP)); - assert_eq!(res, "Error during planning: Err"); + fn test_disabled_backtrace_for_unwrap_or_internal_err() { + fn get_error() -> Result<(), DataFusionError> { + let item = None::<()>; + unwrap_or_internal_err!(item); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: item should not be None", + ); + } + + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_assert_or_internal_err_without_args() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + assert_or_internal_err!(false); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: Assertion failed: false" + ); + } + + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_assert_or_internal_err_with_args() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + assert_or_internal_err!(false, "my cool context"); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: Assertion failed: false: my cool context" + ); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace_for_assert_or_internal_err_without_args() { + fn get_error() -> Result<(), DataFusionError> { + assert_or_internal_err!(false); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: Assertion failed: false", + ); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace_for_assert_or_internal_err_with_args() { + fn get_error() -> Result<(), DataFusionError> { + assert_or_internal_err!(false, "my cool context"); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: Assertion failed: false: my cool context", + ); + } + + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_assert_eq_or_internal_err_without_args() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 2; + assert_eq_or_internal_err!(arg1, arg2); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)" + ); + } + + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_assert_eq_or_internal_err_with_args() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 2; + assert_eq_or_internal_err!(arg1, arg2, "my cool context"); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context" + ); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace_for_assert_eq_or_internal_err_without_args() { + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 2; + assert_eq_or_internal_err!(arg1, arg2); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)", + ); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace_for_assert_eq_or_internal_err_with_args() { + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 2; + assert_eq_or_internal_err!(arg1, arg2, "my cool context"); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context", + ); + } + + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_assert_ne_or_internal_err_without_args() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 1; + assert_ne_or_internal_err!(arg1, arg2); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)" + ); + } + + #[cfg(feature = "backtrace")] + #[test] + fn test_enabled_backtrace_for_assert_ne_or_internal_err_with_args() { + ensure_rust_backtrace_enabled(); + + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 1; + assert_ne_or_internal_err!(arg1, arg2, "my cool context"); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_error_have_message_and_backtrace(&res.unwrap_err(), + "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context" + ); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace_for_assert_ne_or_internal_err_without_args() { + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 1; + assert_ne_or_internal_err!(arg1, arg2); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)", + ); + } + + #[cfg(not(feature = "backtrace"))] + #[test] + fn test_disabled_backtrace_for_assert_ne_or_internal_err_with_args() { + fn get_error() -> Result<(), DataFusionError> { + let arg1 = 1; + let arg2 = 1; + assert_ne_or_internal_err!(arg1, arg2, "my cool context"); + + unreachable!("should return error"); + } + + let res: Result<(), DataFusionError> = get_error(); + assert_err_without_backtrace_and_equal( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context", + ); } #[test] From ae1c623223ce38a8e950d56e246bc00b536e7eae Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:01:32 +0300 Subject: [PATCH 6/8] format --- datafusion/common/src/error.rs | 45 ++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 922c93775df60..6ed94a4cab376 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1353,7 +1353,10 @@ mod test { ensure_rust_backtrace_enabled(); let res: Result<(), DataFusionError> = plan_err!("Err"); - assert_error_have_message_and_backtrace(&res.unwrap_err(), "Error during planning: Err"); + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Error during planning: Err", + ); } #[cfg(not(feature = "backtrace"))] @@ -1377,7 +1380,10 @@ mod test { } #[cfg(feature = "backtrace")] - fn assert_error_have_message_and_backtrace(err: &DataFusionError, message_before_backtrace: &str) { + fn assert_error_have_message_and_backtrace( + err: &DataFusionError, + message_before_backtrace: &str, + ) { let err = err.to_string(); assert!(err.contains(DataFusionError::BACK_TRACE_SEP)); assert!( @@ -1410,8 +1416,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: item should not be None" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: item should not be None", ); } @@ -1445,8 +1452,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: Assertion failed: false" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: Assertion failed: false", ); } @@ -1462,8 +1470,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: Assertion failed: false: my cool context" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: Assertion failed: false: my cool context", ); } @@ -1513,8 +1522,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)", ); } @@ -1532,8 +1542,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context", ); } @@ -1587,8 +1598,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)", ); } @@ -1606,8 +1618,9 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace(&res.unwrap_err(), - "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context" + assert_error_have_message_and_backtrace( + &res.unwrap_err(), + "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context", ); } From 662c54f4c3d12136974b190efacb4b7115b8b7ef Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:16:28 +0300 Subject: [PATCH 7/8] fix assertion message --- datafusion/common/src/error.rs | 60 ++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 6ed94a4cab376..5ed952a2265f6 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1379,6 +1379,22 @@ mod test { assert_eq!(err, expected_message); } + #[cfg(not(feature = "backtrace"))] + fn assert_internal_err_without_backtrace_and_equal( + err: &DataFusionError, + expected_message: &str, + ) { + let expected_message_before_backtrace = format!( + "{expected_message}.\nThis issue was likely caused by a bug in DataFusion's code. \ + Please help us to resolve this by filing a bug report in our issue tracker: \ + https://github.com/apache/datafusion/issues" + ); + assert_err_without_backtrace_and_equal( + err, + expected_message_before_backtrace.as_str(), + ); + } + #[cfg(feature = "backtrace")] fn assert_error_have_message_and_backtrace( err: &DataFusionError, @@ -1403,6 +1419,22 @@ mod test { ); } + #[cfg(feature = "backtrace")] + fn assert_internal_error_have_message_and_backtrace( + err: &DataFusionError, + message_before_backtrace: &str, + ) { + let expected_message_before_backtrace = format!( + "{message_before_backtrace}.\nThis issue was likely caused by a bug in DataFusion's code. \ + Please help us to resolve this by filing a bug report in our issue tracker: \ + https://github.com/apache/datafusion/issues" + ); + assert_error_have_message_and_backtrace( + err, + expected_message_before_backtrace.as_str(), + ); + } + #[cfg(feature = "backtrace")] #[test] fn test_enabled_backtrace_for_unwrap_or_internal_err() { @@ -1416,7 +1448,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: item should not be None", ); @@ -1434,7 +1466,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: item should not be None", ); @@ -1452,7 +1484,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: false", ); @@ -1470,7 +1502,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: false: my cool context", ); @@ -1486,7 +1518,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: Assertion failed: false", ); @@ -1502,7 +1534,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: Assertion failed: false: my cool context", ); @@ -1522,7 +1554,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)", ); @@ -1542,7 +1574,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context", ); @@ -1560,7 +1592,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)", ); @@ -1578,7 +1610,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context", ); @@ -1598,7 +1630,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)", ); @@ -1618,7 +1650,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_error_have_message_and_backtrace( + assert_internal_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context", ); @@ -1636,7 +1668,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)", ); @@ -1654,7 +1686,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_err_without_backtrace_and_equal( + assert_internal_err_without_backtrace_and_equal( &res.unwrap_err(), "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context", ); From 8346c7574371ff8e6732da7051341a354ad23ee3 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Wed, 10 Jun 2026 16:25:32 +0300 Subject: [PATCH 8/8] fix test to use starts_with --- datafusion/common/src/error.rs | 74 +++++++--------------------------- 1 file changed, 15 insertions(+), 59 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 5ed952a2265f6..ce6f8e68aee43 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -1204,7 +1204,6 @@ mod test { use std::sync::Arc; use arrow::error::ArrowError; - use insta::assert_snapshot; fn ok_result() -> Result<()> { Ok(()) @@ -1224,13 +1223,7 @@ mod test { } let err = check().unwrap_err().strip_backtrace(); - assert_snapshot!( - err.to_string(), - @r" - Internal error: Assertion failed: 1 == 2 (left: 1, right: 2): expected equality. - This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues - " - ); + assert!(err.starts_with("Internal error: Assertion failed: 1 == 2 (left: 1, right: 2): expected equality")); } #[test] @@ -1247,13 +1240,7 @@ mod test { } let err = check().unwrap_err().strip_backtrace(); - assert_snapshot!( - err.to_string(), - @r" - Internal error: Assertion failed: 3 != 3 (left: 3, right: 3): values must differ. - This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues - " - ); + assert!(err.starts_with("Internal error: Assertion failed: 3 != 3 (left: 3, right: 3): values must differ")); } #[test] @@ -1271,13 +1258,7 @@ mod test { } let err = check().unwrap_err().strip_backtrace(); - assert_snapshot!( - err.to_string(), - @r" - Internal error: Assertion failed: false. - This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues - " - ); + assert!(err.starts_with("Internal error: Assertion failed: false")); } #[test] @@ -1288,12 +1269,8 @@ mod test { } let err = check().unwrap_err().strip_backtrace(); - assert_snapshot!( - err.to_string(), - @r" - Internal error: Assertion failed: false: custom message. - This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues - " + assert!( + err.starts_with("Internal error: Assertion failed: false: custom message") ); } @@ -1305,13 +1282,7 @@ mod test { } let err = check().unwrap_err().strip_backtrace(); - assert_snapshot!( - err.to_string(), - @r" - Internal error: Assertion failed: false: custom 42. - This issue was likely caused by a bug in DataFusion's code. Please help us to resolve this by filing a bug report in our issue tracker: https://github.com/apache/datafusion/issues - " - ); + assert!(err.starts_with("Internal error: Assertion failed: false: custom 42")); } #[test] @@ -1415,23 +1386,8 @@ mod test { .first() .copied() .unwrap(), - message_before_backtrace - ); - } - - #[cfg(feature = "backtrace")] - fn assert_internal_error_have_message_and_backtrace( - err: &DataFusionError, - message_before_backtrace: &str, - ) { - let expected_message_before_backtrace = format!( - "{message_before_backtrace}.\nThis issue was likely caused by a bug in DataFusion's code. \ - Please help us to resolve this by filing a bug report in our issue tracker: \ - https://github.com/apache/datafusion/issues" - ); - assert_error_have_message_and_backtrace( - err, - expected_message_before_backtrace.as_str(), + message_before_backtrace, + "full error is: {err}" ); } @@ -1448,7 +1404,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: item should not be None", ); @@ -1484,7 +1440,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: false", ); @@ -1502,7 +1458,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: false: my cool context", ); @@ -1554,7 +1510,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2)", ); @@ -1574,7 +1530,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 == arg2 (left: 1, right: 2): my cool context", ); @@ -1630,7 +1586,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1)", ); @@ -1650,7 +1606,7 @@ mod test { } let res: Result<(), DataFusionError> = get_error(); - assert_internal_error_have_message_and_backtrace( + assert_error_have_message_and_backtrace( &res.unwrap_err(), "Internal error: Assertion failed: arg1 != arg2 (left: 1, right: 1): my cool context", );